Keras 3 API 文档 / 实用工具 / 模型绘图实用工具

模型绘图实用工具

[源代码]

plot_model 函数

keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=False,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs
)

将 Keras 模型转换为 dot 格式并保存到文件。

示例

inputs = ...
outputs = ...
model = keras.Model(inputs=inputs, outputs=outputs)

dot_img_file = '/tmp/model_1.png'
keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

参数

  • model:一个 Keras 模型实例
  • to_file:绘图图像的文件名。
  • show_shapes:是否显示形状信息。
  • show_dtype:是否显示层的数据类型。
  • show_layer_names:是否显示层名称。
  • rankdir:传递给 PyDot 的 rankdir 参数,一个指定绘图格式的字符串:"TB" 创建一个垂直绘图;"LR" 创建一个水平绘图。
  • expand_nested:是否将嵌套的函数式模型扩展为集群。
  • dpi:每英寸点数的图像分辨率。
  • show_layer_activations:显示层激活(仅适用于具有 activation 属性的层)。
  • show_trainable:是否显示层是否可训练。

返回

如果安装了 Jupyter,则返回一个 Jupyter 笔记本 Image 对象。这使得可以在笔记本中内联显示模型图。


[源代码]

model_to_dot 函数

keras.utils.model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    subgraph=False,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs
)

将 Keras 模型转换为 dot 格式。

参数

  • model:一个 Keras 模型实例。
  • show_shapes:是否显示形状信息。
  • show_dtype:是否显示层的数据类型。
  • show_layer_names:是否显示层名称。
  • rankdir:传递给 PyDot 的 rankdir 参数,一个指定绘图格式的字符串:"TB" 创建一个垂直绘图;"LR" 创建一个水平绘图。
  • expand_nested:是否将嵌套的函数式模型扩展为集群。
  • dpi:每英寸点数的图像分辨率。
  • subgraph:是否返回 pydot.Cluster 实例。
  • show_layer_activations:显示层激活(仅适用于具有 activation 属性的层)。
  • show_trainable:是否显示层是否可训练。

返回

一个表示 Keras 模型的 pydot.Dot 实例,或者如果 subgraph=True,则返回一个表示嵌套模型的 pydot.Cluster 实例。