Keras 3 API 文档 / 模型 API / 保存与序列化 / 仅权重保存与加载

仅保存和加载权重

[源代码]

save_weights 方法

Model.save_weights(filepath, overwrite=True, max_shard_size=None)

将所有权重保存到单个文件或分片文件中。

默认情况下,权重将保存到单个 .weights.h5 文件中。如果启用了分片(max_shard_size 不为 None),权重将保存到多个文件中,每个文件的最大大小为 max_shard_size (GB)。此外,一个配置文件 .weights.json 将包含分片文件的元数据。

保存的分片文件包含:

  • *.weights.json: 包含 'metadata' 和 'weight_map' 的配置文件。
  • *_xxxxxx.weights.h5: 仅包含权重的分片文件。

参数

  • filepath: strpathlib.Path 对象。保存权重的路径。进行分片时,filepath 必须以 .weights.json 结尾。如果提供了 .weights.h5,它将被覆盖。
  • overwrite: 是否覆盖目标位置已存在的权重,或者通过交互式提示询问用户。
  • max_shard_size: intfloat。每个分片文件的最大大小 (GB)。如果为 None,则不进行分片。默认为 None

示例

# Instantiate a EfficientNetV2L model with about 454MB of weights.
model = keras.applications.EfficientNetV2L(weights=None)

# Save the weights in a single file.
model.save_weights("model.weights.h5")

# Save the weights in sharded files. Use `max_shard_size=0.25` means
# each sharded file will be at most ~250MB.
model.save_weights("model.weights.json", max_shard_size=0.25)

# Load the weights in a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.h5")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

# Load the sharded weights in a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.json")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

[源代码]

load_weights 方法

Model.load_weights(filepath, skip_mismatch=False, **kwargs)

从单个文件或分片文件中加载权重。

权重是根据网络的拓扑结构加载的。这意味着架构应该与保存权重时相同。请注意,没有权重的层在拓扑排序中不被考虑,因此添加或删除没有权重的层是可以的。

部分权重加载

如果您修改了模型,例如通过添加一个新层(带有权重)或改变层的权重的形状,您可以选择通过设置 skip_mismatch=True 来忽略错误并继续加载。在这种情况下,任何权重不匹配的层都将被跳过。每个被跳过的层都会显示一个警告。

分片

加载分片权重时,重要的是指定以 *.weights.json 结尾的 filepath,它用作配置文件。此外,分片文件 *_xxxxx.weights.h5 必须与配置文件在同一个目录中。

参数

  • filepath: strpathlib.Path 对象。保存权重的路径。进行分片时,filepath 必须以 .weights.json 结尾。
  • skip_mismatch: 布尔值,是否跳过加载权重数量不匹配或权重形状不匹配的层。

示例

# Load the weights in a single file.
model.load_weights("model.weights.h5")

# Load the weights in sharded files.
model.load_weights("model.weights.json")