BeamSampler
类keras_hub.samplers.BeamSampler(num_beams=5, return_all_beams=False, **kwargs)
Beam 采样器类。
此采样器实现了束搜索算法。在每个时间步,束搜索保留累积概率最高的 num_beams
个束(序列),并使用每个束来预测候选的下一个 token。
参数
num_beams
必须是严格正数。True
时,采样器将返回所有束及其各自的概率分数。调用参数
{{call_args}}
示例
causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Pass by name to compile.
causal_lm.compile(sampler="beam")
causal_lm.generate(["Keras is a"])
# Pass by object to compile.
sampler = keras_hub.samplers.BeamSampler(num_beams=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])