KerasRS / API 文档 / 检索层 / RemoveAccidentalHits 层

RemoveAccidentalHits 层

[源代码]

RemoveAccidentalHits

keras_rs.layers.RemoveAccidentalHits(
    activity_regularizer=None,
    trainable=True,
    dtype=None,
    autocast=True,
    name=None,
    **kwargs
)

将意外负样本的 logits 置零。

将同一行中与正样本具有相同 ID 的负候选样本的 logits 置零。

示例

# Create layer with the configured number of hard negatives to mine.
remove_accidental_hits = keras_rs.layers.RemoveAccidentalHits()

# This will zero the logits of negative candidates that have the same ID as
# the positive candidate from `labels` so as to not negatively impact the
# true positive.
logits = remove_accidental_hits(logits, labels, candidate_ids)

[源代码]

call 方法

RemoveAccidentalHits.call(logits: Any, labels: Any, candidate_ids: Any)

将选定的 logits 置零。

对于批量中的每一行,将同一行中与正样本具有相同 ID 的负候选样本的 logits 置零。

参数

  • logits: logits 张量,通常形状为 [batch_size, num_candidates],但也可能具有更多维度或为 1D 的 [num_candidates]
  • labels: one-hot 编码的标签张量,必须与 logits 具有相同的形状。
  • candidate_ids: 候选标识符张量,形状可以为 [num_candidates][batch_size, num_candidates],或者具有更多维度,只要其最后维度与 labels 的最后维度匹配即可。

返回值

修改后的 logits,其形状与输入 logits 相同。