SwapEMAWeights 类keras.callbacks.SwapEMAWeights(swap_on_epoch=False)
在评估前后交换模型权重和EMA权重。
此回调在模型评估之前,用优化器的EMA权重(过去模型权重的指数移动平均值,实现“Polyak平均”)的值替换模型权重值,并在评估后恢复之前的权重。
SwapEMAWeights 回调应与设置了 use_ema=True 的优化器结合使用。
请注意,为了节省内存,权重是就地交换的。如果您在其他回调中修改EMA权重或模型权重,则行为未定义。
示例
# Remember to set `use_ema=True` in the optimizer
optimizer = SGD(use_ema=True)
model.compile(optimizer=optimizer, loss=..., metrics=...)
# Metrics will be computed with EMA weights
model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()])
# If you want to save model checkpoint with EMA weights, you can set
# `swap_on_epoch=True` and place ModelCheckpoint after SwapEMAWeights.
model.fit(
X_train,
Y_train,
callbacks=[SwapEMAWeights(swap_on_epoch=True), ModelCheckpoint(...)]
)
参数
on_epoch_begin() 和 on_epoch_end() 执行交换。如果您希望其他回调(如 ModelCheckpoint)使用EMA权重,这将非常有用。默认为 False。