代码示例 / 计算机视觉 / 使用 NNCLR 进行自监督对比学习

使用 NNCLR 进行自监督对比学习

作者: Rishit Dagli
创建日期 2021/09/13
最后修改日期 2024/01/22
描述: NNCLR 的实现,一种用于计算机视觉的自监督学习方法。

ⓘ 本示例使用 Keras 3

在 Colab 中查看 GitHub 源代码

简介

自监督学习

自监督表示学习旨在从原始数据中获取鲁棒的样本表示,而无需昂贵的标签或注释。该领域的早期方法侧重于定义预训练任务,这些任务涉及在具有足够弱监督标签的领域上进行替代任务。经过训练以解决此类任务的编码器有望学习可能对需要昂贵注释(如图像分类)的其他下游任务有用的通用特征。

对比学习

自监督学习技术的一个广泛类别是使用对比损失的技术,这些技术已广泛应用于各种计算机视觉应用,如图像相似度降维 (DrLIM)人脸验证/识别。这些方法学习一个潜在空间,该空间将正样本聚集在一起,同时将负样本推开。

NNCLR

在本例中,我们实现 NNCLR,如 Google Research 和 DeepMind 的论文在朋友的帮助下:最近邻对比学习视觉表示中所述。

NNCLR 学习超越单实例正样本的自监督表示,这使得学习更好的特征能够对不同的视角、变形甚至类内变异保持不变。基于聚类的方法提供了一种超越单实例正样本的好方法,但假设整个聚类都是正样本可能会因早期过度泛化而损害性能。相反,NNCLR 使用学习到的表示空间中的最近邻作为正样本。此外,NNCLR 提高了现有对比学习方法(如SimCLRKeras 示例))的性能,并降低了自监督方法对数据增强策略的依赖。

以下是论文作者展示 NNCLR 如何基于 SimCLR 思想的精彩可视化

我们可以看到 SimCLR 使用同一图像的两个视图作为正样本对。通过随机数据增强生成的这两个视图通过编码器获得正嵌入对,我们最终使用两种增强。NNCLR 则保留一个表示完整数据分布的支持集嵌入,并使用最近邻形成正样本对。支持集在训练期间用作内存,类似于MoCo中的队列(即先进先出)。

本例需要安装 tensorflow_datasets,可以使用以下命令安装

!pip install tensorflow-datasets

设置

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_cv
from keras import ops
from keras import layers

超参数

如原始论文所示,更大的 queue_size 很可能意味着更好的性能,但会引入显著的计算开销。作者表明,NNCLR 的最佳结果是在 queue_size 为 98,304(他们实验过的最大 queue_size)时实现的。我们这里使用 10,000 来展示一个可行的示例。

AUTOTUNE = tf.data.AUTOTUNE
shuffle_buffer = 5000
# The below two values are taken from https://tensorflowcn.cn/datasets/catalog/stl10
labelled_train_images = 5000
unlabelled_images = 100000

temperature = 0.1
queue_size = 10000
contrastive_augmenter = {
    "brightness": 0.5,
    "name": "contrastive_augmenter",
    "scale": (0.2, 1.0),
}
classification_augmenter = {
    "brightness": 0.2,
    "name": "classification_augmenter",
    "scale": (0.5, 1.0),
}
input_shape = (96, 96, 3)
width = 128
num_epochs = 5  # Use 25 for better results
steps_per_epoch = 50  # Use 200 for better results

加载数据集

我们从 TensorFlow Datasets 加载STL-10数据集,这是一个用于开发无监督特征学习、深度学习、自学算法的图像识别数据集。它受到 CIFAR-10 数据集的启发,并进行了一些修改。

dataset_name = "stl10"


def prepare_dataset():
    unlabeled_batch_size = unlabelled_images // steps_per_epoch
    labeled_batch_size = labelled_train_images // steps_per_epoch
    batch_size = unlabeled_batch_size + labeled_batch_size

    unlabeled_train_dataset = (
        tfds.load(
            dataset_name, split="unlabelled", as_supervised=True, shuffle_files=True
        )
        .shuffle(buffer_size=shuffle_buffer)
        .batch(unlabeled_batch_size, drop_remainder=True)
    )
    labeled_train_dataset = (
        tfds.load(dataset_name, split="train", as_supervised=True, shuffle_files=True)
        .shuffle(buffer_size=shuffle_buffer)
        .batch(labeled_batch_size, drop_remainder=True)
    )
    test_dataset = (
        tfds.load(dataset_name, split="test", as_supervised=True)
        .batch(batch_size)
        .prefetch(buffer_size=AUTOTUNE)
    )
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=AUTOTUNE)

    return batch_size, train_dataset, labeled_train_dataset, test_dataset


batch_size, train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()

增强

其他自监督技术,如SimCLRBYOLSwAV 等,严重依赖精心设计的数据增强管道以获得最佳性能。然而,NNCLR 对复杂增强的依赖性较小,因为最近邻已经提供了样本变化的丰富性。增强管道中通常包含的一些常见技术是

  • 随机调整大小裁剪
  • 多种颜色失真
  • 高斯模糊

由于 NNCLR 对复杂增强的依赖性较小,我们只使用随机裁剪和随机亮度来增强输入图像。

准备增强模块

def augmenter(brightness, name, scale):
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Rescaling(1 / 255),
            layers.RandomFlip("horizontal"),
            keras_cv.layers.RandomCropAndResize(
                target_size=(input_shape[0], input_shape[1]),
                crop_area_factor=scale,
                aspect_ratio_factor=(3 / 4, 4 / 3),
            ),
            keras_cv.layers.RandomBrightness(factor=brightness, value_range=(0.0, 1.0)),
        ],
        name=name,
    )

编码器架构

使用 ResNet-50 作为编码器架构是文献中的标准做法。在原始论文中,作者使用 ResNet-50 作为编码器架构,并对 ResNet-50 的输出进行空间平均。但是,请记住,更强大的模型不仅会增加训练时间,还会需要更多内存,并会限制您可使用的最大批量大小。为了本示例的目的,我们只使用四个卷积层。

def encoder():
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )

用于对比预训练的 NNCLR 模型

我们使用对比损失训练一个无标签图像上的编码器。编码器顶部连接一个非线性投影头,因为它可以提高编码器表示的质量。

class NNCLR(keras.Model):
    def __init__(
        self, temperature, queue_size,
    ):
        super().__init__()
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.correlation_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_augmenter = augmenter(**contrastive_augmenter)
        self.classification_augmenter = augmenter(**classification_augmenter)
        self.encoder = encoder()
        self.projection_head = keras.Sequential(
            [
                layers.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)], name="linear_probe"
        )
        self.temperature = temperature

        feature_dimensions = self.encoder.output_shape[1]
        self.feature_queue = keras.Variable(
            keras.utils.normalize(
                keras.random.normal(shape=(queue_size, feature_dimensions)),
                axis=1,
                order=2,
            ),
            trainable=False,
        )

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)
        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

    def nearest_neighbour(self, projections):
        support_similarities = ops.matmul(projections, ops.transpose(self.feature_queue))
        nn_projections = ops.take(
            self.feature_queue, ops.argmax(support_similarities, axis=1), axis=0
        )
        return projections + ops.stop_gradient(nn_projections - projections)

    def update_contrastive_accuracy(self, features_1, features_2):
        features_1 = keras.utils.normalize(features_1, axis=1, order=2)
        features_2 = keras.utils.normalize(features_2, axis=1, order=2)
        similarities = ops.matmul(features_1, ops.transpose(features_2))
        batch_size = ops.shape(features_1)[0]
        contrastive_labels = ops.arange(batch_size)
        self.contrastive_accuracy.update_state(
            ops.concatenate([contrastive_labels, contrastive_labels], axis=0),
            ops.concatenate([similarities, ops.transpose(similarities)], axis=0),
        )

    def update_correlation_accuracy(self, features_1, features_2):
        features_1 = (features_1 - ops.mean(features_1, axis=0)) / ops.std(
            features_1, axis=0
        )
        features_2 = (features_2 - ops.mean(features_2, axis=0)) / ops.std(
            features_2, axis=0
        )

        batch_size = ops.shape(features_1)[0]
        cross_correlation = (
            ops.matmul(ops.transpose(features_1), features_2) / batch_size
        )

        feature_dim = ops.shape(features_1)[1]
        correlation_labels = ops.arange(feature_dim)
        self.correlation_accuracy.update_state(
            ops.concatenate([correlation_labels, correlation_labels], axis=0),
            ops.concatenate(
                [cross_correlation, ops.transpose(cross_correlation)], axis=0
            ),
        )

    def contrastive_loss(self, projections_1, projections_2):
        projections_1 = keras.utils.normalize(projections_1, axis=1, order=2)
        projections_2 = keras.utils.normalize(projections_2, axis=1, order=2)

        similarities_1_2_1 = (
            ops.matmul(
                self.nearest_neighbour(projections_1), ops.transpose(projections_2)
            )
            / self.temperature
        )
        similarities_1_2_2 = (
             ops.matmul(
                projections_2, ops.transpose(self.nearest_neighbour(projections_1))
            )
            / self.temperature
        )

        similarities_2_1_1 = (
            ops.matmul(
                self.nearest_neighbour(projections_2), ops.transpose(projections_1)
            )
            / self.temperature
        )
        similarities_2_1_2 = (
            ops.matmul(
                projections_1, ops.transpose(self.nearest_neighbour(projections_2))
            )
            / self.temperature
        )

        batch_size = ops.shape(projections_1)[0]
        contrastive_labels = ops.arange(batch_size)
        loss = keras.losses.sparse_categorical_crossentropy(
            ops.concatenate(
                [
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                ],
                axis=0,
            ),
            ops.concatenate(
                [
                    similarities_1_2_1,
                    similarities_1_2_2,
                    similarities_2_1_1,
                    similarities_2_1_2,
                ],
                axis=0,
            ),
            from_logits=True,
        )

        self.feature_queue.assign(
            ops.concatenate([projections_1, self.feature_queue[:-batch_size]], axis=0)
        )
        return loss

    def train_step(self, data):
        (unlabeled_images, _), (labeled_images, labels) = data
        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
        augmented_images_1 = self.contrastive_augmenter(images)
        augmented_images_2 = self.contrastive_augmenter(images)

        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1)
            features_2 = self.encoder(augmented_images_2)
            projections_1 = self.projection_head(features_1)
            projections_2 = self.projection_head(features_2)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.update_contrastive_accuracy(features_1, features_2)
        self.update_correlation_accuracy(features_1, features_2)
        preprocessed_images = self.classification_augmenter(labeled_images)

        with tf.GradientTape() as tape:
            features = self.encoder(preprocessed_images)
            class_logits = self.linear_probe(features)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_accuracy.update_state(labels, class_logits)

        return {
            "c_loss": contrastive_loss,
            "c_acc": self.contrastive_accuracy.result(),
            "r_acc": self.correlation_accuracy.result(),
            "p_loss": probe_loss,
            "p_acc": self.probe_accuracy.result(),
        }

    def test_step(self, data):
        labeled_images, labels = data

        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)

        self.probe_accuracy.update_state(labels, class_logits)
        return {"p_loss": probe_loss, "p_acc": self.probe_accuracy.result()}

预训练 NNCLR

我们使用论文中建议的 0.1 temperature 和前面解释过的 10,000 queue_size 来训练网络。我们使用 Adam 作为对比和探测优化器。在本示例中,我们只训练模型 30 个 epoch,但为了获得更好的性能,应该训练更多 epoch。

以下两个指标可用于监控预训练性能,我们也进行了记录(取自此 Keras 示例

  • 对比准确度:自监督指标,表示图像的表示与不同增强版本表示的相似度高于当前批次中任何其他图像表示的案例比例。即使没有标记示例,自监督指标也可用于超参数调整。
  • 线性探测准确度:线性探测是评估自监督分类器的一种流行指标。它计算为在编码器特征之上训练的逻辑回归分类器的准确度。在我们的例子中,这是通过在冻结编码器之上训练一个单密集层来完成的。请注意,与传统方法中分类器在预训练阶段之后进行训练不同,本例中我们将其在预训练期间进行训练。这可能会略微降低其准确度,但这样我们可以在训练期间监控其值,这有助于实验和调试。
model = NNCLR(temperature=temperature, queue_size=queue_size)
model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
    jit_compile=False,
)
pretrain_history = model.fit(
    train_dataset, epochs=num_epochs, validation_data=test_dataset
)

自监督学习在您只能访问非常有限的标记训练数据但可以设法构建大量未标记数据语料库时特别有用,如以前的方法所示,例如SEERSimCLRSwAV 等。

您还应该查看这些论文的博客文章,这些文章清楚地表明,通过首先在大规模未标记数据集上进行预训练,然后对较小的标记数据集进行微调,可以使用少量类别标签获得良好的结果

也建议您查看原始论文

非常感谢 NNCLR 论文的主要作者 Debidatta Dwibedi (Google Research) 对本示例的极具见地的评论。本示例也借鉴了SimCLR Keras 示例