Model 类tf_keras.Model()
一个模型,它将层组合成一个具有训练/推理功能的对象。
参数
keras.Input 对象或一个 keras.Input 对象的字典、列表或元组的组合。keras.Input 对象的张量,或此类张量的字典、列表或元组的组合。请参阅下面的函数式 API 示例。有两种实例化 Model 的方法:
1 - 使用“函数式 API”,从 Input 开始,通过链接层调用来指定模型的正向传播,最后从输入和输出创建模型。
import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
注意:仅支持输入张量的字典、列表和元组。不支持嵌套输入(例如,列表的列表或字典的字典)。
还可以通过使用中间张量来创建一个新的函数式 API 模型。这使得您可以快速提取模型的子组件。
示例
inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)
full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)
请注意,backbone 和 activations 模型不是使用 keras.Input 对象创建的,而是使用源自 keras.Input 对象的张量创建的。在底层,这些模型将共享层和权重,以便用户可以训练 full_model,并使用 backbone 或 activations 进行特征提取。模型的输入和输出也可以是张量的嵌套结构,创建的模型是标准的函数式 API 模型,支持所有现有 API。
2 - 通过继承 Model 类:在这种情况下,您应该在 __init__() 中定义您的层,并在 call() 中实现模型的正向传播。
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
model = MyModel()
如果您继承 Model,可以在 call() 中选择性地包含一个 training 参数(布尔值),您可以使用它来指定训练和推理时的不同行为。
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.dropout = tf.keras.layers.Dropout(0.5)
def call(self, inputs, training=False):
x = self.dense1(inputs)
if training:
x = self.dropout(x, training=training)
return self.dense2(x)
model = MyModel()
模型创建后,您可以使用 model.compile() 配置模型的损失函数和评估指标,使用 model.fit() 训练模型,或使用模型进行预测 model.predict()。
summary 方法Model.summary(
line_length=None,
positions=None,
print_fn=None,
expand_nested=False,
show_trainable=False,
layer_range=None,
)
打印网络的字符串摘要。
参数
[0.3, 0.6, 0.70, 1.]。默认为 None。stdout。如果 stdout 在您的环境中不起作用,请更改为 print。它将在摘要的每一行上被调用。您可以将其设置为自定义函数以捕获字符串摘要。False。False。layer_range[0] 的第一个元素,结束谓词将是它匹配 layer_range[1] 的最后一个元素。默认 None,考虑模型的全部层。引发
summary()。get_layer 方法Model.get_layer(name=None, index=None)
根据层名称(唯一)或索引检索层。
如果同时提供了 name 和 index,则 index 具有优先权。索引基于水平图遍历的顺序(自下而上)。
参数
返回
一个层实例。