HardNegativeMining 类keras_rs.layers.HardNegativeMining(num_hard_negatives: int, **kwargs: Any)
过滤 logits 和 labels,返回硬负例。
输出将包含请求数量的硬负例的 logits 和 labels,以及正例候选。
参数
示例
# Create layer with the configured number of hard negatives to mine.
hard_negative_mining = keras_rs.layers.HardNegativeMining(
num_hard_negatives=10
)
# This will retrieve the top 10 negative candidates plus the positive
# candidate from `labels` for each row.
out_logits, out_labels = hard_negative_mining(in_logits, in_labels)
call 方法HardNegativeMining.call(logits: Any, labels: Any)
通过每查询硬负例挖掘来过滤 logits 和 labels。
结果将包括 num_hard_negatives 个负例的 logits 和 labels,以及正例候选。
参数
[batch_size, num_candidates],但也可以有更多维度或为 1D [num_candidates]。logits 相同。返回
一个包含两个张量的元组,最后一个维度从 num_candidates 替换为 num_hard_negatives + 1。
[..., num_hard_negatives + 1] 形状的 logits 张量。[..., num_hard_negatives + 1] 形状的 one-hot 编码的 labels 张量。