Keras 3 API 文档 / 模型 API / 保存与序列化 / 仅保存和加载权重

仅保存和加载权重

[源代码]

save_weights 方法

Model.save_weights(filepath, overwrite=True, max_shard_size=None)

将所有权重保存到单个文件或分片文件中。

默认情况下,权重将保存到单个 .weights.h5 文件中。如果启用分片(max_shard_size 不是 None),权重将保存到多个文件中,每个文件的大小不超过 max_shard_size (以 GB 为单位)。此外,一个配置文件 .weights.json 将包含分片文件的元数据。

保存的分片文件包含

  • *.weights.json:包含 'metadata' 和 'weight_map' 的配置文件。
  • *_xxxxxx.weights.h5:仅包含权重的分片文件。

参数

  • filepathstrpathlib.Path 对象。权重将保存到的路径。进行分片时,filepath 必须以 .weights.json 结尾。如果提供了 .weights.h5,它将被覆盖。
  • overwrite:是否覆盖目标位置的任何现有权重,或者改为通过交互式提示询问用户。
  • max_shard_sizeintfloat。每个分片文件的最大大小(以 GB 为单位)。如果为 None,则不会进行分片。默认为 None

示例

# Instantiate a EfficientNetV2L model with about 454MB of weights.
model = keras.applications.EfficientNetV2L(weights=None)

# Save the weights in a single file.
model.save_weights("model.weights.h5")

# Save the weights in sharded files. Use `max_shard_size=0.25` means
# each sharded file will be at most ~250MB.
model.save_weights("model.weights.json", max_shard_size=0.25)

# Load the weights in a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.h5")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

# Load the sharded weights in a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.json")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

[源代码]

load_weights 方法

Model.load_weights(filepath, skip_mismatch=False, **kwargs)

从单个文件或分片文件加载权重。

权重的加载是基于网络的拓扑结构的。这意味着架构应与保存权重时相同。请注意,没有权重的层在拓扑顺序中不被考虑在内,因此只要添加或移除的层没有权重,这是可以的。

部分权重加载

如果您修改了模型,例如添加了一个新层(带有权重)或更改了某一层权重的形状,您可以通过设置 skip_mismatch=True 来选择忽略错误并继续加载。在这种情况下,任何权重不匹配的层都将被跳过。对于每个跳过的层,都会显示一个警告。

分片

加载分片权重时,指定以 *.weights.json 结尾的 filepath 非常重要,该文件用作配置文件。此外,分片文件 *_xxxxx.weights.h5 必须与配置文件位于同一目录中。

参数

  • filepathstrpathlib.Path 对象。权重将从中加载的路径。进行分片时,filepath 必须以 .weights.json 结尾。
  • skip_mismatch:布尔值,是否跳过加载权重数量不匹配或权重形状不匹配的层。

示例

# Load the weights in a single file.
model.load_weights("model.weights.h5")

# Load the weights in sharded files.
model.load_weights("model.weights.json")