Keras 2 API 文档 / 模型 API / 保存与序列化 / 仅权重保存与加载

仅保存和加载权重

[源代码]

get_weights 方法

Model.get_weights()

获取模型的权重。

返回

一个扁平的 NumPy 数组列表。


[源代码]

set_weights 方法

Model.set_weights(weights)

从 NumPy 数组设置层的权重。

层的权重代表了层的状态。此函数从 NumPy 数组设置权重值。权重值应按层创建的顺序传递。请注意,必须通过调用层来实例化层的权重,然后才能调用此函数。

例如,一个 Dense 层会返回一个包含两个值的列表:核矩阵和偏置向量。这些可以用于设置另一个 Dense 层的权重。

>>> layer_a = tf.keras.layers.Dense(1,
...   kernel_initializer=tf.constant_initializer(1.))
>>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
>>> layer_a.get_weights()
[array([[1.],
       [1.],
       [1.]], dtype=float32), array([0.], dtype=float32)]
>>> layer_b = tf.keras.layers.Dense(1,
...   kernel_initializer=tf.constant_initializer(2.))
>>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
>>> layer_b.get_weights()
[array([[2.],
       [2.],
       [2.]], dtype=float32), array([0.], dtype=float32)]
>>> layer_b.set_weights(layer_a.get_weights())
>>> layer_b.get_weights()
[array([[1.],
       [1.],
       [1.]], dtype=float32), array([0.], dtype=float32)]

参数

  • weights:一个 NumPy 数组列表。数组的数量和它们的形状必须与层的权重的维度数量相匹配(即,它应该与 get_weights 的输出相匹配)。

引发

  • ValueError:如果提供的权重列表与层的规范不匹配。

[源代码]

save_weights 方法

Model.save_weights(filepath, overwrite=True, save_format=None, options=None)

保存所有层权重。

根据 save_format 参数,可以保存为 HDF5 或 TensorFlow 格式。

当保存为 HDF5 格式时,权重文件包含: - layer_names(属性),一个字符串列表(模型层的有序名称)。 - 对于每个层,一个名为 layer.namegroup。 - 对于每个此类层组,一个 weight_names 的组属性,一个字符串列表(层权重张量的有序名称)。 - 对于层中的每个权重,一个存储权重值的 dataset,以权重张量命名。

当保存为 TensorFlow 格式时,网络引用的所有对象都以与 tf.train.Checkpoint 相同的格式保存,包括分配给对象属性的任何 Layer 实例或 Optimizer 实例。对于使用 tf.keras.Model(inputs, outputs) 从输入和输出构建的网络,网络使用的 Layer 实例会被自动跟踪/保存。对于继承自 tf.keras.Model 的用户定义类,Layer 实例必须分配给对象属性,通常在构造函数中。有关详细信息,请参阅 tf.train.Checkpointtf.keras.Model 的文档。

虽然格式相同,但请不要混合使用 save_weightstf.train.Checkpoint。由 Model.save_weights 保存的检查点应该使用 Model.load_weights 加载。使用 tf.train.Checkpoint.save 保存的检查点应该使用相应的 tf.train.Checkpoint.restore 进行恢复。对于训练检查点,请优先使用 tf.train.Checkpoint 而非 save_weights

TensorFlow 格式通过从根对象(对于 save_weights,是 self)开始,并贪婪地匹配属性名称来匹配对象和变量。对于 Model.save,这是 Model,而对于 Checkpoint.save,这是 Checkpoint,即使 Checkpoint 附加了一个模型。这意味着使用 save_weights 保存的 tf.keras.Model,然后加载到附加了 Modeltf.train.Checkpoint(反之亦然)将无法匹配 Model 的变量。有关 TensorFlow 格式的详细信息,请参阅 训练检查点指南

参数

  • filepath:字符串或 PathLike,要将权重保存到的文件的路径。当保存为 TensorFlow 格式时,这是检查点文件的前缀(会生成多个文件)。请注意,后缀 '.h5' 会导致权重以 HDF5 格式保存。
  • overwrite:是否静默覆盖目标位置的任何现有文件,还是向用户提供手动提示。
  • save_format:可以是 'tf' 或 'h5'。以 '.h5' 或 '.keras' 结尾的 filepath,如果 save_formatNone,则默认为 HDF5。否则,None 将变为 'tf'。默认为 None
  • options:可选的 tf.train.CheckpointOptions 对象,用于指定保存权重的选项。

引发

  • ImportError:当尝试以 HDF5 格式保存时,如果 h5py 不可用。

[源代码]

load_weights 方法

Model.load_weights(filepath, skip_mismatch=False, by_name=False, options=None)

从保存的文件加载所有层权重。

保存的文件可以是 SavedModel 文件、.keras 文件(v3 保存格式),或者通过 model.save_weights() 创建的文件。

默认情况下,权重是根据网络的拓扑结构加载的。这意味着架构应该与保存权重时相同。请注意,没有权重的层在拓扑排序中不被考虑,因此只要它们没有权重,添加或删除层都是可以的。

部分权重加载

如果您修改了模型,例如添加了一个新层(带有权重)或更改了层的权重的形状,您可以通过设置 skip_mismatch=True 来选择忽略错误并继续加载。在这种情况下,任何权重不匹配的层都将被跳过。对于每个跳过的层都会显示一个警告。

按名称加载权重

如果您的权重是作为通过 model.save_weights() 创建的 .h5 文件保存的,您可以使用参数 by_name=True

在这种情况下,只有当层名称相同时,权重才会被加载到层中。这对于需要微调或迁移学习的模型(其中一些层已更改)非常有用。

请注意,当从 .keras v3 格式或 TensorFlow SavedModel 格式加载权重时,仅支持拓扑加载(by_name=False)。

参数

  • filepath:字符串,要加载的权重文件的路径。对于 TensorFlow 格式的权重文件,这是文件前缀(与传递给 save_weights() 的相同)。这也可以是 model.save() 保存的 SavedModel 或 .keras 文件的路径(v3 保存格式)。
  • skip_mismatch:布尔值,是否跳过权重数量不匹配或权重形状不匹配的层的加载。
  • by_name:布尔值,是按名称还是按拓扑顺序加载权重。对于 .keras v3 格式或 TensorFlow SavedModel 格式的权重文件,仅支持拓扑加载。
  • options:可选的 tf.train.CheckpointOptions 对象,用于指定加载权重的选项(仅对 SavedModel 文件有效)。