Keras 3 API 文档 / KerasNLP / 采样器 / BeamSampler

BeamSampler

[源代码]

BeamSampler

keras_nlp.samplers.BeamSampler(num_beams=5, return_all_beams=False, **kwargs)

光束采样器类。

此采样器实现了光束搜索算法。在每个时间步,光束搜索都会保留累积概率最高的 num_beams 个光束(序列),并使用每个光束来预测候选的下一个标记。

参数

  • num_beams: int。每个时间步应保留的光束数量。num_beams 应为严格正数。
  • return_all_beams: bool。设置为 True 时,采样器将返回所有光束及其相应的概率得分。

调用参数

{{call_args}}

示例

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="beam")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_nlp.samplers.BeamSampler(num_beams=5)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])