Keras 3 API 文档 / 回调函数 API / ModelCheckpoint

ModelCheckpoint

[来源]

ModelCheckpoint

keras.callbacks.ModelCheckpoint(
    filepath,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    save_freq="epoch",
    initial_value_threshold=None,
)

用于按一定频率保存 Keras 模型或模型权重的回调函数。

ModelCheckpoint 回调函数与使用 model.fit() 的训练过程结合使用,以便按一定间隔保存模型或权重(在检查点文件中),从而可以在以后加载模型或权重以从保存的状态继续训练。

此回调函数提供的几个选项包括:

  • 是仅保留迄今为止达到“最佳性能”的模型,还是无论性能如何都在每个 epoch 结束时保存模型。
  • “最佳”的定义;要监控哪个指标以及是应该最大化还是最小化。
  • 保存的频率。目前,此回调函数支持在每个 epoch 结束时保存,或在固定数量的训练批次后保存。
  • 是仅保存权重,还是保存整个模型。

示例

model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])

EPOCHS = 10
checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model is saved at the end of every epoch, if it's the best seen so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model (that are considered the best) can be loaded as -
keras.models.load_model(checkpoint_filepath)

# Alternatively, one could checkpoint just the model weights as -
checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) can be loaded as -
model.load_weights(checkpoint_filepath)

参数

  • filepath: 字符串或 PathLike 对象,保存模型文件的路径。filepath 可以包含命名格式选项,这些选项将填充 epoch 的值和 logs 中的键(在 on_epoch_end 中传递)。当 save_weights_only=True 时,filepath 名称需要以 ".weights.h5" 结尾;当保存整个模型时(默认),则应以 ".keras"".h5" 结尾。例如:如果 filepath"{epoch:02d}-{val_loss:.2f}.keras""{epoch:02d}-{val_loss:.2f}.weights.h5",那么模型检查点将以 epoch 编号和验证损失值作为文件名的一部分进行保存。filepath 所在的目录不应被任何其他回调函数重复使用,以避免冲突。
  • monitor: 要监控的指标名称。通常指标通过 Model.compile 方法设置。注意:
    • 在名称前加上 "val_" 以监控验证指标。
    • 使用 "loss""val_loss" 监控模型的总损失。
    • 如果您将指标指定为字符串,例如 "accuracy",请传递相同的字符串(无论是否带有 "val_" 前缀)。
    • 如果您传递 metrics.Metric 对象,monitor 应设置为 metric.name
    • 如果您不确定指标名称,可以查看 history = model.fit() 返回的 history.history 字典的内容。
    • 多输出模型在指标名称上设置了额外的名称前缀。
  • verbose: 详细模式,0 或 1。模式 0 为静默模式,模式 1 在回调函数采取行动时显示消息。
  • save_best_only: 如果 save_best_only=True,则仅在模型被认为是“最佳”时才保存,并且根据监控的指标,最新的最佳模型不会被覆盖。如果 filepath 不包含像 {epoch} 这样的格式选项,那么每次出现新的更好的模型时,filepath 将被覆盖。
  • mode: {"auto", "min", "max"} 之一。如果 save_best_only=True,是否覆盖当前保存文件的决定基于对监控指标的最大化或最小化。对于 val_acc,这应该是 "max";对于 val_loss,这应该是 "min",等等。在 "auto" 模式下,如果监控的指标是 "acc" 或以 "fmeasure" 开头,则模式设置为 "max";对于其余指标,则设置为 "min"
  • save_weights_only: 如果为 True,则仅保存模型的权重(model.save_weights(filepath)),否则保存整个模型(model.save(filepath))。
  • save_freq: "epoch" 或整数。使用 "epoch" 时,回调函数在每个 epoch 后保存模型。使用整数时,回调函数在此指定数量的批次结束时保存模型。如果使用 steps_per_execution=N 编译了 Model,则每隔 N 个批次检查一次保存条件。请注意,如果保存与 epoch 不对齐,监控的指标可能可靠性较低(它可能只反映最少 1 个批次的情况,因为指标在每个 epoch 都会重置)。默认为 "epoch"
  • initial_value_threshold: 要监控指标的浮点型初始“最佳”值。仅当 save_best_only=True 时适用。仅当当前模型的性能优于此值时,才覆盖已保存的模型权重。