Top-K 采样器

[源代码]

TopKSampler

keras_hub.samplers.TopKSampler(k=5, seed=None, **kwargs)

Top-K 采样器类。

此采样器实现了 Top-K 搜索算法。 简而言之,Top-K 算法从概率最高的 K 个词元中随机选择一个词元,选择机会由概率决定。

参数

  • k: int,top-k 的 k 值。
  • seed: int。 随机种子。 默认为 None

调用参数

{{call_args}}

示例

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="top_k")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_hub.samplers.TopKSampler(k=5, temperature=0.7)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])