Callback
类tf_keras.callbacks.Callback()
用于构建新回调函数的抽象基类。
可以将回调函数传递给 Keras 方法,例如 fit
、evaluate
和 predict
,以便在模型训练和推理生命周期的各个阶段进行挂钩。
要创建自定义回调函数,请继承 keras.callbacks.Callback
并重写与感兴趣的阶段相关联的方法。有关更多信息,请参阅自定义回调函数。
示例
>>> training_finished = False
>>> class MyCallback(tf.keras.callbacks.Callback):
... def on_train_end(self, logs=None):
... global training_finished
... training_finished = True
>>> model = tf.keras.Sequential([
... tf.keras.layers.Dense(1, input_shape=(1,))])
>>> model.compile(loss='mean_squared_error')
>>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
... callbacks=[MyCallback()])
>>> assert training_finished == True
如果您想在自定义训练循环中使用 Callback
对象
callbacks.CallbackList
中,以便可以一次性调用它们。on_*
方法。例如示例
callbacks = tf.keras.callbacks.CallbackList([...])
callbacks.append(...)
callbacks.on_train_begin(...)
for epoch in range(EPOCHS):
callbacks.on_epoch_begin(epoch)
for i, data in dataset.enumerate():
callbacks.on_train_batch_begin(i)
batch_logs = model.train_step(data)
callbacks.on_train_batch_end(i, batch_logs)
epoch_logs = ...
callbacks.on_epoch_end(epoch, epoch_logs)
final_logs=...
callbacks.on_train_end(final_logs)
属性
keras.models.Model
的实例。正在训练的模型的引用。回调函数方法作为参数接收的 logs
字典将包含与当前批次或 epoch 相关的数量的键(请参阅特定方法文档字符串)。