BackupAndRestore
类tf_keras.callbacks.BackupAndRestore(
backup_dir, save_freq="epoch", delete_checkpoint=True, save_before_preemption=False
)
用于备份和恢复训练状态的回调函数。
BackupAndRestore
回调旨在通过在每次 epoch 结束时将训练状态备份到临时检查点文件(借助 tf.train.CheckpointManager
),从而从 Model.fit
执行中断中恢复训练。每次备份都会覆盖先前写入的检查点文件,因此在任何给定时间,最多只有一个用于备份/恢复目的的此类检查点文件。
如果在训练完成之前重新开始,训练状态(包括 Model
权重和 epoch 数)将在新的 Model.fit
运行开始时恢复到最近保存的状态。Model.fit
运行完成后,临时检查点文件将被删除。
请注意,用户负责在中断后恢复作业。此回调对于容错的备份和恢复机制非常重要,并且从先前检查点恢复的模型应与用于备份的模型相同。如果用户更改传递给 compile 或 fit 的参数,为容错而保存的检查点可能会失效。
注意
Model.fit
会在训练中断并重新开始的未完成 epoch 中重新执行任何部分工作(因此中断之前完成的工作不会影响最终的模型状态)。Model.fit
与 tf.distribute
一起使用时,它支持 tf.distribute.MirroredStrategy
、tf.distribute.MultiWorkerMirroredStrategy
、tf.distribute.TPUStrategy
和 tf.distribute.experimental.ParameterServerStrategy
。示例
>>> class InterruptingCallback(tf.keras.callbacks.Callback):
... def on_epoch_begin(self, epoch, logs=None):
... if epoch == 4:
... raise RuntimeError('Interrupting!')
>>> callback = tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
>>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
>>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
>>> try:
... model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
... batch_size=1, callbacks=[callback, InterruptingCallback()],
... verbose=0)
... except:
... pass
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
... epochs=10, batch_size=1, callbacks=[callback],
... verbose=0)
>>> # Only 6 more epochs are run, since first training got interrupted at
>>> # zero-indexed epoch 4, second training will continue from 4 to 9.
>>> len(history.history['loss'])
6
除了在每个 epoch 结束时或每 N 步保存的选项之外,如果您正在 Google Cloud Platform 或 Google Borg 上使用 tf.distribute.MultiWorkerMirroredStrategy
进行分布式训练,您还可以使用 save_before_preemption
参数来在 worker 被其他作业抢占和训练中断之前启用保存检查点。有关更多详细信息,请参阅 tf.distribute.experimental.PreemptionCheckpointHandler
。
参数
backup_dir = os.path.join(working_dir, 'backup')
。这是系统存储临时文件以从意外终止的作业中恢复模型的目录。此目录不能用于存储其他文件,例如不能由另一个训练运行的 BackupAndRestore
回调或同一训练的其他回调(例如 ModelCheckpoint
)使用。'epoch'
、整数或 False
。设置为 'epoch'
时,回调在每个 epoch 结束时保存检查点。设置为整数时,回调每 save_freq
批次保存检查点。如果仅使用抢占检查点(即 save_before_preemption=True
),则将 save_freq
设置为 False
。BackupAndRestore
回调通过保存检查点来备份训练状态。如果 delete_checkpoint=True
,则训练完成后将删除检查点。如果您想保留检查点以备将来使用,请设置为 False
。tf.distribute.MultiWorkerMirroredStrategy
。