SeedGenerator
类keras.random.SeedGenerator(seed=None, name=None, **kwargs)
在每次调用生成随机数的函数时生成可变的种子。
在 Keras 中,所有随机数生成器(例如 keras.random.normal()
)都是无状态的,这意味着如果您向它们传递一个整数种子(例如 seed=42
),它们将为重复调用返回相同的值。要为每次调用获取不同的值,必须使用提供随机生成器状态的 SeedGenerator
。
请注意,所有随机数生成器的默认种子为 None,这意味着将使用内部全局 SeedGenerator。如果您需要将 RNG 与全局状态分离,您可以提供一个具有确定性或随机初始状态的本地 StateGenerator
。
关于 JAX 后端的说明:请注意,对于使用 JAX 后端的 RNG 的 JIT 编译,需要使用本地 StateGenerator
作为种子参数,因为不支持使用全局状态。
示例
seed_gen = keras.random.SeedGenerator(seed=42)
values = keras.random.normal(shape=(2, 3), seed=seed_gen)
new_values = keras.random.normal(shape=(2, 3), seed=seed_gen)
在层中的用法
class Dropout(keras.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.seed_generator = keras.random.SeedGenerator(1337)
def call(self, x, training=False):
if training:
return keras.random.dropout(
x, rate=0.5, seed=self.seed_generator
)
return x