HardNegativeMining
类keras_rs.layers.HardNegativeMining(num_hard_negatives: int, **kwargs: Any)
过滤 logits 和 labels 以返回难负例。
输出将包含指定数量的难负例以及正例候选的 logits 和 labels。
参数
示例
# Create layer with the configured number of hard negatives to mine.
hard_negative_mining = keras_rs.layers.HardNegativeMining(
num_hard_negatives=10
)
# This will retrieve the top 10 negative candidates plus the positive
# candidate from `labels` for each row.
out_logits, out_labels = hard_negative_mining(in_logits, in_labels)
call
方法HardNegativeMining.call(logits: Any, labels: Any)
使用按查询的难负例挖掘过滤 logits 和 labels。
结果将包含 num_hard_negatives
个负例以及正例候选的 logits 和 labels。
参数
[batch_size, num_candidates]
,但也可能包含更多维度或为一维 [num_candidates]
。logits
形状相同。返回值
一个包含两个张量的元组,其中最后一维 num_candidates
被替换为 num_hard_negatives + 1
。
[..., num_hard_negatives + 1]
形状的 logits 张量。[..., num_hard_negatives + 1]
形状的 one-hot labels 张量。