Keras 3 API 文档 / KerasHub / 采样器 / 采样器基类

采样器基类

[源代码]

Sampler

keras_hub.samplers.Sampler(temperature=1.0)

采样器基类。

参数

  • temperature: float。可选。用于控制采样的随机性。温度越高,样本越多样化。默认值为 1.0

调用参数

{{call_args}}

此基类可以扩展以实现不同的自回归采样方法。为此,请覆盖 get_next_token() 方法,该方法根据所有可能词汇条目上的概率分布计算下一个标记。

示例

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Greedy search with some tokens forbidden.
class CustomSampler(keras_hub.samplers.Sampler):
    def __init__(self, forbidden_tokens, **kwargs):
        super().__init__(**kwargs)
        self.forbidden_tokens = forbidden_tokens

    def get_next_token(self, probs):
        batch_size, vocab_size = keras.ops.shape(probs)
        for id in self.forbidden_tokens:
            update = keras.ops.zeros((batch_size, 1))
            probs = keras.ops.slice_update(probs, (0, id), update)
        return keras.ops.argmax(probs, axis=-1)

# 257 = "a" with a leading space, 262 = "the" with a leading space.
causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262]))
causal_lm.summary()
causal_lm.generate(["That's strange"])

[源代码]

get_next_token 方法

Sampler.get_next_token(probabilities)

获取下一个标记。参数

  • probabilities: 张量,所有词汇标记上的下一个标记的概率分布。

根据给定的标记概率分布获取下一个标记。子类必须实现此方法。