Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in JAX.
import os
# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"
import jax
# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np
Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.
If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().
Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.
To write a custom training loop, we need the following ingredients:

- A model to train.
- An optimizer. You could either use an optimizer from keras.optimizers, or one from the optax package.
- A loss function.
- A dataset. The standard in the JAX ecosystem is to load data via tf.data, so that's what we'll use.

Let's line them up.
First, let's get the model and the MNIST dataset:
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
model = get_model()
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
Next, here's the loss function and the optimizer. We'll use a Keras optimizer in this case.
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
Let's train our model using mini-batch gradients with a custom training loop.

In JAX, gradients are computed via metaprogramming: you call jax.grad (or jax.value_and_grad) on a function in order to create a gradient-computing function for it.

So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function, something like this:
def compute_loss(x, y):
    ...
    return loss
Once you have such a function, you can compute gradients via metaprogramming as follows:
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
Typically, you don't just want to get the gradient values, you also want to get the loss value. You can do this by using jax.value_and_grad instead of jax.grad:
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
In JAX, everything must be a stateless function, so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that has been updated during the forward pass must be returned as a function output. The function has no side effects.
During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those, so we need something like this:
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables
Once you have such a function, you can get the gradient function by specifying has_aux in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)
Now that we have the basics established, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method which will come in handy here. It works just like model.__call__, but it requires you to explicitly pass the value of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables
Let's get the gradient function:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
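As a quick sanity check, you can call grad_fn once on a single batch, passing the model's current variable values explicitly. This is only an illustrative sketch; the names sample_x, sample_y, trainable_values, non_trainable_values, and sample_grads are not used elsewhere in this guide.

# Sketch: call grad_fn eagerly on one batch. All state is passed in explicitly
# and nothing attached to the model gets mutated.
sample_x, sample_y = next(iter(train_dataset))
trainable_values = [v.value for v in model.trainable_variables]
non_trainable_values = [v.value for v in model.non_trainable_variables]
(sample_loss, non_trainable_values), sample_grads = grad_fn(
    trainable_values, non_trainable_values, sample_x.numpy(), sample_y.numpy()
)
print("Loss on one batch:", float(sample_loss))
print("Number of gradient tensors:", len(sample_grads))  # one per trainable variable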
Next, let's implement the end-to-end training step, the function that will run the forward pass, compute the loss, compute the gradients, and also use the optimizer to update the trainable variables. This function also needs to be stateless, so it will take as input a state tuple that includes every state element we're going to use:

- trainable_variables and non_trainable_variables: the model's variables.
- optimizer_variables: the optimizer's state variables, such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it always requires passing trainable_variables and optimizer_variables. It returns both the updated trainable variables and the updated optimizer_variables.
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
Make it fast with jax.jit

By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow; eager mode is better used as a debugging environment, not as a way to do any actual work. So let's make our train_step fast by compiling it.
When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced during its first execution, and in subsequent executions you will be executing the traced graph (this is just like @tf.function(jit_compile=True)). Let's try it:
@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
We're now ready to train our model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).
Note:

- We convert the TF tensors yielded by the tf.data.Dataset to NumPy before passing them to our JAX function.
- All state variables must be built beforehand: the model must be built, and the optimizer must be built (hence the optimizer.build() call below).

# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)

    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples
A key thing to notice here is that the loop is entirely stateless: the variables attached to the model (model.weights) are never getting updated during the loop. Their new values are only stored in the state tuple. This means that, at some point before saving the model, you should attach the new variable values back to the model.

Just call variable.assign(new_value) on each model variable you want to update:
trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
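With the updated values attached back to the model, it can be saved or evaluated as usual; for example (the filename below is just illustrative):

# Now that model.weights holds the trained values, saving works as usual.
model.save("low_level_jax_mnist.keras")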
Let's add metrics monitoring to this basic training loop.
You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

- Instantiate the metric at the start of the loop.
- Include metric_variables in the train_step arguments and compute_loss_and_updates arguments.
- Call metric.stateless_update_state() in the compute_loss_and_updates function. It's equivalent to update_state(), only stateless.
- When you need to display the current value of the metric, outside the train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
- Call metric.reset_state() when you need to clear the state of the metric, typically at the end of an epoch (a small self-contained sketch of this pattern follows this list).
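To make the flow concrete, here is a small self-contained sketch of the stateless metric pattern, using made-up toy values, including the reset_state() step that the loops below don't show (the name acc_variables is only used here):

# A tiny, self-contained sketch of the stateless metric flow.
acc_metric = keras.metrics.CategoricalAccuracy()
acc_variables = acc_metric.variables
# One stateless update, as it would happen inside a jitted step:
acc_variables = acc_metric.stateless_update_state(
    acc_variables, np.array([[0.0, 1.0]]), np.array([[0.1, 0.9]])
)
# To read the result eagerly, attach the values back to the metric object:
for variable, value in zip(acc_metric.variables, acc_variables):
    variable.assign(value)
print("Accuracy so far:", float(acc_metric.result()))  # 1.0
# Clear the state (e.g. at the end of an epoch) and start from fresh values:
acc_metric.reset_state()
acc_variables = acc_metric.variables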
Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of training:
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )
We'll also prepare an evaluation step function:
@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )
Here are our loops:
# Build optimizer variables.
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)
# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)

    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables
# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)

    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples
Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to use these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs
Let's build a really simple model that uses it:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
Here's what our compute_loss_and_updates function should look like now:

- Pass return_losses=True to model.stateless_call().
- Sum the resulting losses and add them to the main loss.

def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
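Note that since compute_loss_and_updates has been redefined, grad_fn should be re-derived from it before it is used in a training step (and a previously jitted train_step should be re-defined afterwards so the new function gets traced); roughly:

# Re-create the gradient function from the new compute_loss_and_updates.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)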
That's it!