KerasRS / 示例 / 使用 TPU SparseCore 和 JAX 实现 DistributedEmbedding

使用 TPU SparseCore 和 JAX 实现 DistributedEmbedding

作者: Fabien Hertschuh, Abheesht Sharma, C. Antonio Sánchez
创建日期 2025/06/03
最后修改日期 2025/09/02
描述: 使用带有 SparseCore 嵌入的双塔模型对电影进行排序。

在 Colab 中查看 GitHub 源代码


简介

基础排序 教程中,我们展示了如何为 MovieLens 数据集构建一个排序模型,以向用户推荐电影。

本教程在相同的数据集上实现了相同的模型,但使用了 keras_rs.layers.DistributedEmbedding,它利用了 TPU 上的 SparseCore。这是该教程的 JAX 版本。需要将其运行在 TPU v5p 或 v6e 上。

让我们开始选择 JAX 作为后端并导入所有必要的库。

!pip install -q -U jax[tpu]>=0.7.0
!pip install -q jax-tpu-embedding
!pip install -q tensorflow-cpu
!pip install -q keras-rs
import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import keras_rs
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

数据集分布

模型被复制,并且嵌入表被分片到 SparseCores 上,而数据集是通过将每个批次分片到 TPU 上来分发的。我们需要确保批次大小是 TPU 数量的倍数。

PER_REPLICA_BATCH_SIZE = 256
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.local_device_count("tpu")

distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))
keras.distribution.set_distribution(distribution)

准备数据集

我们将使用相同的 MovieLens 数据。评分是我们试图预测的目标。

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

我们需要知道用户的数量,因为我们直接使用用户 ID 作为用户嵌入表中的索引。

users_count = int(
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

我们还需要知道电影的数量,因为我们直接使用电影 ID 作为电影嵌入表中的索引。

movies_count = int(movies.cardinality().numpy())

模型的输入是用户 ID 和电影 ID,标签是评分。

def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )

我们将通过将 80% 的评分放入训练集,20% 放入测试集来划分数据。

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
    shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
    shuffled_ratings.skip(80_000)
    .take(20_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
)

配置 DistributedEmbedding

keras_rs.layers.DistributedEmbedding 处理多个特征和多个嵌入表。这是为了能够共享特征之间的表,并允许将多个嵌入查找合并为一次调用所带来的一些优化。在本节中,我们将描述如何配置这些。

配置表

表使用 keras_rs.layers.TableConfig 进行配置,该配置包含:

  • 一个名称。
  • 一个词汇量大小(输入大小)。
  • 一个嵌入维度(输出大小)。
  • 一个组合器,用于指定在嵌入序列时如何将多个嵌入减少为单个嵌入。请注意,这不适用于我们的示例,因为我们为每个用户和每个电影获取一个嵌入。
  • 一个放置位置,用于指示是否将表放置在 SparseCore 芯片上。在这种情况下,我们希望放置在 "sparsecore" 位置。
  • 一个优化器,用于指定训练时如何应用梯度。每个表都有自己的优化器,传递给 model.compile() 的优化器不用于嵌入表。

配置特征

特征使用 keras_rs.layers.FeatureConfig 进行配置,该配置包含:

  • 一个名称。
  • 一个表,即要使用的嵌入表。
  • 一个输入形状(批次大小适用于所有 TPU)。
  • 一个输出形状(批次大小适用于所有 TPU)。

我们可以以任何我们想要的结构组织特征,可以是嵌套的。字典通常是一个不错的选择,可以为输入和输出提供名称。

EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}

定义模型

现在我们准备在模型中创建一个 DistributedEmbedding。一旦有了配置,我们就只需将其传递给 DistributedEmbedding 的构造函数。然后在模型 call 方法中,DistributedEmbedding 是我们调用的第一个层。

输出具有与输入完全相同的结构。在我们的示例中,我们将获取的嵌入作为输出连接起来,并通过一系列密集层运行它们。

class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()

        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, preprocessed_features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(preprocessed_features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )

现在让我们实例化模型。然后我们使用 model.compile() 来配置损失、指标和优化器。同样,这个 Adagrad 优化器将仅应用于密集层,而不应用于嵌入表。

model = EmbeddingModel(FEATURE_CONFIGS)

model.compile(
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.RootMeanSquaredError()],
    optimizer="adagrad",
)

使用 JAX 后端,我们需要预处理输入,将它们转换为使用 SparseCores 所需的硬件相关格式。我们将通过将数据集包装到生成器函数中来实现此目的。

def train_dataset_generator():
    for inputs, labels in iter(train_ratings):
        yield model.embedding_layer.preprocess(inputs, training=True), labels


def test_dataset_generator():
    for inputs, labels in iter(test_ratings):
        yield model.embedding_layer.preprocess(inputs, training=False), labels

拟合并评估

我们可以使用标准的 Keras model.fit() 来训练模型。Keras 将自动使用 TPUStrategy 来分布式模型和数据。

model.fit(train_dataset_generator(), epochs=5)
Epoch 1/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 14s 37ms/step - loss: 0.2746 - root_mean_squared_error: 0.5200

Epoch 2/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 2s 16us/step - loss: 0.0924 - root_mean_squared_error: 0.3040

Epoch 3/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 0s 18us/step - loss: 0.0922 - root_mean_squared_error: 0.3037

Epoch 4/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 0s 17us/step - loss: 0.0921 - root_mean_squared_error: 0.3034

Epoch 5/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 0s 18us/step - loss: 0.0919 - root_mean_squared_error: 0.3031

<keras.src.callbacks.history.History at 0x775c331de5f0>

model.evaluate() 也一样。

model.evaluate(test_dataset_generator(), return_dict=True)
78/78 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - loss: 0.0964 - root_mean_squared_error: 0.3105

{'loss': 0.09723417460918427, 'root_mean_squared_error': 0.31182393431663513}

就是这样。

这个示例表明,在配置 DistributedEmbedding 和设置所需的预处理之后,您可以使用标准的 Keras 工作流。