Authors: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Retrieve movies using a two-tower model (data parallel training).
In this tutorial, we are going to train the exact same retrieval model as in the basic retrieval tutorial, but in a distributed way.
Distributed training is used to train models on multiple devices or machines simultaneously, thereby reducing training time. Here, we focus on synchronous data parallel training. Each accelerator (GPU/TPU) holds a complete replica of the model and sees a different mini-batch of the input data. Local gradients are computed on each device, then aggregated and used to compute a global gradient update.
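To make this concrete, here is a minimal, self-contained JAX sketch of one synchronous data parallel update (an illustration of the concept only, not what keras.distribution does internally; all names here are made up):
from functools import partial
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Toy loss: mean squared error of a linear model y ~ x @ w.
    return jnp.mean((x @ w - y) ** 2)

@partial(jax.pmap, axis_name="data")
def train_step(w, x, y):
    # Each device computes gradients on its own mini-batch shard...
    grads = jax.grad(loss_fn)(w, x, y)
    # ...then the gradients are averaged across devices, so every
    # replica applies the exact same global update.
    grads = jax.lax.pmean(grads, axis_name="data")
    return w - 0.1 * grads

n = jax.local_device_count()
replicated_w = jnp.stack([jnp.zeros((4,))] * n)  # same weights on each device
x_shards = jnp.ones((n, 8, 4))  # a different mini-batch shard per device
y_shards = jnp.ones((n, 8))
replicated_w = train_step(replicated_w, x_shards, y_shards)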
Before we begin, let's note a couple of things: the keras.distribution API is only available with JAX, so make sure you select JAX as the backend!
!pip install -q keras-rs
import os
os.environ["KERAS_BACKEND"] = "jax"
import random
import jax
import keras
import tensorflow as tf # Needed only for the dataset
import tensorflow_datasets as tfds
import keras_rs
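As a quick optional check (not part of the original tutorial), you can verify that the JAX backend is actually active before going any further:
assert keras.backend.backend() == "jax"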
For the synchronous data parallelism strategy in distributed training, we will use the DataParallel class from the keras.distribution API.
devices = jax.devices()  # Assumes there is more than one local device.
data_parallel = keras.distribution.DataParallel(devices=devices)
Alternatively, you can also create the DataParallel object from a 1D DeviceMesh object, like so:
mesh_1d = keras.distribution.DeviceMesh(
shape=(len(devices),), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)
# Set the global distribution strategy.
keras.distribution.set_distribution(data_parallel)
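As an aside: to the best of our knowledge, calling keras.distribution.DataParallel() with no arguments uses all available devices, and Keras can enumerate accelerators itself via keras.distribution.list_devices(), without going through JAX. A hedged sketch (verify against your Keras version):
# Optional: let Keras discover the accelerators instead of JAX.
# list_devices() returns device strings such as "gpu:0".
print(keras.distribution.list_devices())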
Once the global distribution strategy is defined, the rest of this guide is exactly the same as the previous basic retrieval guide.
Let's load and prepare the dataset. We use the MovieLens dataset here as well.
# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")
# User, movie counts for defining vocabularies.
users_count = (
ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
.reduce(tf.constant(0, tf.int32), tf.maximum)
.numpy()
)
movies_count = movies.cardinality().numpy()
# Preprocess dataset, and split it into train-test datasets.
def preprocess_rating(x):
return (
# Input is the user IDs
tf.strings.to_number(x["user_id"], out_type=tf.int32),
# Labels are movie IDs + ratings between 0 and 1.
{
"movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
"rating": (x["user_rating"] - 1.0) / 4.0,
},
)
shuffled_ratings = ratings.map(preprocess_rating).shuffle(
100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-ratings/0.1.1 has no dataset_info.json
Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1...
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1. Subsequent calls will reuse this data.
WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-movies/0.1.1 has no dataset_info.json
Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/100k-movies/0.1.1...
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/100k-movies/0.1.1. Subsequent calls will reuse this data.
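Before moving on, it can help to eyeball one preprocessed example. This quick sanity check is not part of the original tutorial, but it confirms the dtypes and the [0, 1] rating scale:
# Print a single (user_id, label) pair produced by preprocess_rating.
for user_id, label in ratings.map(preprocess_rating).take(1):
    print(user_id)            # scalar tf.int32 user ID
    print(label["movie_id"])  # scalar tf.int32 movie ID
    print(label["rating"])    # rating rescaled from [1, 5] to [0, 1]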
We build the two-tower retrieval model, combining a query tower for users and a candidate tower for movies. Note that nothing needs to change here compared to the previous basic retrieval tutorial.
class RetrievalModel(keras.Model):
"""Create the retrieval model with the provided parameters.
Args:
num_users: Number of entries in the user embedding table.
num_candidates: Number of entries in the candidate embedding table.
embedding_dimension: Output dimension for user and movie embedding tables.
"""
def __init__(
self,
num_users,
num_candidates,
embedding_dimension=32,
**kwargs,
):
super().__init__(**kwargs)
# Our query tower, simply an embedding table.
self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
# Our candidate tower, simply an embedding table.
self.candidate_embedding = keras.layers.Embedding(
num_candidates, embedding_dimension
)
# The layer that performs the retrieval.
self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
self.loss_fn = keras.losses.MeanSquaredError()
def build(self, input_shape):
self.user_embedding.build(input_shape)
self.candidate_embedding.build(input_shape)
# In this case, the candidates are directly the movie embeddings.
# We take a shortcut and directly reuse the variable.
self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
self.retrieval.build(input_shape)
super().build(input_shape)
def call(self, inputs, training=False):
user_embeddings = self.user_embedding(inputs)
result = {
"user_embeddings": user_embeddings,
}
if not training:
# Skip the retrieval of top movies during training as the
# predictions are not used.
result["predictions"] = self.retrieval(user_embeddings)
return result
def compute_loss(self, x, y, y_pred, sample_weight, training=True):
candidate_id, rating = y["movie_id"], y["rating"]
user_embeddings = y_pred["user_embeddings"]
candidate_embeddings = self.candidate_embedding(candidate_id)
labels = keras.ops.expand_dims(rating, -1)
# Compute the affinity score by multiplying the two embeddings.
scores = keras.ops.sum(
keras.ops.multiply(user_embeddings, candidate_embeddings),
axis=1,
keepdims=True,
)
return self.loss_fn(labels, scores, sample_weight)
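The affinity score computed in compute_loss above is simply the dot product of a user embedding and a movie embedding. A tiny numeric illustration with made-up 4-dimensional embeddings:
u = keras.ops.convert_to_tensor([[0.1, 0.2, 0.3, 0.4]])  # user embedding
m = keras.ops.convert_to_tensor([[0.4, 0.3, 0.2, 0.1]])  # movie embedding
score = keras.ops.sum(keras.ops.multiply(u, m), axis=1, keepdims=True)
print(score)  # [[0.2]] -- a higher score means a stronger predicted match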
With the model defined, we can use the standard Keras model.fit() to train and evaluate it.
model = RetrievalModel(users_count + 1, movies_count + 1)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.2))
Let's train the model. Evaluation takes a bit of time, so we only evaluate the model every 5 epochs.
history = model.fit(
train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)
Epoch 1/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - loss: 0.4772
Epoch 2/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4771
Epoch 3/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4770
Epoch 4/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4769
Epoch 5/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 37ms/step - loss: 0.4769 - val_loss: 0.4836
Epoch 6/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4768
Epoch 7/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4767
Epoch 8/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4766
Epoch 9/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4764
Epoch 10/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4763 - val_loss: 0.4833
Epoch 11/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4761
Epoch 12/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4759
Epoch 13/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4757
Epoch 14/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4754
Epoch 15/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4750 - val_loss: 0.4821
Epoch 16/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4746
Epoch 17/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4740
Epoch 18/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4734
Epoch 19/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4725
Epoch 20/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4715 - val_loss: 0.4784
Epoch 21/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4702
Epoch 22/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4686
Epoch 23/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4666
Epoch 24/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4641
Epoch 25/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4609 - val_loss: 0.4664
Epoch 26/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4571
Epoch 27/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4524
Epoch 28/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4466
Epoch 29/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4395
Epoch 30/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4311 - val_loss: 0.4326
Epoch 31/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4210
Epoch 32/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4093
Epoch 33/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3957
Epoch 34/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3805
Epoch 35/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.3636 - val_loss: 0.3597
Epoch 36/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3455
Epoch 37/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3265
Epoch 38/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3072
Epoch 39/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2880
Epoch 40/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.2696 - val_loss: 0.2664
Epoch 41/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2523
Epoch 42/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2363
Epoch 43/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2218
Epoch 44/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2087
Epoch 45/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.1970 - val_loss: 0.1986
Epoch 46/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1866
Epoch 47/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1773
Epoch 48/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1689
Epoch 49/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1613
Epoch 50/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.1544 - val_loss: 0.1586
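As an optional follow-up (not part of the original tutorial), the History object returned by fit() records the losses; val_loss has one entry per validation run, since validation_freq=5:
print(history.history["loss"][-1])  # final training loss
print(history.history["val_loss"])  # logged every 5 epochs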
Now that we have a model, let's run inference and make predictions.
movie_id_to_movie_title = {
int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset.
We then simply use the Keras model.predict() method. Under the hood, it calls the BruteForceRetrieval layer to perform the actual retrieval.
# Sample only valid user IDs (1..users_count); the embedding table
# has users_count + 1 entries, so larger IDs would be out of range.
user_ids = random.sample(range(1, users_count + 1), len(devices))
predictions = model.predict(keras.ops.convert_to_tensor(user_ids))
predictions = keras.ops.convert_to_numpy(predictions["predictions"])
for i, user_id in enumerate(user_ids):
print(f"\n==Recommended movies for user {user_id}==")
for movie_id in predictions[i]:
print(movie_id_to_movie_title[movie_id])
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 211ms/step
==Recommended movies for user 793==
b'Star Wars (1977)'
b'Godfather, The (1972)'
b'Raiders of the Lost Ark (1981)'
b'Fargo (1996)'
b'Silence of the Lambs, The (1991)'
b"Schindler's List (1993)"
b'Shawshank Redemption, The (1994)'
b'Titanic (1997)'
b'Braveheart (1995)'
b'Pulp Fiction (1994)'
==Recommended movies for user 188==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Silence of the Lambs, The (1991)'
b"Schindler's List (1993)"
b'Return of the Jedi (1983)'
b'Raiders of the Lost Ark (1981)'
b'Pulp Fiction (1994)'
b'Toy Story (1995)'
b'Empire Strikes Back, The (1980)'
==Recommended movies for user 865==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Silence of the Lambs, The (1991)'
b'Raiders of the Lost Ark (1981)'
b"Schindler's List (1993)"
b'Return of the Jedi (1983)'
b'Shawshank Redemption, The (1994)'
b'Pulp Fiction (1994)'
b'Empire Strikes Back, The (1980)'
==Recommended movies for user 710==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Silence of the Lambs, The (1991)'
b'Raiders of the Lost Ark (1981)'
b"Schindler's List (1993)"
b'Pulp Fiction (1994)'
b'Return of the Jedi (1983)'
b'Empire Strikes Back, The (1980)'
b'Toy Story (1995)'
==Recommended movies for user 721==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Raiders of the Lost Ark (1981)'
b'Silence of the Lambs, The (1991)'
b"Schindler's List (1993)"
b'Return of the Jedi (1983)'
b'Empire Strikes Back, The (1980)'
b'Pulp Fiction (1994)'
b'Casablanca (1942)'
==Recommended movies for user 451==
b'Star Wars (1977)'
b'Raiders of the Lost Ark (1981)'
b'Godfather, The (1972)'
b'Fargo (1996)'
b'Silence of the Lambs, The (1991)'
b'Return of the Jedi (1983)'
b'Contact (1997)'
b'Casablanca (1942)'
b'Empire Strikes Back, The (1980)'
b'Pulp Fiction (1994)'
==Recommended movies for user 228==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Raiders of the Lost Ark (1981)'
b'Silence of the Lambs, The (1991)'
b"Schindler's List (1993)"
b'Return of the Jedi (1983)'
b'Pulp Fiction (1994)'
b'Empire Strikes Back, The (1980)'
b'Shawshank Redemption, The (1994)'
==Recommended movies for user 175==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Silence of the Lambs, The (1991)'
b'Raiders of the Lost Ark (1981)'
b'Return of the Jedi (1983)'
b'Casablanca (1942)'
b"Schindler's List (1993)"
b'Empire Strikes Back, The (1980)'
b'Godfather, The (1972)'
b"One Flew Over the Cuckoo's Nest (1975)"
And we're done! To enable data parallel training, all we had to do was add roughly 3-5 lines of code. The rest is exactly the same.
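For reference, here is the entire data-parallel setup from this guide collected in one place; everything else is unchanged from the basic retrieval tutorial:
import jax
import keras

devices = jax.devices()
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)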