ModelParallel API

[source]

ModelParallel

keras.distribution.ModelParallel(
    layout_map=None, batch_dim_name=None, auto_shard_dataset=True, **kwargs
)

Distribution that shards model variables.

Unlike DataParallel, which replicates variables across all devices, ModelParallel allows you to shard variables in addition to the input data.

To construct a ModelParallel distribution, you need to provide a DeviceMesh and a LayoutMap.

  1. DeviceMesh contains the physical device information. The axis names in the mesh will be used to map the variable and data layout.
  2. LayoutMap contains the mapping between variable paths and their corresponding TensorLayout.

Example

from keras.distribution import (DeviceMesh, LayoutMap, ModelParallel,
                                list_devices, set_distribution)

devices = list_devices()    # Assume there are 8 devices.

# Create a mesh with 2 devices for data parallelism and 4 devices for
# model parallelism.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                         devices=devices)
# Create a layout map that shards the `Dense` layer and `Conv2D`
# layer variables on the last dimension.
# Based on the `device_mesh`, this means the variables
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

distribution = ModelParallel(
    layout_map=layout_map,
    batch_dim_name='batch',
)

# Set the global distribution, or via `with distribution.scope():`
set_distribution(distribution)

model = model_creation()
model.compile()
model.fit(data)
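
The comment in the example above notes that the distribution can also be applied through a scope rather than being set globally; a minimal sketch of that variant, reusing the same distribution, the model_creation() placeholder, and data from the example:

# Apply the distribution only within a scope instead of setting it globally.
with distribution.scope():
    model = model_creation()
    model.compile()
    model.fit(data)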

You can quickly update the device mesh shape to change the sharding factor of the variables. E.g.

# With only the shape change for the device mesh, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of variables on each of the device.
device_mesh = DeviceMesh(
    shape=(1, 8),
    axis_names=('batch', 'model'),
    devices=devices,
)
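
A minimal sketch of wiring the reshaped mesh back in, assuming the sharding rules from the earlier example are simply rebuilt against it:

# Rebuild the layout map and distribution against the reshaped mesh; the same
# rules now split the 'model' dimension across 8 devices instead of 4.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)

distribution = ModelParallel(layout_map=layout_map, batch_dim_name='batch')
set_distribution(distribution)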

To figure out a proper layout mapping rule for all the model variables, you can first list out all the model variable paths, which will be used as the keys that map the variables to TensorLayout.

For example:

model = create_model()
for v in model.variables:
    print(v.path)
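
The layout map keys are matched against these printed paths as regular expressions. As an illustration only, assuming the loop prints paths such as dense/kernel and dense/bias (the exact format depends on the model), a matching rule could look like this:

# Hypothetical printed paths (illustrative only):
#   dense/kernel
#   dense/bias
# A regex key matching those paths shards the kernel on its last dimension,
# while the bias, having no matching rule, stays fully replicated.
layout_map['dense.*kernel'] = (None, 'model')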

Arguments

  • layout_map: LayoutMap instance which maps the variable path to the corresponding tensor layout.
  • batch_dim_name: Optional string, the axis name in the device mesh (of the layout_map object) that will be used to distribute data. If unspecified, the first axis from the device mesh will be used.
  • auto_shard_dataset: Automatically shard the dataset amongst processes in a multi-process setting. Set to False if the dataset is already sharded across hosts. Defaults to True.
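
A minimal sketch combining the arguments above, assuming a multi-host run where each host already reads its own shard of the dataset:

# The dataset is already sharded per host, so automatic dataset sharding is
# disabled; 'batch' names the mesh axis used to distribute the data.
distribution = ModelParallel(
    layout_map=layout_map,
    batch_dim_name='batch',
    auto_shard_dataset=False,
)
set_distribution(distribution)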