KerasRS / API文档 / Retrieval Layers / SamplingProbabilityCorrection layer

SamplingProbabilityCorrection 层

[源代码]

SamplingProbabilityCorrection

keras_rs.layers.SamplingProbabilityCorrection(epsilon: float = 1e-06, **kwargs: Any)

采样概率校正。

校正 logits 以反映负样本的采样概率。

参数

  • epsilon: float。添加到采样概率的小浮点数,以避免对零取对数。默认为 1e-6。
  • **kwargs: 传递给基类的参数。

示例

# Create the layer.
sampling_probability_correction = (
    keras_rs.layers.SamplingProbabilityCorrection()
)

# Correct the logits based on the provided candidate sampling probability.
logits = sampling_probability_correction(logits, probabilities)

[源代码]

call 方法

SamplingProbabilityCorrection.call(logits: Any, candidate_sampling_probability: Any)

校正输入 logits 以考虑候选采样概率。

参数

  • logits: 需要校正的 logits 张量,通常为 [batch_size, num_candidates],但也可以有更多维度或为 1D 形式 [num_candidates]
  • candidate_sampling_probability: 采样概率,形状与 logits 相同。

返回

校正后的 logits,形状与输入 logits 相同。