save_weights
方法Model.save_weights(filepath, overwrite=True, max_shard_size=None)
将所有权重保存到单个文件或分片文件中。
默认情况下,权重将保存到单个 .weights.h5
文件中。如果启用分片(max_shard_size
不是 None
),权重将保存到多个文件中,每个文件的大小不超过 max_shard_size
(以 GB 为单位)。此外,一个配置文件 .weights.json
将包含分片文件的元数据。
保存的分片文件包含
*.weights.json
:包含 'metadata' 和 'weight_map' 的配置文件。*_xxxxxx.weights.h5
:仅包含权重的分片文件。参数
str
或 pathlib.Path
对象。权重将保存到的路径。进行分片时,filepath 必须以 .weights.json
结尾。如果提供了 .weights.h5
,它将被覆盖。int
或 float
。每个分片文件的最大大小(以 GB 为单位)。如果为 None
,则不会进行分片。默认为 None
。示例
# Instantiate a EfficientNetV2L model with about 454MB of weights.
model = keras.applications.EfficientNetV2L(weights=None)
# Save the weights in a single file.
model.save_weights("model.weights.h5")
# Save the weights in sharded files. Use `max_shard_size=0.25` means
# each sharded file will be at most ~250MB.
model.save_weights("model.weights.json", max_shard_size=0.25)
# Load the weights in a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.h5")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
# Load the sharded weights in a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.json")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
load_weights
方法Model.load_weights(filepath, skip_mismatch=False, **kwargs)
从单个文件或分片文件加载权重。
权重的加载是基于网络的拓扑结构的。这意味着架构应与保存权重时相同。请注意,没有权重的层在拓扑顺序中不被考虑在内,因此只要添加或移除的层没有权重,这是可以的。
部分权重加载
如果您修改了模型,例如添加了一个新层(带有权重)或更改了某一层权重的形状,您可以通过设置 skip_mismatch=True
来选择忽略错误并继续加载。在这种情况下,任何权重不匹配的层都将被跳过。对于每个跳过的层,都会显示一个警告。
分片
加载分片权重时,指定以 *.weights.json
结尾的 filepath
非常重要,该文件用作配置文件。此外,分片文件 *_xxxxx.weights.h5
必须与配置文件位于同一目录中。
参数
str
或 pathlib.Path
对象。权重将从中加载的路径。进行分片时,filepath 必须以 .weights.json
结尾。示例
# Load the weights in a single file.
model.load_weights("model.weights.h5")
# Load the weights in sharded files.
model.load_weights("model.weights.json")