get_weights
方法Model.get_weights()
检索模型的权重。
返回
一个由 Numpy 数组组成的扁平列表。
set_weights
方法Model.set_weights(weights)
从 NumPy 数组设置层的权重。
层的权重代表了层的状态。此函数从 numpy 数组设置权重值。权重值应按照层创建它们的顺序传入。请注意,必须通过调用层来实例化层的权重,然后才能调用此函数。
例如,一个 Dense
层返回一个包含两个值的列表:核矩阵和偏置向量。这些值可以用来设置另一个 Dense
层的权重
>>> layer_a = tf.keras.layers.Dense(1,
... kernel_initializer=tf.constant_initializer(1.))
>>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
>>> layer_a.get_weights()
[array([[1.],
[1.],
[1.]], dtype=float32), array([0.], dtype=float32)]
>>> layer_b = tf.keras.layers.Dense(1,
... kernel_initializer=tf.constant_initializer(2.))
>>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
>>> layer_b.get_weights()
[array([[2.],
[2.],
[2.]], dtype=float32), array([0.], dtype=float32)]
>>> layer_b.set_weights(layer_a.get_weights())
>>> layer_b.get_weights()
[array([[1.],
[1.],
[1.]], dtype=float32), array([0.], dtype=float32)]
参数
get_weights
的输出匹配)。抛出
save_weights
方法Model.save_weights(filepath, overwrite=True, save_format=None, options=None)
保存所有层的权重。
根据 save_format
参数,保存为 HDF5 格式或 TensorFlow 格式。
以 HDF5 格式保存时,权重文件包含: - layer_names
(属性),一个字符串列表(按顺序排列的模型层名称)。 - 对于每一层,一个名为 layer.name
的 group
。 - 对于每个此类层组,一个组属性 weight_names
,一个字符串列表(按顺序排列的层权重张量名称)。 - 对于层中的每个权重,一个存储权重值的数据集,其名称与权重张量名称相同。
以 TensorFlow 格式保存时,网络引用的所有对象都以与 tf.train.Checkpoint
相同的格式保存,包括分配给对象属性的任何 Layer
实例或 Optimizer
实例。对于使用 tf.keras.Model(inputs, outputs)
从输入和输出构建的网络,网络使用的 Layer
实例会被自动跟踪/保存。对于继承自 tf.keras.Model
的用户自定义类,Layer
实例必须被分配给对象属性,通常在构造函数中进行。有关详细信息,请参阅 tf.train.Checkpoint
和 tf.keras.Model
的文档。
虽然格式相同,但不要混用 save_weights
和 tf.train.Checkpoint
。通过 Model.save_weights
保存的检查点应使用 Model.load_weights
加载。使用 tf.train.Checkpoint.save
保存的检查点应使用相应的 tf.train.Checkpoint.restore
恢复。对于训练检查点,优先使用 tf.train.Checkpoint
而非 save_weights
。
TensorFlow 格式通过从根对象(对于 save_weights
而言是 self
)开始,并贪婪地匹配属性名来匹配对象和变量。对于 Model.save
,根对象是 Model
;对于 Checkpoint.save
,根对象是 Checkpoint
,即使该 Checkpoint
附加了模型。这意味着使用 save_weights
保存一个 tf.keras.Model
并加载到一个附加了 Model
的 tf.train.Checkpoint
中(反之亦然)将无法匹配 Model
的变量。有关 TensorFlow 格式的详细信息,请参阅训练检查点指南。
参数
save_format
为 None
,以 '.h5' 或 '.keras' 结尾的 filepath
将默认采用 HDF5 格式。否则,None
将变为 'tf'。默认为 None
。tf.train.CheckpointOptions
对象,用于指定保存权重的选项。抛出
h5py
不可用。load_weights
方法Model.load_weights(filepath, skip_mismatch=False, by_name=False, options=None)
从已保存的文件加载所有层的权重。
保存的文件可以是 SavedModel 文件、.keras
文件(v3 保存格式),或通过 model.save_weights()
创建的文件。
默认情况下,权重基于网络的拓扑结构加载。这意味着架构应与保存权重时相同。请注意,没有权重的层在拓扑排序中不被考虑,因此只要它们没有权重,添加或移除层是可以的。
部分权重加载
如果您修改了模型,例如添加了一个新层(带有权重)或改变了层的权重形状,您可以通过设置 skip_mismatch=True
来选择忽略错误并继续加载。在这种情况下,任何权重不匹配的层将被跳过。对于每个跳过的层,将显示一条警告。
按名称加载权重
如果您的权重是通过 model.save_weights()
保存为 .h5
文件,您可以使用参数 by_name=True
。
在这种情况下,只有当层具有相同名称时,权重才会被加载到层中。这对于微调或迁移学习模型很有用,因为在这些模型中,某些层可能已更改。
请注意,从 .keras
v3 格式或 TensorFlow SavedModel 格式加载权重时,仅支持拓扑加载(by_name=False
)。
参数
save_weights()
的相同)。这也可以是 SavedModel 或通过 model.save()
保存的 .keras
文件(v3 保存格式)的路径。.keras
v3 格式或 TensorFlow SavedModel 格式的权重文件加载时,仅支持拓扑加载。tf.train.CheckpointOptions
对象,用于指定加载权重的选项(仅对 SavedModel 文件有效)。