Keras 2 API文档 / 工具 / 模型绘图工具

模型绘图实用工具

[源代码]

plot_model 函数

tf_keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)

将TF-Keras模型转换为dot格式并保存到文件。

示例

input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(input)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[input], outputs=[output])
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

参数

  • model: 一个TF-Keras模型实例
  • to_file: 绘图文件的名称。
  • show_shapes: 是否显示形状信息。
  • show_dtype: 是否显示层数据类型。
  • show_layer_names: 是否显示层名称。
  • rankdir: 传递给PyDot的rankdir参数,一个指定绘图格式的字符串:'TB'创建垂直绘图;'LR'创建水平绘图。
  • expand_nested: 是否将嵌套模型展开为集群。
  • dpi: 每英寸点数。
  • layer_range: 输入一个包含两个str项的list,分别是起始层名称和结束层名称(均包含),用于指示生成绘图的层范围。它也接受正则表达式而不是精确名称。在这种情况下,起始谓词将是它匹配到的第一个元素layer_range[0],结束谓词将是它匹配到的最后一个元素layer_range[1]。默认为None,表示考虑模型的所有层。请注意,您必须传递一个有效的范围,使得生成的子图是完整的。
  • show_layer_activations: 显示层激活(仅对具有activation属性的层有效)。
  • show_trainable: 是否显示层是否可训练。当层可训练时显示'T',当层不可训练时显示'NT'。

引发

  • ImportError: 如果graphviz或pydot不可用。
  • ValueError: 如果在模型构建之前调用了plot_model

返回

一个Jupyter notebook的Image对象,如果安装了Jupyter。这使得模型绘图可以在notebook中内联显示。


[源代码]

model_to_dot 函数

tf_keras.utils.model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)

将TF-Keras模型转换为dot格式。

参数

  • model: 一个TF-Keras模型实例。
  • show_shapes: 是否显示形状信息。
  • show_dtype: 是否显示层数据类型。
  • show_layer_names: 是否显示层名称。
  • rankdir: 传递给PyDot的rankdir参数,一个指定绘图格式的字符串:'TB'创建垂直绘图;'LR'创建水平绘图。
  • expand_nested: 是否将嵌套模型展开为集群。
  • dpi: 每英寸点数。
  • subgraph: 是否返回一个pydot.Cluster实例。
  • layer_range: 输入一个包含两个str项的list,分别是起始层名称和结束层名称(均包含),用于指示生成pydot.Dot的层范围。它也接受正则表达式而不是精确名称。在这种情况下,起始谓词将是它匹配到的第一个元素layer_range[0],结束谓词将是它匹配到的最后一个元素layer_range[1]。默认为None,表示考虑模型的所有层。请注意,您必须传递一个有效的范围,使得生成的子图是完整的。
  • show_layer_activations: 显示层激活(仅对具有activation属性的层有效)。
  • show_trainable: 是否显示层是否可训练。当层可训练时显示'T',当层不可训练时显示'NT'。

返回

一个代表TF-Keras模型或嵌套模型的pydot.Dot实例,如果subgraph=True则为pydot.Cluster实例。

引发

  • ValueError: 如果在模型构建之前调用了model_to_dot
  • ImportError: 如果pydot不可用。