代码示例 / 计算机视觉 / 使用TensorFlow Similarity进行图像相似性搜索的度量学习

使用TensorFlow Similarity进行图像相似性搜索的度量学习

作者: Owen Vallis
创建日期 2021/09/30
最后修改日期 2022/02/29
描述: 在CIFAR-10图像上使用相似性度量学习的示例。

ⓘ 此示例使用 Keras 2

在Colab中查看 GitHub源代码


概述

本示例基于“用于图像相似性搜索的度量学习”示例。我们的目标是使用相同的数据集,但使用TensorFlow Similarity来实现模型。

度量学习旨在训练模型,将输入嵌入到高维空间中,使“相似”的输入彼此靠近,“不相似”的输入彼此远离。一旦训练完成,这些模型可以为下游系统生成嵌入,其中这种相似性非常有用,例如作为搜索的排序信号或作为另一个监督问题的预训练嵌入模型。

有关度量学习的更详细概述,请参阅


设置

本教程将使用TensorFlow Similarity库来学习和评估相似性嵌入。TensorFlow Similarity提供了以下组件:

  • 使对比模型训练简单快捷。
  • 更容易确保批次包含示例对。
  • 能够评估嵌入的质量。

TensorFlow Similarity可以通过pip轻松安装,如下所示

pip -q install tensorflow_similarity
import random

from matplotlib import pyplot as plt
from mpl_toolkits import axes_grid1
import numpy as np

import tensorflow as tf
from tensorflow import keras

import tensorflow_similarity as tfsim


tfsim.utils.tf_cap_memory()

print("TensorFlow:", tf.__version__)
print("TensorFlow Similarity:", tfsim.__version__)
TensorFlow: 2.7.0
TensorFlow Similarity: 0.15.5

数据集采样器

本教程将使用CIFAR-10数据集。

为了使相似性模型有效学习,每个批次必须包含至少2个每个类别的示例。

为了简化此过程,tf_similarity提供了Sampler对象,允许您设置批次中类别的数量和每个类别最小示例数。

训练和验证数据集将使用TFDatasetMultiShotMemorySampler对象创建。这会创建一个采样器,该采样器从TensorFlow Datasets加载数据集,并生成包含目标类别数量和每个类别目标示例数量的批次。此外,我们可以限制采样器只生成class_list中定义的类别子集,从而使我们能够在类别子集上进行训练,然后测试嵌入对未见类别的泛化能力。这在处理少样本学习问题时非常有用。

以下单元格创建了一个train_ds样本,该样本

  • 从TFDS加载CIFAR-10数据集,然后获取examples_per_class_per_batch
  • 确保采样器将类别限制在class_list中定义的类别。
  • 确保每个批次包含10个不同的类别,每个类别有8个示例。

我们还以相同的方式创建了一个验证数据集,但我们将每个类别的总示例数限制为100,并且每个批次每个类别的示例数设置为默认的2。

# This determines the number of classes used during training.
# Here we are using all the classes.
num_known_classes = 10
class_list = random.sample(population=range(10), k=num_known_classes)

classes_per_batch = 10
# Passing multiple examples per class per batch ensures that each example has
# multiple positive pairs. This can be useful when performing triplet mining or
# when using losses like `MultiSimilarityLoss` or `CircleLoss` as these can
# take a weighted mix of all the positive pairs. In general, more examples per
# class will lead to more information for the positive pairs, while more classes
# per batch will provide more varied information in the negative pairs. However,
# the losses compute the pairwise distance between the examples in a batch so
# the upper limit of the batch size is restricted by the memory.
examples_per_class_per_batch = 8

print(
    "Batch size is: "
    f"{min(classes_per_batch, num_known_classes) * examples_per_class_per_batch}"
)

print(" Create Training Data ".center(34, "#"))
train_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
    "cifar10",
    classes_per_batch=min(classes_per_batch, num_known_classes),
    splits="train",
    steps_per_epoch=4000,
    examples_per_class_per_batch=examples_per_class_per_batch,
    class_list=class_list,
)

print("\n" + " Create Validation Data ".center(34, "#"))
val_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
    "cifar10",
    classes_per_batch=classes_per_batch,
    splits="test",
    total_examples_per_class=100,
)
Batch size is: 80
###### Create Training Data ######

converting train:   0%|          | 0/50000 [00:00<?, ?it/s]
The initial batch size is 80 (10 classes * 8 examples per class) with 0 augmenters

filtering examples:   0%|          | 0/50000 [00:00<?, ?it/s]

selecting classes:   0%|          | 0/10 [00:00<?, ?it/s]

gather examples:   0%|          | 0/50000 [00:00<?, ?it/s]

indexing classes:   0%|          | 0/50000 [00:00<?, ?it/s]
##### Create Validation Data #####

converting test:   0%|          | 0/10000 [00:00<?, ?it/s]
The initial batch size is 20 (10 classes * 2 examples per class) with 0 augmenters

filtering examples:   0%|          | 0/10000 [00:00<?, ?it/s]

selecting classes:   0%|          | 0/10 [00:00<?, ?it/s]

gather examples:   0%|          | 0/1000 [00:00<?, ?it/s]

indexing classes:   0%|          | 0/1000 [00:00<?, ?it/s]

可视化数据集

采样器将打乱数据集,因此我们可以通过绘制前25张图像来了解数据集。

采样器提供了get_slice(begin, size)方法,允许我们轻松选择一个样本块。

或者,我们可以使用generate_batch()方法生成一个批次。这可以让我们检查批次是否包含预期数量的类别和每个类别的示例。

num_cols = num_rows = 5
# Get the first 25 examples.
x_slice, y_slice = train_ds.get_slice(begin=0, size=num_cols * num_rows)

fig = plt.figure(figsize=(6.0, 6.0))
grid = axes_grid1.ImageGrid(fig, 111, nrows_ncols=(num_cols, num_rows), axes_pad=0.1)

for ax, im, label in zip(grid, x_slice, y_slice):
    ax.imshow(im)
    ax.axis("off")

png


嵌入模型

接下来,我们使用Keras函数式API定义一个SimilarityModel。该模型是一个标准卷积网络,并添加了一个应用L2归一化的MetricEmbedding层。当使用Cosine距离时,度量嵌入层很有用,因为我们只关心向量之间的角度。

此外,SimilarityModel还提供了许多辅助方法,用于

  • 索引嵌入示例
  • 执行示例查找
  • 评估分类
  • 评估嵌入空间的质量

有关更多详细信息,请参阅TensorFlow Similarity文档

embedding_size = 256

inputs = keras.layers.Input((32, 32, 3))
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
x = keras.layers.Conv2D(64, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(128, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D((4, 4))(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.GlobalMaxPool2D()(x)
outputs = tfsim.layers.MetricEmbedding(embedding_size)(x)

# building model
model = tfsim.models.SimilarityModel(inputs, outputs)
model.summary()
Model: "similarity_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         

 rescaling (Rescaling)       (None, 32, 32, 3)         0         

 conv2d (Conv2D)             (None, 30, 30, 64)        1792      

 batch_normalization (BatchN  (None, 30, 30, 64)       256       
 ormalization)                                                   

 conv2d_1 (Conv2D)           (None, 28, 28, 128)       73856     

 batch_normalization_1 (Batc  (None, 28, 28, 128)      512       
 hNormalization)                                                 

 max_pooling2d (MaxPooling2D  (None, 7, 7, 128)        0         
 )                                                               

 conv2d_2 (Conv2D)           (None, 5, 5, 256)         295168    

 batch_normalization_2 (Batc  (None, 5, 5, 256)        1024      
 hNormalization)                                                 

 conv2d_3 (Conv2D)           (None, 3, 3, 256)         590080    

 global_max_pooling2d (Globa  (None, 256)              0         
 lMaxPooling2D)                                                  

 metric_embedding (MetricEmb  (None, 256)              65792     
 edding)                                                         

=================================================================
Total params: 1,028,480
Trainable params: 1,027,584
Non-trainable params: 896
_________________________________________________________________

相似性损失

相似性损失需要批次中包含每个类别至少2个示例,从中计算成对正负距离的损失。这里我们使用MultiSimilarityLoss() (论文),它是TensorFlow Similarity中的几种损失之一。这种损失试图使用批次中所有有信息的对,考虑自相似性、正相似性和负相似性。

epochs = 3
learning_rate = 0.002
val_steps = 50

# init similarity loss
loss = tfsim.losses.MultiSimilarityLoss()

# compiling and training
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate), loss=loss, steps_per_execution=10,
)
history = model.fit(
    train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps
)
Distance metric automatically set to cosine use the distance arg to override.
Epoch 1/3

4000/4000 [==============================] - ETA: 0s - loss: 2.2179Warmup complete
4000/4000 [==============================] - 38s 9ms/step - loss: 2.2179 - val_loss: 0.8894
Warmup complete
Epoch 2/3
4000/4000 [==============================] - 34s 9ms/step - loss: 1.9047 - val_loss: 0.8767
Epoch 3/3
4000/4000 [==============================] - 35s 9ms/step - loss: 1.6336 - val_loss: 0.8469

索引

训练模型后,我们可以创建示例索引。这里我们通过将x和y传递给索引,并将图像存储在数据参数中,批量索引前200个验证示例。x_index值被嵌入,然后添加到索引中以使其可搜索。y_index和数据参数是可选的,但允许用户将元数据与嵌入示例关联起来。

x_index, y_index = val_ds.get_slice(begin=0, size=200)
model.reset_index()
model.index(x_index, y_index, data=x_index)
[Indexing 200 points]
|-Computing embeddings
|-Storing data points in key value store
|-Adding embeddings to index.
|-Building index.
0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************

校准

一旦索引构建完成,我们就可以使用匹配策略和校准度量来校准距离阈值。

这里我们寻找最佳F1分数,同时使用K=1作为分类器。所有等于或低于校准阈值距离的匹配都将被标记为查询示例与匹配结果关联标签之间的正匹配,而所有高于阈值距离的匹配都将被标记为负匹配。

此外,我们还传入了要计算的额外度量。输出中的所有值都在校准阈值处计算。

最后,model.calibrate()返回一个CalibrationResults对象,其中包含

  • "cutpoints":一个Python字典,将切点名称映射到一个字典,其中包含与特定距离阈值相关的ClassificationMetric值,例如,"optimal" : {"acc": 0.90, "f1": 0.92}
  • "thresholds":一个Python字典,将ClassificationMetric名称映射到包含在每个距离阈值处计算的度量值的列表,例如,{"f1": [0.99, 0.80], "distance": [0.0, 1.0]}
x_train, y_train = train_ds.get_slice(begin=0, size=1000)
calibration = model.calibrate(
    x_train,
    y_train,
    calibration_metric="f1",
    matcher="match_nearest",
    extra_metrics=["precision", "recall", "binary_accuracy"],
    verbose=1,
)
Performing NN search
Building NN list:   0%|          | 0/1000 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]

computing thresholds:   0%|          | 0/989 [00:00<?, ?it/s]
 name       value    distance    precision    recall    binary_accuracy        f1
-------  -------  ----------  -----------  --------  -----------------  --------
optimal     0.93    0.048435        0.869         1              0.869  0.929909

可视化

仅从度量来看可能难以了解模型质量。一种补充方法是手动检查一组查询结果,以了解匹配质量。

这里我们选取10个验证示例,并绘制它们与5个最近邻居以及到查询示例的距离。查看结果,我们发现虽然它们不完美,但它们仍然代表有意义的相似图像,并且模型能够找到相似图像,无论它们的姿势或图像光照如何。

我们还可以看到,模型对某些图像非常有信心,导致查询和邻居之间的距离非常小。相反,我们看到随着距离变大,类别标签中的错误更多。这是校准对于匹配应用程序至关重要的原因之一。

num_neighbors = 5
labels = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
    "Unknown",
]
class_mapping = {c_id: c_lbl for c_id, c_lbl in zip(range(11), labels)}

x_display, y_display = val_ds.get_slice(begin=200, size=10)
# lookup nearest neighbors in the index
nns = model.lookup(x_display, k=num_neighbors)

# display
for idx in np.argsort(y_display):
    tfsim.visualization.viz_neigbors_imgs(
        x_display[idx],
        y_display[idx],
        nns[idx],
        class_mapping=class_mapping,
        fig_size=(16, 2),
    )
Performing NN search
Building NN list:   0%|          | 0/10 [00:00<?, ?it/s]

png

png

png

png

png

png

png

png

png

png


指标

我们还可以绘制CalibrationResults中包含的额外度量,以了解随着距离阈值的增加匹配性能如何。

以下图表显示了精确度、召回率和F1分数。我们可以看到,随着距离的增加,匹配精确度会下降,但我们接受为正匹配的查询百分比(召回率)在校准距离阈值之前增长得更快。

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
x = calibration.thresholds["distance"]

ax1.plot(x, calibration.thresholds["precision"], label="precision")
ax1.plot(x, calibration.thresholds["recall"], label="recall")
ax1.plot(x, calibration.thresholds["f1"], label="f1 score")
ax1.legend()
ax1.set_title("Metric evolution as distance increase")
ax1.set_xlabel("Distance")
ax1.set_ylim((-0.05, 1.05))

ax2.plot(calibration.thresholds["recall"], calibration.thresholds["precision"])
ax2.set_title("Precision recall curve")
ax2.set_xlabel("Recall")
ax2.set_ylabel("Precision")
ax2.set_ylim((-0.05, 1.05))
plt.show()

png

我们还可以为每个类别选取100个示例,并绘制每个示例及其最近匹配的混淆矩阵。我们还添加了一个“额外”的第10个类别,以表示超出校准距离阈值的匹配。

我们可以看到,大多数错误发生在动物类别之间,飞机和鸟类之间有一些有趣的混淆。此外,我们看到每个类别的100个示例中只有少数返回的匹配超出了校准距离阈值。

cutpoint = "optimal"

# This yields 100 examples for each class.
# We defined this when we created the val_ds sampler.
x_confusion, y_confusion = val_ds.get_slice(0, -1)

matches = model.match(x_confusion, cutpoint=cutpoint, no_match_label=10)
cm = tfsim.visualization.confusion_matrix(
    matches,
    y_confusion,
    labels=labels,
    title="Confusion matrix for cutpoint:%s" % cutpoint,
    normalize=False,
)

png


不匹配

我们可以绘制超出校准阈值的示例,以查看哪些图像不匹配任何索引示例。

这可能有助于深入了解可能需要索引的其他示例或揭示类别内的异常示例。

idx_no_match = np.where(np.array(matches) == 10)
no_match_queries = x_confusion[idx_no_match]
if len(no_match_queries):
    plt.imshow(no_match_queries[0])
else:
    print("All queries have a match below the distance threshold.")

png


可视化簇

快速了解模型质量并理解其不足的最佳方法之一是将嵌入投影到2D空间中。

这使我们能够检查图像簇并了解哪些类别是纠缠的。

# Each class in val_ds was restricted to 100 examples.
num_examples_to_clusters = 1000
thumb_size = 96
plot_size = 800
vx, vy = val_ds.get_slice(0, num_examples_to_clusters)

# Uncomment to run the interactive projector.
# tfsim.visualization.projector(
#     model.predict(vx),
#     labels=vy,
#     images=vx,
#     class_mapping=class_mapping,
#     image_size=thumb_size,
#     plot_size=plot_size,
# )