Keras 3 API 文档 / 工具 / 模型绘图工具

模型绘图工具

[源代码]

plot_model 函数

keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=False,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs
)

将 Keras 模型转换为 dot 格式并保存到文件。

示例

inputs = ...
outputs = ...
model = keras.Model(inputs=inputs, outputs=outputs)

dot_img_file = '/tmp/model_1.png'
keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

参数

  • model: Keras 模型实例
  • to_file: 绘图图像的文件名。
  • show_shapes: 是否显示形状信息。
  • show_dtype: 是否显示层的数据类型。
  • show_layer_names: 是否显示层名称。
  • rankdir: 传递给 PyDot 的 rankdir 参数,一个字符串,指定绘图的格式:"TB" 创建垂直绘图;"LR" 创建水平绘图。
  • expand_nested: 是否将嵌套的功能模型展开为集群。
  • dpi: 图像分辨率,单位为每英寸点数。
  • show_layer_activations: 显示层激活(仅适用于具有 activation 属性的层)。
  • show_trainable: 是否显示层是否可训练。

返回值

如果安装了 Jupyter,则为 Jupyter Notebook 图像对象。这使得可以在笔记本中内联显示模型图。


[源代码]

model_to_dot 函数

keras.utils.model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    subgraph=False,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs
)

将 Keras 模型转换为 dot 格式。

参数

  • model: Keras 模型实例。
  • show_shapes: 是否显示形状信息。
  • show_dtype: 是否显示层的数据类型。
  • show_layer_names: 是否显示层名称。
  • rankdir: 传递给 PyDot 的 rankdir 参数,一个字符串,指定绘图的格式:"TB" 创建垂直绘图;"LR" 创建水平绘图。
  • expand_nested: 是否将嵌套的功能模型展开为集群。
  • dpi: 图像分辨率,单位为每英寸点数。
  • subgraph: 是否返回 pydot.Cluster 实例。
  • show_layer_activations: 显示层激活(仅适用于具有 activation 属性的层)。
  • show_trainable: 是否显示层是否可训练。

返回值

表示 Keras 模型的 pydot.Dot 实例,或者如果 subgraph=True 则表示嵌套模型的 pydot.Cluster 实例。