Keras 3 API 文档 / 多设备分布 / LayoutMap API

LayoutMap API

[源代码]

LayoutMap

keras.distribution.LayoutMap(device_mesh)

一个类似字典的对象,用于将字符串映射到 TensorLayout 实例。

LayoutMap 使用字符串作为键,TensorLayout 作为值。它与普通的 Python 字典之间存在行为差异。在检索值时,字符串键将被视为正则表达式。有关更多详细信息,请参阅 get 的文档字符串。

以下是一个使用示例。您可以定义 TensorLayout 的命名方案,然后检索相应的 TensorLayout 实例。

在正常情况下,查询的键通常是 variable.path,它是变量的标识符。

作为快捷方式,在插入值时也允许使用轴名称的元组或列表,并将转换为 TensorLayout

layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

layout_1 = layout_map['dense_1.kernel']             # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias']               # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel']             # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias']               # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias']   # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel']   # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias']     # layout_8 == None

参数


[源代码]

DeviceMesh

keras.distribution.DeviceMesh(shape, axis_names, devices=None)

用于分布式计算的计算设备集群。

此 API 与 jax.sharding.Meshtf.dtensor.Mesh 保持一致,它们表示全局上下文中的计算设备。

有关更多详细信息,请参阅 jax.sharding.Meshtf.dtensor.Mesh

参数

  • shape: 整数元组或列表。整个 DeviceMesh 的形状,例如,对于仅数据并行的分布,(8,),或者对于模型+数据并行的分布,(4, 2)
  • axis_names: 字符串列表。DeviceMesh 中每个轴的逻辑名称。axis_names 的长度应与 shape 的秩匹配。在分布数据和变量时,axis_names 将用于匹配/创建 TensorLayout
  • devices: 可选的设备列表。默认为 keras.distribution.list_devices() 从本地获取的所有可用设备。

[源代码]

TensorLayout

keras.distribution.TensorLayout(axes, device_mesh=None)

应用于张量的布局。

此 API 与 jax.sharding.NamedShardingtf.dtensor.Layout 保持一致。

有关更多详细信息,请参阅 jax.sharding.NamedShardingtf.dtensor.Layout

参数

  • axes: 字符串元组,应映射到 DeviceMesh 中的 axis_names。对于不需要任何分区的任何维度,可以使用 None 作为占位符。
  • device_mesh: 可选的 DeviceMesh,用于创建布局。在指定网格之前,张量到物理设备的实际映射是未知的。

[源代码]

distribute_tensor 函数

keras.distribution.distribute_tensor(tensor, layout)

在 jit 函数执行中更改张量值的布局。

参数

  • tensor: 要更改布局的张量。
  • layout: 要应用于值的 TensorLayout

返回值

具有指定张量布局的新值。