Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2022/04/08
Description: Distillation of Vision Transformers through attention.
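In the original Vision Transformers (ViT) paper (Dosovitskiy et al.), the authors concluded that to perform on par with Convolutional Neural Networks (CNNs), ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly due to the lack of inductive biases in the ViT architecture: unlike CNNs, ViTs don't have layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that it is possible to substantially improve the performance of ViTs with stronger regularization and longer training.
Many groups have proposed different ways to deal with the data-intensiveness of ViT training. One such way was shown in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduced a distillation technique that is specific to Transformer-based vision models. DeiT is among the first works to show that it's possible to train ViTs well without using larger datasets.
In this example, we implement the distillation recipe proposed in DeiT. This requires us to slightly tweak the original ViT architecture and write a custom training loop to implement the distillation recipe.
To run this example, you need TensorFlow Addons, which you can install with the following command: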
pip install tensorflow-addons
To work through this example comfortably, you should know how ViTs and knowledge distillation work.
from typing import List
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers
tfds.disable_progress_bar()
tf.keras.utils.set_random_seed(42)
# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
PROJECTION_DIM * 4,
PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1
# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001
# Data
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5
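You probably noticed that DROPOUT_RATE has been set to 0.0. Dropout is included in the implementation for completeness. Smaller models (like the one used in this example) don't need it, but for bigger models, using dropout helps.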
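As a quick sanity check on the constants above, the following plain-Python sketch derives the sequence lengths that the Transformer blocks operate on. The names `seq_len_vit` and `seq_len_deit` are illustrative only and do not appear in the example itself.

```python
# Plain-Python arithmetic on the constants defined above (no TensorFlow needed).
RESOLUTION = 224
PATCH_SIZE = 16
PROJECTION_DIM = 192

# A 224x224 image is cut into non-overlapping 16x16 patches.
patches_per_side = RESOLUTION // PATCH_SIZE
num_patches = patches_per_side ** 2

# ViT prepends one class token; DeiT adds a distillation token as well.
seq_len_vit = num_patches + 1   # illustrative name
seq_len_deit = num_patches + 2  # illustrative name

# Hidden and output widths of the feed-forward network in each block.
mlp_units = [PROJECTION_DIM * 4, PROJECTION_DIM]

print(patches_per_side, num_patches, seq_len_vit, seq_len_deit, mlp_units)
# 14 196 197 198 [768, 192]
```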
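Load the tf_flowers dataset and prepare preprocessing utilities

The authors use an array of different augmentation techniques, including MixUp (Zhang et al.), RandAugment (Cubuk et al.), and so on. However, to keep the example simple to work through, we'll discard them.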
def preprocess_dataset(is_training=True):
    def fn(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
            image = tf.image.random_crop(image, (RESOLUTION, RESOLUTION, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
        label = tf.one_hot(label, depth=NUM_CLASSES)
        return image, label

    return fn


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)
train_dataset, val_dataset = tfds.load(
"tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367
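Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend it to support DeiT's components.
First, we'll implement a layer for Stochastic Depth (Huang et al.), which is used in DeiT for regularization.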
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x, training=True):
        if training:
            keep_prob = 1 - self.drop_prob
            # One random draw per sample, broadcast over the remaining axes.
            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
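To make the mechanics of the layer above concrete, here is a minimal plain-Python sketch of the same idea on nested lists instead of tensors. The helpers `stochastic_depth` and `linear_drop_rates` are hypothetical names introduced for illustration; the second one mimics the linearly increasing per-block rates that the classifier later builds with `tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)`.

```python
import math
import random


def stochastic_depth(batch, drop_prob, rng):
    """Drop whole samples of a residual branch, rescaling the survivors."""
    keep_prob = 1.0 - drop_prob
    out = []
    for sample in batch:
        # floor(keep_prob + u) with u ~ U[0, 1) is 1 with probability keep_prob.
        mask = math.floor(keep_prob + rng.random())
        # Dividing by keep_prob keeps the expected value of the branch unchanged.
        out.append([value / keep_prob * mask for value in sample])
    return out


def linear_drop_rates(max_rate, num_layers):
    """Drop rates rising linearly from 0 to max_rate, one per block."""
    return [max_rate * i / (num_layers - 1) for i in range(num_layers)]


rng = random.Random(0)
batch = [[1.0, 1.0] for _ in range(10000)]
dropped = stochastic_depth(batch, drop_prob=0.1, rng=rng)

# The empirical mean stays close to the input value of 1.0.
mean = sum(sample[0] for sample in dropped) / len(dropped)
print(round(mean, 2))
print(linear_drop_rates(0.1, 12)[:3])
```

Each sample is either zeroed out entirely or scaled by `1 / keep_prob`, which is why the layer can simply return `x` unchanged at inference time.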
Now, we'll implement the MLP and Transformer blocks.
def mlp(x, dropout_rate: float, hidden_units: List):
    """FFN for a Transformer block."""
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for (idx, units) in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(drop_prob: float, name: str) -> keras.Model:
    """Transformer block with pre-norm."""
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)
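Now, we'll implement a ViTClassifier class building on top of the components we just developed. Here we follow the original pooling strategy used in the ViT paper: use a class token and classify with the feature representation corresponding to it.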
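The projection layer in the class that follows uses a convolution whose kernel size and stride both equal the patch size, which amounts to cutting the image into non-overlapping patches and applying one shared linear map to each. The plain-Python sketch below illustrates only the patch-extraction half on a toy single-channel image; `patchify` is a hypothetical helper, not part of the example.

```python
def patchify(image, patch_size):
    """Split a 2D image (list of rows) into flattened non-overlapping patches."""
    height, width = len(image), len(image[0])
    patches = []
    # Step by patch_size in both directions, mirroring stride == kernel size.
    # Incomplete patches at the border are skipped, like "VALID" padding.
    for top in range(0, height - patch_size + 1, patch_size):
        for left in range(0, width - patch_size + 1, patch_size):
            patch = [
                image[top + dy][left + dx]
                for dy in range(patch_size)
                for dx in range(patch_size)
            ]
            patches.append(patch)
    return patches


# A 4x4 toy image whose pixel values are 0..15 in row-major order.
toy_image = [[row * 4 + col for col in range(4)] for row in range(4)]
toy_patches = patchify(toy_image, patch_size=2)
print(len(toy_patches))  # 4
print(toy_patches[0])    # [0, 1, 4, 5]
```

With a 224x224 input and a patch size of 16, the same procedure yields 196 patches, which is the value of NUM_PATCHES.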
class ViTClassifier(keras.Model):
    """Vision Transformer base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Patchify + linear projection + reshaping.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # Positional embedding.
        init_shape = (
            1,
            NUM_PATCHES + 1,
            PROJECTION_DIM,
        )
        self.positional_embedding = tf.Variable(
            tf.zeros(init_shape), name="pos_embedding"
        )

        # Transformer blocks.
        dpr = [x for x in tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]

        # CLS token.
        initial_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.cls_token = tf.Variable(
            initial_value=initial_value, trainable=True, name="cls"
        )

        # Other layers.
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append class token if needed.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        encoded_patches = representation[:, 0]

        # Classification head.
        output = self.head(encoded_patches)
        return output
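This class can be used standalone as a ViT and is end-to-end trainable. Just remove the "distilled" phrase from MODEL_TYPE and it should work with vit_tiny = ViTClassifier(). Let's now extend it to DeiT. The DeiT paper provides a schematic of the architecture.
Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared to the true labels, and the logits corresponding to the distillation token are compared to the teacher's predictions.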
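A minimal plain-Python sketch of how the two heads are used: separate logits while distilling, and their average otherwise. `combine_heads` is a hypothetical helper that mirrors the branch at the end of the `call()` method of the class that follows; the logit values are made up.

```python
def combine_heads(cls_logits, dist_logits, training, regular_training=False):
    """Return both logit vectors while distilling, else their average."""
    if not training or regular_training:
        # Inference, or ordinary training / fine-tuning without a teacher.
        return [(c + d) / 2 for c, d in zip(cls_logits, dist_logits)]
    # Distillation: each head gets its own target (labels vs. teacher).
    return cls_logits, dist_logits


cls_logits = [2.0, 0.0, -1.0]  # made-up logits from the class-token head
dist_logits = [1.0, 1.0, 0.0]  # made-up logits from the distillation-token head

print(combine_heads(cls_logits, dist_logits, training=False))  # [1.5, 0.5, -0.5]
print(combine_heads(cls_logits, dist_logits, training=True))
```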
class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2
        self.regular_training = regular_training

        # CLS and distillation tokens, positional embedding.
        init_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.dist_token = tf.Variable(init_value, name="dist_token")
        self.positional_embedding = tf.Variable(
            tf.zeros(
                (
                    1,
                    NUM_PATCHES + self.num_tokens,
                    PROJECTION_DIM,
                )
            ),
            name="pos_embedding",
        )

        # Head layers.
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
        self.head_dist = layers.Dense(
            NUM_CLASSES,
            name="distillation_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the tokens.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        dist_token = tf.tile(self.dist_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        dist_token = tf.cast(dist_token, projected_patches.dtype)
        projected_patches = tf.concat(
            [cls_token, dist_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Classification heads.
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )

        if not training or self.regular_training:
            # During standard train / finetune, inference average the classifier
            # predictions.
            return (x + x_dist) / 2

        elif training:
            # Only return separate classification predictions when training in distilled
            # mode.
            return x, x_dist
Let's verify that the ViTDistilled class can be initialized and called as expected.
deit_tiny_distilled = ViTDistilled()
dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)
(2, 5)
Unlike standard knowledge distillation (Hinton et al.), which uses a temperature-scaled softmax together with KL divergence, the DeiT authors use the following loss function:

$$\mathcal{L}_{\text{global}}^{\text{hardDistill}} = \frac{1}{2}\mathcal{L}_{CE}(\psi(Z_s), y) + \frac{1}{2}\mathcal{L}_{CE}(\psi(Z_s), y_t), \qquad y_t = \operatorname{argmax}_c Z_t(c)$$

Here, $\mathcal{L}_{CE}$ is cross-entropy, $\psi$ is the softmax function, $Z_s$ denotes the student logits, $Z_t$ the teacher logits, $y$ the true labels, and $y_t$ the teacher's hard predictions.

class DeiT(keras.Model):
    # Reference:
    # https://keras.org.cn/examples/vision/knowledge_distillation/
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        return metrics

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data.
        x, y = data

        # Forward pass of teacher
        teacher_predictions = tf.nn.softmax(self.teacher(x, training=False), -1)
        teacher_predictions = tf.argmax(teacher_predictions, -1)

        with tf.GradientTape() as tape:
            # Forward pass of student.
            cls_predictions, dist_predictions = self.student(x / 255.0, training=True)

            # Compute losses.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.compiled_metrics.update_state(y, student_predictions)
        self.dist_loss_tracker.update_state(distillation_loss)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        # Unpack the data.
        x, y = data

        # Compute predictions.
        y_prediction = self.student(x / 255.0, training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        return self.student(inputs / 255.0, training=False)
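To see what `train_step` computes for a single example, here is a plain-Python sketch of the hard-distillation loss: cross-entropy of the class-token logits against the (label-smoothed) ground truth, cross-entropy of the distillation-token logits against the teacher's argmax, and the average of the two. All helper names and the logit values are made up for illustration, and the smoothing value of 0.1 is the one passed to the student loss when the model is compiled. Note that taking the argmax of the teacher's softmax is the same as taking the argmax of its logits.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target_probs, logits):
    """Cross-entropy between a target distribution and softmax(logits)."""
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(target_probs, probs) if t > 0)


def one_hot(index, num_classes, label_smoothing=0.0):
    """One-hot vector, optionally smoothed towards the uniform distribution."""
    off = label_smoothing / num_classes
    return [off + (1.0 - label_smoothing) * (i == index) for i in range(num_classes)]


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, label, smoothing=0.1):
    num_classes = len(cls_logits)
    # Class-token head is compared with the (smoothed) ground-truth label.
    student_loss = cross_entropy(one_hot(label, num_classes, smoothing), cls_logits)
    # Distillation-token head is compared with the teacher's hard decision.
    teacher_label = max(range(num_classes), key=lambda i: teacher_logits[i])
    distillation_loss = cross_entropy(one_hot(teacher_label, num_classes), dist_logits)
    return (student_loss + distillation_loss) / 2


# An untrained student with uniform predictions over 5 classes.
loss = hard_distillation_loss(
    cls_logits=[0.0] * 5,
    dist_logits=[0.0] * 5,
    teacher_logits=[0.1, 2.0, 0.3, 0.0, -1.0],
    label=1,
)
print(round(loss, 4))  # 1.6094, i.e. log(5)
```

A student that predicts uniformly over the five flower classes starts at a loss of about log(5) ≈ 1.61, whatever the targets are.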
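The teacher model is based on the BiT family of ResNets (Kolesnikov et al.) and was fine-tuned on the tf_flowers dataset. You can refer to this notebook to see how the training was performed. The teacher has about 212 million parameters, roughly 40x more than the student.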
!wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
!unzip -q bit_teacher_flowers.zip
bit_teacher_flowers = keras.models.load_model("bit_teacher_flowers")
deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)
lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
optimizer=tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
metrics=["accuracy"],
student_loss_fn=keras.losses.CategoricalCrossentropy(
from_logits=True, label_smoothing=0.1
),
distillation_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
Epoch 1/20
13/13 [==============================] - 44s 2s/step - accuracy: 0.2343 - student_loss: 2.2630 - distillation_loss: 1.7818 - val_accuracy: 0.2234 - val_student_loss: 1.6622 - val_distillation_loss: 0.0000e+00
Epoch 2/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2150 - student_loss: 1.6377 - distillation_loss: 1.6138 - val_accuracy: 0.1907 - val_student_loss: 1.6150 - val_distillation_loss: 0.0000e+00
Epoch 3/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2552 - student_loss: 1.6073 - distillation_loss: 1.5970 - val_accuracy: 0.1907 - val_student_loss: 1.6093 - val_distillation_loss: 0.0000e+00
Epoch 4/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2564 - student_loss: 1.5954 - distillation_loss: 1.5902 - val_accuracy: 0.2997 - val_student_loss: 1.5958 - val_distillation_loss: 0.0000e+00
Epoch 5/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2922 - student_loss: 1.5839 - distillation_loss: 1.5704 - val_accuracy: 0.3488 - val_student_loss: 1.5635 - val_distillation_loss: 0.0000e+00
Epoch 6/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.3815 - student_loss: 1.4865 - distillation_loss: 1.4551 - val_accuracy: 0.3815 - val_student_loss: 1.4975 - val_distillation_loss: 0.0000e+00
Epoch 7/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4151 - student_loss: 1.4027 - distillation_loss: 1.3441 - val_accuracy: 0.3733 - val_student_loss: 1.4083 - val_distillation_loss: 0.0000e+00
Epoch 8/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4423 - student_loss: 1.3616 - distillation_loss: 1.2877 - val_accuracy: 0.4005 - val_student_loss: 1.4014 - val_distillation_loss: 0.0000e+00
Epoch 9/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4475 - student_loss: 1.3095 - distillation_loss: 1.2200 - val_accuracy: 0.4496 - val_student_loss: 1.3211 - val_distillation_loss: 0.0000e+00
Epoch 10/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4959 - student_loss: 1.2638 - distillation_loss: 1.1508 - val_accuracy: 0.4932 - val_student_loss: 1.2839 - val_distillation_loss: 0.0000e+00
Epoch 11/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5431 - student_loss: 1.2063 - distillation_loss: 1.0948 - val_accuracy: 0.5559 - val_student_loss: 1.1938 - val_distillation_loss: 0.0000e+00
Epoch 12/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5771 - student_loss: 1.1742 - distillation_loss: 1.0461 - val_accuracy: 0.5695 - val_student_loss: 1.1362 - val_distillation_loss: 0.0000e+00
Epoch 13/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5601 - student_loss: 1.1724 - distillation_loss: 1.0457 - val_accuracy: 0.5477 - val_student_loss: 1.1929 - val_distillation_loss: 0.0000e+00
Epoch 14/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5777 - student_loss: 1.1717 - distillation_loss: 1.0378 - val_accuracy: 0.5777 - val_student_loss: 1.1171 - val_distillation_loss: 0.0000e+00
Epoch 15/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6173 - student_loss: 1.1232 - distillation_loss: 0.9782 - val_accuracy: 0.5640 - val_student_loss: 1.1229 - val_distillation_loss: 0.0000e+00
Epoch 16/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6237 - student_loss: 1.1091 - distillation_loss: 0.9627 - val_accuracy: 0.5886 - val_student_loss: 1.1371 - val_distillation_loss: 0.0000e+00
Epoch 17/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6261 - student_loss: 1.0880 - distillation_loss: 0.9341 - val_accuracy: 0.6322 - val_student_loss: 1.0972 - val_distillation_loss: 0.0000e+00
Epoch 18/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6427 - student_loss: 1.0688 - distillation_loss: 0.9117 - val_accuracy: 0.6431 - val_student_loss: 1.0548 - val_distillation_loss: 0.0000e+00
Epoch 19/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6458 - student_loss: 1.0529 - distillation_loss: 0.8903 - val_accuracy: 0.6076 - val_student_loss: 1.0761 - val_distillation_loss: 0.0000e+00
Epoch 20/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6382 - student_loss: 1.0641 - distillation_loss: 0.9049 - val_accuracy: 0.6240 - val_student_loss: 1.0521 - val_distillation_loss: 0.0000e+00
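Two numbers in the training run above can be checked by hand: the learning rate actually passed to the optimizer and the 13 steps per epoch shown in the log. The sketch below redoes that arithmetic in plain Python; `NUM_TRAIN` is an illustrative name for the training-set size printed earlier.

```python
import math

BASE_LR = 0.0005
BATCH_SIZE = 256
NUM_TRAIN = 3303  # training examples reported when the dataset was loaded

# Linear scaling: the base rate is taken to correspond to a batch size of 512.
lr_scaled = (BASE_LR / 512) * BATCH_SIZE

# The final partial batch is kept, so round up.
steps_per_epoch = math.ceil(NUM_TRAIN / BATCH_SIZE)

print(lr_scaled)        # 0.00025
print(steps_per_epoch)  # 13
```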
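If we had trained the same model (the ViTClassifier) from scratch with the exact same hyperparameters, it would have scored about 59% accuracy. You can adapt the following code to reproduce this result.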
vit_tiny = ViTClassifier()
inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1./255)(inputs)
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)
model.compile(...)
model.fit(...)
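Thanks to the maintainers of timm for keeping it updated with readable implementations. I referred to their implementations of ViT and DeiT a lot while implementing them in TensorFlow. Thanks also go to the contributor of some portions of the ViTClassifier.
An example is available on HuggingFace (trained model and demo).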