代码示例 / 计算机视觉 / 近似重复图像搜索

近似重复图像搜索

作者: Sayak Paul
创建日期 2021/09/10
上次修改 2023/08/30

ⓘ 此示例使用 Keras 2

在 Colab 中查看 GitHub 源码

描述:使用深度学习和局部敏感哈希构建近似重复图像搜索工具。


介绍

以(近)实时方式获取类似图像是在信息检索系统中的一个重要用例。一些使用它的流行产品包括 Pinterest、Google 图片搜索等。在本示例中,我们将使用局部敏感哈希 (LSH) 和随机投影在预训练图像分类器计算的图像表示之上构建一个类似的图像搜索工具。这种类型的搜索引擎也称为近似重复(或近似重复)图像检测器。我们还将研究如何使用TensorRT优化搜索工具在 GPU 上的推理性能。

在这方面,keras.io/examples/vision下还有其他一些值得查看的示例。

最后,此示例使用以下资源作为参考,并因此重用了其中一些代码:用于相似项目搜索的局部敏感哈希

请注意,为了优化解析器的性能,您应该拥有可用的 GPU 运行时。


设置

!pip install tensorrt

导入

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorrt
import numpy as np
import time

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

加载数据集并创建一个包含 1000 张图像的训练集

为了使示例的运行时间较短,我们将使用来自tf_flowers数据集(可通过TensorFlow 数据集获得)的 1000 张图像的子集来构建我们的词汇表。

train_ds, validation_ds = tfds.load(
    "tf_flowers", split=["train[:85%]", "train[85%:]"], as_supervised=True
)

IMAGE_SIZE = 224
NUM_IMAGES = 1000

images = []
labels = []

for (image, label) in train_ds.take(NUM_IMAGES):
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
    images.append(image.numpy())
    labels.append(label.numpy())

images = np.array(images)
labels = np.array(labels)

加载预训练模型

在本节中,我们加载了一个在tf_flowers数据集上训练的图像分类模型。总图像的 85% 用于构建训练集。有关训练的更多详细信息,请参阅此笔记本

底层模型是 BiT-ResNet(在大型迁移 (BiT):通用视觉表示学习中提出)。BiT-ResNet 模型系列以在各种不同的下游任务中提供出色的迁移性能而闻名。

!wget -q https://github.com/sayakpaul/near-dup-parser/releases/download/v0.1.0/flower_model_bit_0.96875.zip
!unzip -qq flower_model_bit_0.96875.zip
bit_model = tf.keras.models.load_model("flower_model_bit_0.96875")
bit_model.count_params()
23510597

创建嵌入模型

为了在给定查询图像的情况下检索类似图像,我们需要首先生成所有相关图像的向量表示。我们通过嵌入模型来实现这一点,该模型从我们的预训练分类器中提取输出特征,并对生成的特征向量进行归一化。

embedding_model = tf.keras.Sequential(
    [
        tf.keras.layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
        tf.keras.layers.Rescaling(scale=1.0 / 255),
        bit_model.layers[1],
        tf.keras.layers.Normalization(mean=0, variance=1),
    ],
    name="embedding_model",
)

embedding_model.summary()
Model: "embedding_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling (Rescaling)        (None, 224, 224, 3)       0         
_________________________________________________________________
keras_layer (KerasLayer)     (None, 2048)              23500352  
_________________________________________________________________
normalization (Normalization (None, 2048)              0         
=================================================================
Total params: 23,500,352
Trainable params: 23,500,352
Non-trainable params: 0
_________________________________________________________________

请注意模型内部的归一化层。它用于将表示向量投影到单位球面空间。


哈希实用程序

def hash_func(embedding, random_vectors):
    embedding = np.array(embedding)

    # Random projection.
    bools = np.dot(embedding, random_vectors) > 0
    return [bool2int(bool_vec) for bool_vec in bools]


def bool2int(x):
    y = 0
    for i, j in enumerate(x):
        if j:
            y += 1 << i
    return y

来自embedding_model的向量的形状是(2048,),考虑到实际方面(存储、检索性能等),它相当大。因此,需要在不降低嵌入向量信息内容的情况下降低嵌入向量的维数。这就是随机投影发挥作用的地方。它基于这样一个原理,即如果给定平面上一组点之间的距离近似保持不变,则该平面的维数可以进一步降低。

hash_func()内部,我们首先降低嵌入向量的维数。然后我们计算图像的按位哈希值以确定它们的哈希桶。具有相同哈希值的图像很可能进入相同的哈希桶。从部署的角度来看,按位哈希值存储和操作的成本更低。


查询实用程序

Table类负责构建单个哈希表。哈希表中的每个条目都是我们数据集中的图像的缩减嵌入与唯一标识符之间的映射。因为我们的降维技术涉及随机性,所以每次运行过程时,类似图像可能不会映射到相同的哈希桶。为了减少这种影响,我们将考虑来自多个表的的结果——表的数量和降维是此处的关键超参数。

至关重要的是,在处理现实世界应用程序时,您不会自己重新实现局部敏感哈希。相反,您可能会使用以下流行库之一

class Table:
    def __init__(self, hash_size, dim):
        self.table = {}
        self.hash_size = hash_size
        self.random_vectors = np.random.randn(hash_size, dim).T

    def add(self, id, vectors, label):
        # Create a unique indentifier.
        entry = {"id_label": str(id) + "_" + str(label)}

        # Compute the hash values.
        hashes = hash_func(vectors, self.random_vectors)

        # Add the hash values to the current table.
        for h in hashes:
            if h in self.table:
                self.table[h].append(entry)
            else:
                self.table[h] = [entry]

    def query(self, vectors):
        # Compute hash value for the query vector.
        hashes = hash_func(vectors, self.random_vectors)
        results = []

        # Loop over the query hashes and determine if they exist in
        # the current table.
        for h in hashes:
            if h in self.table:
                results.extend(self.table[h])
        return results

在下面的LSH类中,我们将打包实用程序以拥有多个哈希表。

class LSH:
    def __init__(self, hash_size, dim, num_tables):
        self.num_tables = num_tables
        self.tables = []
        for i in range(self.num_tables):
            self.tables.append(Table(hash_size, dim))

    def add(self, id, vectors, label):
        for table in self.tables:
            table.add(id, vectors, label)

    def query(self, vectors):
        results = []
        for table in self.tables:
            results.extend(table.query(vectors))
        return results

现在,我们可以将构建和操作主 LSH 表(许多表的集合)的逻辑封装在一个类中。它有两个方法

  • train():负责构建最终的 LSH 表。
  • query():计算给定查询图像的匹配数,并量化相似度得分。
class BuildLSHTable:
    def __init__(
        self,
        prediction_model,
        concrete_function=False,
        hash_size=8,
        dim=2048,
        num_tables=10,
    ):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)

        self.prediction_model = prediction_model
        self.concrete_function = concrete_function

    def train(self, training_files):
        for id, training_file in enumerate(training_files):
            # Unpack the data.
            image, label = training_file
            if len(image.shape) < 4:
                image = image[None, ...]

            # Compute embeddings and update the LSH tables.
            # More on `self.concrete_function()` later.
            if self.concrete_function:
                features = self.prediction_model(tf.constant(image))[
                    "normalization"
                ].numpy()
            else:
                features = self.prediction_model.predict(image)
            self.lsh.add(id, features, label)

    def query(self, image, verbose=True):
        # Compute the embeddings of the query image and fetch the results.
        if len(image.shape) < 4:
            image = image[None, ...]

        if self.concrete_function:
            features = self.prediction_model(tf.constant(image))[
                "normalization"
            ].numpy()
        else:
            features = self.prediction_model.predict(image)

        results = self.lsh.query(features)
        if verbose:
            print("Matches:", len(results))

        # Calculate Jaccard index to quantify the similarity.
        counts = {}
        for r in results:
            if r["id_label"] in counts:
                counts[r["id_label"]] += 1
            else:
                counts[r["id_label"]] = 1
        for k in counts:
            counts[k] = float(counts[k]) / self.dim
        return counts

创建 LSH 表

在我们实现的辅助工具和类之后,我们现在可以构建我们的 LSH 表。由于我们将在优化和未优化嵌入模型之间进行性能基准测试,因此我们还将预热 GPU 以避免任何不公平的比较。

# Utility to warm up the GPU.
def warmup():
    dummy_sample = tf.ones((1, IMAGE_SIZE, IMAGE_SIZE, 3))
    for _ in range(100):
        _ = embedding_model.predict(dummy_sample)

现在,我们可以先进行 GPU 预热,然后使用 embedding_model 构建主 LSH 表。

warmup()

training_files = zip(images, labels)
lsh_builder = BuildLSHTable(embedding_model)
lsh_builder.train(training_files)

在撰写本文时,在 Tesla T4 GPU 上的墙上时间为 54.1 秒。此时间可能会根据您使用的 GPU 而有所不同。


使用 TensorRT 优化模型

对于基于 NVIDIA 的 GPU,TensorRT 框架可用于通过使用各种模型优化技术(如剪枝、常量折叠、层融合等)来大幅提高推理延迟。在这里,我们将使用 tf.experimental.tensorrt 模块来优化我们的嵌入模型。

# First serialize the embedding model as a SavedModel.
embedding_model.save("embedding_model")

# Initialize the conversion parameters.
params = tf.experimental.tensorrt.ConversionParams(
    precision_mode="FP16", maximum_cached_engines=16
)

# Run the conversion.
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="embedding_model", conversion_params=params
)
converter.convert()
converter.save("tensorrt_embedding_model")
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.

WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.

INFO:tensorflow:Assets written to: embedding_model/assets

INFO:tensorflow:Assets written to: embedding_model/assets

INFO:tensorflow:Linked TensorRT version: (0, 0, 0)

INFO:tensorflow:Linked TensorRT version: (0, 0, 0)

INFO:tensorflow:Loaded TensorRT version: (0, 0, 0)

INFO:tensorflow:Loaded TensorRT version: (0, 0, 0)

INFO:tensorflow:Assets written to: tensorrt_embedding_model/assets

INFO:tensorflow:Assets written to: tensorrt_embedding_model/assets

关于 tf.experimental.tensorrt.ConversionParams() 内部的参数的说明:

  • precision_mode 定义待转换模型中操作的数值精度。
  • maximum_cached_engines 指定将缓存以处理动态操作(具有未知形状的操作)的 TRT 引擎的最大数量。

要了解有关其他选项的更多信息,请参阅 官方文档。您还可以探索 tf.experimental.tensorrt 模块提供的不同量化选项。

# Load the converted model.
root = tf.saved_model.load("tensorrt_embedding_model")
trt_model_function = root.signatures["serving_default"]

使用优化模型构建 LSH 表

warmup()

training_files = zip(images, labels)
lsh_builder_trt = BuildLSHTable(trt_model_function, concrete_function=True)
lsh_builder_trt.train(training_files)

请注意墙上时间的差异,为 13.1 秒。之前,使用未优化的模型时为 54.1 秒

我们可以更仔细地查看其中一个哈希表,并了解它们是如何表示的。

idx = 0
for hash, entry in lsh_builder_trt.lsh.tables[0].table.items():
    if idx == 5:
        break
    if len(entry) < 5:
        print(hash, entry)
        idx += 1
145 [{'id_label': '3_4'}, {'id_label': '727_3'}]
5 [{'id_label': '12_4'}]
128 [{'id_label': '30_2'}, {'id_label': '480_2'}]
208 [{'id_label': '34_2'}, {'id_label': '132_2'}, {'id_label': '984_2'}]
188 [{'id_label': '42_0'}, {'id_label': '135_3'}, {'id_label': '436_3'}, {'id_label': '670_3'}]

可视化验证图像上的结果

在本节中,我们将首先编写几个实用程序函数来可视化类似图像解析过程。然后,我们将对使用和不使用优化的模型的查询性能进行基准测试。

首先,我们从验证集中获取 100 张图像用于测试目的。

validation_images = []
validation_labels = []

for image, label in validation_ds.take(100):
    image = tf.image.resize(image, (224, 224))
    validation_images.append(image.numpy())
    validation_labels.append(label.numpy())

validation_images = np.array(validation_images)
validation_labels = np.array(validation_labels)
validation_images.shape, validation_labels.shape
((100, 224, 224, 3), (100,))

现在我们编写可视化实用程序。

def plot_images(images, labels):
    plt.figure(figsize=(20, 10))
    columns = 5
    for (i, image) in enumerate(images):
        ax = plt.subplot(len(images) // columns + 1, columns, i + 1)
        if i == 0:
            ax.set_title("Query Image\n" + "Label: {}".format(labels[i]))
        else:
            ax.set_title("Similar Image # " + str(i) + "\nLabel: {}".format(labels[i]))
        plt.imshow(image.astype("int"))
        plt.axis("off")


def visualize_lsh(lsh_class):
    idx = np.random.choice(len(validation_images))
    image = validation_images[idx]
    label = validation_labels[idx]
    results = lsh_class.query(image)

    candidates = []
    labels = []
    overlaps = []

    for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
        if idx == 4:
            break
        image_id, label = r.split("_")[0], r.split("_")[1]
        candidates.append(images[int(image_id)])
        labels.append(label)
        overlaps.append(results[r])

    candidates.insert(0, image)
    labels.insert(0, label)

    plot_images(candidates, labels)

非 TRT 模型

for _ in range(5):
    visualize_lsh(lsh_builder)

visualize_lsh(lsh_builder)
Matches: 507
Matches: 554
Matches: 438
Matches: 370
Matches: 407
Matches: 306

png

png

png

png

png

png

TRT 模型

for _ in range(5):
    visualize_lsh(lsh_builder_trt)
Matches: 458
Matches: 181
Matches: 280
Matches: 280
Matches: 503

png

png

png

png

png

您可能已经注意到,有一些错误的结果。这可以通过几种方法来缓解

  • 更好的模型来生成初始嵌入,尤其是在噪声样本的情况下。我们可以使用诸如 ArcFace监督对比学习 等技术,这些技术隐式地鼓励更好地学习用于检索目的的表示。
  • 表数和降维之间的权衡至关重要,有助于为您的应用程序设置所需的正确召回率。

查询性能基准测试

def benchmark(lsh_class):
    warmup()

    start_time = time.time()
    for _ in range(1000):
        image = np.ones((1, 224, 224, 3)).astype("float32")
        _ = lsh_class.query(image, verbose=False)
    end_time = time.time() - start_time
    print(f"Time taken: {end_time:.3f}")


benchmark(lsh_builder)

benchmark(lsh_builder_trt)
Time taken: 54.359
Time taken: 13.963

我们可以立即注意到这两个模型的查询性能之间存在明显差异。


最后备注

在这个例子中,我们探索了 NVIDIA 的 TensorRT 框架来优化我们的模型。它最适合基于 GPU 的推理服务器。还有其他此类框架的选择,可以满足不同的硬件平台

  • TensorFlow Lite 用于移动和边缘设备。
  • ONNX 用于商品 CPU 基于服务器。
  • Apache TVM,涵盖各种平台的机器学习模型编译器。

以下是一些您可能想要查看的资源,以了解有关基于向量相似性搜索的应用程序的更多信息