作者: Owen Vallis
创建日期 2021/09/30
最后修改日期 2022/02/29
描述: 使用相似性度量学习在 CIFAR-10 图像上的示例。
本示例基于“用于图像相似度搜索的度量学习”示例。我们的目标是使用相同的数据集,但使用 TensorFlow Similarity 来实现模型。
度量学习旨在训练可以将输入嵌入到高维空间中的模型,使得“相似”的输入彼此靠近,“不相似”的输入彼此远离。一旦训练完成,这些模型可以为下游系统生成嵌入,其中这种相似性很有用,例如作为搜索的排名信号或作为另一个监督问题的预训练嵌入模型。
有关度量学习的更详细概述,请参阅
本教程将使用 TensorFlow Similarity 库来学习和评估相似性嵌入。TensorFlow Similarity 提供的组件可以
TensorFlow Similarity 可以通过 pip 轻松安装,如下所示
pip -q install tensorflow_similarity
import random
from matplotlib import pyplot as plt
from mpl_toolkits import axes_grid1
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_similarity as tfsim
tfsim.utils.tf_cap_memory()
print("TensorFlow:", tf.__version__)
print("TensorFlow Similarity:", tfsim.__version__)
TensorFlow: 2.7.0
TensorFlow Similarity: 0.15.5
在本教程中,我们将使用 CIFAR-10 数据集。
为了使相似性模型有效地学习,每个批次必须至少包含每个类的 2 个示例。
为了简化此操作,tf_similarity 提供了 Sampler
对象,使您可以设置每个批次的类数和每个类的最少示例数。
训练和验证数据集将使用 TFDatasetMultiShotMemorySampler
对象创建。这将创建一个采样器,该采样器从 TensorFlow 数据集加载数据集,并生成包含目标类数和每个类的目标示例数的批次。此外,我们可以将采样器限制为仅生成 class_list
中定义的类的子集,从而使我们能够在类的子集上进行训练,然后测试嵌入如何推广到看不见的类。这在处理少样本学习问题时很有用。
以下单元格创建 train_ds 样本,该样本
examples_per_class_per_batch
。class_list
中定义的类。我们还以相同的方式创建验证数据集,但我们将每个类的示例总数限制为 100,并将每个批次的每个类的示例数设置为默认值 2。
# This determines the number of classes used during training.
# Here we are using all the classes.
num_known_classes = 10
class_list = random.sample(population=range(10), k=num_known_classes)
classes_per_batch = 10
# Passing multiple examples per class per batch ensures that each example has
# multiple positive pairs. This can be useful when performing triplet mining or
# when using losses like `MultiSimilarityLoss` or `CircleLoss` as these can
# take a weighted mix of all the positive pairs. In general, more examples per
# class will lead to more information for the positive pairs, while more classes
# per batch will provide more varied information in the negative pairs. However,
# the losses compute the pairwise distance between the examples in a batch so
# the upper limit of the batch size is restricted by the memory.
examples_per_class_per_batch = 8
print(
"Batch size is: "
f"{min(classes_per_batch, num_known_classes) * examples_per_class_per_batch}"
)
print(" Create Training Data ".center(34, "#"))
train_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
"cifar10",
classes_per_batch=min(classes_per_batch, num_known_classes),
splits="train",
steps_per_epoch=4000,
examples_per_class_per_batch=examples_per_class_per_batch,
class_list=class_list,
)
print("\n" + " Create Validation Data ".center(34, "#"))
val_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
"cifar10",
classes_per_batch=classes_per_batch,
splits="test",
total_examples_per_class=100,
)
Batch size is: 80
###### Create Training Data ######
converting train: 0%| | 0/50000 [00:00<?, ?it/s]
The initial batch size is 80 (10 classes * 8 examples per class) with 0 augmenters
filtering examples: 0%| | 0/50000 [00:00<?, ?it/s]
selecting classes: 0%| | 0/10 [00:00<?, ?it/s]
gather examples: 0%| | 0/50000 [00:00<?, ?it/s]
indexing classes: 0%| | 0/50000 [00:00<?, ?it/s]
##### Create Validation Data #####
converting test: 0%| | 0/10000 [00:00<?, ?it/s]
The initial batch size is 20 (10 classes * 2 examples per class) with 0 augmenters
filtering examples: 0%| | 0/10000 [00:00<?, ?it/s]
selecting classes: 0%| | 0/10 [00:00<?, ?it/s]
gather examples: 0%| | 0/1000 [00:00<?, ?it/s]
indexing classes: 0%| | 0/1000 [00:00<?, ?it/s]
采样器将对数据集进行洗牌,因此我们可以通过绘制前 25 张图像来了解数据集。
采样器提供 get_slice(begin, size)
方法,使我们可以轻松选择一个样本块。
或者,我们可以使用 generate_batch()
方法来生成一个批次。这可以让我们检查一个批次是否包含预期数量的类别和每个类别的示例。
num_cols = num_rows = 5
# Get the first 25 examples.
x_slice, y_slice = train_ds.get_slice(begin=0, size=num_cols * num_rows)
fig = plt.figure(figsize=(6.0, 6.0))
grid = axes_grid1.ImageGrid(fig, 111, nrows_ncols=(num_cols, num_rows), axes_pad=0.1)
for ax, im, label in zip(grid, x_slice, y_slice):
ax.imshow(im)
ax.axis("off")
接下来,我们使用 Keras 函数式 API 定义一个 SimilarityModel
。该模型是一个标准的卷积神经网络,添加了一个应用 L2 归一化的 MetricEmbedding
层。当使用 Cosine
距离时,度量嵌入层很有用,因为我们只关心向量之间的角度。
此外,SimilarityModel
提供了一些辅助方法,用于:
有关更多详细信息,请参阅 TensorFlow Similarity 文档。
embedding_size = 256
inputs = keras.layers.Input((32, 32, 3))
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
x = keras.layers.Conv2D(64, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(128, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D((4, 4))(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.GlobalMaxPool2D()(x)
outputs = tfsim.layers.MetricEmbedding(embedding_size)(x)
# building model
model = tfsim.models.SimilarityModel(inputs, outputs)
model.summary()
Model: "similarity_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 32, 32, 3)] 0
rescaling (Rescaling) (None, 32, 32, 3) 0
conv2d (Conv2D) (None, 30, 30, 64) 1792
batch_normalization (BatchN (None, 30, 30, 64) 256
ormalization)
conv2d_1 (Conv2D) (None, 28, 28, 128) 73856
batch_normalization_1 (Batc (None, 28, 28, 128) 512
hNormalization)
max_pooling2d (MaxPooling2D (None, 7, 7, 128) 0
)
conv2d_2 (Conv2D) (None, 5, 5, 256) 295168
batch_normalization_2 (Batc (None, 5, 5, 256) 1024
hNormalization)
conv2d_3 (Conv2D) (None, 3, 3, 256) 590080
global_max_pooling2d (Globa (None, 256) 0
lMaxPooling2D)
metric_embedding (MetricEmb (None, 256) 65792
edding)
=================================================================
Total params: 1,028,480
Trainable params: 1,027,584
Non-trainable params: 896
_________________________________________________________________
相似性损失期望批次包含每个类别的至少 2 个示例,从中计算成对正距离和负距离的损失。这里我们使用 MultiSimilarityLoss()
(论文),它是 TensorFlow Similarity 中的几种损失之一。此损失尝试使用批次中的所有信息性对,同时考虑自相似性、正相似性和负相似性。
epochs = 3
learning_rate = 0.002
val_steps = 50
# init similarity loss
loss = tfsim.losses.MultiSimilarityLoss()
# compiling and training
model.compile(
optimizer=keras.optimizers.Adam(learning_rate), loss=loss, steps_per_execution=10,
)
history = model.fit(
train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps
)
Distance metric automatically set to cosine use the distance arg to override.
Epoch 1/3
4000/4000 [==============================] - ETA: 0s - loss: 2.2179Warmup complete
4000/4000 [==============================] - 38s 9ms/step - loss: 2.2179 - val_loss: 0.8894
Warmup complete
Epoch 2/3
4000/4000 [==============================] - 34s 9ms/step - loss: 1.9047 - val_loss: 0.8767
Epoch 3/3
4000/4000 [==============================] - 35s 9ms/step - loss: 1.6336 - val_loss: 0.8469
现在我们已经训练了模型,我们可以创建一个示例索引。这里我们通过将 x 和 y 传递给索引以及将图像存储在 data 参数中来批量索引前 200 个验证示例。x_index
值被嵌入,然后添加到索引中以使其可搜索。y_index
和 data 参数是可选的,但允许用户将元数据与嵌入的示例关联起来。
x_index, y_index = val_ds.get_slice(begin=0, size=200)
model.reset_index()
model.index(x_index, y_index, data=x_index)
[Indexing 200 points]
|-Computing embeddings
|-Storing data points in key value store
|-Adding embeddings to index.
|-Building index.
0% 10 20 30 40 50 60 70 80 90 100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
一旦构建了索引,我们可以使用匹配策略和校准指标来校准距离阈值。
在这里,我们正在使用 K=1 作为我们的分类器来搜索最佳 F1 分数。所有距离小于或等于校准阈值距离的匹配都将被标记为查询示例与匹配结果关联的标签之间的正匹配,而所有距离大于阈值距离的匹配都将被标记为负匹配。
此外,我们还传入额外的指标来计算。输出中的所有值都在校准的阈值处计算。
最后,model.calibrate()
返回一个 CalibrationResults
对象,其中包含:
"cutpoints"
:一个 Python 字典,将分割点名称映射到一个包含与特定距离阈值关联的 ClassificationMetric
值的字典,例如,"optimal" : {"acc": 0.90, "f1": 0.92}
。"thresholds"
:一个 Python 字典,将 ClassificationMetric
名称映射到一个列表,其中包含在每个距离阈值处计算的指标的值,例如,{"f1": [0.99, 0.80], "distance": [0.0, 1.0]}
。x_train, y_train = train_ds.get_slice(begin=0, size=1000)
calibration = model.calibrate(
x_train,
y_train,
calibration_metric="f1",
matcher="match_nearest",
extra_metrics=["precision", "recall", "binary_accuracy"],
verbose=1,
)
Performing NN search
Building NN list: 0%| | 0/1000 [00:00<?, ?it/s]
Evaluating: 0%| | 0/4 [00:00<?, ?it/s]
computing thresholds: 0%| | 0/989 [00:00<?, ?it/s]
name value distance precision recall binary_accuracy f1
------- ------- ---------- ----------- -------- ----------------- --------
optimal 0.93 0.048435 0.869 1 0.869 0.929909
仅从指标中可能难以了解模型的质量。一种补充方法是手动检查一组查询结果,以了解匹配质量。
这里我们选取 10 个验证示例,并将它们与 5 个最近的邻居以及到查询示例的距离一起绘制出来。查看结果,我们看到虽然它们是不完美的,但它们仍然代表着有意义的相似图像,并且该模型能够找到相似的图像,而与它们的姿势或图像照明无关。
我们还可以看到,该模型对某些图像非常有信心,导致查询和邻居之间的距离非常小。相反,随着距离变大,我们看到类标签中出现更多的错误。这是校准对于匹配应用至关重要的原因之一。
num_neighbors = 5
labels = [
"Airplane",
"Automobile",
"Bird",
"Cat",
"Deer",
"Dog",
"Frog",
"Horse",
"Ship",
"Truck",
"Unknown",
]
class_mapping = {c_id: c_lbl for c_id, c_lbl in zip(range(11), labels)}
x_display, y_display = val_ds.get_slice(begin=200, size=10)
# lookup nearest neighbors in the index
nns = model.lookup(x_display, k=num_neighbors)
# display
for idx in np.argsort(y_display):
tfsim.visualization.viz_neigbors_imgs(
x_display[idx],
y_display[idx],
nns[idx],
class_mapping=class_mapping,
fig_size=(16, 2),
)
Performing NN search
Building NN list: 0%| | 0/10 [00:00<?, ?it/s]
我们还可以绘制 CalibrationResults
中包含的额外指标,以了解匹配性能如何随着距离阈值的增加而变化。
以下图表显示了精确率、召回率和 F1 分数。我们可以看到,随着距离的增加,匹配的精确率会降低,但是我们接受为正匹配(召回率)的查询百分比在校准的距离阈值之前增长得更快。
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
x = calibration.thresholds["distance"]
ax1.plot(x, calibration.thresholds["precision"], label="precision")
ax1.plot(x, calibration.thresholds["recall"], label="recall")
ax1.plot(x, calibration.thresholds["f1"], label="f1 score")
ax1.legend()
ax1.set_title("Metric evolution as distance increase")
ax1.set_xlabel("Distance")
ax1.set_ylim((-0.05, 1.05))
ax2.plot(calibration.thresholds["recall"], calibration.thresholds["precision"])
ax2.set_title("Precision recall curve")
ax2.set_xlabel("Recall")
ax2.set_ylabel("Precision")
ax2.set_ylim((-0.05, 1.05))
plt.show()
我们还可以为每个类别选取 100 个示例,并绘制每个示例及其最近匹配的混淆矩阵。我们还添加了一个“额外”的第 10 个类别来表示高于校准距离阈值的匹配。
我们可以看到,大多数错误都发生在动物类别之间,并且飞机和鸟类之间存在相当数量的混淆。此外,我们看到每个类别的 100 个示例中只有少数返回了校准距离阈值之外的匹配。
cutpoint = "optimal"
# This yields 100 examples for each class.
# We defined this when we created the val_ds sampler.
x_confusion, y_confusion = val_ds.get_slice(0, -1)
matches = model.match(x_confusion, cutpoint=cutpoint, no_match_label=10)
cm = tfsim.visualization.confusion_matrix(
matches,
y_confusion,
labels=labels,
title="Confusion matrix for cutpoint:%s" % cutpoint,
normalize=False,
)
我们可以绘制校准阈值之外的示例,以查看哪些图像与任何索引示例不匹配。
这可以提供有关哪些其他示例可能需要索引的见解,或显示类别中的异常示例。
idx_no_match = np.where(np.array(matches) == 10)
no_match_queries = x_confusion[idx_no_match]
if len(no_match_queries):
plt.imshow(no_match_queries[0])
else:
print("All queries have a match below the distance threshold.")
快速了解模型运行情况并了解其缺点的最佳方法之一是将嵌入投影到 2D 空间中。
这使我们可以检查图像聚类并了解哪些类别是纠缠的。
# Each class in val_ds was restricted to 100 examples.
num_examples_to_clusters = 1000
thumb_size = 96
plot_size = 800
vx, vy = val_ds.get_slice(0, num_examples_to_clusters)
# Uncomment to run the interactive projector.
# tfsim.visualization.projector(
# model.predict(vx),
# labels=vy,
# images=vx,
# class_mapping=class_mapping,
# image_size=thumb_size,
# plot_size=plot_size,
# )