Keras 3 API 文档 / 多设备分布 / 数据并行 API

数据并行 API

[源代码]

DataParallel

keras.distribution.DataParallel(
    device_mesh=None, devices=None, auto_shard_dataset=True
)

数据并行分布。

您可以通过指定 device_meshdevices 参数(但不能同时指定两者)来创建此实例。

device_mesh 参数应为 DeviceMesh 实例,且应仅为 1D。如果网格具有多个轴,则第一个轴将被视为数据并行维度(并将引发警告)。

当提供 devices 列表时,它们将用于构建 1D 网格。

meshdevices 都不存在时,将使用 list_devices() 来检测任何可用设备并从中创建 1D 网格。

参数

  • device_mesh: 可选的 DeviceMesh 实例。
  • devices: 可选的设备列表。
  • auto_shard_dataset: 自动在进程之间分片数据集。默认为 true。