Keras 3 API文档 / 回调API / TensorBoard

TensorBoard

[源代码]

TensorBoard

keras.callbacks.TensorBoard(
    log_dir="logs",
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    write_steps_per_second=False,
    update_freq="epoch",
    profile_batch=0,
    embeddings_freq=0,
    embeddings_metadata=None,
)

为TensorBoard启用可视化。

TensorBoard是TensorFlow提供的可视化工具。要使用此回调,需要安装TensorFlow。

此回调记录TensorBoard的事件,包括:

  • 指标摘要图
  • 训练图可视化
  • 权重直方图
  • 采样剖析

当在model.evaluate()或常规验证中使用时,除了每个epoch的摘要之外,还会生成一个摘要,该摘要记录评估指标与model.optimizer.iterations的关系。指标名称将以evaluation作为前缀,其中model.optimizer.iterations是TensorBoard中可视化的步数。

如果您已使用 pip 安装 TensorFlow,则应该能够从命令行启动 TensorBoard:

tensorboard --logdir=path_to_your_logs

您可以在这里找到更多关于TensorBoard的信息。

参数

  • log_dir: 用于保存TensorBoard解析的日志文件的目录路径。例如,log_dir = os.path.join(working_dir, 'logs')。此目录不应被任何其他回调重复使用。
  • histogram_freq: 计算模型层权重直方图的频率(以epoch为单位)。如果设置为0,则不计算直方图。需要指定验证数据(或拆分)才能进行直方图可视化。
  • write_graph: (目前不支持) 是否在TensorBoard中可视化图。请注意,当write_graph设置为True时,日志文件可能会变得很大。
  • write_images: 是否将模型权重写入TensorBoard以图像形式可视化。
  • write_steps_per_second: 是否将每秒训练步数记录到TensorBoard。这支持按epoch和按batch进行频率记录。
  • update_freq: "batch""epoch"或整数。当使用"epoch"时,在每个epoch后将损失和指标写入TensorBoard。如果使用整数,例如1000,则每1000个batch会将所有指标和损失(包括通过Model.compile添加的自定义指标)写入TensorBoard。"batch"是1的同义词,意味着它们将在每个batch写入。请注意,过于频繁地写入TensorBoard可能会减慢训练速度,尤其是在使用分布式策略时,因为它会产生额外的同步开销。通过train_step覆盖也可以进行batch级别的摘要写入。有关更多详细信息,请参阅TensorBoardScalars教程
  • profile_batch: 剖析batch以采样计算特性。profile_batch必须是非负整数或整数元组。一对正整数表示要剖析的batch范围。默认情况下,剖析是禁用的。
  • embeddings_freq: 可视化嵌入层(embedding layers)的频率(以epoch为单位)。如果设置为0,则不可视化嵌入。
  • embeddings_metadata: 一个字典,将嵌入层名称映射到用于保存嵌入层元数据的文件的文件名。如果为所有嵌入层使用相同的元数据文件,则可以传递单个文件名。

示例

tensorboard_callback = keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# Then run the tensorboard command to view the visualizations.

子类化模型中的自定义batch级别摘要

class MyModel(keras.Model):

    def build(self, _):
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        outputs = self.dense(x)
        tf.summary.histogram('outputs', outputs)
        return outputs

model = MyModel()
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N
# batches.  In addition to any [`tf.summary`](https://tensorflowcn.cn/api_docs/python/tf/summary) contained in `model.call()`,
# metrics added in `Model.compile` will be logged every N batches.
tb_callback = keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

函数式API模型中的自定义batch级别摘要

def my_summary(x):
    tf.summary.histogram('x', x)
    return x

inputs = keras.Input(10)
x = keras.layers.Dense(10)(inputs)
outputs = keras.layers.Lambda(my_summary)(x)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N
# batches. In addition to any [`tf.summary`](https://tensorflowcn.cn/api_docs/python/tf/summary) contained in `Model.call`,
# metrics added in `Model.compile` will be logged every N batches.
tb_callback = keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

剖析

# Profile a single batch, e.g. the 5th batch.
tensorboard_callback = keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

# Profile a range of batches, e.g. from 10 to 20.
tensorboard_callback = keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=(10,20))
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])