Author: fchollet
Date created: 2023/07/11
Last modified: 2023/07/11
Description: Guide to multi-GPU/TPU training for Keras models with JAX.
There are generally two ways to distribute computation across multiple devices:
Data parallelism, where a single model gets replicated on multiple devices or multiple machines. Each of them processes different batches of data, then they merge their results. There exist many variants of this setup, which differ in how the different model replicas merge results, in whether they stay in sync at every batch or whether they are more loosely coupled, etc.
Model parallelism, where different parts of a single model run on different devices, processing a single batch of data together. This works best with models that have a naturally parallel architecture, such as models that feature multiple branches.
This guide focuses on data parallelism, in particular synchronous data parallelism, where the different replicas of the model stay in sync after each batch they process. Synchronicity keeps the model convergence behavior identical to what you would see for single-device training.
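To see why synchronicity preserves single-device convergence behavior, note that averaging equal-sized per-replica gradients recovers exactly the gradient over the full global batch. Here is a tiny numeric sketch of that argument (the per-sample gradient values below are made up purely for illustration):

import numpy as np

# Per-sample gradients of one weight over a global batch of 8 samples (made-up values).
per_sample_grads = np.array([0.2, -0.4, 0.1, 0.3, -0.1, 0.5, 0.0, -0.2])

# Split the global batch into 2 equal local batches, one per replica.
local_batches = np.split(per_sample_grads, 2)

# Each replica averages its own local gradients; the replica means are then merged.
merged = np.mean([b.mean() for b in local_batches])

# Same gradient a single device would compute on the full batch (up to float rounding).
print(merged, per_sample_grads.mean())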
Specifically, this guide teaches you how to use the jax.sharding APIs to train Keras models, with minimal changes to your code, on multiple GPUs or TPUs (typically 2 to 16) installed on a single machine (single host, multi-device training). This is the most common setup for researchers and small-scale industry workflows.
Let's start by defining the function that creates the model we will train, and the function that creates the dataset we will train on (MNIST in this case).
import os
os.environ["KERAS_BACKEND"] = "jax"
import jax
import numpy as np
import tensorflow as tf
import keras
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
def get_model():
    # Make a simple convnet with batch normalization and dropout.
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
        x
    )
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=24,
        kernel_size=6,
        use_bias=False,
        strides=2,
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=32,
        kernel_size=6,
        padding="same",
        strides=2,
        name="large_k",
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model
def get_datasets():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create TF Datasets
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    return train_data, eval_data
In this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16). Each device will run a copy of your model (called a replica). For simplicity, in what follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.
How it works
At each step of training:
- The current batch of data (called the global batch) is split into 8 different sub-batches (called local batches). For instance, if the global batch has 512 samples, each of the 8 local batches will have 64 samples.
- Each of the 8 replicas independently processes a local batch: they run a forward pass, then a backward pass, outputting the gradient of the weights with respect to the loss of the model on the local batch.
- The weight updates originating from local gradients are efficiently merged across the 8 replicas. Because this is done at the end of every step, the replicas always stay in sync.
In practice, the process of synchronously updating the weights of the model replicas is handled at the level of each individual weight variable. This is done through a jax.sharding.NamedSharding that is configured to replicate the variables.
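As a minimal standalone sketch of that mechanism (the mesh axis name and the array shape below are illustrative assumptions, not part of the training script that follows), replicating a single weight array with a NamedSharding whose PartitionSpec names no axes looks like this:

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over whatever local devices are available (1 CPU, 8 GPUs, ...).
devices = mesh_utils.create_device_mesh((len(jax.local_devices()),))
mesh = Mesh(devices, axis_names=("_",))

# A PartitionSpec with no axes means: do not split this array along any dimension,
# i.e. keep a full copy of it on every device in the mesh.
replicated = NamedSharding(mesh, P())

kernel = np.ones((3, 3, 1, 12), dtype="float32")  # a made-up conv kernel
kernel_on_devices = jax.device_put(kernel, replicated)
print(kernel_on_devices.sharding)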
How to use it
To do single-host, multi-device synchronous training with a Keras model, you would use the jax.sharding features. Here's how it works:
- We first create a device mesh using mesh_utils.create_device_mesh.
- We use jax.sharding.Mesh, jax.sharding.NamedSharding and jax.sharding.PartitionSpec to define how to partition JAX arrays.
  - We specify that we want to replicate the model and optimizer variables across all devices by using a spec with no axis.
  - We specify that we want to shard the data across devices by using a spec that splits along the batch dimension.
- We use jax.device_put to replicate the model and optimizer variables across devices. This happens once at the beginning.
- In the training loop, for each batch that we process, we use jax.device_put to split the batch across devices before invoking the train step.

Here's the flow, where each step is split into its own utility function:
# Config
num_epochs = 2
batch_size = 64
train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)
model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)
# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables
# Function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)
# Training step, Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )
# Replicate the model and optimizer variables on all devices
def get_replicated_train_state(devices):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_"))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())

    # Apply the distribution settings to the model variables
    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)

    # Combine all state in a tuple
    return (trainable_variables, non_trainable_variables, optimizer_variables)
num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))
# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",)) # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # naming axes of the sharded partition
# Display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))
train_state = get_replicated_train_state(devices)
# Custom training loop
for epoch in range(num_epochs):
    data_iter = iter(train_data)
    for data in data_iter:
        x, y = data
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value)

# Post-processing model state update to write them back into the model
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Running on 1 devices: [CpuDevice(id=0)]
Data sharding
CPU 0
Epoch 0 loss: 0.28599858
Epoch 1 loss: 0.23666474
That's it!
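Once the trained values have been written back, the model behaves like any regular Keras model. As a small optional sketch (assuming we reuse the eval_data, batch_size and model defined above; the accuracy loop itself is not part of the training flow), you could check the synced-back weights against the held-out test split:

# Batch the evaluation split the same way as the training data.
eval_batches = eval_data.batch(batch_size, drop_remainder=True)

correct = 0
total = 0
for x_eval, y_eval in eval_batches:
    # predict_on_batch runs a plain forward pass with the restored weights.
    logits = model.predict_on_batch(x_eval.numpy())
    correct += int(np.sum(np.argmax(logits, axis=-1) == y_eval.numpy()))
    total += x_eval.shape[0]

print("Test accuracy:", correct / total)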