代码示例 / 计算机视觉 / 探究 Vision Transformer 表示

探究 Vision Transformer 表示

作者: Aritra Roy Gosthipaty, Sayak Paul (同等贡献)
创建日期 2022/04/12
最后修改 2023/11/20
描述: 探究不同 Vision Transformer 变体学习到的表示。

ⓘ 本示例使用 Keras 3

在 Colab 中查看 GitHub 源代码


引言

在本示例中,我们将探究不同 Vision Transformer (ViT) 模型学习到的表示。本示例的主要目标是深入了解 Vision Transformer 如何从图像数据中学习。具体来说,本示例讨论了几种不同的 Vision Transformer 分析工具的实现。

注意: 当我们提及 "Vision Transformer" 时,我们指的是包含 Transformer 块的计算机视觉架构(Vaswani 等人),而不一定特指原始的 Vision Transformer 模型(Dosovitskiy 等人)。


考虑的模型

自原始 Vision Transformer 问世以来,计算机视觉社区已经看到了许多不同的 ViT 变体,它们在训练改进、架构改进等方面对原始模型进行了提升。在本示例中,我们将考虑以下 ViT 模型系列:

  • 使用 ImageNet-1k 和 ImageNet-21k 数据集进行监督预训练的 Vision Transformer(Dosovitskiy 等人
  • 仅使用 ImageNet-1k 数据集进行监督预训练,并采用更多正则化和蒸馏的 Vision Transformer(Touvron 等人)(DeiT)。
  • 使用自监督预训练的 Vision Transformer(Caron 等人)(DINO)。

由于这些预训练模型并未在 Keras 中实现,我们首先尽可能忠实地实现了它们。然后,我们用官方的预训练参数填充了这些模型。最后,我们在 ImageNet-1k 验证集上评估了我们的实现,以确保评估结果与原始实现相匹配。我们实现的详细信息可在此仓库中找到。

为了使示例简洁,我们不会将每种模型与所有分析方法一一配对。我们将在相应的章节中提供注释,以便您可以自行尝试。

要在 Google Colab 上运行此示例,我们需要更新 gdown 库,如下所示:

pip install -U gdown -q

导入

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import zipfile
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests

from PIL import Image
from sklearn.preprocessing import MinMaxScaler
import keras
from keras import ops

常量

RESOLUTION = 224
PATCH_SIZE = 16
GITHUB_RELEASE = "https://github.com/sayakpaul/probing-vits/releases/download/v1.0.0/probing_vits.zip"
FNAME = "probing_vits.zip"
MODELS_ZIP = {
    "vit_dino_base16": "Probing_ViTs/vit_dino_base16.zip",
    "vit_b16_patch16_224": "Probing_ViTs/vit_b16_patch16_224.zip",
    "vit_b16_patch16_224-i1k_pretrained": "Probing_ViTs/vit_b16_patch16_224-i1k_pretrained.zip",
}

数据工具函数

对于原始的 ViT 模型,输入图像需要缩放到 [-1, 1] 的范围。对于开头提到的其他模型系列,我们需要使用 ImageNet-1k 训练集的通道均值和标准差对图像进行归一化。

crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
norm_layer = keras.layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)
rescale_layer = keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1)


def preprocess_image(image, model_type, size=RESOLUTION):
    # Turn the image into a numpy array and add batch dim.
    image = np.array(image)
    image = ops.expand_dims(image, 0)

    # If model type is vit rescale the image to [-1, 1].
    if model_type == "original_vit":
        image = rescale_layer(image)

    # Resize the image using bicubic interpolation.
    resize_size = int((256 / 224) * size)
    image = ops.image.resize(image, (resize_size, resize_size), interpolation="bicubic")

    # Crop the image.
    image = crop_layer(image)

    # If model type is DeiT or DINO normalize the image.
    if model_type != "original_vit":
        image = norm_layer(image)

    return ops.convert_to_numpy(image)


def load_image_from_url(url, model_type):
    # Credit: Willi Gierke
    response = requests.get(url)
    image = Image.open(BytesIO(response.content))
    preprocessed_image = preprocess_image(image, model_type)
    return image, preprocessed_image

加载并显示测试图像

# ImageNet-1k label mapping file and load it.

mapping_file = keras.utils.get_file(
    origin="https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
)

with open(mapping_file, "r") as f:
    lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]

img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, model_type="original_vit")

plt.imshow(image)
plt.axis("off")
plt.show()

png


加载模型

zip_path = keras.utils.get_file(
    fname=FNAME,
    origin=GITHUB_RELEASE,
)

with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall("./")

os.rename("Probing ViTs", "Probing_ViTs")


def load_model(model_path: str) -> keras.Model:
    with zipfile.ZipFile(model_path, "r") as zip_ref:
        zip_ref.extractall("Probing_ViTs/")
    model_name = model_path.split(".")[0]

    inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
    model = keras.layers.TFSMLayer(model_name, call_endpoint="serving_default")
    outputs = model(inputs, training=False)

    return keras.Model(inputs, outputs=outputs)


vit_base_i21k_patch16_224 = load_model(MODELS_ZIP["vit_b16_patch16_224-i1k_pretrained"])
print("Model loaded.")
Model loaded.

关于模型的更多信息:

该模型在 ImageNet-21k 数据集上进行了预训练,然后在 ImageNet-1k 数据集上进行了微调。要了解更多关于我们如何在 TensorFlow 中开发此模型(使用来自此来源的预训练权重)的信息,请参阅此 Notebook


使用模型运行常规推理

现在我们使用加载的模型对测试图像运行推理。

def split_prediction_and_attention_scores(outputs):
    predictions = outputs["output_1"]
    attention_score_dict = {}
    for key, value in outputs.items():
        if key.startswith("output_2_"):
            attention_score_dict[key[len("output_2_") :]] = value
    return predictions, attention_score_dict


predictions, attention_score_dict = split_prediction_and_attention_scores(
    vit_base_i21k_patch16_224.predict(preprocessed_image)
)
predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]
print(predicted_label)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 5s 5s/step
toucan

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700526824.965785   75784 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

attention_score_dict 包含来自每个 Transformer 块的每个注意力头的注意力得分(softmax 输出)。


方法一:平均注意力距离

Dosovitskiy 等人Raghu 等人使用一种称为“平均注意力距离”的度量,从不同 Transformer 块的每个注意力头中提取,以了解局部和全局信息如何流入 Vision Transformers。

平均注意力距离定义为查询 tokens 与其他 tokens 之间的距离乘以注意力权重。因此,对于单张图像:

  • 我们提取图像中的单个 patch(即 token),
  • 计算它们的几何距离,然后
  • 将其与注意力得分相乘。

在推理模式下将图像前向通过网络后,此处计算注意力得分。下图可能有助于您更好地理解该过程。

此动画由 Ritwik Raha 创建。

def compute_distance_matrix(patch_size, num_patches, length):
    distance_matrix = np.zeros((num_patches, num_patches))
    for i in range(num_patches):
        for j in range(num_patches):
            if i == j:  # zero distance
                continue

            xi, yi = (int(i / length)), (i % length)
            xj, yj = (int(j / length)), (j % length)
            distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])

    return distance_matrix


def compute_mean_attention_dist(patch_size, attention_weights, model_type):
    num_cls_tokens = 2 if "distilled" in model_type else 1

    # The attention_weights shape = (batch, num_heads, num_patches, num_patches)
    attention_weights = attention_weights[
        ..., num_cls_tokens:, num_cls_tokens:
    ]  # Removing the CLS token
    num_patches = attention_weights.shape[-1]
    length = int(np.sqrt(num_patches))
    assert length**2 == num_patches, "Num patches is not perfect square"

    distance_matrix = compute_distance_matrix(patch_size, num_patches, length)
    h, w = distance_matrix.shape

    distance_matrix = distance_matrix.reshape((1, 1, h, w))
    # The attention_weights along the last axis adds to 1
    # this is due to the fact that they are softmax of the raw logits
    # summation of the (attention_weights * distance_matrix)
    # should result in an average distance per token.
    mean_distances = attention_weights * distance_matrix
    mean_distances = np.sum(
        mean_distances, axis=-1
    )  # Sum along last axis to get average distance per token
    mean_distances = np.mean(
        mean_distances, axis=-1
    )  # Now average across all the tokens

    return mean_distances

感谢 Google 的 Simon Kornblith 帮助我们提供了此代码片段。它可以在这里找到。现在,让我们使用这些工具函数,结合我们加载的模型和测试图像,生成注意力距离的图。

# Build the mean distances for every Transformer block.
mean_distances = {
    f"{name}_mean_dist": compute_mean_attention_dist(
        patch_size=PATCH_SIZE,
        attention_weights=attention_weight,
        model_type="original_vit",
    )
    for name, attention_weight in attention_score_dict.items()
}

# Get the number of heads from the mean distance output.
num_heads = mean_distances["transformer_block_0_att_mean_dist"].shape[-1]

# Print the shapes
print(f"Num Heads: {num_heads}.")

plt.figure(figsize=(9, 9))

for idx in range(len(mean_distances)):
    mean_distance = mean_distances[f"transformer_block_{idx}_att_mean_dist"]
    x = [idx] * num_heads
    y = mean_distance[0, :]
    plt.scatter(x=x, y=y, label=f"transformer_block_{idx}")

plt.legend(loc="lower right")
plt.xlabel("Attention Head", fontsize=14)
plt.ylabel("Attention Distance", fontsize=14)
plt.title("vit_base_i21k_patch16_224", fontsize=14)
plt.grid()
plt.show()
Num Heads: 12.

png

检查图表

自注意力机制如何跨越输入空间?它们是关注局部区域还是全局区域?

自注意力机制的优势在于能够学习上下文依赖关系,从而使模型能够关注输入中最显著的区域。从上面的图表中我们可以注意到,不同的注意力头产生不同的注意力距离,这表明它们使用了图像中的局部和全局信息。但随着我们在 Transformer 块中深入,注意力头倾向于更多地关注全局聚合信息。

Raghu 等人的启发,我们计算了从 ImageNet-1k 验证集中随机抽取的 1000 张图像的平均注意力距离,并对开头提到的所有模型重复了该过程。有趣的是,我们注意到以下几点:

  • 使用更大的数据集进行预训练有助于获得更全局的注意力跨度
在 ImageNet-21k 上预训练
在 ImageNet-1k 上微调
在 ImageNet-1k 上预训练
  • 当从 CNN 进行蒸馏时,Vision Transformer 倾向于具有较小的全局注意力跨度
未蒸馏 (来自 DeiT 的 ViT B-16) 来自 DeiT 的蒸馏 ViT B-16

要重现这些图表,请参阅此 notebook


方法二:注意力展开 (Attention Rollout)

Abnar 等人引入了“注意力展开”(Attention rollout)来量化信息如何在 Transformer 块的自注意力层中流动。原始 Vision Transformer 作者使用此方法来研究学习到的表示,他们指出:

简而言之,我们对 ViTL/16 中所有注意力头的权重进行了平均,然后递归地将所有层的权重矩阵相乘。这解释了注意力在所有层中 token 之间的混合。

我们使用了此 notebook,并修改了其中的注意力展开代码,使其与我们的模型兼容。

def attention_rollout_map(image, attention_score_dict, model_type):
    num_cls_tokens = 2 if "distilled" in model_type else 1

    # Stack the individual attention matrices from individual Transformer blocks.
    attn_mat = ops.stack([attention_score_dict[k] for k in attention_score_dict.keys()])
    attn_mat = ops.squeeze(attn_mat, axis=1)

    # Average the attention weights across all heads.
    attn_mat = ops.mean(attn_mat, axis=1)

    # To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_attn = ops.eye(attn_mat.shape[1])
    aug_attn_mat = attn_mat + residual_attn
    aug_attn_mat = aug_attn_mat / ops.sum(aug_attn_mat, axis=-1)[..., None]
    aug_attn_mat = ops.convert_to_numpy(aug_attn_mat)

    # Recursively multiply the weight matrices.
    joint_attentions = np.zeros(aug_attn_mat.shape)
    joint_attentions[0] = aug_attn_mat[0]

    for n in range(1, aug_attn_mat.shape[0]):
        joint_attentions[n] = np.matmul(aug_attn_mat[n], joint_attentions[n - 1])

    # Attention from the output token to the input space.
    v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_attn_mat.shape[-1]))
    mask = v[0, num_cls_tokens:].reshape(grid_size, grid_size)
    mask = cv2.resize(mask / mask.max(), image.size)[..., np.newaxis]
    result = (mask * image).astype("uint8")
    return result

现在让我们使用这些工具函数,基于前面“使用模型运行常规推理”部分的结果,生成注意力图。以下是下载每个独立模型的链接:

attn_rollout_result = attention_rollout_map(
    image, attention_score_dict, model_type="original_vit"
)

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 10))
fig.suptitle(f"Predicted label: {predicted_label}.", fontsize=20)

_ = ax1.imshow(image)
_ = ax2.imshow(attn_rollout_result)
ax1.set_title("Input Image", fontsize=16)
ax2.set_title("Attention Map", fontsize=16)
ax1.axis("off")
ax2.axis("off")

fig.tight_layout()
fig.subplots_adjust(top=1.35)
fig.show()

png

检查图表

我们如何量化通过注意力层传播的信息流?

我们注意到模型能够将其注意力集中在输入图像的显著部分。我们鼓励您将此方法应用于我们提到的其他模型并比较结果。注意力展开图会根据模型训练的任务和数据增强方式而有所不同。我们观察到 DeiT 具有最佳的展开图,这可能归因于其数据增强策略。


方法三:注意力热力图

一种简单而有用的方法来探究 Vision Transformer 的表示是可视化叠加在输入图像上的注意力图。这有助于直观了解模型关注的内容。我们为此使用了 DINO 模型,因为它能生成更好的注意力热力图。

# Load the model.
vit_dino_base16 = load_model(MODELS_ZIP["vit_dino_base16"])
print("Model loaded.")

# Preprocess the same image but with normlization.
img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, model_type="dino")

# Grab the predictions.
predictions, attention_score_dict = split_prediction_and_attention_scores(
    vit_dino_base16.predict(preprocessed_image)
)
Model loaded.
 1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step

Transformer 块由多个注意力头组成。Transformer 块中的每个注意力头将输入数据投影到不同的子空间。这有助于每个独立的头关注图像的不同部分。因此,分别可视化每个注意力头的图以理解每个头所关注的内容是有意义的。

注意:

  • 以下代码是从原始 DINO 代码库复制修改而来。
  • 这里我们获取最后一个 Transformer 块的注意力图。
  • DINO 使用自监督目标进行预训练。
def attention_heatmap(attention_score_dict, image, model_type="dino"):
    num_tokens = 2 if "distilled" in model_type else 1

    # Sort the Transformer blocks in order of their depth.
    attention_score_list = list(attention_score_dict.keys())
    attention_score_list.sort(key=lambda x: int(x.split("_")[-2]), reverse=True)

    # Process the attention maps for overlay.
    w_featmap = image.shape[2] // PATCH_SIZE
    h_featmap = image.shape[1] // PATCH_SIZE
    attention_scores = attention_score_dict[attention_score_list[0]]

    # Taking the representations from CLS token.
    attentions = attention_scores[0, :, 0, num_tokens:].reshape(num_heads, -1)

    # Reshape the attention scores to resemble mini patches.
    attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
    attentions = attentions.transpose((1, 2, 0))

    # Resize the attention patches to 224x224 (224: 14x16).
    attentions = ops.image.resize(
        attentions, size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE)
    )
    return attentions

我们可以使用与 DINO 进行推理时使用的同一张图像,以及从结果中提取的 attention_score_dict

# De-normalize the image for visual clarity.
in1k_mean = np.array([0.485 * 255, 0.456 * 255, 0.406 * 255])
in1k_std = np.array([0.229 * 255, 0.224 * 255, 0.225 * 255])
preprocessed_img_orig = (preprocessed_image * in1k_std) + in1k_mean
preprocessed_img_orig = preprocessed_img_orig / 255.0
preprocessed_img_orig = ops.convert_to_numpy(ops.clip(preprocessed_img_orig, 0.0, 1.0))

# Generate the attention heatmaps.
attentions = attention_heatmap(attention_score_dict, preprocessed_img_orig)

# Plot the maps.
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
img_count = 0

for i in range(3):
    for j in range(4):
        if img_count < len(attentions):
            axes[i, j].imshow(preprocessed_img_orig[0])
            axes[i, j].imshow(attentions[..., img_count], cmap="inferno", alpha=0.6)
            axes[i, j].title.set_text(f"Attention head: {img_count}")
            axes[i, j].axis("off")
            img_count += 1

png

检查图表

我们如何定性评估注意力权重?

Transformer 块的注意力权重在 key 和 query 之间计算。权重量化了 key 对 query 的重要程度。在 Vision Transformer 中,key 和 query 来自同一张图像,因此权重决定了图像的哪个部分是重要的。

将注意力权重叠加在图像上绘制出来,可以很好地直观了解对 Transformer 来说重要的图像部分。此图定性地评估了注意力权重的作用。


方法四:可视化学习到的投影滤波器

提取非重叠的 patch 后,Vision Transformer 会将这些 patch 沿空间维度展平,然后对其进行线性投影。人们可能会想,这些投影是什么样的?下面,我们以 ViT B-16 模型为例,可视化其学习到的投影。

def extract_weights(model, name):
    for variable in model.weights:
        if variable.name.startswith(name):
            return variable.numpy()


# Extract the projections.
projections = extract_weights(vit_base_i21k_patch16_224, "conv_projection/kernel")
projection_dim = projections.shape[-1]
patch_h, patch_w, patch_channels = projections.shape[:-1]

# Scale the projections.
scaled_projections = MinMaxScaler().fit_transform(
    projections.reshape(-1, projection_dim)
)

# Reshape the scaled projections so that the leading
# three dimensions resemble an image.
scaled_projections = scaled_projections.reshape(patch_h, patch_w, patch_channels, -1)

# Visualize the first 128 filters of the learned
# projections.
fig, axes = plt.subplots(nrows=8, ncols=16, figsize=(13, 8))
img_count = 0
limit = 128

for i in range(8):
    for j in range(16):
        if img_count < limit:
            axes[i, j].imshow(scaled_projections[..., img_count])
            axes[i, j].axis("off")
            img_count += 1

fig.tight_layout()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

检查图表

投影滤波器学习到了什么?

可视化后,卷积神经网络的核会显示它们在图像中寻找的模式。这可能是圆形,有时是线条——当它们组合在一起(在 ConvNet 的后期阶段)时,滤波器会转化为更复杂的形状。我们发现这种 ConvNet 核与 ViT 的投影滤波器之间存在显著的相似性。


方法五:可视化位置嵌入

Transformer 是置换不变的。这意味着它们不考虑输入 tokens 的空间位置。为了克服这一限制,我们在输入 tokens 中添加位置信息。

位置信息可以是学习到的位置嵌入,也可以是手工制作的固定嵌入。在本例中,Vision Transformer 的所有三个变体都使用了学习到的位置嵌入。

在本节中,我们可视化学习到的位置嵌入与其自身的相似性。下面,我们以 ViT B-16 模型为例,通过计算位置嵌入的内积来可视化它们的相似性。

position_embeddings = extract_weights(vit_base_i21k_patch16_224, "pos_embedding")

# Discard the batch dimension and the position embeddings of the
# cls token.
position_embeddings = position_embeddings.squeeze()[1:, ...]

similarity = position_embeddings @ position_embeddings.T
plt.imshow(similarity, cmap="inferno")
plt.show()

png

检查图表

位置嵌入告诉我们什么?

该图具有独特的对角线模式。主对角线最亮,表示一个位置与自身最为相似。一个有趣的模式是重复的对角线。这种重复模式描绘了一个正弦函数,其本质上接近于Vaswani 等人作为手工特征提出的方法。


注意

  • DINO 将注意力热力图生成过程扩展到了视频。我们还我们的 DINO 实现在一系列视频上应用,并获得了类似的结果。这是一个注意力热力图视频示例:

    dino

  • Raghu 等人使用了一系列技术来探究 Vision Transformer 学习到的表示,并与 ResNet 的表示进行了比较。我们强烈建议阅读他们的工作。
  • 为了编写本示例,我们开发了此仓库,以便引导读者轻松重现和扩展这些实验。
  • 在这方面您可能感兴趣的另一个仓库是vit-explain
  • 使用我们的 Hugging Face Spaces,您还可以使用自定义图像绘制注意力展开图和注意力热力图。
注意力热力图 注意力展开
Generic badge Generic badge

致谢