Keras 3 API 文档 / 回调 API / 模型检查点

模型检查点

[源代码]

ModelCheckpoint

keras.callbacks.ModelCheckpoint(
    filepath,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    save_freq="epoch",
    initial_value_threshold=None,
)

用于以一定频率保存 Keras 模型或模型权重的回调。

ModelCheckpoint 回调与使用 model.fit() 进行训练结合使用,以便以一定间隔保存模型或权重(在检查点文件中),以便稍后加载模型或权重以从保存的状态继续训练。

此回调提供的一些选项包括

  • 是否只保留迄今为止已获得“最佳性能”的模型,或者是否无论性能如何,在每个 epoch 结束时都保存模型。
  • “最佳”的定义;要监控的量以及是应最大化还是最小化。
  • 应保存的频率。当前,回调支持在每个 epoch 结束时保存,或在固定数量的训练批次后保存。
  • 是否仅保存权重,或者保存整个模型。

示例

model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])

EPOCHS = 10
checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model is saved at the end of every epoch, if it's the best seen so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model (that are considered the best) can be loaded as -
keras.models.load_model(checkpoint_filepath)

# Alternatively, one could checkpoint just the model weights as -
checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) can be loaded as -
model.load_weights(checkpoint_filepath)

参数

  • filepath:字符串或PathLike,保存模型文件的路径。filepath 可以包含命名格式选项,这些选项将填充epoch的值和logs中的键(在on_epoch_end中传递)。当save_weights_only=True时,filepath名称需要以".weights.h5"结尾,或者当检查点保存整个模型时,应以".keras"结尾(默认值)。例如:如果filepath"{epoch:02d}-{val_loss:.2f}.keras",则模型检查点将使用 epoch 编号和文件名中的验证损失进行保存。filepath 的目录不应被任何其他回调重用,以避免冲突。
  • monitor:要监控的指标名称。通常,指标由Model.compile方法设置。注意
    • "val_"为前缀来监控验证指标。
    • 使用"loss""val_loss"监控模型的总损失。
    • 如果将指标指定为字符串,例如"accuracy",请传递相同的字符串(带有或不带"val_"前缀)。
    • 如果传递metrics.Metric对象,则monitor应设置为metric.name
    • 如果不确定指标名称,可以检查history = model.fit()返回的history.history字典的内容。
    • 多输出模型在指标名称上设置其他前缀。
  • verbose:详细模式,0 或 1。模式 0 为静默,模式 1 在回调采取操作时显示消息。
  • save_best_only:如果save_best_only=True,则仅在模型被认为是“最佳”模型时才保存,并且根据监控数量,最新的最佳模型将不会被覆盖。如果filepath不包含诸如{epoch}之类的格式选项,则每个新的更好的模型都会覆盖filepath
  • mode"auto""min""max"之一。如果save_best_only=True,则覆盖当前保存文件的决定是根据监控数量的最大化或最小化做出的。对于val_acc,这应该是"max",对于val_loss,这应该是"min",等等。在"auto"模式下,如果监控的量是"acc"或以"fmeasure"开头,则模式设置为"max",对于其余量,则设置为"min"
  • save_weights_only:如果为True,则仅保存模型的权重(model.save_weights(filepath)),否则保存完整的模型(model.save(filepath))。
  • save_freq"epoch"或整数。使用"epoch"时,回调在每个 epoch 后保存模型。使用整数时,回调在此批次结束时保存模型。如果Model使用steps_per_execution=N编译,则保存条件将每 N 个批次检查一次。请注意,如果保存与 epoch 不对齐,则监控的指标可能会不太可靠(它可能反映了少至 1 个批次,因为指标在每个 epoch 重置)。默认为"epoch"
  • initial_value_threshold:要监控的指标的浮点初始“最佳”值。仅在save_best_value=True时适用。仅当当前模型的性能优于此值时,才会覆盖已保存的模型权重。