Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in JAX from scratch.
import os
# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"
import jax
# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np
Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.
If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().
Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.
To write a custom training loop, we need the following ingredients:
- A model to train.
- An optimizer. You could pick one from keras.optimizers, or one from the optax package (see the sketch right after this list).
- A loss function.
- A dataset. We'll use tf.data to load the data.
Let's line them all up.
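The optax route is not used in the rest of this guide, but for reference, here is a minimal, hypothetical sketch of the optax update cycle; the toy parameters and toy loss below exist only to make the snippet self-contained.
import jax
import jax.numpy as jnp
import optax

# Toy parameters and loss, just to illustrate the optax update cycle.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

def toy_loss(params, x, y):
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

optax_optimizer = optax.adam(learning_rate=1e-3)
opt_state = optax_optimizer.init(params)

x, y = jnp.ones((8, 3)), jnp.zeros((8,))
grads = jax.grad(toy_loss)(params, x, y)
updates, opt_state = optax_optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)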
First, let's get the model and the MNIST dataset:
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
model = get_model()
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
Next, here's the loss function and the optimizer. We'll use a Keras optimizer in this case:
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
Let's train our model using mini-batch gradient descent with a custom training loop.
In JAX, gradients are computed via metaprogramming: you call jax.grad (or jax.value_and_grad) on a function in order to create a gradient-computing function for it.
So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function. Something like this:
def compute_loss(x, y):
    ...
    return loss
Once you have such a function, you can compute gradients via metaprogramming as such:
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
Typically, you don't just want to get the gradient values; you also want to get the loss value. You can do this by using jax.value_and_grad instead of jax.grad:
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
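To make this concrete, here is a tiny self-contained illustration on a toy function (not this guide's model) of what jax.value_and_grad gives you:
import jax
import jax.numpy as jnp

def toy_loss(w, x, y):
    # Mean squared error of a one-parameter linear model.
    return jnp.mean((w * x - y) ** 2)

toy_grad_fn = jax.value_and_grad(toy_loss)  # differentiates w.r.t. the first argument, w
loss, grad = toy_grad_fn(2.0, jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0]))
# Here loss == 0.0 and grad == 0.0, since w=2 fits the toy data exactly.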
In JAX, everything must be a stateless function -- so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that gets updated during the forward pass must be returned as a function output. The function can have no side effects.
During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So we need something like this:
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables
Once you have such a function, you can get the gradient function by specifying has_aux in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)
Now that we have the basics in place, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method that comes in handy here. It works just like model.__call__, but it requires you to explicitly pass the values of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables
Let's get the gradient function:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
Next, let's implement the end-to-end training step: the function that runs the forward pass, computes the loss, computes the gradients, and uses the optimizer to update the trainable variables. This function also needs to be stateless, so it receives a state tuple as input that includes every state element we're going to use:
- trainable_variables and non_trainable_variables: the model's variables.
- optimizer_variables: the optimizer's state variables, such as momentum accumulators.
To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it requires always passing trainable_variables and optimizer_variables. It returns both the updated trainable variables and the updated optimizer variables.
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
Make it fast with jax.jit
By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow -- eager mode is better used as a debugging environment, not as a way to do any actual work. So let's make our train_step fast by compiling it.
When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced (compiled) during its first execution, and in subsequent executions the traced graph will be executed (this is just like @tf.function(jit_compile=True)). Let's try it:
@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
We're now ready to train our model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).
Note:
- We convert the TF tensors yielded by the tf.data.Dataset to NumPy before passing them to our JAX function.
# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples
A key thing to notice here is that the loop is entirely stateless -- the variables attached to the model (model.weights) are never updated during the loop. Their new values are only stored in the state tuple. That means that at some point, before saving the model, you should reattach the new variable values to the model.
Just call variable.assign(new_value) on each model variable you want to update:
trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
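With the values reattached, the model holds its trained weights again, so, for example, you could save it at this point (the filename below is only illustrative):
model.save("model_trained_with_custom_loop.keras")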
Let's add metrics monitoring to this basic training loop.
You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:
- Include metric_variables in the train_step arguments and the compute_loss_and_updates arguments.
- Call metric.stateless_update_state() in the compute_loss_and_updates function. It's equivalent to update_state() -- only stateless.
- When you need to display the current value of the metric, outside the train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
- Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch).
Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of training:
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )
We'll also prepare an evaluation step function:
@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )
Here are our loops:
# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)
# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables
# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples
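One thing these loops don't show: if you train for more than one epoch, you would typically clear the metric state between epochs with metric.reset_state(), as mentioned above. Here is a minimal sketch, assuming the same four-element state layout as in the training loop:
# At the end of an epoch: read out the result, then reset the metric.
trainable_variables, non_trainable_variables, optimizer_variables, metric_variables = state
for variable, value in zip(train_acc_metric.variables, metric_variables):
    variable.assign(value)
print(f"Epoch-end training accuracy: {train_acc_metric.result()}")
train_acc_metric.reset_state()
# Start the next epoch with freshly reset metric variables.
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    train_acc_metric.variables,
)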
Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.
If you want to use these loss components, you should sum them and add them to the main loss in your training step.
Consider this layer, which creates an activity regularization loss:
class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs
Let's build a really simple model that uses it:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
Now, here's what our compute_loss_and_updates function should look like:
- Pass return_losses=True to model.stateless_call().
- Sum the resulting losses and add them to the main loss.
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    # Pack the auxiliary outputs into a tuple so this stays compatible with
    # jax.value_and_grad(..., has_aux=True) as used earlier in this guide.
    return loss, (non_trainable_variables, metric_variables)
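To wire this version into the training step, you would regenerate the gradient function exactly as before:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)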
That's it!