Keras 2 API 文档 / 回调 API / TensorBoard

TensorBoard

[源]

TensorBoard

tf_keras.callbacks.TensorBoard(
    log_dir="logs",
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    write_steps_per_second=False,
    update_freq="epoch",
    profile_batch=0,
    embeddings_freq=0,
    embeddings_metadata=None,
    **kwargs
)

启用 TensorBoard 的可视化功能。

TensorBoard 是 TensorFlow 提供的一个可视化工具。

此回调函数会将事件记录到 TensorBoard,包括

  • 指标摘要图
  • 训练图可视化
  • 权重直方图
  • 采样性能分析

当用于 Model.evaluate 或常规验证(on_test_end)时,除了 epoch 摘要,还会写入一个记录评估指标相对于 Model.optimizer.iterations 的摘要。指标名称将以 evaluation 为前缀,而 Model.optimizer.iterations 将是 TensorBoard 可视化中的步骤。

如果你使用 pip 安装了 TensorFlow,应该能够从命令行启动 TensorBoard

tensorboard --logdir=path_to_your_logs

你可以在这里找到关于 TensorBoard 的更多信息。

参数

  • log_dir: 保存 TensorBoard 解析的日志文件的目录路径。例如:log_dir = os.path.join(working_dir, 'logs') 此目录不应被任何其他回调函数重复使用。
  • histogram_freq: 计算模型层权重直方图的频率(以 epoch 为单位)。如果设置为 0,则不计算直方图。必须指定验证数据(或拆分)才能进行直方图可视化。
  • write_graph: 是否在 TensorBoard 中可视化模型图。当 write_graph 设置为 True 时,日志文件会变得相当大。
  • write_images: 是否将模型权重写入 TensorBoard 以图像形式可视化。
  • write_steps_per_second: 是否将每秒的训练步数记录到 TensorBoard 中。这支持按 epoch 和按 batch 频率记录。
  • update_freq: 'batch''epoch' 或整数。使用 'epoch' 时,会在每个 epoch 后将损失和指标写入 TensorBoard。如果使用整数(例如 1000),则所有指标和损失(包括通过 Model.compile 添加的自定义指标和损失)将每 1000 个 batch 记录到 TensorBoard。'batch'1 的同义词,表示它们将每 batch 写入。但请注意,过于频繁地写入 TensorBoard 可能会减慢训练速度,特别是与 tf.distribute.Strategy 一起使用时,因为它会产生额外的同步开销。不支持与 ParameterServerStrategy 一起使用。batch 级别的摘要写入也可以通过覆盖 train_step 来实现。更多详细信息请参见TensorBoard Scalars 教程 # noqa: E501。
  • profile_batch: 对 batch 进行性能分析以采样计算特征。profile_batch 必须是非负整数或整数元组。一对正整数表示要分析的 batch 范围。默认情况下,性能分析是禁用的。
  • embeddings_freq: 可视化 embedding 层的频率(以 epoch 为单位)。如果设置为 0,则不可视化 embedding。
  • embeddings_metadata: 字典,将 embedding 层名称映射到用于保存 embedding 层元数据的文件的文件名。如果所有 embedding 层使用同一个元数据文件,则可以传递单个文件名。

示例

基本用法

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# Then run the tensorboard command to view the visualizations.

子类化 Model 中的自定义 batch 级别摘要

class MyModel(tf.keras.Model):

  def build(self, _):
    self.dense = tf.keras.layers.Dense(10)

  def call(self, x):
    outputs = self.dense(x)
    tf.summary.histogram('outputs', outputs)
    return outputs

model = MyModel()
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N
# batches.  In addition to any [`tf.summary`](https://tensorflowcn.cn/api_docs/python/tf/summary) contained in `Model.call`,
# metrics added in `Model.compile` will be logged every N batches.
tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

函数式 API Model 中的自定义 batch 级别摘要

def my_summary(x):
  tf.summary.histogram('x', x)
  return x

inputs = tf.keras.Input(10)
x = tf.keras.layers.Dense(10)(inputs)
outputs = tf.keras.layers.Lambda(my_summary)(x)
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N
# batches. In addition to any [`tf.summary`](https://tensorflowcn.cn/api_docs/python/tf/summary) contained in `Model.call`,
# metrics added in `Model.compile` will be logged every N batches.
tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

性能分析

# Profile a single batch, e.g. the 5th batch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

# Profile a range of batches, e.g. from 10 to 20.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=(10,20))
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])