Keras 2 API 文档 / 工具类 / 模型绘制工具

模型绘制工具

[源]

plot_model 函数

tf_keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)

将 TF-Keras 模型转换为 dot 格式并保存到文件。

示例

input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(input)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[input], outputs=[output])
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

参数

  • model: 一个 TF-Keras 模型实例
  • to_file: 绘制图像的文件名。
  • show_shapes: 是否显示形状信息。
  • show_dtype: 是否显示层数据类型。
  • show_layer_names: 是否显示层名称。
  • rankdir: 传递给 PyDot 的 rankdir 参数,一个指定绘制格式的字符串:'TB' 创建垂直布局图;'LR' 创建水平布局图。
  • expand_nested: 是否将嵌套模型展开为集群。
  • dpi: 每英寸点数。
  • layer_range: 包含两个 str 元素的 list 输入,它们是起始层名称和结束层名称(两端都包含),指示将生成图的层范围。它也接受正则表达式模式而不是确切名称。在这种情况下,起始谓词将是它匹配 layer_range[0] 的第一个元素,结束谓词将是它匹配 layer_range[1] 的最后一个元素。默认情况下为 None,这会考虑模型的所有层。请注意,传递的范围必须使生成的子图完整。
  • show_layer_activations: 显示层激活(仅适用于具有 activation 属性的层)。
  • show_trainable: 是否显示层是否可训练。当层可训练时显示 'T',不可训练时显示 'NT'。

抛出

  • ImportError: 如果 graphviz 或 pydot 不可用。
  • ValueError: 如果在模型构建之前调用 plot_model

返回

如果安装了 Jupyter,则返回一个 Jupyter notebook Image 对象。这允许在 notebook 中内联显示模型图。


[源]

model_to_dot 函数

tf_keras.utils.model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)

将 TF-Keras 模型转换为 dot 格式。

参数

  • model: 一个 TF-Keras 模型实例。
  • show_shapes: 是否显示形状信息。
  • show_dtype: 是否显示层数据类型。
  • show_layer_names: 是否显示层名称。
  • rankdir: 传递给 PyDot 的 rankdir 参数,一个指定绘制格式的字符串:'TB' 创建垂直布局图;'LR' 创建水平布局图。
  • expand_nested: 是否将嵌套模型展开为集群。
  • dpi: 每英寸点数。
  • subgraph: 是否返回一个 pydot.Cluster 实例。
  • layer_range: 包含两个 str 元素的 list 输入,它们是起始层名称和结束层名称(两端都包含),指示将生成 pydot.Dot 的层范围。它也接受正则表达式模式而不是确切名称。在这种情况下,起始谓词将是它匹配 layer_range[0] 的第一个元素,结束谓词将是它匹配 layer_range[1] 的最后一个元素。默认情况下为 None,这会考虑模型的所有层。请注意,传递的范围必须使生成的子图完整。
  • show_layer_activations: 显示层激活(仅适用于具有 activation 属性的层)。
  • show_trainable: 是否显示层是否可训练。当层可训练时显示 'T',不可训练时显示 'NT'。

返回

一个代表 TF-Keras 模型的 pydot.Dot 实例,如果 subgraph=True,则为一个代表嵌套模型的 pydot.Cluster 实例。

抛出

  • ValueError: 如果在模型构建之前调用 model_to_dot
  • ImportError: 如果 pydot 不可用。