Author: Samaneh Saadat
Date created: 2025/08/20
Last modified: 2025/08/20
Description: A guide on how to save Orbax checkpoints when training a model with the JAX backend.
Orbax is the recommended default checkpointing library for JAX ecosystem users. It is a high-level checkpointing library that provides checkpoint management as well as composable and extensible serialization. This guide explains how to perform Orbax checkpointing when training a model with the JAX backend.
Note that when you do multi-host training with the Keras distribution API, you should use Orbax checkpointing, as the current default Keras checkpointing does not support multi-host training.
Let's start by installing the Orbax checkpointing library:
!pip install -q -U orbax-checkpoint
We need to set the Keras backend to JAX, as this guide is intended for the JAX backend. Then we import Keras and the other libraries we need, including the Orbax checkpointing library.
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np
import orbax.checkpoint as ocp
We need to create two main utilities to manage Orbax checkpointing in Keras:

- KerasOrbaxCheckpointManager: a wrapper over orbax.checkpoint.CheckpointManager for Keras models. KerasOrbaxCheckpointManager uses the Model's get_state_tree and set_state_tree APIs to save and restore model variables.
- OrbaxCheckpointCallback: a Keras callback that uses KerasOrbaxCheckpointManager to automatically save and restore the model state during training.

Using Orbax checkpointing in Keras is as simple as copying these utilities into your own codebase and passing OrbaxCheckpointCallback to the fit method.
class KerasOrbaxCheckpointManager(ocp.CheckpointManager):
    """A wrapper over Orbax CheckpointManager for Keras with the JAX
    backend."""

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        """Initialize the Keras Orbax Checkpoint Manager.

        Args:
            model: The Keras model to checkpoint.
            checkpoint_dir: Directory path where checkpoints will be saved.
            max_to_keep: Maximum number of checkpoints to keep in the
                directory. Default is 5.
            steps_per_epoch: Number of steps per epoch. Default is 1.
            **kwargs: Additional keyword arguments to pass to Orbax's
                CheckpointManagerOptions.
        """
        options = ocp.CheckpointManagerOptions(
            max_to_keep=max_to_keep, enable_async_checkpointing=False, **kwargs
        )
        self._model = model
        self._steps_per_epoch = steps_per_epoch
        self._checkpoint_dir = checkpoint_dir
        super().__init__(checkpoint_dir, options=options)

    def _get_state(self):
        """Gets the model state and metrics.

        This method retrieves the complete state tree from the model and
        separates the metrics variables from the rest of the state.

        Returns:
            A tuple containing:
            - state: A dictionary containing the model's state (weights,
              optimizer state, etc.)
            - metrics: The model's metrics variables, if any
        """
        state = self._model.get_state_tree().copy()
        metrics = state.pop("metrics_variables", None)
        return state, metrics

    def save_state(self, epoch):
        """Saves the model to the checkpoint directory.

        Args:
            epoch: The epoch number at which the state is saved.
        """
        state, metrics_value = self._get_state()
        self.save(
            epoch * self._steps_per_epoch,
            args=ocp.args.StandardSave(item=state),
            metrics=metrics_value,
        )

    def restore_state(self, step=None):
        """Restores the model from the checkpoint directory.

        Args:
            step: The step number to restore the state from. Default=None
                restores the latest step.
        """
        step = step or self.latest_step()
        if step is None:
            return
        # Restore the model state only, not metrics.
        state, _ = self._get_state()
        restored_state = self.restore(step, args=ocp.args.StandardRestore(item=state))
        self._model.set_state_tree(restored_state)


class OrbaxCheckpointCallback(keras.callbacks.Callback):
    """A callback for checkpointing and restoring state using Orbax."""

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        """Initialize the Orbax checkpoint callback.

        Args:
            model: The Keras model to checkpoint.
            checkpoint_dir: Directory path where checkpoints will be saved.
            max_to_keep: Maximum number of checkpoints to keep in the
                directory. Default is 5.
            steps_per_epoch: Number of steps per epoch. Default is 1.
            **kwargs: Additional keyword arguments to pass to Orbax's
                CheckpointManagerOptions.
        """
        if keras.config.backend() != "jax":
            raise ValueError(
                "`OrbaxCheckpointCallback` is only supported on the "
                f"`jax` backend. Provided backend is {keras.config.backend()}."
            )
        self._checkpoint_manager = KerasOrbaxCheckpointManager(
            model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
        )

    def on_train_begin(self, logs=None):
        if not self.model.built or not self.model.optimizer.built:
            raise ValueError(
                "To use `OrbaxCheckpointCallback`, your model and "
                "optimizer must be built before you call `fit()`."
            )
        latest_epoch = self._checkpoint_manager.latest_step()
        if latest_epoch is not None:
            print("Load Orbax checkpoint on_train_begin.")
            self._checkpoint_manager.restore_state(step=latest_epoch)

    def on_epoch_end(self, epoch, logs=None):
        print("Save Orbax checkpoint on_epoch_end.")
        self._checkpoint_manager.save_state(epoch)
Let's see how we can use OrbaxCheckpointCallback to save Orbax checkpoints during training. First, let's define a simple model and a toy training dataset.
def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1, name="dense")(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


model = get_model()
x_train = np.random.random((128, 32))
y_train = np.random.random((128, 1))
Then, we create an Orbax checkpoint callback and pass it to the callbacks argument of the fit method.
orbax_callback = OrbaxCheckpointCallback(
model,
checkpoint_dir="/tmp/ckpt",
max_to_keep=1,
steps_per_epoch=1,
)
history = model.fit(
x_train,
y_train,
batch_size=32,
epochs=3,
verbose=0,
validation_split=0.2,
callbacks=[orbax_callback],
)
Load Orbax checkpoint on_train_begin.
Save Orbax checkpoint on_epoch_end.
Save Orbax checkpoint on_epoch_end.
Save Orbax checkpoint on_epoch_end.
Now, if you look at the Orbax checkpoint directory, you can see all the files saved as part of the Orbax checkpoint. Since we set max_to_keep=1, only the latest checkpoint (step 2) is kept.
!ls -R /tmp/ckpt
/tmp/ckpt:
2
/tmp/ckpt/2:
_CHECKPOINT_METADATA default
/tmp/ckpt/2/default:
array_metadatas d manifest.ocdbt _METADATA ocdbt.process_0 _sharding
/tmp/ckpt/2/default/array_metadatas:
process_0
/tmp/ckpt/2/default/d:
18ec9a2094133d1aa1a3d7513dae3e8d
/tmp/ckpt/2/default/ocdbt.process_0:
d manifest.ocdbt
/tmp/ckpt/2/default/ocdbt.process_0/d:
08372fc5734e445753b38235cb522988 c8af54d085d2d516444bd71f32a3787c
4601db15b67650f7c8818bfc8afeb9f5 cfe1e3ea313d637df6f6d2b2c66ca17a
a6ca20e04d8fe161ed95f6f71e8fe113