Keras 3 API文档 / 回调API / BackupAndRestore

BackupAndRestore

[源代码]

BackupAndRestore

keras.callbacks.BackupAndRestore(
    backup_dir, save_freq="epoch", double_checkpoint=False, delete_checkpoint=True
)

用于备份和恢复训练状态的回调。

BackupAndRestore 回调旨在通过在每个 epoch 结束时将训练状态备份到临时检查点文件来恢复在 Model.fit 执行过程中断的训练。每次备份都会覆盖之前写入的检查点文件,因此在任何给定时间最多只有一个用于备份/恢复的检查点文件。

如果在完成之前重新启动训练,则训练状态(包括 Model 权重和 epoch 编号)将在新的 Model.fit 运行时恢复到最近保存的状态。Model.fit 运行完成后,将删除临时检查点文件。

请注意,用户负责在中断后恢复作业。此回调对于容错目的的备份和恢复机制很重要,并且期望用于从先前检查点恢复的模型与用于备份的模型相同。如果用户更改了传递给 compile 或 fit 的参数,则为容错保存的检查点可能会失效。

示例

>>> class InterruptingCallback(keras.callbacks.Callback):
...   def on_epoch_begin(self, epoch, logs=None):
...     if epoch == 4:
...       raise RuntimeError('Interrupting!')
>>> callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
>>> model = keras.models.Sequential([keras.layers.Dense(10)])
>>> model.compile(keras.optimizers.SGD(), loss='mse')
>>> model.build(input_shape=(None, 20))
>>> try:
...   model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
...             batch_size=1, callbacks=[callback, InterruptingCallback()],
...             verbose=0)
... except:
...   pass
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
...                     epochs=10, batch_size=1, callbacks=[callback],
...                     verbose=0)
>>> # Only 6 more epochs are run, since first training got interrupted at
>>> # zero-indexed epoch 4, second training will continue from 4 to 9.
>>> len(history.history['loss'])
>>> 6

参数

  • backup_dir: 字符串,用于存储恢复模型所需数据的目录的路径。该目录不能在其他地方重复使用以存储其他文件,例如,用于另一个训练运行的 BackupAndRestore 回调,或者用于同一训练运行的另一个回调(例如 ModelCheckpoint)。
  • save_freq: "epoch",整数,或 False。当设置为 "epoch" 时,回调会在每个 epoch 结束时保存检查点。当设置为整数时,回调会每 save_freq 个批次保存一次检查点。仅当使用抢占式检查点(即,与 save_before_preemption=True 一起使用)时,才将 save_freq 设置为 False
  • double_checkpoint: 布尔值。如果启用,BackupAndRestore 回调将保存最后两个训练状态(当前和上一个)。中断后,如果由于 IO 错误(例如文件损坏)无法加载当前状态,它将尝试恢复上一个状态。这种行为将消耗两倍的磁盘空间,但会提高容错能力。默认为 False
  • delete_checkpoint: 布尔值。此 BackupAndRestore 回调通过保存检查点来备份训练状态。如果 delete_checkpoint=True,则在训练完成后删除检查点。如果您希望保留检查点供将来使用,请使用 False。默认为 True