ModelCheckpoint 类keras.callbacks.ModelCheckpoint(
filepath,
monitor="val_loss",
verbose=0,
save_best_only=False,
save_weights_only=False,
mode="auto",
save_freq="epoch",
initial_value_threshold=None,
)
在一定频率下保存 Keras 模型或模型权重的回调。
ModelCheckpoint 回调与使用 model.fit() 进行的训练结合使用,以便在某个时间间隔保存模型或权重(在检查点文件中),这样模型或权重就可以在之后加载,以从保存的状态继续训练。
此回调提供的一些选项包括:
示例
model.compile(loss=..., optimizer=...,
metrics=['accuracy'])
EPOCHS = 10
checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model is saved at the end of every epoch, if it's the best seen so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model (that are considered the best) can be loaded as -
keras.models.load_model(checkpoint_filepath)
# Alternatively, one could checkpoint just the model weights as -
checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model weights (that are considered the best) can be loaded as -
model.load_weights(checkpoint_filepath)
参数
PathLike,模型文件的保存路径。filepath 可以包含命名的格式化选项,这些选项将由 epoch 的值和 logs 中的键(在 on_epoch_end 中传递)填充。当 save_weights_only=True 时,filepath 名称需要以 ".weights.h5" 结尾,或者在检查点保存整个模型时(默认)需要以 ".keras" 或 ".h5" 结尾。例如:如果 filepath 是 "{epoch:02d}-{val_loss:.2f}.keras" 或 "{epoch:02d}-{val_loss:.2f}.weights.h5"`,则模型检查点将以 epoch 号和验证损失作为文件名保存。filepath 的目录不应被任何其他回调复用,以避免冲突。Model.compile 方法设置的。注意:"val_" 以监控验证度量。"loss" 或 "val_loss" 来监控模型的总损失。"accuracy",则传递相同的字符串(带或不带 "val_" 前缀)。metrics.Metric 对象,则 monitor 应设置为 metric.name。history = model.fit() 返回的 history.history 字典的内容。save_best_only=True,则仅在模型被视为“最佳”时保存,并且根据监控的数量,最新的最佳模型不会被覆盖。如果 filepath 不包含 {epoch} 等格式化选项,那么 filepath 将会被每个新的更优模型覆盖。"auto", "min", "max"} 中的一个。如果 save_best_only=True,则决定是否覆盖当前保存文件的依据是监控数量的最小化或最大化。对于 val_acc,这应该是 "max";对于 val_loss,这应该是 "min",依此类推。在 "auto" 模式下,方向会自动从监控数量的名称推断出来。True,则只保存模型的权重(model.save_weights(filepath)),否则保存整个模型(model.save(filepath))。"epoch" 或整数。使用 "epoch" 时,回调在每个 epoch 后保存模型。使用整数时,回调在此数量的批次后保存模型。如果 Model 使用 steps_per_execution=N 进行编译,则每 N 个批次检查一次保存标准。请注意,如果保存与 epoch 不对齐,监控的度量可能不太可靠(它可能只反映 1 个批次,因为度量会在每个 epoch 重置)。默认为 "epoch"。save_best_value=True 时适用。仅当当前模型的性能优于此值时,才会覆盖已保存的模型权重。