Keras 3 API 文档 / 多设备分布 / 模型并行 API

模型并行 API

[源代码]

ModelParallel

keras.distribution.ModelParallel(layout_map=None, batch_dim_name=None, **kwargs)

将模型变量分片的分布。

与在所有设备上复制变量的 DataParallel 相比,ModelParallel 允许您除了输入数据外,还可以对变量进行分片。

要构建 ModelParallel 分布,您需要提供 DeviceMeshLayoutMap

  1. DeviceMesh 包含物理设备信息。网格中的轴名将用于映射变量和数据布局。
  2. LayoutMap 包含变量路径与其对应 TensorLayout 之间的映射。

示例

devices = list_devices()    # Assume there are 8 devices.

# Create a mesh with 2 devices for data parallelism and 4 devices for
# model parallelism.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                         devices=devices)
# Create a layout map that shard the `Dense` layer and `Conv2D`
# layer variables on the last dimension.
# Based on the `device_mesh`, this means the variables
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

distribution = ModelParallel(
    layout_map=layout_map,
    batch_dim_name='batch',
)

# Set the global distribution, or via `with distribution.scope():`
set_distribution(distribution)

model = model_creation()
model.compile()
model.fit(data)

您可以快速更新设备网格形状以更改变量的分片因子。例如:

# With only the shape change for the device mesh, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of variables on each of the device.
device_mesh = DeviceMesh(
    shape=(1, 8),
    axis_names=('batch', 'model'),
    devices=devices,
)

要找出所有模型变量的适当布局映射规则,您可以先列出所有模型变量路径,这些路径将用作将变量映射到 TensorLayout 的键。

例如:

model = create_model()
for v in model.variables:
    print(v.path)

参数

  • layout_map: LayoutMap 实例,将变量路径映射到相应的张量布局。
  • batch_dim_name: 可选字符串,设备网格(layout_map 对象)中将用于分发数据的轴名。如果未指定,将使用设备网格中的第一个轴。