开发者指南 / 如何使用 Keras 配合 NNX 后端

如何使用 Keras 配合 NNX 后端

作者: Divyashree Sreepathihalli
创建日期 2025/08/07
最后修改日期 2025/08/07
描述: 如何使用 Keras 配合 NNX 后端。

在 Colab 中查看 GitHub 源码

Keras 和 Flax NNX 集成指南

本教程将引导您完成 Keras 与 Flax 的 NNX(Neural Networks JAX)模块系统的集成,展示它如何显著增强变量处理能力,并在 JAX 生态系统中实现高级训练功能。无论您喜欢 model.fit() 的简洁性,还是自定义训练循环的精细控制,此集成都能让您兼顾两者。让我们开始吧!

为何要集成 Keras 和 NNX?

Keras 以其用户友好性和高级 API 而闻名,使深度学习更加易于访问。另一方面,JAX 提供高性能数值计算,由于其 JIT 编译和自动微分功能,特别适合机器学习研究。NNX 是 Flax 基于 JAX 构建的函数式模块系统,提供显式状态管理和强大的函数式编程范例。

NNX 的设计宗旨是简洁。其特点是 Pythonic 方法,模块是标准的 Python 类,易于使用和熟悉。NNX 优先考虑用户友好性,并通过类型化的 Variable 集合提供对 JAX 转换的精细控制。

Keras 与 NNX 的集成使您能够两全其美:Keras 用于模型构建的简洁性和模块化,结合 NNX 和 JAX 用于变量管理和复杂训练循环的强大功能和显式控制。

入门:设置您的环境

!pip install -q -U keras
!pip install -q -U flax==0.11.0

启用 NNX 模式

要激活此集成,我们必须在导入 Keras 之前设置两个环境变量。这会告诉 Keras 使用 JAX 后端,并选择 opt-in 功能来切换到 NNX。

import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"
from flax import nnx
import keras
import jax.numpy as jnp

print("✅ Keras is now running on JAX with NNX enabled!")

核心集成:NNX 中的 Keras 变量

此集成的核心是新的 `keras.Variable`,它被设计为 Flax NNX 生态系统的原生组件。这意味着您可以自由混合 Keras 和 NNX 组件,并且 NNX 的跟踪和状态管理工具将理解您的 Keras 变量。让我们来证明这一点。我们将创建一个包含标准 `nnx.Linear` 层和 `keras.Variable` 的 `nnx.Module`。

from keras import Variable as KerasVariable


class MyNnxModel(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(2, 3, rngs=rngs)
        self.custom_variable = KerasVariable(jnp.ones((1, 3)))

    def __call__(self, x):
        return self.linear(x) + self.custom_variable


# Instantiate the model
model = MyNnxModel(rngs=nnx.Rngs(0))

# --- Verification ---
# 1. Is the KerasVariable traced by NNX?
print(f"✅ Traced: {hasattr(model.custom_variable, '_trace_state')}")

# 2. Does NNX see the KerasVariable in the model's state?
print("✅ Variables:", nnx.variables(model))

# 3. Can we access its value directly?
print("✅ Value:", model.custom_variable.value)

这表明:KerasVariable 被 NNX 成功跟踪,就像任何原生的 `nnx.Variable` 一样。`nnx.variables()` 函数能够正确识别并列出我们的 `custom_variable` 作为模型状态的一部分。这证实了 Keras 状态和 NNX 状态可以完美和谐地共存。

两全其美:训练工作流

现在来到激动人心的部分:训练模型。此集成开启了两个强大的工作流。


工作流 1:经典的 Keras 体验 (model.fit)

import numpy as np
  1. 创建 Keras 模型
model = keras.Sequential(
    [keras.layers.Dense(units=1, input_shape=(10,), name="my_dense_layer")]
)

print("--- Initial Model Weights ---")
initial_weights = model.get_weights()
print(f"Initial Kernel: {initial_weights[0].T}")  # .T for better display
print(f"Initial Bias: {initial_weights[1]}")
  1. 创建虚拟数据
X_dummy = np.random.rand(100, 10)
y_dummy = np.random.rand(100, 1)
  1. 编译和拟合
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    loss="mean_squared_error",
)

print("\n--- Training with model.fit() ---")
history = model.fit(X_dummy, y_dummy, epochs=5, batch_size=32, verbose=1)
  1. 验证更改
print("\n--- Weights After Training ---")
updated_weights = model.get_weights()
print(f"Updated Kernel: {updated_weights[0].T}")
print(f"Updated Bias: {updated_weights[1]}")

# Verification
if not np.array_equal(initial_weights[1], updated_weights[1]):
    print("\n✅ SUCCESS: Model variables were updated during training.")
else:
    print("\n❌ FAILURE: Model variables were not updated.")

如您所见,您现有的 Keras 代码可以开箱即用,让您在 JAX 和 NNX 的底层支持下获得高级、高效的体验。


工作流 2:NNX 的强大功能:自定义训练循环

为了获得最大的灵活性,您可以将任何 Keras 层或模型视为 `nnx.Module`,并使用 Optax 等库编写自己的训练循环。这在您需要对梯度和更新过程进行精细控制时非常有用。

import numpy as np
import optax

X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * X + 0.1 + np.random.normal(0, 0.1, size=X.shape)


class MySimpleKerasModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Define the layers of your model
        self.dense_layer = keras.layers.Dense(1)

    def call(self, inputs):
        # Define the forward pass
        # The 'inputs' argument will receive the input tensor when the model is
        #  called
        return self.dense_layer(inputs)


model = MySimpleKerasModel()
model(X)

tx = optax.sgd(1e-3)
trainable_var = nnx.All(keras.Variable, lambda path, x: x.trainable)
optimizer = nnx.Optimizer(model, tx, wrt=trainable_var)


@nnx.jit
def train_step(model, optimizer, batch):
    x, y = batch

    def loss_fn(model_):
        y_pred = model_(x)
        return jnp.mean((y - y_pred) ** 2)

    diff_state = nnx.DiffState(0, trainable_var)
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(model, grads)


@nnx.jit
def test_step(model, batch):
    x, y = batch
    y_pred = model(x)
    loss = jnp.mean((y - y_pred) ** 2)
    return {"loss": loss}


def dataset(batch_size=10):
    while True:
        idx = np.random.choice(len(X), size=batch_size)
        yield X[idx], Y[idx]


for step, batch in enumerate(dataset()):
    train_step(model, optimizer, batch)

    if step % 100 == 0:
        logs = test_step(model, (X, Y))
        print(f"step: {step}, loss: {logs['loss']}")

    if step >= 500:
        break

此示例演示了如何将 keras 模型对象无缝地传递给 `nnx.Optimizer` 并由 `nnx.grad` 进行区分。这种组合允许您将 Keras 组件集成到复杂的 JAX/NNX 工作流中。此方法与 sequential、functional、subclassed keras 模型甚至仅仅是层都能完美配合。

保存和加载

您在 Keras 生态系统中的投入是安全的。模型序列化等标准功能将按预期工作。

# Create a simple model
model = keras.Sequential([keras.layers.Dense(units=1, input_shape=(10,))])
dummy_input = np.random.rand(1, 10)

# Test call
print("Original model output:", model(dummy_input))

# Save and load
model.save("my_nnx_model.keras")
restored_model = keras.models.load_model("my_nnx_model.keras")

print("Restored model output:", restored_model(dummy_input))

# Verification
np.testing.assert_allclose(model(dummy_input), restored_model(dummy_input))
print("\n✅ SUCCESS: Restored model output matches original model output.")

实际应用:训练 Gemma

在尝试此 KerasHub 模型之前,请确保您已在 colab secrets 中设置好 Kaggle 凭据。colab 会拉取 `KAGGLE_KEY` 和 `KAGGLE_USERNAME` 来进行身份验证和下载模型。

import keras_hub

# Set a float16 policy for memory efficiency
keras.config.set_dtype_policy("float16")

# Load Gemma from KerasHub
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")

# --- 1. Inference / Generation ---
print("--- Gemma Generation ---")
output = gemma_lm.generate("Keras is a", max_length=30)
print(output)

# --- 2. Fine-tuning ---
print("\n--- Gemma Fine-tuning ---")
# Dummy data for demonstration
features = np.array(["The quick brown fox jumped.", "I forgot my homework."])
# The model.fit() API works seamlessly!
gemma_lm.fit(x=features, batch_size=2)
print("\n✅ Gemma fine-tuning step completed successfully!")

结论

Keras-NNX 集成代表了向前迈出的重要一步,它提供了一个统一的框架,既能快速原型化,又能进行高性能、可定制的研究。您现在可以:- 在 JAX 后端上使用熟悉的 Keras API(Sequential、Model、fit、save)。- 将 Keras 层和模型直接集成到 Flax NNX 模块和训练循环中。- 将 Keras 代码/模型与 NNX 生态系统(如 Qwix、Tunix 等)集成。- 利用整个 JAX 生态系统(例如,`nnx.jit`、`optax`)来使用您的 Keras 模型。- 与 KerasHub 的大型模型无缝协作。