作者: Fabien Hertschuh, Abheesht Sharma, C. Antonio Sánchez
创建日期 2025/06/03
最后修改日期 2025/09/02
描述: 使用带有 SparseCore 嵌入的双塔模型对电影进行排序。
在 基础排序 教程中,我们展示了如何为 MovieLens 数据集构建一个排序模型,以向用户推荐电影。
本教程在相同的数据集上实现了相同的模型,但使用了 keras_rs.layers.DistributedEmbedding,它利用了 TPU 上的 SparseCore。这是该教程的 JAX 版本。需要将其运行在 TPU v5p 或 v6e 上。
让我们开始选择 JAX 作为后端并导入所有必要的库。
!pip install -q -U jax[tpu]>=0.7.0
!pip install -q jax-tpu-embedding
!pip install -q tensorflow-cpu
!pip install -q keras-rs
import os
os.environ["KERAS_BACKEND"] = "jax"
import jax
import keras
import keras_rs
import tensorflow as tf # Needed for the dataset
import tensorflow_datasets as tfds
模型被复制,并且嵌入表被分片到 SparseCores 上,而数据集是通过将每个批次分片到 TPU 上来分发的。我们需要确保批次大小是 TPU 数量的倍数。
PER_REPLICA_BATCH_SIZE = 256
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.local_device_count("tpu")
distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))
keras.distribution.set_distribution(distribution)
我们将使用相同的 MovieLens 数据。评分是我们试图预测的目标。
# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")
我们需要知道用户的数量,因为我们直接使用用户 ID 作为用户嵌入表中的索引。
users_count = int(
ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
.reduce(tf.constant(0, tf.int32), tf.maximum)
.numpy()
)
我们还需要知道电影的数量,因为我们直接使用电影 ID 作为电影嵌入表中的索引。
movies_count = int(movies.cardinality().numpy())
模型的输入是用户 ID 和电影 ID,标签是评分。
def preprocess_rating(x):
return (
# Inputs are user IDs and movie IDs
{
"user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
"movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
},
# Labels are ratings between 0 and 1.
(x["user_rating"] - 1.0) / 4.0,
)
我们将通过将 80% 的评分放入训练集,20% 放入测试集来划分数据。
shuffled_ratings = ratings.map(preprocess_rating).shuffle(
100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
shuffled_ratings.skip(80_000)
.take(20_000)
.batch(BATCH_SIZE, drop_remainder=True)
.cache()
)
keras_rs.layers.DistributedEmbedding 处理多个特征和多个嵌入表。这是为了能够共享特征之间的表,并允许将多个嵌入查找合并为一次调用所带来的一些优化。在本节中,我们将描述如何配置这些。
表使用 keras_rs.layers.TableConfig 进行配置,该配置包含:
"sparsecore" 位置。model.compile() 的优化器不用于嵌入表。特征使用 keras_rs.layers.FeatureConfig 进行配置,该配置包含:
我们可以以任何我们想要的结构组织特征,可以是嵌套的。字典通常是一个不错的选择,可以为输入和输出提供名称。
EMBEDDING_DIMENSION = 32
movie_table = keras_rs.layers.TableConfig(
name="movie_table",
vocabulary_size=movies_count + 1, # +1 for movie ID 0, which is not used
embedding_dim=EMBEDDING_DIMENSION,
optimizer="adam",
placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
name="user_table",
vocabulary_size=users_count + 1, # +1 for user ID 0, which is not used
embedding_dim=EMBEDDING_DIMENSION,
optimizer="adam",
placement="sparsecore",
)
FEATURE_CONFIGS = {
"movie_id": keras_rs.layers.FeatureConfig(
name="movie",
table=movie_table,
input_shape=(BATCH_SIZE,),
output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
),
"user_id": keras_rs.layers.FeatureConfig(
name="user",
table=user_table,
input_shape=(BATCH_SIZE,),
output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
),
}
现在我们准备在模型中创建一个 DistributedEmbedding。一旦有了配置,我们就只需将其传递给 DistributedEmbedding 的构造函数。然后在模型 call 方法中,DistributedEmbedding 是我们调用的第一个层。
输出具有与输入完全相同的结构。在我们的示例中,我们将获取的嵌入作为输出连接起来,并通过一系列密集层运行它们。
class EmbeddingModel(keras.Model):
"""Create the model with the embedding configuration.
Args:
feature_configs: the configuration for `DistributedEmbedding`.
"""
def __init__(self, feature_configs):
super().__init__()
self.embedding_layer = keras_rs.layers.DistributedEmbedding(
feature_configs=feature_configs
)
self.ratings = keras.Sequential(
[
# Learn multiple dense layers.
keras.layers.Dense(256, activation="relu"),
keras.layers.Dense(64, activation="relu"),
# Make rating predictions in the final layer.
keras.layers.Dense(1),
]
)
def call(self, preprocessed_features):
# Embedding lookup. Outputs have the same structure as the inputs.
embedding = self.embedding_layer(preprocessed_features)
return self.ratings(
keras.ops.concatenate(
[embedding["user_id"], embedding["movie_id"]],
axis=1,
)
)
现在让我们实例化模型。然后我们使用 model.compile() 来配置损失、指标和优化器。同样,这个 Adagrad 优化器将仅应用于密集层,而不应用于嵌入表。
model = EmbeddingModel(FEATURE_CONFIGS)
model.compile(
loss=keras.losses.MeanSquaredError(),
metrics=[keras.metrics.RootMeanSquaredError()],
optimizer="adagrad",
)
使用 JAX 后端,我们需要预处理输入,将它们转换为使用 SparseCores 所需的硬件相关格式。我们将通过将数据集包装到生成器函数中来实现此目的。
def train_dataset_generator():
for inputs, labels in iter(train_ratings):
yield model.embedding_layer.preprocess(inputs, training=True), labels
def test_dataset_generator():
for inputs, labels in iter(test_ratings):
yield model.embedding_layer.preprocess(inputs, training=False), labels
我们可以使用标准的 Keras model.fit() 来训练模型。Keras 将自动使用 TPUStrategy 来分布式模型和数据。
model.fit(train_dataset_generator(), epochs=5)
Epoch 1/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 14s 37ms/step - loss: 0.2746 - root_mean_squared_error: 0.5200
Epoch 2/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 2s 16us/step - loss: 0.0924 - root_mean_squared_error: 0.3040
Epoch 3/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 0s 18us/step - loss: 0.0922 - root_mean_squared_error: 0.3037
Epoch 4/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 0s 17us/step - loss: 0.0921 - root_mean_squared_error: 0.3034
Epoch 5/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 0s 18us/step - loss: 0.0919 - root_mean_squared_error: 0.3031
<keras.src.callbacks.history.History at 0x775c331de5f0>
model.evaluate() 也一样。
model.evaluate(test_dataset_generator(), return_dict=True)
78/78 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - loss: 0.0964 - root_mean_squared_error: 0.3105
{'loss': 0.09723417460918427, 'root_mean_squared_error': 0.31182393431663513}
就是这样。
这个示例表明,在配置 DistributedEmbedding 和设置所需的预处理之后,您可以使用标准的 Keras 工作流。