JaxLayer
类keras.layers.JaxLayer(
call_fn, init_fn=None, params=None, state=None, seed=None, **kwargs
)
封装 JAX 模型的 Keras 层。
当使用 JAX 作为 Keras 的后端时,此层允许在 Keras 中使用 JAX 组件。
此层接受 JAX 模型,形式为一个函数 call_fn
,它必须接受以下具有精确名称的参数
params
: 模型的训练参数。state
(可选): 模型的非训练状态。如果模型没有非训练状态,则可以省略。rng
(可选): 一个 jax.random.PRNGKey
实例。如果模型在训练或推理期间不需要随机数生成器 (RNG),则可以省略。inputs
: 模型的输入,一个 JAX 数组或一个数组的 PyTree
。training
(可选): 一个参数,指定是处于训练模式还是推理模式,训练模式下传入 True
。如果模型在训练模式和推理模式下的行为相同,则可以省略。inputs
参数是必需的。模型的输入必须通过单个参数提供。如果 JAX 模型接受多个输入作为单独的参数,则必须将它们组合成一个单一的结构,例如一个 tuple
或一个 dict
。
模型的 params
和 state
的初始化可以由此层处理,在这种情况下必须提供 init_fn
参数。这允许模型以正确的形状动态初始化。或者,如果形状已知,可以使用 params
参数和可选的 state
参数来创建一个已初始化的模型。
如果提供了 init_fn
函数,它必须接受以下具有精确名称的参数
rng
: 一个 jax.random.PRNGKey
实例。inputs
: 一个 JAX 数组或一个数组的 PyTree
,带有占位符值以提供输入的形状。training
(可选): 一个参数,指定是处于训练模式还是推理模式。始终向 init_fn
传入 True
。无论 call_fn
是否有 training
参数,都可以省略此参数。对于具有非训练状态的 JAX 模型
call_fn
必须有一个 state
参数call_fn
必须返回一个 tuple
,其中包含模型的输出和模型的新非训练状态init_fn
必须返回一个 tuple
,其中包含模型的初始训练参数和初始非训练状态。此代码显示了具有非训练状态的模型可能拥有的 call_fn
和 init_fn
函数签名组合。在此示例中,call_fn
中有一个 training
参数和一个 rng
参数。
def stateful_call(params, state, rng, inputs, training):
outputs = ...
new_state = ...
return outputs, new_state
def stateful_init(rng, inputs):
initial_params = ...
initial_state = ...
return initial_params, initial_state
对于没有非训练状态的 JAX 模型
call_fn
不得有 state
参数call_fn
必须只返回模型的输出init_fn
必须只返回模型的初始训练参数。此代码显示了没有非训练状态的模型可能拥有的 call_fn
和 init_fn
函数签名组合。在此示例中,call_fn
中没有 training
参数,也没有 rng
参数。
def stateless_call(params, inputs):
outputs = ...
return outputs
def stateless_init(rng, inputs):
initial_params = ...
return initial_params
如果一个模型具有与 JaxLayer
所需不同的函数签名,可以轻松编写一个包装方法来适配参数。此示例展示了一个模型,它将多个输入作为单独的参数,期望在 dict
中提供多个 RNG,并且有一个 deterministic
参数,其含义与 training
相反。为了符合要求,输入使用 tuple
组合成单个结构,RNG 被拆分并用于填充预期的 dict
,布尔标志被取反
def my_model_fn(params, rngs, input1, input2, deterministic):
...
if not deterministic:
dropout_rng = rngs["dropout"]
keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape)
x = jax.numpy.where(keep, x / dropout_rate, 0)
...
...
return outputs
def my_model_wrapper_fn(params, rng, inputs, training):
input1, input2 = inputs
rng1, rng2 = jax.random.split(rng)
rngs = {"dropout": rng1, "preprocessing": rng2}
deterministic = not training
return my_model_fn(params, rngs, input1, input2, deterministic)
keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params)
JaxLayer
允许使用 Haiku 组件,形式为 haiku.Module
。这通过按照 Haiku 模式转换模块,然后在 call_fn
参数中传入 module.apply
,如果需要则在 init_fn
参数中传入 module.init
来实现。
如果模型具有非训练状态,应使用 haiku.transform_with_state
进行转换。如果模型没有非训练状态,应使用 haiku.transform
进行转换。此外,如果模块在“apply”中不使用 RNG,则可以选择使用 haiku.without_apply_rng
进行转换。
以下示例展示了如何从一个 Haiku 模块创建 JaxLayer
,该模块通过 hk.next_rng_key()
使用随机数生成器并接受一个训练位置参数
class MyHaikuModule(hk.Module):
def __call__(self, x, training):
x = hk.Conv2D(32, (3, 3))(x)
x = jax.nn.relu(x)
x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), "VALID")(x)
x = hk.Flatten()(x)
x = hk.Linear(200)(x)
if training:
x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x)
x = jax.nn.relu(x)
x = hk.Linear(10)(x)
x = jax.nn.softmax(x)
return x
def my_haiku_module_fn(inputs, training):
module = MyHaikuModule()
return module(inputs, training)
transformed_module = hk.transform(my_haiku_module_fn)
keras_layer = JaxLayer(
call_fn=transformed_module.apply,
init_fn=transformed_module.init,
)
参数
None
,则必须提供 params
和/或 state
。PyTree
。这允许传入已训练参数或控制初始化。如果 params
和 state
都为 None
,则在构建时调用 init_fn
来初始化模型的训练参数。PyTree
。这允许传入学习到的状态或控制初始化。如果 params
和 state
都为 None
,并且 call_fn
接受 state
参数,则在构建时调用 init_fn
来初始化模型的非训练状态。keras.DTypePolicy
。可选。默认为默认策略。