Keras 2 API 文档 / 模型 API / Model 类

Model 类

[源]

Model

tf_keras.Model()

一个将层分组为一个具有训练/推理功能的对象的模型。

参数

  • inputs: 模型的输入:一个 keras.Input 对象,或 dict、list 或 tuple 中 keras.Input 对象的组合。
  • outputs: 模型的输出:一个源自 keras.Input 对象的张量,或 dict、list 或 tuple 中此类张量的组合。参见下面的函数式 API 示例。
  • name: 字符串,模型的名称。

有两种方式可以实例化一个 Model

1 - 使用“函数式 API”,从 Input 开始,将层调用链式连接起来指定模型的前向传播,最后从输入和输出创建你的模型

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

注意:只支持输入张量的 dict、list 和 tuple。不支持嵌套输入(例如 list 的 list 或 dict 的 dict)。

还可以使用中间张量创建一个新的函数式 API 模型。这使你能够快速提取模型的子组件。

示例

inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)

full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)

注意,backboneactivations 模型不是使用 keras.Input 对象创建的,而是使用源自 keras.Input 对象的张量创建的。在底层,这些模型的层和权重是共享的,这样用户可以训练 full_model,并使用 backboneactivations 进行特征提取。模型的输入和输出也可以是张量的嵌套结构,并且创建的模型是标准的函数式 API 模型,支持所有现有的 API。

2 - 通过继承 Model 类:在这种情况下,你应该在 __init__() 中定义你的层,并在 call() 中实现模型的前向传播。

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

model = MyModel()

如果你继承 Model,你可以在 call() 中选择性地添加一个 training 参数(布尔值),你可以用它来指定在训练和推理时的不同行为

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, inputs, training=False):
    x = self.dense1(inputs)
    if training:
      x = self.dropout(x, training=training)
    return self.dense2(x)

model = MyModel()

模型创建后,你可以使用 model.compile() 配置模型的损失函数和评估指标,使用 model.fit() 训练模型,或使用 model.predict() 进行预测。


[源]

summary 方法

Model.summary(
    line_length=None,
    positions=None,
    print_fn=None,
    expand_nested=False,
    show_trainable=False,
    layer_range=None,
)

打印网络的字符串摘要。

参数

  • line_length: 打印行的总长度(例如,设置此值以适应不同的终端窗口大小)。
  • positions: 每行日志元素的相对或绝对位置。如果未提供,则为 [0.3, 0.6, 0.70, 1.]。默认为 None
  • print_fn: 要使用的打印函数。默认情况下,打印到 stdout。如果 stdout 在你的环境中不起作用,更改为 print。它将对摘要的每一行调用。你可以将其设置为一个自定义函数,以捕获字符串摘要。
  • expand_nested: 是否展开嵌套模型。默认为 False
  • show_trainable: 是否显示层是否可训练。默认为 False
  • layer_range: 一个包含 2 个字符串的列表或 tuple,分别是起始层名称和结束层名称(均包含),指示要在摘要中打印的层范围。它也接受正则表达式模式而不是精确名称。在这种情况下,起始谓词将是它匹配 layer_range[0] 的第一个元素,结束谓词将是它匹配 layer_range[1] 的最后一个元素。默认为 None,表示考虑模型的所有层。

引发

  • ValueError: 如果在模型构建之前调用 summary()

[源]

get_layer 方法

Model.get_layer(name=None, index=None)

根据层的名称(唯一)或索引检索层。

如果同时提供了 nameindex,则 index 优先。索引基于水平图遍历(自下而上)的顺序。

参数

  • name: 字符串,层的名称。
  • index: 整数,层的索引。

返回

一个层实例。