SamplingProbabilityCorrection
类keras_rs.layers.SamplingProbabilityCorrection(epsilon: float = 1e-06, **kwargs: Any)
采样概率校正。
校正 logits 以反映负样本的采样概率。
参数
示例
# Create the layer.
sampling_probability_correction = (
keras_rs.layers.SamplingProbabilityCorrection()
)
# Correct the logits based on the provided candidate sampling probability.
logits = sampling_probability_correction(logits, probabilities)
call
方法SamplingProbabilityCorrection.call(logits: Any, candidate_sampling_probability: Any)
校正输入 logits 以考虑候选样本的采样概率。
参数
[batch_size, num_candidates]
,但可以有更多维度或是一维的 [num_candidates]
。logits
形状相同的采样概率。返回值
校正后的 logits,形状与输入 logits 相同。