ContrastiveSampler 类keras_hub.samplers.ContrastiveSampler(k=5, alpha=0.6, **kwargs)
对比采样器类。
此采样器实现了对比搜索算法。简而言之,采样器选择具有最大“分数”的 token 作为下一个 token。“分数”是 token 的概率与对先前 token 的最大相似度之间的加权和。通过使用此联合分数,对比采样器可以减少重复已见 token 的行为。
参数
k 值。下一个 token 将从 k 个 token 中选择。alpha 值越大,分数在多大程度上依赖于相似度而不是 token 概率。调用参数
{{call_args}}
示例
causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Pass by name to compile.
causal_lm.compile(sampler="contrastive")
causal_lm.generate(["Keras is a"])
# Pass by object to compile.
sampler = keras_hub.samplers.ContrastiveSampler(k=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])