Keras 2 API 文档 / Callbacks API / BackupAndRestore

BackupAndRestore

[源代码]

BackupAndRestore

tf_keras.callbacks.BackupAndRestore(
    backup_dir, save_freq="epoch", delete_checkpoint=True, save_before_preemption=False
)

用于备份和恢复训练状态的回调函数。

BackupAndRestore 回调函数的目的是通过在每个 epoch 结束时将训练状态备份到临时检查点文件(借助 tf.train.CheckpointManager),从而从 `Model.fit` 执行过程中中断的中间环节恢复训练。每次备份都会覆盖之前写入的检查点文件,因此在任何给定时间最多只有一个用于备份/恢复目的的检查点文件。

如果训练在完成前重新启动,训练状态(包括 `Model` 权重和 epoch 编号)将在新的 `Model.fit` 运行时恢复到最近保存的状态。在 `Model.fit` 运行完成后,临时检查点文件将被删除。

请注意,用户负责在中断后恢复任务。此回调函数对于实现容错的备份和恢复机制很重要,并且用于从先前检查点恢复的模型预计与用于备份的模型相同。如果用户更改了传递给 `compile` 或 `fit` 的参数,则为容错保存的检查点可能会失效。

注意

  1. 此回调函数与禁用 eager 执行不兼容。
  2. 检查点在每个 epoch 结束时保存。恢复后,`Model.fit` 会重做在训练重新启动的未完成 epoch 中进行的任何部分工作(因此中断之前完成的工作不会影响最终模型状态)。
  3. 这适用于单工作节点和多工作节点模式。当 `Model.fit` 与 tf.distribute 一起使用时,它支持 tf.distribute.MirroredStrategytf.distribute.MultiWorkerMirroredStrategytf.distribute.TPUStrategytf.distribute.experimental.ParameterServerStrategy

示例

>>> class InterruptingCallback(tf.keras.callbacks.Callback):
...   def on_epoch_begin(self, epoch, logs=None):
...     if epoch == 4:
...       raise RuntimeError('Interrupting!')
>>> callback = tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
>>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
>>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
>>> try:
...   model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
...             batch_size=1, callbacks=[callback, InterruptingCallback()],
...             verbose=0)
... except:
...   pass
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
...                     epochs=10, batch_size=1, callbacks=[callback],
...                     verbose=0)
>>> # Only 6 more epochs are run, since first training got interrupted at
>>> # zero-indexed epoch 4, second training will continue from 4 to 9.
>>> len(history.history['loss'])
6

除了在每个 epoch 或每 N 步结束时保存的选项外,如果您正在使用 tf.distribute.MultiWorkerMirroredStrategy 在 Google Cloud Platform 或 Google Borg 上进行分布式训练,您还可以使用 `save_before_preemption` 参数来启用在工作节点被其他作业抢占而训练被中断之前保存检查点。有关更多详细信息,请参阅 tf.distribute.experimental.PreemptionCheckpointHandler

参数

  • backup_dir:字符串,用于存储检查点的路径。例如:backup_dir = os.path.join(working_dir, 'backup')。这是系统用于存储临时文件以从意外终止的任务中恢复模型的目录。该目录不能用于存储其他文件,例如,不能被另一个训练运行的 BackupAndRestore 回调函数使用,也不能被同一训练的其他回调函数(例如 ModelCheckpoint)使用。
  • save_freq'epoch'、整数或 False。设置为 'epoch' 时,回调函数在每个 epoch 结束时保存检查点。设置为整数时,回调函数每 save_freq 个 batch 保存一次检查点。如果仅使用抢占式检查点(将 save_before_preemption=True),则将 save_freq 设置为 False
  • delete_checkpoint:布尔值,默认为 True。此 BackupAndRestore 回调函数通过保存检查点来备份训练状态。如果 delete_checkpoint=True,则在训练完成后删除该检查点。如果您希望保留检查点供将来使用,请使用 False
  • save_before_preemption:一个布尔值,指示是否启用针对抢占/维护事件的自动检查点保存。目前仅支持在 Google Cloud Platform 或 Google Borg 上的 tf.distribute.MultiWorkerMirroredStrategy