KerasRS / API 文档 / 检索层 / HardNegativeMining 层

HardNegativeMining 层

[源代码]

HardNegativeMining

keras_rs.layers.HardNegativeMining(num_hard_negatives: int, **kwargs: Any)

过滤 logits 和 labels 以返回难负例。

输出将包含指定数量的难负例以及正例候选的 logits 和 labels。

参数

  • num_hard_negatives: 要返回的难负例数量。
  • **kwargs: 传递给基类的参数。

示例

# Create layer with the configured number of hard negatives to mine.
hard_negative_mining = keras_rs.layers.HardNegativeMining(
    num_hard_negatives=10
)

# This will retrieve the top 10 negative candidates plus the positive
# candidate from `labels` for each row.
out_logits, out_labels = hard_negative_mining(in_logits, in_labels)

[源代码]

call 方法

HardNegativeMining.call(logits: Any, labels: Any)

使用按查询的难负例挖掘过滤 logits 和 labels。

结果将包含 num_hard_negatives 个负例以及正例候选的 logits 和 labels。

参数

  • logits: logits 张量,通常为 [batch_size, num_candidates],但也可能包含更多维度或为一维 [num_candidates]
  • labels: one-hot labels 张量,必须与 logits 形状相同。

返回值

一个包含两个张量的元组,其中最后一维 num_candidates 被替换为 num_hard_negatives + 1

  • logits: [..., num_hard_negatives + 1] 形状的 logits 张量。
  • labels: [..., num_hard_negatives + 1] 形状的 one-hot labels 张量。