Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training and evaluation loops in JAX.
import os
# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"
import jax
# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np
Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.
If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().
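That workflow has its own guide, but as a rough sketch of the idea with the JAX backend (assuming the convention that train_step() receives a state tuple and returns (logs, state); the CustomModel name and the minimal logs dict below are illustrative, and a complete version would also update the metric variables statelessly):

class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self, trainable_variables, non_trainable_variables, x, y
    ):
        # Stateless forward pass: variables go in, updated variables come out.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=True
        )
        loss = self.compute_loss(y=y, y_pred=y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables, non_trainable_variables, x, y
        )
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )
        # Metrics are passed through untouched in this sketch; a full version
        # would update them with stateless_update_state() and report them in logs.
        logs = {"loss": loss}
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        )
        return logs, state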
Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.
To write a custom training loop, we need the following ingredients:
- A model to train, of course.
- An optimizer. You could either use an optimizer from keras.optimizers, or one from the optax package.
- A loss function.
- A dataset. The standard in the JAX ecosystem is to load data via tf.data, so that's what we'll use.
Let's line them up.
First, let's get the model and the MNIST dataset:
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
model = get_model()
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
Next, here's the loss function and the optimizer. We'll use a Keras optimizer in this case.
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
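If you preferred an optimizer from optax instead, the setup might look roughly like this. This is only a sketch, not used in the rest of the guide; it assumes Keras variables expose their backing array via .value, and that you then apply the optax updates yourself instead of calling the Keras optimizer's stateless_apply:

# Hypothetical alternative: an optax optimizer instead of the Keras one above.
import optax

optax_optimizer = optax.adam(learning_rate=1e-3)
# optax keeps its own state, initialized from the current parameter values.
optax_opt_state = optax_optimizer.init([v.value for v in model.trainable_variables])
# Inside the training step you would then apply updates yourself:
#   updates, optax_opt_state = optax_optimizer.update(grads, optax_opt_state)
#   trainable_variables = optax.apply_updates(trainable_variables, updates)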
Let's train our model using mini-batch gradients with a custom training loop.
In JAX, gradients are computed via metaprogramming: you call jax.grad (or jax.value_and_grad) on a function in order to create a gradient-computing function for that first function.
So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function. Something like this:
def compute_loss(x, y):
    ...
    return loss
Once you have such a function, you can compute gradients via metaprogramming as such:
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
Typically, you don't just want to get the gradient values, you also want to get the loss value. You can do this by using jax.value_and_grad instead of jax.grad:
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
In JAX, everything must be a stateless function, so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that has been updated during the forward pass must be returned as a function output. The function has no side effects.
During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So we need something like this:
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables
Once you have such a function, you can get the gradient function by specifying has_aux in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)
Now that we have all the basics in place, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method which will come in handy here. It works just like model.__call__, but it requires you to explicitly pass the value of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables
Let's get the gradient function:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
Next, let's implement the end-to-end training step: the function that runs the forward pass, computes the loss, computes the gradients, and also uses the optimizer to update the trainable variables. This function also needs to be stateless, so it will take as input a state tuple that includes every state element we're going to use:
- trainable_variables and non_trainable_variables: the model's variables.
- optimizer_variables: the optimizer's state variables, such as momentum accumulators.
To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it requires always passing trainable_variables and optimizer_variables. It returns both the updated trainable variables and the updated optimizer_variables.
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
Make it fast with jax.jit
By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow: eager mode is better suited as a debugging environment than as a way to do any actual work. So let's make our train_step fast by compiling it.
When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced during its first execution, and in subsequent executions you will be executing the traced graph (this is just like @tf.function(jit_compile=True)). Let's try it:
@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
We're now ready to train our model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).
Note that we convert the TF tensors yielded by the tf.data.Dataset to NumPy before passing them to our JAX function.
# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples
A key thing to notice here is that the loop is entirely stateless: the variables attached to the model (model.weights) are never updated during the loop. Their new values are only stored in the state tuple. That means that at some point, before saving your model, you should reattach the new variable values to the model.
Just call variable.assign(new_value) on each model variable you want to update:
trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
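Once the values are reattached, you can save or keep using the model as usual. For example (the filename here is just a placeholder):

# Save the model now that its variables hold the trained values.
model.save("mnist_custom_loop.keras")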
Let's add metrics monitoring to this basic training loop.
You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:
- Instantiate the metric at the start of the loop.
- Include metric_variables in the train_step arguments and the compute_loss_and_updates arguments.
- Call metric.stateless_update_state() in the compute_loss_and_updates function. It's equivalent to update_state(), only stateless.
- When you need to display the current value of the metric, outside the train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
- Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch).
Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of training:
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )
We'll also prepare an evaluation step function:
@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )
Here are our loops:
# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)
# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables
# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples
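The loops above cover a single pass over the data. If you ran several epochs, you would typically clear the metrics between epochs, as mentioned in the flow above. A minimal sketch of that end-of-epoch housekeeping (assuming an outer epoch loop around the code above):

# At the end of each epoch: reset the metrics and refresh the metric
# variables that get carried around in the state tuple.
train_acc_metric.reset_state()
val_acc_metric.reset_state()
metric_variables = train_acc_metric.variables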
Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.
If you want to use these loss components, you should sum them and add them to the main loss in your training step.
Consider this layer, which creates an activity regularization loss:
class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs
Let's build a really simple model that uses it:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
Here's what our compute_loss_and_updates function should look like now:
- Pass return_losses=True to model.stateless_call().
- Sum the resulting losses and add them to the main loss.
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
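Since the auxiliary outputs are still packed into a single tuple, the grad_fn and train_step from the previous section keep working unchanged; you would simply regenerate the gradient function:

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)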
That's it!