
Writing a training loop from scratch in JAX

Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training and evaluation loops from scratch in JAX.



Setup

import os

# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax

# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np

Introduction

Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.

If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().

Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.


A first end-to-end example

To write a custom training loop, we need the following ingredients:

  • A model to train, of course.
  • An optimizer. You could either use an optimizer from keras.optimizers, or one from the optax package (a brief optax sketch follows this list).
  • A loss function.
  • A dataset. The standard in the JAX ecosystem is to load data via tf.data, so that's what we'll use.
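
As an aside, the rest of this guide uses a Keras optimizer, but if you preferred the optax route, a minimal sketch of that API (with made-up toy parameters standing in for the model's variables) could look like this:

import jax.numpy as jnp
import optax

# Toy parameters and gradients, standing in for real model variables.
params = {"w": jnp.ones((3,)), "b": jnp.zeros((3,))}
grads = {"w": jnp.full((3,), 0.5), "b": jnp.full((3,), 0.1)}

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# One optimizer step: transform the gradients, then apply the updates to the params.
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)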

Let's line them all up.

First, let's get the model and the MNIST dataset:

def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

Next, here's the loss function and the optimizer. We'll use a Keras optimizer in this case.

# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

Getting gradients in JAX

Let's train our model using mini-batch gradients with a custom training loop.

In JAX, gradients are computed via metaprogramming: you call jax.grad (or jax.value_and_grad) on a function in order to create a gradient-computing function for that function.

So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function. Something like this:

def compute_loss(x, y):
    ...
    return loss

Once you have such a function, you can compute gradients via metaprogramming like this:

grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)

Typically, you don't just want to get the gradient values, you also want to get the loss value. You can do this by using jax.value_and_grad instead of jax.grad:

grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
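
To make this concrete, here is a small self-contained toy example (the squared_error function and its inputs are made up purely for illustration):

import jax
import jax.numpy as jnp


def squared_error(w, x, y):
    # A toy "loss": mean squared error of the linear prediction x * w.
    return jnp.mean((x * w - y) ** 2)


# value_and_grad differentiates with respect to the first argument, w.
grad_fn = jax.value_and_grad(squared_error)
loss, grad = grad_fn(3.0, jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0]))
# loss == 2.5 and grad == 5.0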

JAX computation is purely stateless

In JAX, everything must be a stateless function, so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that gets updated during the forward pass must be returned as a function output. The function cannot have any side effects.

During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So we need something like this:

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables

Once you have such a function, you can get the gradient function by specifying has_aux=True in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)

Now that we have established the basics, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method that will come in handy here. It works just like model.__call__, but it requires you to explicitly pass the value of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables

Let's get the gradient function:

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

The training step function

Next, let's implement the end-to-end training step: the function that will run the forward pass, compute the loss, compute the gradients, and use the optimizer to update the trainable variables. This function also needs to be stateless, so it will take as input a state tuple that includes every state element we're going to use:

  • trainable_variables and non_trainable_variables: the model's variables.
  • optimizer_variables: the optimizer's state variables, such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it always requires passing trainable_variables and optimizer_variables. It returns both the updated trainable variables and the updated optimizer variables.

def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

Make it fast with jax.jit

By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow: eager mode is better used as a debugging environment, not as a way to do any actual work. So let's make our train_step fast by compiling it.

When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced during its first execution, and in subsequent executions you will be executing the traced graph (this is just like @tf.function(jit_compile=True)). Let's try it:

@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

We're now ready to train our model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).

Note:

  • We convert the TF tensors yielded by the tf.data.Dataset to NumPy before passing them to our JAX function.
  • All variables must be built beforehand: the model must be built and the optimizer must be built. Since we're using a Functional API model, it's already built; if it were a subclassed model, you would need to call it on a batch of data to build it.

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples

A key thing to notice here is that the loop is entirely stateless: the variables attached to the model (model.weights) are never getting updated during the loop. Their new values are only stored in the state tuple. That means that at some point, before saving your model, you should reattach the new variable values to the model.

Just call variable.assign(new_value) on each model variable you want to update:

trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)

Low-level handling of metrics

Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

  • Instantiate the metric at the start of the loop.
  • Include metric_variables in the train_step arguments and the compute_loss_and_updates arguments.
  • Call metric.stateless_update_state() in the compute_loss_and_updates function. It's equivalent to update_state(), only stateless.
  • When you need to display the current value of the metric, outside the train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
  • Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch); a short sketch of this step follows this list.
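
That last step is not demonstrated in the loops below, so here is a minimal sketch of what it could look like at an epoch boundary (train_acc_metric and metric_variables refer to the objects defined in the code that follows):

# At the end of an epoch: clear the metric's state and refresh the values
# that will be carried in the `state` tuple for the next epoch.
train_acc_metric.reset_state()
metric_variables = train_acc_metric.variables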

Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of training:

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )

We'll also prepare an evaluation step function:

@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )

Here are our loops:

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables

# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples

Low-level handling of losses tracked by the model

Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs

Let's build a really simple model that uses it:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

Here's what our compute_loss_and_updates function should look like now:

  • Pass return_losses=True to model.stateless_call().
  • Sum the resulting losses and add them to the main loss.

def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)
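
One detail to keep in mind: grad_fn was generated from the previous compute_loss_and_updates, so to actually train with this variant you would regenerate it; the jitted train_step and the optimizer were likewise traced and built against the previous model, so they would need to be re-defined and re-built against the new one as well. A minimal sketch of the first step:

# Regenerate the gradient function so it wraps the new compute_loss_and_updates.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)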

And that's it!