Keras 3 API 文档 / 回调 API / 学习率调度器 (LearningRateScheduler)

学习率调度器 (LearningRateScheduler)

[源代码]

LearningRateScheduler

keras.callbacks.LearningRateScheduler(schedule, verbose=0)

学习率调度器。

在每个 epoch 的开始,此回调从 __init__ 中提供的 schedule 函数获取更新后的学习率值,使用当前 epoch 和当前学习率作为输入,并将更新后的学习率应用于优化器。

参数

  • schedule:一个函数,它接受一个 epoch 索引(整数,从 0 开始索引)和当前学习率(浮点数)作为输入,并返回一个新的学习率作为输出(浮点数)。
  • verbose:整数。0:静默,1:记录更新消息。

示例

>>> # This function keeps the initial learning rate for the first ten epochs
>>> # and decreases it exponentially after that.
>>> def scheduler(epoch, lr):
...     if epoch < 10:
...         return lr
...     else:
...         return lr * ops.exp(-0.1)
>>>
>>> model = keras.models.Sequential([keras.layers.Dense(10)])
>>> model.compile(keras.optimizers.SGD(), loss='mse')
>>> round(model.optimizer.learning_rate, 5)
0.01
>>> callback = keras.callbacks.LearningRateScheduler(scheduler)
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
...                     epochs=15, callbacks=[callback], verbose=0)
>>> round(model.optimizer.learning_rate, 5)
0.00607