TopPSampler
类keras_nlp.samplers.TopPSampler(p=0.1, k=None, seed=None, **kwargs)
Top-P 采样器类。
此采样器实现 top-p 搜索算法。Top-p 搜索从输出概率的最小子集中选择标记,这些子集的总和大于 p
。换句话说,top-p 将首先按可能性对标记预测进行排序,并忽略所有在选定标记的累积概率超过 p
之后的标记,然后从剩余的标记中选择一个标记。
参数
p
值。k
个中的所有 logits 将被丢弃,并将对剩余的 logits 进行排序以找到 p
的截止点。设置此参数可以通过减少要排序的标记数量来显着加快采样速度。默认为 None
。None
。调用参数
{{call_args}}
示例
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Pass by name to compile.
causal_lm.compile(sampler="top_p")
causal_lm.generate(["Keras is a"])
# Pass by object to compile.
sampler = keras_nlp.samplers.TopPSampler(p=0.1, k=1_000)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])