DistributedEmbedding 类keras_rs.layers.DistributedEmbedding(
feature_configs: Union[
keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
Sequence[
Union[
keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
ForwardRef("Nested[T]"),
]
],
Mapping[
str,
Union[
keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
ForwardRef("Nested[T]"),
],
],
],
table_stacking: Union[str, Sequence[str], Sequence[Sequence[str]]] = "auto",
**kwargs: Any
)
DistributedEmbedding,一个用于加速大型嵌入查找的层。
DistributedEmbedding 处于预览阶段。DistributedEmbedding 是一个针对具有 SparseCore 的 TPU 芯片进行优化的层,可以显著提高嵌入查找和嵌入训练的速度。它通过将多个查找合并为一个调用,并将嵌入表分片到可用芯片上来实现。请注意,只有当嵌入表足够大以至于无法装入单个芯片并需要分片时,才能看到性能提升。更多详细信息请参见下面的“放置”部分。
在其他硬件上,如 GPU、CPU 和没有 SparseCore 的 TPU 上,DistributedEmbedding 提供相同的 API,但没有特定的加速。除了通过 keras.distribution.set_distribution 设置的分片方案外,不应用任何特定的分片方案。
DistributedEmbedding 嵌入输入序列,并通过应用可配置的组合器函数将它们缩减为单个嵌入。
DistributedEmbedding 嵌入层通过一组 keras_rs.layers.FeatureConfig 对象进行配置,这些对象本身引用 keras_rs.layers.TableConfig 对象。
TableConfig 定义了一个嵌入表,其中包含词汇表大小、嵌入维度以及用于组合和训练的优化器等参数。FeatureConfig 定义了 DistributedEmbedding 将处理的输入特征以及要使用的嵌入表。请注意,多个特征可以使用同一个嵌入表。table1 = keras_rs.layers.TableConfig(
name="table1",
vocabulary_size=TABLE1_VOCABULARY_SIZE,
embedding_dim=TABLE1_EMBEDDING_SIZE,
placement="auto",
)
table2 = keras_rs.layers.TableConfig(
name="table2",
vocabulary_size=TABLE2_VOCABULARY_SIZE,
embedding_dim=TABLE2_EMBEDDING_SIZE,
placement="auto",
)
feature1 = keras_rs.layers.FeatureConfig(
name="feature1",
table=table1,
input_shape=(GLOBAL_BATCH_SIZE,),
output_shape=(GLOBAL_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
)
feature2 = keras_rs.layers.FeatureConfig(
name="feature2",
table=table2,
input_shape=(GLOBAL_BATCH_SIZE,),
output_shape=(GLOBAL_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
)
feature_configs = {
"feature1": feature1,
"feature2": feature2,
}
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
DistributedEmbedding 中的每个嵌入表都使用自己的训练优化器,这与在模型中通过 model.compile() 设置的优化器是独立的。
请注意,并非所有优化器都受支持。目前,以下优化器在所有后端和加速器上都受支持:
此外,并非所有优化器参数都受支持(例如 SGD 的 nesterov 选项)。当使用不受支持的优化器或不受支持的优化器参数时,会引发错误。
DistributedEmbedding 中的每个嵌入表都可以放置在 SparseCore 芯片上,或者放置在加速器的默认设备上(例如 TPU 上的 Tensor Cores 的 HBM)。这由 keras_rs.layers.TableConfig 的 placement 属性控制。
"sparsecore" 表示该表应放置在 SparseCore 芯片上。如果选择此选项但没有 SparseCore 芯片,则会引发错误。"default_device" 表示该表不应放置在 SparseCore 上,即使有。相反,该表放置在模型通常放置的设备上,即 TPU 和 GPU 上的 HBM。在这种情况下,如果适用,该表将使用通过 keras.distribution.set_distribution 设置的分片方案进行分片。在 GPU、CPU 和没有 SparseCore 的 TPU 上,这是唯一可用的放置方式,也是 "auto" 选择的方式。"auto" 表示如果可用,则使用 "sparsecore",否则使用 "default_device"。这是未指定时的默认值。优化 TPU 性能
"sparsecore" 放置。"default_device" 放置,并且通常应通过使用 keras.distribution.DataParallel 分片选项在 TPU 上进行复制。除了 tf.Tensor 外,DistributedEmbedding 还接受 tf.RaggedTensor 和 tf.SparseTensor 作为嵌入查找的输入。Ragged 텐서 必须在索引为 1 的维度上是 ragged 的。请注意,如果提供了权重,则每个权重张量必须与该特定特征的输入属于同一类,并且对于 ragged 텐서 使用完全相同的 ragged 行长度,对于 sparse 텐서 使用相同的索引。DistributedEmbedding 的所有输出都是密集张量。
要在 TPU 上使用 DistributedEmbedding 和 TensorFlow,必须使用 tf.distribute.TPUStrategy。DistributedEmbedding 层必须在 TPUStrategy 下创建。
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
)
strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment
)
with strategy.scope():
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
要使用 Keras 的 model.fit(),必须在 TPUStrategy 下编译模型。然后,可以直接调用 model.fit()、model.evaluate() 或 model.predict()。Keras 模型负责使用策略运行模型,并自动分片数据集。
with strategy.scope():
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
model = create_model(embedding)
model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
model.fit(dataset, epochs=10)
DistributedEmbedding 必须通过嵌套在 tf.function 中的 strategy.run 调用来调用。
@tf.function
def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
def strategy_fn(st_fn_inputs, st_fn_weights):
return embedding(st_fn_inputs, st_fn_weights)
return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights)))
embedding_wrapper(my_inputs, my_weights)
使用数据集时,必须对数据集进行分片。然后可以将迭代器传递给使用 strategy.run 的 tf.function。
dataset = strategy.experimental_distribute_dataset(dataset)
@tf.function
def run_loop(iterator):
def step(data):
(inputs, weights), labels = data
with tf.GradientTape() as tape:
result = embedding(inputs, weights)
loss = keras.losses.mean_squared_error(labels, result)
tape.gradient(loss, embedding.trainable_variables)
return result
for _ in tf.range(4):
result = strategy.run(step, args=(next(iterator),))
run_loop(iter(dataset))
要在 TPU 上使用 DistributedEmbedding 和 JAX,必须创建一个 Keras Distribution 并进行设置。
distribution = keras.distribution.DataParallel(devices=jax.device("tpu"))
keras.distribution.set_distribution(distribution)
对于 JAX,输入可以是密集张量,也可以是 ragged(嵌套)的 NumPy 数组。要启用 jit_compile = True,必须在输入上显式调用 layer.preprocess(...),然后将预处理后的输出馈送给模型。有关详细信息,请参阅下一节关于预处理的内容。
Ragged 输入数组必须在索引为 1 的维度上是 ragged 的。请注意,如果提供了权重,则每个权重张量必须与该特定特征的输入属于同一类,并且对于 ragged 数组使用完全相同的 ragged 行长度。DistributedEmbedding 的所有输出都是密集张量。
在 JAX 中,SparseCore 的使用需要特殊格式的数据,该数据取决于可用硬件的属性。此数据重格式化目前不支持即时编译 (jit-compilation),因此必须在将数据传递给模型之前进行应用。
预处理适用于密集或 ragged 的 NumPy 数组,或者可转换为密集或 ragged NumPy 数组的张量,例如 tf.RaggedTensor。
添加预处理的一种简单方法是通过使用 Python 生成器将其附加到输入管道。
# Create the embedding layer.
embedding_layer = DistributedEmbedding(feature_configs)
# Add preprocessing to a data input pipeline.
def preprocessed_dataset_generator(dataset):
for (inputs, weights), labels in iter(dataset):
yield embedding_layer.preprocess(
inputs, weights, training=True
), labels
preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
此显式的预处理阶段会合并输入和可选的权重,因此新数据可以直接传递到层或模型的 inputs 参数。
注意:在多主机环境中进行数据并行处理时,需要将数据正确地分片到各个主机。如果原始数据集是 tf.data.Dataset 类型,则需要在应用预处理生成器之前手动对其进行分片。
# Manually shard the dataset across hosts.
train_dataset = distribution.distribute_dataset(train_dataset)
distribution.auto_shard_dataset = False # Dataset is already sharded.
# Add a preprocessing stage to the distributed data input pipeline.
train_dataset = preprocessed_dataset_generator(train_dataset)
如果原始数据集不是 tf.data.Dataset,则它必须已经跨主机进行了预分片。
一旦设置了全局分片并定义了输入预处理管道,模型训练就可以正常进行。例如:
# Construct, compile, and fit the model using the preprocessed data.
model = keras.Sequential(
[
embedding_layer,
keras.layers.Dense(2),
keras.layers.Dense(3),
keras.layers.Dense(4),
]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(preprocessed_train_dataset, epochs=10)
DistributedEmbedding 层也可以直接调用。使用即时编译时需要显式预处理。
# Call the layer directly.
activations = embedding_layer(my_inputs, my_weights)
# Call the layer with JIT compilation and explicitly preprocessed inputs.
embedding_layer_jit = jax.jit(embedding_layer)
preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
activations = embedding_layer_jit(preprocessed_inputs)
同样,对于自定义训练循环,必须在将数据传递到即时编译的训练步骤之前应用预处理。
# Create an optimizer and loss function.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x, training=True
)
loss = keras.losses.mean_squared_error(y, y_pred)
return loss, non_trainable_variables
grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)
# Create a JIT-compiled training step.
@jax.jit
def train_step(state, x, y):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
) = state
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
# Build optimizer variables.
optimizer.build(model.trainable_variables)
# Assemble the training state.
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop.
for (inputs, weights), labels in train_dataset:
# Explicitly preprocess the data.
preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
loss, state = train_step(state, preprocessed_inputs, labels)
参数
keras_rs.layers.FeatureConfig。None 表示不进行表堆叠。"auto" 表示自动堆叠表。表名列表或表名列表的列表表示将内部列表中的表堆叠在一起。请注意,较旧的 TPU 不支持表堆叠,在这种情况下,默认值 "auto" 将被解释为不进行表堆叠。call 方法DistributedEmbedding.call(
inputs: Union[
Any,
Sequence[Union[Any, ForwardRef("Nested[T]")]],
Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
],
weights: Union[
Any,
Sequence[Union[Any, ForwardRef("Nested[T]")]],
Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
NoneType,
] = None,
training: bool = False,
)
在嵌入表中查找特征并应用缩减。
参数
feature_configs 相同。或者,它可能包含已预处理的输入(请参阅 preprocess)。inputs 相同,并且形状必须匹配。返回
一个嵌套结构的密集 2D 张量,这些张量是来自传入特征的缩减后的嵌入。结构与 inputs 相同。
preprocess 方法DistributedEmbedding.preprocess(
inputs: Union[
Any,
Sequence[Union[Any, ForwardRef("Nested[T]")]],
Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
],
weights: Union[
Any,
Sequence[Union[Any, ForwardRef("Nested[T]")]],
Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
NoneType,
] = None,
training: bool = False,
)
预处理和重构数据以供模型使用。
对于 JAX 后端,将输入数据转换为与硬件相关的格式,这是与 SparseCores 一起使用所必需的。仅当要启用 jit_compile = True 时,才需要显式调用 preprocess。
对于非 JAX 后端,预处理会将输入和权重捆绑在一起,并按设备放置分离输入。此步骤是完全可选的。
参数
返回
一组预处理后的输入,可以直接馈送到层的 inputs 参数。