FlaxLayer
类keras.layers.FlaxLayer(module, method=None, variables=None, **kwargs)
包装 Flax 模块的 Keras 层。
当使用 JAX 作为 Keras 后端时,此层允许以 flax.linen.Module
实例的形式在 Keras 中使用 Flax 组件。
用于前向传递的模块方法可以通过 method
参数指定,默认情况下为 __call__
。此方法必须接受以下参数,且名称必须完全一致:
__call__
的情况),则为 self
;否则为 module
以传递模块。inputs
:模型的输入,一个 JAX 数组或数组的 PyTree
。training
(可选):一个参数,指定我们是否处于训练模式或推理模式,在训练模式下传递 True
。FlaxLayer
自动处理模型的非可训练状态和所需的 RNG。请注意,flax.linen.Module.apply()
的 mutable
参数设置为 DenyList(["params"])
,因此假定“params”集合之外的所有变量都是非可训练权重。
此示例展示了如何使用默认的 __call__
方法和无训练参数从 Flax Module
创建 FlaxLayer
class MyFlaxModule(flax.linen.Module):
@flax.linen.compact
def __call__(self, inputs):
x = inputs
x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
x = flax.linen.relu(x)
x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = flax.linen.Dense(features=200)(x)
x = flax.linen.relu(x)
x = flax.linen.Dense(features=10)(x)
x = flax.linen.softmax(x)
return x
flax_module = MyFlaxModule()
keras_layer = FlaxLayer(flax_module)
此示例展示了如何包装模块方法以符合所需的签名。这允许有多个输入参数和一个具有不同名称和值的训练参数。这还展示了如何使用不绑定到模块的函数。
class MyFlaxModule(flax.linen.Module):
@flax.linen.compact
def forward(self, input1, input2, deterministic):
...
return outputs
def my_flax_module_wrapper(module, inputs, training):
input1, input2 = inputs
return module.forward(input1, input2, not training)
flax_module = MyFlaxModule()
keras_layer = FlaxLayer(
module=flax_module,
method=my_flax_module_wrapper,
)
参数
flax.linen.Module
或其子类的实例。Module
中的一个方法。如果未提供,则使用 __call__
方法。method
也可以是未在 Module
中定义的函数,在这种情况下,它必须将 Module
作为第一个参数。它用于 Module.init
和 Module.apply
。详细信息已记录在 flax.linen.Module.apply()
的 method
参数中。dict
,包含模块的所有变量,格式与 flax.linen.Module.init()
返回的格式相同。它应该包含一个“params”键,如果适用,还应包含其他键,用于非可训练状态的变量集合。这允许传递训练好的参数和学习到的非可训练状态,或控制初始化。如果传递 None
,则在构建时调用模块的 init
函数来初始化模型的变量。