Keras 3 API 文档 / KerasNLP / 采样器 / GreedySampler

贪婪采样器(GreedySampler)

[源代码]

GreedySampler

keras_nlp.samplers.GreedySampler(**kwargs)

贪婪采样器类。

此采样器基于贪婪搜索实现,即始终选择概率最大的标记作为下一个标记。

示例

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="greedy")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.GreedySampler()
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])