Customizing what happens in `fit()` with JAX

Author: fchollet
Date created: 2023/06/27
Last modified: 2023/06/27
Description: Overriding the training step of the Model class with JAX.
When you're doing supervised learning, you can use `fit()` and everything works smoothly.

When you need to take control of every little detail, you can write your own training loop entirely from scratch.

But what if you need a custom training algorithm, but you still want to benefit from the convenient features of `fit()`, such as callbacks, built-in distribution support, or step fusing?

A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should override the training step function of the `Model` class. This is the function that is called by `fit()` for every batch of data. You will then be able to call `fit()` as usual -- and it will be running your own learning algorithm.

Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you're building `Sequential` models, Functional API models, or subclassed models.

Let's see how that works.
import os
# This guide can only be run with the JAX backend.
os.environ["KERAS_BACKEND"] = "jax"
import jax
import keras
import numpy as np
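With the JAX backend, the training step must be fully stateless: `fit()` passes in the current state (trainable variables, non-trainable variables, optimizer variables, metrics variables) together with one batch of data, and your `train_step` returns the metric logs plus the updated state. Before the full example, here is a do-nothing skeleton that only illustrates this contract (the class name is a placeholder and the step trains nothing; it just hands the state back unchanged):

class SkeletonModel(keras.Model):
    def train_step(self, state, data):
        # `state` bundles (trainable_variables, non_trainable_variables,
        # optimizer_variables, metrics_variables); `data` is one batch.
        # A real implementation computes gradients and returns updated
        # variable values; this skeleton returns the state unchanged.
        logs = {}
        return logs, state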
Let's start from a simple example: we create a new class that subclasses `keras.Model`, implement a fully-stateless `compute_loss_and_updates()` method (which calls `stateless_call()` and the built-in `stateless_compute_loss()`) to compute the loss as well as the updated values of the model's non-trainable variables, and implement a fully-stateless `train_step()` method that computes the current metric values (including the loss) as well as the updated values for the trainable variables, the optimizer variables, the metrics variables, and the non-trainable variables.
class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = self.stateless_compute_loss(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x=x,
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            training=training,
        )
        return loss, (y_pred, non_trainable_variables, metrics_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x,
            y,
            sample_weight,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, loss, sample_weight=sample_weight
                )
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state
Let's try this out:
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - mae: 0.3765 - loss: 0.2093
Epoch 2/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 232us/step - mae: 0.3634 - loss: 0.1968
Epoch 3/3
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 228us/step - mae: 0.3543 - loss: 0.1877
<keras.src.callbacks.history.History at 0x15d8472e0>
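`fit()` returns a `History` object whose `history` attribute records the per-epoch metric values -- a quick way to check that your custom step is actually learning. For example:

history = model.fit(x, y, epochs=3, verbose=0)
# The keys match the metric names logged by `train_step`.
print(history.history["loss"])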
Naturally, you could skip passing a loss function in `compile()`, and instead do everything manually in `train_step`. Likewise for metrics.
Here's a lower-level example that only uses `compile()` to configure the optimizer. We start by creating `Metric` instances to track our loss and an MAE score (in `__init__()`). We then implement a custom `train_step()` that updates the state of these metrics (by calling `stateless_update_state()` on them) and queries them (via `stateless_result()`) to return their current average value, to be displayed by the progress bar and passed to any callback. Note that the metrics need to be reset between epochs; listing them in the model's `metrics` property takes care of that, as shown below.
class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            sample_weight,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

        loss_tracker_vars = self.loss_tracker.stateless_update_state(
            loss_tracker_vars, loss, sample_weight=sample_weight
        )
        mae_metric_vars = self.mae_metric.stateless_update_state(
            mae_metric_vars, y, y_pred, sample_weight=sample_weight
        )

        logs = {}
        logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
            loss_tracker_vars
        )
        logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)
        new_metrics_vars = loss_tracker_vars + mae_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
# We don't pass a loss or metrics here.
model.compile(optimizer="adam")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.9146 - mae: 0.8248
Epoch 2/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 225us/step - loss: 0.4087 - mae: 0.5116
Epoch 3/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 230us/step - loss: 0.2766 - mae: 0.4233
Epoch 4/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 202us/step - loss: 0.2631 - mae: 0.4106
Epoch 5/5
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 198us/step - loss: 0.2604 - mae: 0.4070
<keras.src.callbacks.history.History at 0x15dccb0a0>
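Since `train_step` unpacks `sample_weight` and forwards it to both the loss and the metrics, per-sample weighting works with `fit()` as usual. For example:

# Optionally weight each sample's contribution to the loss and metrics.
sample_weight = np.random.random((1000,))
model.fit(x, y, sample_weight=sample_weight, epochs=1)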
What if you want to do the same for calls to `model.evaluate()`? Then you would override `test_step` in exactly the same way. Here's what it looks like:
class CustomModel(keras.Model):
    def test_step(self, state, data):
        # Unpack the data.
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state

        # Compute predictions and loss.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=False,
        )
        loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = self.stateless_compute_loss(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x=x,
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            training=False,
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, loss, sample_weight=sample_weight
                )
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            new_metrics_vars,
        )
        return logs, state
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])
# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y, return_dict=True)
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - mae: 0.5369 - loss: 0.4170
{'compile_metrics': {'mae': Array(0.5368782, dtype=float32)},
'loss': Array(0.41702443, dtype=float32)}
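If you also need to customize inference, `predict_step` follows the same stateless pattern. The sketch below is a reasonable starting point, but note that its state layout -- just `(trainable_variables, non_trainable_variables)` -- and its return convention are assumptions based on the JAX backend's default implementation; check the `keras.Model` source for your version before relying on it:

class CustomModel(keras.Model):
    def predict_step(self, state, data):
        # Assumed layout: predict-time state carries only the model
        # variables, with no optimizer or metrics variables.
        trainable_variables, non_trainable_variables = state
        x, _, _ = keras.utils.unpack_x_y_sample_weight(data)
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=False,
        )
        # Assumed return convention: (outputs, updated non-trainable state).
        return y_pred, non_trainable_variables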
That's it!