Keras 3 API 文档 / 随机数生成 API / SeedGenerator 类

SeedGenerator 类

[源代码]

SeedGenerator

keras.random.SeedGenerator(seed=None, name=None, **kwargs)

每次调用使用 RNG 的函数时生成可变种子。

在 Keras 中,所有使用 RNG 的方法(例如 keras.random.normal())都是无状态的,这意味着如果您向它们传递一个整数种子(例如 seed=42),它们在每次调用时都会返回相同的值。为了在每次调用时获得不同的值,您必须使用 SeedGenerator 而不是作为种子参数。SeedGenerator 对象是有状态的。

示例

seed_gen = keras.random.SeedGenerator(seed=42)
values = keras.random.normal(shape=(2, 3), seed=seed_gen)
new_values = keras.random.normal(shape=(2, 3), seed=seed_gen)

在层中使用

class Dropout(keras.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            return keras.random.dropout(
                x, rate=0.5, seed=self.seed_generator
            )
        return x