FlaxLayer

[source]

FlaxLayer

keras.layers.FlaxLayer(module, method=None, variables=None, **kwargs)

Keras 层,用于封装 Flax 模块。

当使用 JAX 作为 Keras 的后端时,此层允许在 Keras 中以 flax.linen.Module 实例的形式使用 Flax 组件。

用于前向传递的模块方法可以通过 method 参数指定,默认情况下为 __call__。此方法必须接受以下参数,且参数名称必须完全一致:

  • self:如果该方法绑定到模块(对于默认的 __call__ 是这种情况),否则为 module 以传递模块。
  • inputs:模型的输入,JAX 数组或数组的 PyTree
  • training (可选):一个参数,用于指定我们处于训练模式还是推理模式,训练模式下会传入 True

FlaxLayer 自动处理模型的非可训练状态和所需的 RNG。请注意,flax.linen.Module.apply()mutable 参数设置为 DenyList(["params"]),因此假设 "params" 集合之外的所有变量都是非可训练权重。

此示例展示了如何从具有默认 __call__ 方法且没有 training 参数的 Flax Module 创建 FlaxLayer

class MyFlaxModule(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, inputs):
        x = inputs
        x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = flax.linen.Dense(features=200)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=10)(x)
        x = flax.linen.softmax(x)
        return x

flax_module = MyFlaxModule()
keras_layer = FlaxLayer(flax_module)

此示例展示了如何封装模块方法以符合所需的签名。这允许拥有多个输入参数和一个名称和值不同的 training 参数。此外,此示例还展示了如何使用未绑定到模块的函数。

class MyFlaxModule(flax.linen.Module):
    @flax.linen.compact
    def forward(self, input1, input2, deterministic):
        ...
        return outputs

def my_flax_module_wrapper(module, inputs, training):
    input1, input2 = inputs
    return module.forward(input1, input2, not training)

flax_module = MyFlaxModule()
keras_layer = FlaxLayer(
    module=flax_module,
    method=my_flax_module_wrapper,
)

参数

  • moduleflax.linen.Module 或子类的实例。
  • method:调用模型的方法。这通常是 Module 中的一个方法。如果未提供,则使用 __call__ 方法。method 也可以是未在 Module 中定义的函数,在这种情况下,它必须将 Module 作为第一个参数。它用于 Module.initModule.apply。详细信息请参阅 flax.linen.Module.apply()method 参数文档。
  • variables:一个 dict,其中包含模块的所有变量,格式与 flax.linen.Module.init() 返回的格式相同。它应该包含一个 "params" 键,如果适用,还应包含其他键,用于存放非可训练状态的变量集合。这允许传递已训练的参数和已学习的非可训练状态,或控制初始化。如果传递 None,则会在构建时调用模块的 init 函数来初始化模型的变量。