Writing a training loop from scratch in JAX

Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in JAX.

Setup

import os

# This guide can only be run with the jax backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax

# We import TF so we can use tf.data.
import tensorflow as tf
import keras
import numpy as np

Introduction

Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.

If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().

Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.


A first end-to-end example

To write a custom training loop, we need the following ingredients:

  • A model to train, of course.
  • An optimizer. You could either use an optimizer from keras.optimizers, or one from the optax package.
  • A loss function.
  • A dataset. The standard in the JAX ecosystem is to load data via tf.data, so that's what we'll use.

Let's line them up.

First, let's get the model and the MNIST dataset:

def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

Next, here's the loss function and the optimizer. We'll use a Keras optimizer in this case.

# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
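
The ingredients list above mentioned that an optimizer from the optax package could be used instead of a Keras one. Purely for reference, here is a minimal sketch of that alternative, assuming optax is installed; the rest of this guide sticks with the Keras optimizer and its stateless_apply method.

import optax  # assumption: optax is available in the environment

# Extract the underlying JAX arrays from the Keras variables.
params = [v.value for v in model.trainable_variables]

optax_optimizer = optax.adam(learning_rate=1e-3)
optax_opt_state = optax_optimizer.init(params)

# Inside a training step, you would then apply gradients along these lines:
#   updates, optax_opt_state = optax_optimizer.update(grads, optax_opt_state)
#   params = optax.apply_updates(params, updates)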

Getting gradients in JAX

Let's train our model using mini-batch gradients with a custom training loop.

In JAX, gradients are computed via metaprogramming: you call jax.grad (or jax.value_and_grad) on a function in order to create a gradient-computing function for that first function.

So the first thing we need is a function that returns the loss value. That's the function we'll use to generate the gradient function. Something like this:

def compute_loss(x, y):
    ...
    return loss

Once you have such a function, you can compute gradients via metaprogramming as such:

grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)

Typically, you don't just want to get the gradient values, you also want to get the loss value. You can do this by using jax.value_and_grad instead of jax.grad:

grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
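
To make this concrete, here is a tiny self-contained sketch with a hypothetical toy loss function (by default, jax.value_and_grad differentiates with respect to the first argument):

import jax
import jax.numpy as jnp


def toy_loss(w, x, y):
    # Mean squared error of a one-parameter linear model (toy example).
    return jnp.mean((w * x - y) ** 2)


toy_grad_fn = jax.value_and_grad(toy_loss)
loss_value, w_grad = toy_grad_fn(2.0, jnp.array([1.0, 2.0]), jnp.array([2.0, 5.0]))
# loss_value is the scalar loss, w_grad is the gradient of the loss with respect to w.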

JAX computation is purely stateless

In JAX, everything must be a stateless function, so our loss computation function must be stateless as well. That means that all Keras variables (e.g. weight tensors) must be passed as function inputs, and any variable that gets updated during the forward pass must be returned as a function output. The function has no side effects.

During the forward pass, the non-trainable variables of a Keras model might get updated. These variables could be, for instance, RNG seed state variables or BatchNormalization statistics. We're going to need to return those. So we need something like this:

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    ...
    return loss, non_trainable_variables

Once you have such a function, you can get the gradient function by specifying has_aux in value_and_grad: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)
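
As a quick illustration with a hypothetical toy function, has_aux=True makes JAX treat everything after the first output as auxiliary data, passed through untouched next to the gradients:

import jax
import jax.numpy as jnp


def toy_loss_with_aux(w, x):
    y_pred = w * x
    # The auxiliary output (here, the predictions) is returned as-is.
    return jnp.sum(y_pred**2), {"y_pred": y_pred}


toy_grad_fn = jax.value_and_grad(toy_loss_with_aux, has_aux=True)
(loss_value, aux), w_grad = toy_grad_fn(3.0, jnp.arange(4.0))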

Now that we have the basics, let's implement this compute_loss_and_updates function. Keras models have a stateless_call method which will come in handy here. It works just like model.__call__, but it requires you to explicitly pass the value of all the variables in the model, and it returns not just the __call__ outputs but also the (potentially updated) non-trainable variables.

def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables

Let's get the gradient function:

grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

The training step function

Next, let's implement the end-to-end training step: the function that runs the forward pass, computes the loss, computes the gradients, and also uses the optimizer to update the trainable variables. This function also needs to be stateless, so it will take as input a state tuple that includes every state element we're going to use:

  • trainable_variables and non_trainable_variables: the model's variables.
  • optimizer_variables: the optimizer's state variables, such as momentum accumulators.

To update the trainable variables, we use the optimizer's stateless method stateless_apply. It's equivalent to optimizer.apply(), but it always requires passing trainable_variables and optimizer_variables. It returns both the updated trainable variables and the updated optimizer_variables.

def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

Make it fast with jax.jit

By default, JAX operations run eagerly, just like in TensorFlow eager mode and PyTorch eager mode. And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow: eager mode is better suited as a debugging environment, not as a way to do any actual work. So let's make our train_step fast by compiling it.

When you have a stateless JAX function, you can compile it to XLA via the @jax.jit decorator. It will get traced during its first execution, and in subsequent executions you will be executing the traced graph (this is just like @tf.function(jit_compile=True)). Let's try it:

@jax.jit
def train_step(state, data):
    trainable_variables, non_trainable_variables, optimizer_variables = state
    x, y = data
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

We're now ready to train our model. The training loop itself is trivial: we just repeatedly call loss, state = train_step(state, data).

Note that:

  • We convert the TF tensors yielded by the tf.data.Dataset to NumPy before passing them to our JAX function.
  • All variables must be built beforehand: the model must be built and the optimizer must be built. Since we're using a Functional API model, it's already built, but if it were a subclassed model you'd need to call it on a batch of data to build it (see the sketch below).
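
For instance, a hypothetical subclassed model could be built by calling it once on a dummy batch before retrieving its variables; a minimal sketch:

# Hypothetical subclassed model: calling it once on a dummy batch creates its variables.
class SubclassedMLP(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense_1 = keras.layers.Dense(64, activation="relu")
        self.dense_2 = keras.layers.Dense(10)

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


subclassed_model = SubclassedMLP()
subclassed_model(np.zeros((1, 784), dtype="float32"))  # builds all the variables

Back to our Functional API model, let's build the optimizer variables and run the loop:
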
# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 96.2726
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0853
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6535
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2679
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.7563
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.7154
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0267
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6860
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7306
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4571
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6023
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.9140
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4224
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6696
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1399
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5761
Seen so far: 48032 samples

A key thing to notice here is that the loop is entirely stateless: the variables attached to the model (model.weights) are never getting updated during the loop. Their new values are only stored in the state tuple. That means that at some point, before saving your model, you should reattach the new variable values to the model.

Just call variable.assign(new_value) on each model variable you want to update:

trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)

Low-level handling of metrics

Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

  • Instantiate the metric at the start of the loop
  • Include metric_variables in the train_step arguments and the compute_loss_and_updates arguments.
  • Call metric.stateless_update_state() in the compute_loss_and_updates function. It's equivalent to update_state(), only stateless.
  • When you need to display the current value of the metric, outside the train_step (in the eager scope), attach the new metric variable values to the metric object and call metric.result().
  • Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch; see the sketch right after this list)
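
As a minimal end-of-epoch sketch (assuming the 4-element state tuple and the train_acc_metric defined below), that last step could look like this:

# Hypothetical end-of-epoch handling: sync the metric object from the state tuple,
# read its result, then reset it before starting the next epoch.
*_, metric_variables = state
for variable, value in zip(train_acc_metric.variables, metric_variables):
    variable.assign(value)
print(f"Epoch accuracy: {train_acc_metric.result()}")
train_acc_metric.reset_state()
# The fresh train_acc_metric.variables would then go back into the state tuple
# for the next epoch.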

Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of training:

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()


def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)


grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)


@jax.jit
def train_step(state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    ) = state
    x, y = data
    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
        trainable_variables, non_trainable_variables, metric_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    # Return updated state
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metric_variables,
    )

We will also prepare an evaluation step function:

@jax.jit
def eval_step(state, data):
    trainable_variables, non_trainable_variables, metric_variables = state
    x, y = data
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    metric_variables = val_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        metric_variables,
    )

Here's our loop:

# Build optimizer variables.
optimizer.build(model.trainable_variables)

trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
)

# Training loop
for step, data in enumerate(train_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = train_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, _, metric_variables = state
        for variable, value in zip(train_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Training accuracy: {train_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")

metric_variables = val_acc_metric.variables
(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables,
    metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables

# Eval loop
for step, data in enumerate(val_dataset):
    data = (data[0].numpy(), data[1].numpy())
    loss, state = eval_step(state, data)
    # Log every 100 batches.
    if step % 100 == 0:
        print(f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}")
        _, _, metric_variables = state
        for variable, value in zip(val_acc_metric.variables, metric_variables):
            variable.assign(value)
        print(f"Validation accuracy: {val_acc_metric.result()}")
        print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 70.8851
Training accuracy: 0.09375
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.1930
Training accuracy: 0.6596534848213196
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.0249
Training accuracy: 0.7352300882339478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6004
Training accuracy: 0.7588247656822205
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.4633
Training accuracy: 0.7736907601356506
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.3367
Training accuracy: 0.7826846241950989
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8767
Training accuracy: 0.7930532693862915
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3479
Training accuracy: 0.8004636168479919
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3608
Training accuracy: 0.8066869378089905
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.7582
Training accuracy: 0.8117369413375854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 1.3135
Training accuracy: 0.8142170310020447
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.0202
Training accuracy: 0.8186308145523071
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.6766
Training accuracy: 0.822023332118988
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.7606
Training accuracy: 0.8257110118865967
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.7657
Training accuracy: 0.8290283679962158
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6563
Training accuracy: 0.831653892993927
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.1622
Validation accuracy: 0.8329269289970398
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 0.7455
Validation accuracy: 0.8338780999183655
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.2738
Validation accuracy: 0.836174488067627
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.1255
Validation accuracy: 0.8390461206436157
Seen so far: 9632 samples

Low-level handling of losses tracked by the model

Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * jax.numpy.sum(inputs))
        return inputs

Let's build a really simple model that uses it:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

Here's what our compute_loss_and_updates function should look like now:

  • Pass return_losses=True to model.stateless_call().
  • Sum the resulting losses and add them to the main loss.

def compute_loss_and_updates(
    trainable_variables, non_trainable_variables, metric_variables, x, y
):
    y_pred, non_trainable_variables, losses = model.stateless_call(
        trainable_variables, non_trainable_variables, x, return_losses=True
    )
    loss = loss_fn(y, y_pred)
    if losses:
        loss += jax.numpy.sum(losses)
    metric_variables = train_acc_metric.stateless_update_state(
        metric_variables, y, y_pred
    )
    return loss, (non_trainable_variables, metric_variables)

That's it!