Keras 2 API文档 / 回调API / ModelCheckpoint

ModelCheckpoint

[源代码]

ModelCheckpoint

tf_keras.callbacks.ModelCheckpoint(
    filepath,
    monitor: str = "val_loss",
    verbose: int = 0,
    save_best_only: bool = False,
    save_weights_only: bool = False,
    mode: str = "auto",
    save_freq="epoch",
    options=None,
    initial_value_threshold=None,
    **kwargs
)

在一定频率下保存TF-Keras模型或模型权重的回调。

ModelCheckpoint 回调用于与model.fit() 训练结合使用,以便在特定间隔保存模型或权重(在检查点文件中),这样就可以稍后加载模型或权重,以从保存的状态继续训练。

此回调提供了一些选项,包括:

  • 是否只保留迄今为止取得“最佳性能”的模型,或者无论性能如何,都在每个 epoch 结束时保存模型。
  • “最佳”的定义;要监控哪个量,以及是应该最大化还是最小化它。
  • 它应该保存的频率。目前,回调支持在每个 epoch 结束时保存,或在固定数量的训练批次后保存。
  • 是只保存权重,还是保存整个模型。

注意:如果您看到 WARNING:tensorflow:Can save best model only with <name> available, skipping,请参阅 monitor 参数的描述以了解如何正确设置。

示例

model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])

EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) are loaded into the
# model.
model.load_weights(checkpoint_filepath)

参数

  • filepath: string 或 PathLike,保存模型文件的路径。例如:filepath = os.path.join(working_dir, 'ckpt', file_name)。filepath 可以包含命名格式选项,这些选项将由 epoch 的值和 logs 中的键(在 on_epoch_end 中传递)填充。例如:如果 filepathweights.{epoch:02d}-{val_loss:.2f}.hdf5,那么模型检查点将以 epoch 编号和验证损失保存在文件名中。filepath 的目录不应被任何其他回调重用,以避免冲突。
  • monitor: 要监控的指标名称。通常指标由 Model.compile 方法设置。注意:

    • 在名称前加上 "val_" 以监控验证指标。
    • 使用 "loss" 或 "val_loss" 来监控模型的总损失。
    • 如果以字符串形式指定指标,例如 "accuracy",请传递相同的字符串(带或不带 "val_" 前缀)。
    • 如果您传递 metrics.Metric 对象,则 monitor 应设置为 metric.name
    • 如果您不确定指标名称,可以检查 history = model.fit() 返回的 history.history 字典的内容。
    • 多输出模型会在指标名称上设置额外的​​前缀。
    • __ verbose__: 详细程度模式,0 或 1。模式 0 为静默模式,模式 1 在回调执行操作时显示消息。
    • save_best_only: 如果 save_best_only=True,则仅在模型被认为是“最佳”时保存,并且根据监控的量,最新的最佳模型不会被覆盖。如果 filepath 不包含 {epoch} 等格式选项,则 filepath 将被每个新的更好模型覆盖。
    • mode: {'auto', 'min', 'max'} 中的一个。如果 save_best_only=True,则决定是否覆盖当前保存文件是基于被监控数量的最大化或最小化。对于 val_acc,这应该是 max;对于 val_loss,这应该是 min,依此类推。在 auto 模式下,如果监控的数量是 'acc' 或以 'fmeasure' 开头,则模式设置为 max,对于其他数量则设置为 min
    • save_weights_only: 如果为 True,则只保存模型的权重(model.save_weights(filepath)),否则保存整个模型(model.save(filepath))。
    • save_freq: 'epoch' 或 integer。当使用 'epoch' 时,回调会在每个 epoch 后保存模型。当使用 integer 时,回调会在这个数量的批次后保存模型。如果 Model 使用 steps_per_execution=N 编译,那么将在每 N 个批次检查一次保存条件。请注意,如果保存与 epoch 对齐,则监控的指标可能会不太可靠(它可能只反映 1 个批次,因为指标在每个 epoch 都会重置)。默认为 'epoch'
    • options: 可选的 tf.train.CheckpointOptions 对象(如果 save_weights_only 为 true)或可选的 tf.saved_model.SaveOptions 对象(如果 save_weights_only 为 false)。
    • initial_value_threshold: 要监控的指标的浮点型初始“最佳”值。仅当 save_best_value=True 时适用。只有当当前模型的性能优于此值时,才会覆盖已保存的模型权重。
    • **kwargs: 向后兼容的其他参数。可能的键是 period