Authors: Aritra Roy Gosthipaty, Sayak Paul (equal contribution)
Date created: 2022/04/12
Last modified: 2023/11/20
Description: Looking into the representations learned by different Vision Transformer variants.
In this example, we look into the representations learned by different Vision Transformer (ViT) models. Our main goal with this example is to provide insights into what empowers ViTs to learn from image data. In particular, the example discusses implementations of a few different ViT analysis tools.
Note: when we say "Vision Transformer", we mean a computer vision architecture involving Transformer blocks (Vaswani et al.) and not necessarily the original Vision Transformer model (Dosovitskiy et al.).
Since the advent of the original Vision Transformer, the computer vision community has seen a number of different ViT variants improving upon the original in various ways: training improvements, architecture improvements, and so on. In this example, we consider the following ViT model families:
Since the pretrained models are not implemented in Keras, we first implemented them as faithfully as possible. We then populated them with the official pretrained parameters. Finally, we evaluated our implementations on the ImageNet-1k validation set to ensure the evaluation numbers matched the original implementations'. The details of our implementations are available in this repository.
To keep the example concise, we won't exhaustively pair each model with every analysis method. We'll provide notes in the respective sections so that you can pick up the pointers.
To run this example on Google Colab, we need to update the gdown library as follows:
pip install -U gdown -q
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import zipfile
from io import BytesIO
import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image
from sklearn.preprocessing import MinMaxScaler
import keras
from keras import ops
RESOLUTION = 224
PATCH_SIZE = 16
GITHUB_RELEASE = "https://github.com/sayakpaul/probing-vits/releases/download/v1.0.0/probing_vits.zip"
FNAME = "probing_vits.zip"
MODELS_ZIP = {
"vit_dino_base16": "Probing_ViTs/vit_dino_base16.zip",
"vit_b16_patch16_224": "Probing_ViTs/vit_b16_patch16_224.zip",
"vit_b16_patch16_224-i1k_pretrained": "Probing_ViTs/vit_b16_patch16_224-i1k_pretrained.zip",
}
For the original ViT models, the input images need to be scaled to the range [-1, 1]. For the other model families mentioned at the beginning, we need to normalize the images with the channel-wise mean and standard deviation of the ImageNet-1k training set.
crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
norm_layer = keras.layers.Normalization(
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)
rescale_layer = keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1)
def preprocess_image(image, model_type, size=RESOLUTION):
# Turn the image into a numpy array and add batch dim.
image = np.array(image)
image = ops.expand_dims(image, 0)
# If model type is vit rescale the image to [-1, 1].
if model_type == "original_vit":
image = rescale_layer(image)
# Resize the image using bicubic interpolation.
resize_size = int((256 / 224) * size)
image = ops.image.resize(image, (resize_size, resize_size), interpolation="bicubic")
# Crop the image.
image = crop_layer(image)
# If model type is DeiT or DINO normalize the image.
if model_type != "original_vit":
image = norm_layer(image)
return ops.convert_to_numpy(image)
def load_image_from_url(url, model_type):
# Credit: Willi Gierke
response = requests.get(url)
image = Image.open(BytesIO(response.content))
preprocessed_image = preprocess_image(image, model_type)
return image, preprocessed_image
# Download the ImageNet-1k label mapping file and load it.
mapping_file = keras.utils.get_file(
origin="https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
)
with open(mapping_file, "r") as f:
lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]
img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, model_type="original_vit")
plt.imshow(image)
plt.axis("off")
plt.show()
zip_path = keras.utils.get_file(
fname=FNAME,
origin=GITHUB_RELEASE,
)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall("./")
os.rename("Probing ViTs", "Probing_ViTs")
def load_model(model_path: str) -> keras.Model:
with zipfile.ZipFile(model_path, "r") as zip_ref:
zip_ref.extractall("Probing_ViTs/")
model_name = model_path.split(".")[0]
inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
model = keras.layers.TFSMLayer(model_name, call_endpoint="serving_default")
outputs = model(inputs, training=False)
return keras.Model(inputs, outputs=outputs)
vit_base_i21k_patch16_224 = load_model(MODELS_ZIP["vit_b16_patch16_224-i1k_pretrained"])
print("Model loaded.")
Model loaded.
More about the model:
This model was pretrained on the ImageNet-21k dataset and then fine-tuned on the ImageNet-1k dataset. To learn more about how we developed this model in TensorFlow (with pretrained weights from this source), refer to this notebook.
We now run inference with the loaded model on our test image.
def split_prediction_and_attention_scores(outputs):
predictions = outputs["output_1"]
attention_score_dict = {}
for key, value in outputs.items():
if key.startswith("output_2_"):
attention_score_dict[key[len("output_2_") :]] = value
return predictions, attention_score_dict
predictions, attention_score_dict = split_prediction_and_attention_scores(
vit_base_i21k_patch16_224.predict(preprocessed_image)
)
predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]
print(predicted_label)
1/1 ━━━━━━━━━━━━━━━━━━━━ 5s 5s/step
toucan
attention_score_dict contains the attention scores (softmaxed outputs) from each attention head of every Transformer block.
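To get a feel for this dictionary's layout, we can print the name and shape of every entry (a small inspection snippet added here; it only assumes the shape convention described above):

```python
# Print each Transformer block's attention tensor and its shape, which is
# expected to be (batch, num_heads, num_tokens, num_tokens).
for block_name, scores in attention_score_dict.items():
    print(block_name, scores.shape)
```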
Dosovitskiy et al. and Raghu et al. use a measure called "mean attention distance" from each attention head of different Transformer blocks to understand how local and global information flows into Vision Transformers.
Mean attention distance is defined as the distance between query tokens and the other tokens multiplied by the attention weights; so, for a single image, we get one such quantity per attention head of every Transformer block.
The attention scores were calculated here, after forward-passing the image through the network in inference mode. The following figure may help you understand the process a little better.
This animation was created by Ritwik Raha.
def compute_distance_matrix(patch_size, num_patches, length):
distance_matrix = np.zeros((num_patches, num_patches))
for i in range(num_patches):
for j in range(num_patches):
if i == j: # zero distance
continue
xi, yi = (int(i / length)), (i % length)
xj, yj = (int(j / length)), (j % length)
distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])
return distance_matrix
def compute_mean_attention_dist(patch_size, attention_weights, model_type):
num_cls_tokens = 2 if "distilled" in model_type else 1
# The attention_weights shape = (batch, num_heads, num_patches, num_patches)
attention_weights = attention_weights[
..., num_cls_tokens:, num_cls_tokens:
] # Removing the CLS token
num_patches = attention_weights.shape[-1]
length = int(np.sqrt(num_patches))
assert length**2 == num_patches, "Num patches is not perfect square"
distance_matrix = compute_distance_matrix(patch_size, num_patches, length)
h, w = distance_matrix.shape
distance_matrix = distance_matrix.reshape((1, 1, h, w))
# The attention_weights along the last axis adds to 1
# this is due to the fact that they are softmax of the raw logits
# summation of the (attention_weights * distance_matrix)
# should result in an average distance per token.
mean_distances = attention_weights * distance_matrix
mean_distances = np.sum(
mean_distances, axis=-1
) # Sum along last axis to get average distance per token
mean_distances = np.mean(
mean_distances, axis=-1
) # Now average across all the tokens
return mean_distances
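As a quick sanity check of these utilities (a toy example added here, not part of the original snippet), consider a token that attends uniformly to a 2x2 grid of 16-pixel patches: its mean attention distance should equal the average pairwise patch distance, (2 * 16 + 16 * sqrt(2)) / 4, roughly 13.66 pixels.

```python
# Toy check: one CLS token plus 4 patch tokens, every token attending
# uniformly (weight 0.25) to the 4 patch tokens.
toy_attn = np.zeros((1, 1, 5, 5))
toy_attn[..., :, 1:] = 0.25
print(
    compute_mean_attention_dist(
        patch_size=16, attention_weights=toy_attn, model_type="original_vit"
    )
)  # ~[[13.66]]
```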
Thanks to Simon Kornblith from Google for helping us with this code snippet. It can be found here. Let's now use these utilities to generate an attention distance plot with our loaded model and test image.
# Build the mean distances for every Transformer block.
mean_distances = {
f"{name}_mean_dist": compute_mean_attention_dist(
patch_size=PATCH_SIZE,
attention_weights=attention_weight,
model_type="original_vit",
)
for name, attention_weight in attention_score_dict.items()
}
# Get the number of heads from the mean distance output.
num_heads = mean_distances["transformer_block_0_att_mean_dist"].shape[-1]
# Print the shapes
print(f"Num Heads: {num_heads}.")
plt.figure(figsize=(9, 9))
for idx in range(len(mean_distances)):
mean_distance = mean_distances[f"transformer_block_{idx}_att_mean_dist"]
x = [idx] * num_heads
y = mean_distance[0, :]
plt.scatter(x=x, y=y, label=f"transformer_block_{idx}")
plt.legend(loc="lower right")
plt.xlabel("Attention Head", fontsize=14)
plt.ylabel("Attention Distance", fontsize=14)
plt.title("vit_base_i21k_patch16_224", fontsize=14)
plt.grid()
plt.show()
Num Heads: 12.
How does self-attention span across the input space? Do the heads attend to input regions locally or globally?
The promise of self-attention is to enable the learning of contextual dependencies so that a model can attend to the regions of the input that are most salient with respect to the objective. From the above plot we can notice that different attention heads yield different attention distances, suggesting they use both local and global information from an image. But as we go deeper into the Transformer blocks, the heads tend to focus more on global, aggregate information.
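To make that trend concrete, a short loop (added here) can average the mean attention distance over heads for every block and print it in depth order, mirroring the scatter plot above:

```python
# Average the per-head mean attention distances (in pixels) for each block.
for idx in range(len(mean_distances)):
    block_dist = mean_distances[f"transformer_block_{idx}_att_mean_dist"]
    print(f"transformer_block_{idx}: {float(np.mean(block_dist)):.2f} px")
```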
Inspired by Raghu et al., we computed mean attention distances over 1000 images randomly sampled from the ImageNet-1k validation set and repeated the process for all the models mentioned at the beginning. Interestingly, we noticed the following:
| Pre-trained on ImageNet-21k, fine-tuned on ImageNet-1k | Pre-trained on ImageNet-1k |
| --- | --- |

| No distillation (ViT B-16 from DeiT) | Distilled ViT B-16 from DeiT |
| --- | --- |
To reproduce these plots, please refer to this notebook.
Abnar et al. introduce "Attention Rollout" for quantifying how information flows through the self-attention layers of Transformer blocks. The original ViT authors used this method to investigate the learned representations, stating:
Briefly, we averaged attention weights of ViT-L/16 across all heads and then recursively multiplied the weight matrices of all layers. This accounts for the mixing of attention across tokens through all layers.
We used this notebook and modified the attention rollout code from it to make it compatible with our models.
def attention_rollout_map(image, attention_score_dict, model_type):
num_cls_tokens = 2 if "distilled" in model_type else 1
# Stack the individual attention matrices from individual Transformer blocks.
attn_mat = ops.stack([attention_score_dict[k] for k in attention_score_dict.keys()])
attn_mat = ops.squeeze(attn_mat, axis=1)
# Average the attention weights across all heads.
attn_mat = ops.mean(attn_mat, axis=1)
# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
residual_attn = ops.eye(attn_mat.shape[1])
aug_attn_mat = attn_mat + residual_attn
aug_attn_mat = aug_attn_mat / ops.sum(aug_attn_mat, axis=-1)[..., None]
aug_attn_mat = ops.convert_to_numpy(aug_attn_mat)
# Recursively multiply the weight matrices.
joint_attentions = np.zeros(aug_attn_mat.shape)
joint_attentions[0] = aug_attn_mat[0]
for n in range(1, aug_attn_mat.shape[0]):
joint_attentions[n] = np.matmul(aug_attn_mat[n], joint_attentions[n - 1])
# Attention from the output token to the input space.
v = joint_attentions[-1]
grid_size = int(np.sqrt(aug_attn_mat.shape[-1]))
mask = v[0, num_cls_tokens:].reshape(grid_size, grid_size)
mask = cv2.resize(mask / mask.max(), image.size)[..., np.newaxis]
result = (mask * image).astype("uint8")
return result
Now, let's use these utilities to generate an attention plot based on our previous results from the "Running regular inference with the model" section. Following are the links to download each individual model:
attn_rollout_result = attention_rollout_map(
image, attention_score_dict, model_type="original_vit"
)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 10))
fig.suptitle(f"Predicted label: {predicted_label}.", fontsize=20)
_ = ax1.imshow(image)
_ = ax2.imshow(attn_rollout_result)
ax1.set_title("Input Image", fontsize=16)
ax2.set_title("Attention Map", fontsize=16)
ax1.axis("off")
ax2.axis("off")
fig.tight_layout()
fig.subplots_adjust(top=1.35)
fig.show()
How can we quantify the information flow that propagates through the attention layers?
We notice that the model is able to focus its attention on the salient parts of the input image. We encourage you to apply this method to the other models we mentioned and compare the results. The attention rollout plots will differ according to the task and the augmentation the model was trained with. We observe that DeiT has the best rollout plot, likely due to its augmentation regime.
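As a starting point for such a comparison, here is a sketch (added by us, reusing only the utilities defined above) that runs attention rollout with the DINO weights from MODELS_ZIP (the same weights the next section uses for heatmaps):

```python
# Preprocess the image with DINO-style normalization and compute its rollout map.
dino_model = load_model(MODELS_ZIP["vit_dino_base16"])
_, dino_preprocessed = load_image_from_url(img_url, model_type="dino")
_, dino_scores = split_prediction_and_attention_scores(
    dino_model.predict(dino_preprocessed)
)
dino_rollout_result = attention_rollout_map(image, dino_scores, model_type="dino")
plt.imshow(dino_rollout_result)
plt.axis("off")
plt.show()
```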
A simple yet useful way to probe into the representations of a Vision Transformer is to visualize the attention maps overlaid on the input images. This helps form an intuition about what the model attends to. We use the DINO model for this purpose, because it yields better attention heatmaps.
# Load the model.
vit_dino_base16 = load_model(MODELS_ZIP["vit_dino_base16"])
print("Model loaded.")
# Preprocess the same image but with normalization.
img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, model_type="dino")
# Grab the predictions.
predictions, attention_score_dict = split_prediction_and_attention_scores(
vit_dino_base16.predict(preprocessed_image)
)
Model loaded.
1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step
A Transformer block consists of multiple heads. Each head in a Transformer block projects the input data into a different sub-space. This helps each individual head attend to different parts of the image. Therefore, it makes sense to visualize each attention head map separately, to make sense of what each head looks at.
Notes:
def attention_heatmap(attention_score_dict, image, model_type="dino"):
num_tokens = 2 if "distilled" in model_type else 1
# Sort the Transformer blocks in order of their depth.
attention_score_list = list(attention_score_dict.keys())
attention_score_list.sort(key=lambda x: int(x.split("_")[-2]), reverse=True)
# Process the attention maps for overlay.
w_featmap = image.shape[2] // PATCH_SIZE
h_featmap = image.shape[1] // PATCH_SIZE
attention_scores = attention_score_dict[attention_score_list[0]]
# Taking the representations from CLS token.
attentions = attention_scores[0, :, 0, num_tokens:].reshape(num_heads, -1)
# Reshape the attention scores to resemble mini patches.
attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
attentions = attentions.transpose((1, 2, 0))
# Resize the attention patches to 224x224 (224: 14x16).
attentions = ops.image.resize(
attentions, size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE)
)
return attentions
We can use the same image we used for inference with DINO, along with the attention_score_dict we extracted from the results.
# De-normalize the image for visual clarity.
in1k_mean = np.array([0.485 * 255, 0.456 * 255, 0.406 * 255])
in1k_std = np.array([0.229 * 255, 0.224 * 255, 0.225 * 255])
preprocessed_img_orig = (preprocessed_image * in1k_std) + in1k_mean
preprocessed_img_orig = preprocessed_img_orig / 255.0
preprocessed_img_orig = ops.convert_to_numpy(ops.clip(preprocessed_img_orig, 0.0, 1.0))
# Generate the attention heatmaps.
attentions = attention_heatmap(attention_score_dict, preprocessed_img_orig)
# Plot the maps.
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
img_count = 0
for i in range(3):
for j in range(4):
if img_count < len(attentions):
axes[i, j].imshow(preprocessed_img_orig[0])
axes[i, j].imshow(attentions[..., img_count], cmap="inferno", alpha=0.6)
axes[i, j].title.set_text(f"Attention head: {img_count}")
axes[i, j].axis("off")
img_count += 1
How can we qualitatively evaluate the attention weights?
The attention weights of a Transformer block are computed between the keys and the queries. The weights quantify how important a key is for a query. In a ViT, the keys and queries come from the same image, hence the weights determine which part of the image is important.
Plotting the attention weights overlaid on the image gives us a great intuition about the parts of the image that are important to the Transformer. This plot qualitatively evaluates the purpose of the attention weights.
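For reference, this is how a single attention head arrives at those weights: a scaled dot product between queries and keys followed by a softmax. The sketch below is ours and uses purely random tensors with ViT-B-like shapes (196 patch tokens plus 1 CLS token, head dimension 64) just to show the computation:

```python
# weights = softmax(Q K^T / sqrt(d)); each row of `weights` sums to 1.
q = np.random.randn(1, 197, 64)  # (batch, num_tokens, head_dim)
k = np.random.randn(1, 197, 64)
logits = q @ k.transpose(0, 2, 1) / np.sqrt(64)
weights = ops.convert_to_numpy(ops.softmax(logits, axis=-1))
print(weights.shape)  # (1, 197, 197)
```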
After extracting non-overlapping patches, a ViT flattens those patches along their spatial dimensions and then linearly projects them. One might wonder: what do these projections look like? Below, we take the ViT B-16 model and visualize its learned projections.
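For intuition, the patchify-and-project step can be written as a single strided convolution whose kernel size and stride both equal the patch size. The snippet below is an illustrative sketch of ours (a freshly initialized layer, not the model's actual weights), with 768 filters matching the ViT-B hidden size:

```python
# A Conv2D with kernel_size == strides == PATCH_SIZE extracts non-overlapping
# patches and linearly projects each one in a single operation.
patchify_and_project = keras.layers.Conv2D(
    filters=768, kernel_size=PATCH_SIZE, strides=PATCH_SIZE
)
projected = patchify_and_project(
    np.zeros((1, RESOLUTION, RESOLUTION, 3), dtype="float32")
)
print(projected.shape)  # (1, 14, 14, 768), flattened to a (1, 196, 768) token sequence
```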
def extract_weights(model, name):
for variable in model.weights:
if variable.name.startswith(name):
return variable.numpy()
# Extract the projections.
projections = extract_weights(vit_base_i21k_patch16_224, "conv_projection/kernel")
projection_dim = projections.shape[-1]
patch_h, patch_w, patch_channels = projections.shape[:-1]
# Scale the projections.
scaled_projections = MinMaxScaler().fit_transform(
projections.reshape(-1, projection_dim)
)
# Reshape the scaled projections so that the leading
# three dimensions resemble an image.
scaled_projections = scaled_projections.reshape(patch_h, patch_w, patch_channels, -1)
# Visualize the first 128 filters of the learned
# projections.
fig, axes = plt.subplots(nrows=8, ncols=16, figsize=(13, 8))
img_count = 0
limit = 128
for i in range(8):
for j in range(16):
if img_count < limit:
axes[i, j].imshow(scaled_projections[..., img_count])
axes[i, j].axis("off")
img_count += 1
fig.tight_layout()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
What do the projection filters learn?
When visualized, the kernels of a convolutional neural network show the patterns they look for in an image. These could be circles, and sometimes lines; when combined together (in later stages of a ConvNet), the filters transform into more complex shapes. We have found striking similarities between such ConvNet kernels and the projection filters of a ViT.
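If you want to make that comparison yourself, the snippet below (our addition, using keras.applications.ResNet50 purely as an illustration) visualizes the first-layer convolution kernels of a pretrained ConvNet with the same scaling recipe used above:

```python
# Load a pretrained ResNet50 and grab the kernels of its first convolution.
resnet = keras.applications.ResNet50(weights="imagenet", include_top=False)
conv1_kernels = resnet.get_layer("conv1_conv").get_weights()[0]  # (7, 7, 3, 64)

# Scale each filter to [0, 1] and plot all 64 of them.
scaled_kernels = MinMaxScaler().fit_transform(conv1_kernels.reshape(-1, 64))
scaled_kernels = scaled_kernels.reshape(7, 7, 3, 64)

fig, axes = plt.subplots(nrows=4, ncols=16, figsize=(13, 4))
for idx, ax in enumerate(axes.flat):
    ax.imshow(scaled_kernels[..., idx])
    ax.axis("off")
fig.tight_layout()
```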
Transformers are permutation-invariant. This means they do not take into account the spatial position of the input tokens. To overcome this limitation, we add positional information to the input tokens.
The positional information can be in the form of learned positional embeddings or handcrafted constant embeddings. In our case, all three variants of ViTs feature learned positional embeddings.
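Schematically, a learned positional embedding is just a trainable tensor of shape (1, num_tokens, hidden_dim) that is added to the token embeddings. The sketch below (ours, with random stand-ins for both tensors) only illustrates the shapes involved for a ViT-B/16 operating on 224x224 inputs:

```python
# 196 patch tokens + 1 CLS token, hidden size 768 (ViT-B).
num_tokens = (RESOLUTION // PATCH_SIZE) ** 2 + 1
hidden_dim = 768
dummy_tokens = np.random.randn(1, num_tokens, hidden_dim).astype("float32")
learned_pos_embedding = 0.02 * np.random.randn(1, num_tokens, hidden_dim).astype("float32")
tokens_with_positions = dummy_tokens + learned_pos_embedding
print(tokens_with_positions.shape)  # (1, 197, 768)
```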
In this section, we visualize the similarities between the learned positional embeddings themselves. Below, we take the ViT B-16 model and visualize the similarity of its positional embeddings by taking their dot product.
position_embeddings = extract_weights(vit_base_i21k_patch16_224, "pos_embedding")
# Discard the batch dimension and the position embeddings of the
# cls token.
position_embeddings = position_embeddings.squeeze()[1:, ...]
similarity = position_embeddings @ position_embeddings.T
plt.imshow(similarity, cmap="inferno")
plt.show()
What do the positional embeddings tell us?
The plot has a distinctive diagonal pattern. The main diagonal is the brightest, signifying that a position is most similar to itself. An interesting pattern to look out for is the repeating diagonals. The repeating pattern portrays a sinusoidal function, which is close in nature to what Vaswani et al. proposed as handcrafted positional features.
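For comparison (an addition of ours), we can build the handcrafted sinusoidal embeddings of Vaswani et al. for the same number of positions and dimensions and plot their dot-product similarity against the learned one above:

```python
# Handcrafted sinusoidal embeddings: pe[pos, 2i] = sin(pos / 10000^(2i/d)),
# pe[pos, 2i + 1] = cos(pos / 10000^(2i/d)).
def sinusoidal_embeddings(num_positions, dim):
    positions = np.arange(num_positions)[:, None]
    i = np.arange(dim // 2)[None, :]
    angles = positions / np.power(10000.0, (2 * i) / dim)
    pe = np.zeros((num_positions, dim))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


pe = sinusoidal_embeddings(*position_embeddings.shape)
plt.imshow(pe @ pe.T, cmap="inferno")
plt.show()
```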
DINO extended the attention heatmap generation process to videos. We also applied our DINO implementation to a series of videos and obtained similar results. Here's one video of attention heatmaps:
vit-explain

| Attention heatmaps | Attention rollout |
| --- | --- |