作者: Divyashree Sreepathihalli
创建日期 2025/08/07
最后修改日期 2025/08/07
描述: 如何使用 Keras 配合 NNX 后端。
本教程将引导您完成 Keras 与 Flax 的 NNX(Neural Networks JAX)模块系统的集成,展示它如何显著增强变量处理能力,并在 JAX 生态系统中实现高级训练功能。无论您喜欢 model.fit() 的简洁性,还是自定义训练循环的精细控制,此集成都能让您兼顾两者。让我们开始吧!
Keras 以其用户友好性和高级 API 而闻名,使深度学习更加易于访问。另一方面,JAX 提供高性能数值计算,由于其 JIT 编译和自动微分功能,特别适合机器学习研究。NNX 是 Flax 基于 JAX 构建的函数式模块系统,提供显式状态管理和强大的函数式编程范例。
NNX 的设计宗旨是简洁。其特点是 Pythonic 方法,模块是标准的 Python 类,易于使用和熟悉。NNX 优先考虑用户友好性,并通过类型化的 Variable 集合提供对 JAX 转换的精细控制。
Keras 与 NNX 的集成使您能够两全其美:Keras 用于模型构建的简洁性和模块化,结合 NNX 和 JAX 用于变量管理和复杂训练循环的强大功能和显式控制。
!pip install -q -U keras
!pip install -q -U flax==0.11.0
要激活此集成,我们必须在导入 Keras 之前设置两个环境变量。这会告诉 Keras 使用 JAX 后端,并选择 opt-in 功能来切换到 NNX。
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"
from flax import nnx
import keras
import jax.numpy as jnp
print("✅ Keras is now running on JAX with NNX enabled!")
此集成的核心是新的 `keras.Variable`,它被设计为 Flax NNX 生态系统的原生组件。这意味着您可以自由混合 Keras 和 NNX 组件,并且 NNX 的跟踪和状态管理工具将理解您的 Keras 变量。让我们来证明这一点。我们将创建一个包含标准 `nnx.Linear` 层和 `keras.Variable` 的 `nnx.Module`。
from keras import Variable as KerasVariable
class MyNnxModel(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(2, 3, rngs=rngs)
self.custom_variable = KerasVariable(jnp.ones((1, 3)))
def __call__(self, x):
return self.linear(x) + self.custom_variable
# Instantiate the model
model = MyNnxModel(rngs=nnx.Rngs(0))
# --- Verification ---
# 1. Is the KerasVariable traced by NNX?
print(f"✅ Traced: {hasattr(model.custom_variable, '_trace_state')}")
# 2. Does NNX see the KerasVariable in the model's state?
print("✅ Variables:", nnx.variables(model))
# 3. Can we access its value directly?
print("✅ Value:", model.custom_variable.value)
这表明:KerasVariable 被 NNX 成功跟踪,就像任何原生的 `nnx.Variable` 一样。`nnx.variables()` 函数能够正确识别并列出我们的 `custom_variable` 作为模型状态的一部分。这证实了 Keras 状态和 NNX 状态可以完美和谐地共存。
现在来到激动人心的部分:训练模型。此集成开启了两个强大的工作流。
import numpy as np
model = keras.Sequential(
[keras.layers.Dense(units=1, input_shape=(10,), name="my_dense_layer")]
)
print("--- Initial Model Weights ---")
initial_weights = model.get_weights()
print(f"Initial Kernel: {initial_weights[0].T}") # .T for better display
print(f"Initial Bias: {initial_weights[1]}")
X_dummy = np.random.rand(100, 10)
y_dummy = np.random.rand(100, 1)
model.compile(
optimizer=keras.optimizers.SGD(learning_rate=0.01),
loss="mean_squared_error",
)
print("\n--- Training with model.fit() ---")
history = model.fit(X_dummy, y_dummy, epochs=5, batch_size=32, verbose=1)
print("\n--- Weights After Training ---")
updated_weights = model.get_weights()
print(f"Updated Kernel: {updated_weights[0].T}")
print(f"Updated Bias: {updated_weights[1]}")
# Verification
if not np.array_equal(initial_weights[1], updated_weights[1]):
print("\n✅ SUCCESS: Model variables were updated during training.")
else:
print("\n❌ FAILURE: Model variables were not updated.")
如您所见,您现有的 Keras 代码可以开箱即用,让您在 JAX 和 NNX 的底层支持下获得高级、高效的体验。
为了获得最大的灵活性,您可以将任何 Keras 层或模型视为 `nnx.Module`,并使用 Optax 等库编写自己的训练循环。这在您需要对梯度和更新过程进行精细控制时非常有用。
import numpy as np
import optax
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * X + 0.1 + np.random.normal(0, 0.1, size=X.shape)
class MySimpleKerasModel(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Define the layers of your model
self.dense_layer = keras.layers.Dense(1)
def call(self, inputs):
# Define the forward pass
# The 'inputs' argument will receive the input tensor when the model is
# called
return self.dense_layer(inputs)
model = MySimpleKerasModel()
model(X)
tx = optax.sgd(1e-3)
trainable_var = nnx.All(keras.Variable, lambda path, x: x.trainable)
optimizer = nnx.Optimizer(model, tx, wrt=trainable_var)
@nnx.jit
def train_step(model, optimizer, batch):
x, y = batch
def loss_fn(model_):
y_pred = model_(x)
return jnp.mean((y - y_pred) ** 2)
diff_state = nnx.DiffState(0, trainable_var)
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
optimizer.update(model, grads)
@nnx.jit
def test_step(model, batch):
x, y = batch
y_pred = model(x)
loss = jnp.mean((y - y_pred) ** 2)
return {"loss": loss}
def dataset(batch_size=10):
while True:
idx = np.random.choice(len(X), size=batch_size)
yield X[idx], Y[idx]
for step, batch in enumerate(dataset()):
train_step(model, optimizer, batch)
if step % 100 == 0:
logs = test_step(model, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")
if step >= 500:
break
此示例演示了如何将 keras 模型对象无缝地传递给 `nnx.Optimizer` 并由 `nnx.grad` 进行区分。这种组合允许您将 Keras 组件集成到复杂的 JAX/NNX 工作流中。此方法与 sequential、functional、subclassed keras 模型甚至仅仅是层都能完美配合。
您在 Keras 生态系统中的投入是安全的。模型序列化等标准功能将按预期工作。
# Create a simple model
model = keras.Sequential([keras.layers.Dense(units=1, input_shape=(10,))])
dummy_input = np.random.rand(1, 10)
# Test call
print("Original model output:", model(dummy_input))
# Save and load
model.save("my_nnx_model.keras")
restored_model = keras.models.load_model("my_nnx_model.keras")
print("Restored model output:", restored_model(dummy_input))
# Verification
np.testing.assert_allclose(model(dummy_input), restored_model(dummy_input))
print("\n✅ SUCCESS: Restored model output matches original model output.")
在尝试此 KerasHub 模型之前,请确保您已在 colab secrets 中设置好 Kaggle 凭据。colab 会拉取 `KAGGLE_KEY` 和 `KAGGLE_USERNAME` 来进行身份验证和下载模型。
import keras_hub
# Set a float16 policy for memory efficiency
keras.config.set_dtype_policy("float16")
# Load Gemma from KerasHub
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
# --- 1. Inference / Generation ---
print("--- Gemma Generation ---")
output = gemma_lm.generate("Keras is a", max_length=30)
print(output)
# --- 2. Fine-tuning ---
print("\n--- Gemma Fine-tuning ---")
# Dummy data for demonstration
features = np.array(["The quick brown fox jumped.", "I forgot my homework."])
# The model.fit() API works seamlessly!
gemma_lm.fit(x=features, batch_size=2)
print("\n✅ Gemma fine-tuning step completed successfully!")
Keras-NNX 集成代表了向前迈出的重要一步,它提供了一个统一的框架,既能快速原型化,又能进行高性能、可定制的研究。您现在可以:- 在 JAX 后端上使用熟悉的 Keras API(Sequential、Model、fit、save)。- 将 Keras 层和模型直接集成到 Flax NNX 模块和训练循环中。- 将 Keras 代码/模型与 NNX 生态系统(如 Qwix、Tunix 等)集成。- 利用整个 JAX 生态系统(例如,`nnx.jit`、`optax`)来使用您的 Keras 模型。- 与 KerasHub 的大型模型无缝协作。