
Orbax Checkpointing in Keras

Author: Samaneh Saadat
Date created: 2025/08/20
Last modified: 2025/08/20
Description: A guide on how to save Orbax checkpoints when training a model with the JAX backend.



Introduction

Orbax is the recommended default checkpointing library for JAX ecosystem users. It is a high-level checkpointing library that provides checkpoint management as well as composable and extensible serialization. This guide explains how to do Orbax checkpointing when training a model with the JAX backend.

Note that you should use Orbax checkpointing when doing multi-host training with the Keras distribution API, because the current default Keras checkpointing does not support multi-host training.
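For reference, here is a minimal, illustrative sketch of what enabling the Keras distribution API looks like on the JAX backend (assuming Keras has already been imported with the JAX backend, as in the Setup section below, and that multiple devices are available):

# Illustrative sketch only: enable data-parallel training across all
# available devices before building the model.
devices = keras.distribution.list_devices()
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)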


Setup

Let's start by installing the Orbax checkpointing library:

!pip install -q -U orbax-checkpoint

We need to set the Keras backend to JAX, since this guide targets the JAX backend. We then import Keras and the other libraries we need, including the Orbax checkpointing library.

import os

os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np
import orbax.checkpoint as ocp

Orbax callbacks

We need to create two main utilities to manage Orbax checkpointing in Keras:

  1. KerasOrbaxCheckpointManager: This is a wrapper around orbax.checkpoint.CheckpointManager for Keras models. KerasOrbaxCheckpointManager uses the Model get_state_tree and set_state_tree APIs to save and restore model variables.
  2. OrbaxCheckpointCallback: This is a Keras callback that uses KerasOrbaxCheckpointManager to automatically save and restore the model state during training.

Using Orbax checkpointing in Keras is as simple as copying these utilities into your own codebase and passing an OrbaxCheckpointCallback to the fit method.

class KerasOrbaxCheckpointManager(ocp.CheckpointManager):
    """A wrapper over Orbax CheckpointManager for Keras with the JAX
    backend."""

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        """Initialize the Keras Orbax Checkpoint Manager.

        Args:
            model: The Keras model to checkpoint.
            checkpoint_dir: Directory path where checkpoints will be saved.
            max_to_keep: Maximum number of checkpoints to keep in the directory.
                Default is 5.
            steps_per_epoch: Number of steps per epoch. Default is 1.
            **kwargs: Additional keyword arguments to pass to Orbax's
                CheckpointManagerOptions.
        """
        options = ocp.CheckpointManagerOptions(
            max_to_keep=max_to_keep, enable_async_checkpointing=False, **kwargs
        )
        self._model = model
        self._steps_per_epoch = steps_per_epoch
        self._checkpoint_dir = checkpoint_dir
        super().__init__(checkpoint_dir, options=options)

    def _get_state(self):
        """Gets the model state and metrics.

        This method retrieves the complete state tree from the model and separates
        the metrics variables from the rest of the state.

        Returns:
            A tuple containing:
                - state: A dictionary containing the model's state (weights, optimizer state, etc.)
                - metrics: The model's metrics variables, if any
        """
        state = self._model.get_state_tree().copy()
        metrics = state.pop("metrics_variables", None)
        return state, metrics

    def save_state(self, epoch):
        """Saves the model to the checkpoint directory.

        Args:
          epoch: The epoch number at which the state is saved.
        """
        state, metrics_value = self._get_state()
        self.save(
            epoch * self._steps_per_epoch,
            args=ocp.args.StandardSave(item=state),
            metrics=metrics_value,
        )

    def restore_state(self, step=None):
        """Restores the model from the checkpoint directory.

        Args:
          step: The step number to restore the state from. Default=None
            restores the latest step.
        """
        step = step or self.latest_step()
        if step is None:
            return
        # Restore the model state only, not metrics.
        state, _ = self._get_state()
        restored_state = self.restore(step, args=ocp.args.StandardRestore(item=state))
        self._model.set_state_tree(restored_state)


class OrbaxCheckpointCallback(keras.callbacks.Callback):
    """A callback for checkpointing and restoring state using Orbax."""

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        """Initialize the Orbax checkpoint callback.

        Args:
            model: The Keras model to checkpoint.
            checkpoint_dir: Directory path where checkpoints will be saved.
            max_to_keep: Maximum number of checkpoints to keep in the directory.
                Default is 5.
            steps_per_epoch: Number of steps per epoch. Default is 1.
            **kwargs: Additional keyword arguments to pass to Orbax's
                CheckpointManagerOptions.
        """
        if keras.config.backend() != "jax":
            raise ValueError(
                "`OrbaxCheckpointCallback` is only supported on the "
                f"`jax` backend. Provided backend is {keras.config.backend()}."
            )
        super().__init__()
        self._checkpoint_manager = KerasOrbaxCheckpointManager(
            model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
        )

    def on_train_begin(self, logs=None):
        if not self.model.built or not self.model.optimizer.built:
            raise ValueError(
                "To use `OrbaxCheckpointCallback`, your model and "
                "optimizer must be built before you call `fit()`."
            )
        latest_epoch = self._checkpoint_manager.latest_step()
        if latest_epoch is not None:
            print("Load Orbax checkpoint on_train_begin.")
            self._checkpoint_manager.restore_state(step=latest_epoch)

    def on_epoch_end(self, epoch, logs=None):
        print("Save Orbax checkpoint on_epoch_end.")
        self._checkpoint_manager.save_state(epoch)

An Orbax checkpointing example

Let's see how we can use OrbaxCheckpointCallback to save Orbax checkpoints during training. First, let's define a simple model and a toy training dataset.

def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1, name="dense")(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


model = get_model()

x_train = np.random.random((128, 32))
y_train = np.random.random((128, 1))
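Before wiring up the callback, it can help to peek at the state tree that KerasOrbaxCheckpointManager saves and restores through the get_state_tree and set_state_tree APIs. Below is a minimal sketch; the top-level keys are as in Keras 3, and the exact contents depend on the model:

# Inspect the top-level structure of the model's state tree.
# Note: optimizer variables are only populated once the optimizer is
# built, which happens on the first call to fit.
state = model.get_state_tree()
print(list(state.keys()))
# e.g. ['trainable_variables', 'non_trainable_variables',
#       'optimizer_variables', 'metrics_variables']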

Then, we create an Orbax checkpoint callback and pass it to the callbacks argument of the fit method.

orbax_callback = OrbaxCheckpointCallback(
    model,
    checkpoint_dir="/tmp/ckpt",
    max_to_keep=1,
    steps_per_epoch=1,
)
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=3,
    verbose=0,
    validation_split=0.2,
    callbacks=[orbax_callback],
)
Load Orbax checkpoint on_train_begin.

Save Orbax checkpoint on_epoch_end.
Save Orbax checkpoint on_epoch_end.
Save Orbax checkpoint on_epoch_end.

Now, if you look at the Orbax checkpoint directory, you can see all the files saved as part of the Orbax checkpoint. Since epochs are 0-indexed and steps_per_epoch=1, the last epoch is saved as step 2, which with max_to_keep=1 is the only checkpoint kept.

!ls -R /tmp/ckpt
/tmp/ckpt:
2

/tmp/ckpt/2:
_CHECKPOINT_METADATA  default

/tmp/ckpt/2/default:
array_metadatas  d  manifest.ocdbt  _METADATA  ocdbt.process_0  _sharding

/tmp/ckpt/2/default/array_metadatas:
process_0

/tmp/ckpt/2/default/d:
18ec9a2094133d1aa1a3d7513dae3e8d

/tmp/ckpt/2/default/ocdbt.process_0:
d  manifest.ocdbt

/tmp/ckpt/2/default/ocdbt.process_0/d:
08372fc5734e445753b38235cb522988  c8af54d085d2d516444bd71f32a3787c
4601db15b67650f7c8818bfc8afeb9f5  cfe1e3ea313d637df6f6d2b2c66ca17a
a6ca20e04d8fe161ed95f6f71e8fe113
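The checkpoints can also be restored outside of fit, for example to resume from the latest saved state in a new process. Below is a minimal sketch that builds a fresh model and restores into it using the KerasOrbaxCheckpointManager defined above; note that the optimizer must be built first so that its variables can be restored:

# Build a fresh model and restore the latest checkpoint into it.
restored_model = get_model()

# The optimizer variables must exist before they can be restored.
restored_model.optimizer.build(restored_model.trainable_variables)

manager = KerasOrbaxCheckpointManager(
    restored_model, checkpoint_dir="/tmp/ckpt", max_to_keep=1
)
manager.restore_state()  # Defaults to the latest saved step.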