Keras 3 API 文档 / 回调 API / SwapEMAWeights

SwapEMAWeights

[源代码]

SwapEMAWeights

keras.callbacks.SwapEMAWeights(swap_on_epoch=False)

在评估前后交换模型权重和 EMA 权重。

此回调在模型评估之前将模型权重值替换为优化器的 EMA 权重值(即过去模型权重值的指数移动平均,实现了“Polyak 平均”),并在评估之后恢复之前的权重。

SwapEMAWeights 回调应与设置了 use_ema=True 的优化器一起使用。

请注意,为了节省内存,权重是就地交换的。如果您在其他回调中修改 EMA 权重或模型权重,则行为是未定义的。

示例

# Remember to set `use_ema=True` in the optimizer
optimizer = SGD(use_ema=True)
model.compile(optimizer=optimizer, loss=..., metrics=...)

# Metrics will be computed with EMA weights
model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()])

# If you want to save model checkpoint with EMA weights, you can set
# `swap_on_epoch=True` and place ModelCheckpoint after SwapEMAWeights.
model.fit(
    X_train,
    Y_train,
    callbacks=[SwapEMAWeights(swap_on_epoch=True), ModelCheckpoint(...)]
)

参数

  • swap_on_epoch:是否在 on_epoch_begin()on_epoch_end() 时执行交换。如果您想将 EMA 权重用于其他回调(例如 ModelCheckpoint),这会很有用。默认为 False