KerasRS / API文档 / 检索层 / HardNegativeMining层

HardNegativeMining 层

[源代码]

HardNegativeMining

keras_rs.layers.HardNegativeMining(num_hard_negatives: int, **kwargs: Any)

过滤 logits 和 labels,返回硬负例。

输出将包含请求数量的硬负例的 logits 和 labels,以及正例候选。

参数

  • num_hard_negatives:要返回的硬负例的数量。
  • **kwargs: 传递给基类的参数。

示例

# Create layer with the configured number of hard negatives to mine.
hard_negative_mining = keras_rs.layers.HardNegativeMining(
    num_hard_negatives=10
)

# This will retrieve the top 10 negative candidates plus the positive
# candidate from `labels` for each row.
out_logits, out_labels = hard_negative_mining(in_logits, in_labels)

[源代码]

call 方法

HardNegativeMining.call(logits: Any, labels: Any)

通过每查询硬负例挖掘来过滤 logits 和 labels。

结果将包括 num_hard_negatives 个负例的 logits 和 labels,以及正例候选。

参数

  • logits:logits 张量,通常为 [batch_size, num_candidates],但也可以有更多维度或为 1D [num_candidates]
  • labels:one-hot 编码的 labels 张量,形状必须与 logits 相同。

返回

一个包含两个张量的元组,最后一个维度从 num_candidates 替换为 num_hard_negatives + 1

  • logits:[..., num_hard_negatives + 1] 形状的 logits 张量。
  • labels:[..., num_hard_negatives + 1] 形状的 one-hot 编码的 labels 张量。