Keras 3 API 文档 / RNG API / SeedGenerator 类

SeedGenerator 类

[源代码]

SeedGenerator

keras.random.SeedGenerator(seed=None, name=None, **kwargs)

在每次调用生成随机数的函数时生成可变的种子。

在 Keras 中,所有随机数生成器(例如 keras.random.normal())都是无状态的,这意味着如果您向它们传递一个整数种子(例如 seed=42),它们将为重复调用返回相同的值。要为每次调用获取不同的值,必须使用提供随机生成器状态的 SeedGenerator

请注意,所有随机数生成器的默认种子为 None,这意味着将使用内部全局 SeedGenerator。如果您需要将 RNG 与全局状态分离,您可以提供一个具有确定性或随机初始状态的本地 StateGenerator

关于 JAX 后端的说明:请注意,对于使用 JAX 后端的 RNG 的 JIT 编译,需要使用本地 StateGenerator 作为种子参数,因为不支持使用全局状态。

示例

seed_gen = keras.random.SeedGenerator(seed=42)
values = keras.random.normal(shape=(2, 3), seed=seed_gen)
new_values = keras.random.normal(shape=(2, 3), seed=seed_gen)

在层中的用法

class Dropout(keras.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            return keras.random.dropout(
                x, rate=0.5, seed=self.seed_generator
            )
        return x