Model
类keras.Model()
一个将层分组到具有训练/推理功能的对象的模型。
有三种实例化 Model
的方法
从 Input
开始,通过链接层调用来指定模型的正向传播,最后,从输入和输出创建模型。
inputs = keras.Input(shape=(37,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(5, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
注意:仅支持字典、列表和元组形式的输入张量。不支持嵌套输入(例如列表的列表或字典的字典)。
还可以通过使用中间张量来创建新的函数式 API 模型。这使您可以快速提取模型的子组件。
示例
inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=128, height=128)(inputs)
conv = keras.layers.Conv2D(filters=32, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)
full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)
请注意,`backbone` 和 `activations` 模型并非使用 `keras.Input` 对象创建,而是使用源自 `keras.Input` 对象的张量创建。在底层,这些模型将共享层和权重,因此用户可以训练 `full_model`,并使用 `backbone` 或 `activations` 进行特征提取。模型的输入和输出也可以是张量的嵌套结构,并且创建的模型是标准的函数式 API 模型,支持所有现有 API。
Model
类在这种情况下,您应该在 __init__()
中定义您的层,并且应该在 call()
中实现模型的正向传播。
class MyModel(keras.Model):
def __init__(self):
super().__init__()
self.dense1 = keras.layers.Dense(32, activation="relu")
self.dense2 = keras.layers.Dense(5, activation="softmax")
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
model = MyModel()
如果您子类化 Model
,您可以选择在 call()
中包含一个 training
参数(布尔值),您可以使用它来指定训练和推理中的不同行为
class MyModel(keras.Model):
def __init__(self):
super().__init__()
self.dense1 = keras.layers.Dense(32, activation="relu")
self.dense2 = keras.layers.Dense(5, activation="softmax")
self.dropout = keras.layers.Dropout(0.5)
def call(self, inputs, training=False):
x = self.dense1(inputs)
x = self.dropout(x, training=training)
return self.dense2(x)
model = MyModel()
模型创建后,您可以使用 model.compile()
配置模型的损失和指标,使用 model.fit()
训练模型,或使用 model.predict()
进行预测。
Sequential
类此外,keras.Sequential
是模型的一种特殊情况,其中模型纯粹是单输入、单输出层的堆栈。
model = keras.Sequential([
keras.Input(shape=(None, None, 3)),
keras.layers.Conv2D(filters=32, kernel_size=3),
])
summary
方法Model.summary(
line_length=None,
positions=None,
print_fn=None,
expand_nested=False,
show_trainable=False,
layer_range=None,
)
打印网络的字符串摘要。
参数
[0.3, 0.6, 0.70, 1.]
。默认为 None
。stdout
。如果 stdout
在您的环境中不起作用,请更改为 print
。它将针对摘要的每一行调用。您可以将其设置为自定义函数以捕获字符串摘要。False
。False
。引发
summary()
。get_layer
方法Model.get_layer(name=None, index=None)
根据其名称(唯一)或索引检索层。
如果同时提供了 name
和 index
,则 index
将优先。索引基于水平图遍历(自下而上)的顺序。
参数
返回
一个层实例。