开发者指南 / 使用 JAX 进行多 GPU 分布式训练

使用 JAX 进行多 GPU 分布式训练

作者: fchollet
创建日期 2023/07/11
上次修改日期 2023/07/11
描述:使用 JAX 对 Keras 模型进行多 GPU/TPU 训练的指南。

在 Colab 中查看 GitHub 源代码


简介

通常有两种方法可以在多个设备之间分配计算

数据并行,其中单个模型在多个设备或多台机器上复制。每个设备处理不同的数据批次,然后合并其结果。此设置存在许多变体,它们在不同模型副本如何合并结果、它们是否在每个批次都保持同步或它们是否更松散耦合等方面有所不同。

模型并行,其中单个模型的不同部分在不同的设备上运行,共同处理单个数据批次。这最适合具有自然并行架构的模型,例如具有多个分支的模型。

本指南重点介绍数据并行,特别是同步数据并行,其中模型的不同副本在处理每个批次后保持同步。同步性使模型收敛行为与您在单设备训练中看到的行为相同。

具体来说,本指南将教您如何使用 jax.sharding API 对 Keras 模型进行训练,只需对代码进行最少的更改,即可在安装在单个机器上的多个 GPU 或 TPU(通常为 2 到 16 个)上进行训练(单主机,多设备训练)。对于研究人员和小型行业工作流程来说,这是最常见的设置。


设置

让我们首先定义创建将要训练的模型的函数,以及创建将要训练的数据集的函数(在本例中为 MNIST)。

import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import tensorflow as tf
import keras

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


def get_model():
    # Make a simple convnet with batch normalization and dropout.
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
        x
    )
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=24,
        kernel_size=6,
        use_bias=False,
        strides=2,
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=32,
        kernel_size=6,
        padding="same",
        strides=2,
        name="large_k",
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model


def get_datasets():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create TF Datasets
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    return train_data, eval_data

单主机,多设备同步训练

在此设置中,您有一台机器,上面安装了多个 GPU 或 TPU(通常为 2 到 16 个)。每个设备将运行您的模型的一个副本(称为副本)。为简单起见,在下文中,我们将假设我们正在处理 8 个 GPU,这不会影响一般性。

工作原理

在训练的每个步骤中

  • 当前数据批次(称为全局批次)被分成 8 个不同的子批次(称为本地批次)。例如,如果全局批次有 512 个样本,则 8 个本地批次中的每一个将有 64 个样本。
  • 8 个副本中的每一个都独立地处理一个本地批次:它们运行前向传递,然后运行后向传递,输出模型在本地批次上损失相对于权重的梯度。
  • 来自本地梯度的权重更新将有效地跨 8 个副本合并。因为这在每个步骤结束时完成,所以副本始终保持同步。

在实践中,同步更新模型副本权重的过程在每个单独的权重变量级别上处理。这是通过使用配置为复制变量的 jax.sharding.NamedSharding 来完成的。

使用方法

要使用 Keras 模型进行单主机,多设备同步训练,可以使用 jax.sharding 功能。以下是其工作原理

  • 我们首先使用 mesh_utils.create_device_mesh 创建一个设备网格。
  • 我们使用 jax.sharding.Meshjax.sharding.NamedShardingjax.sharding.PartitionSpec 来定义如何对 JAX 数组进行分区。- 我们指定希望跨所有设备复制模型和优化器变量,方法是使用没有轴的规范。- 我们指定希望跨设备对数据进行分片,方法是使用沿批次维度进行分割的规范。
  • 我们使用 jax.device_put 将模型和优化器变量复制到各个设备上。这在开始时只发生一次。
  • 在训练循环中,对于我们处理的每个批次,我们使用 jax.device_put 将批次拆分到各个设备上,然后再调用训练步骤。

以下是流程,其中每个步骤都拆分为其自身的实用程序函数

# Config
num_epochs = 2
batch_size = 64

train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)

model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)


# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables


# Function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)


# Training step, Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )

    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


# Replicate the model and optimizer variable on all devices
def get_replicated_train_state(devices):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_"))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())

    # Apply the distribution settings to the model variables
    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)

    # Combine all state in a tuple
    return (trainable_variables, non_trainable_variables, optimizer_variables)


num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))

# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # naming axes of the sharded partition

# Display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))

train_state = get_replicated_train_state(devices)

# Custom training loop
for epoch in range(num_epochs):
    data_iter = iter(train_data)
    for data in data_iter:
        x, y = data
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value)

# Post-processing model state update to write them back into the model
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Running on 1 devices: [CpuDevice(id=0)]
Data sharding
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                     CPU 0                                      
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
Epoch 0 loss: 0.28599858
Epoch 1 loss: 0.23666474

就是这样!