KerasHub: 预训练模型 / API 文档 / 采样器 / 对比采样器 (ContrastiveSampler)

对比采样器 (ContrastiveSampler)

[源代码]

ContrastiveSampler

keras_hub.samplers.ContrastiveSampler(k=5, alpha=0.6, **kwargs)

对比采样器类。

此采样器实现了对比搜索算法。简而言之,采样器选择“得分”最高的标记作为下一个标记。“得分”是标记概率和与先前标记的最大相似度之间的加权和。通过使用这种联合得分,对比采样器减少了重复看到标记的行为。

参数

  • k: int, top-k 的 k 值。下一个标记将从 k 个标记中选择。
  • alpha: float, 在联合得分计算中,负最大相似度的权重。alpha 的值越大,得分越依赖相似度而不是标记概率。

调用参数

{{call_args}}

示例

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="contrastive")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_hub.samplers.ContrastiveSampler(k=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])