Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2022/04/08
Description: Distillation of Vision Transformers through attention.
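In the original Vision Transformer (ViT) paper (Dosovitskiy et al.), the authors concluded that to perform on par with Convolutional Neural Networks (CNNs), ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly due to the lack of inductive biases in the ViT architecture: unlike CNNs, ViTs have no layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that the performance of ViTs can be improved substantially with stronger regularization and longer training.
Many groups have proposed different ways to deal with the data-intensiveness of ViT training. One such way is shown in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduce a distillation technique that is specific to Transformer-based vision models. DeiT is among the first works to show that ViTs can be trained well without using larger datasets.
In this example, we implement the distillation recipe proposed in DeiT. This requires us to slightly tweak the original ViT architecture and to write a custom training loop that implements the distillation recipe.
To run the example, you'll need TensorFlow Addons, which you can install with the following command: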
pip install tensorflow-addons
To follow this example comfortably, you should know how ViTs and knowledge distillation work; review both topics first if you need a refresher.
from typing import List
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers
tfds.disable_progress_bar()
tf.keras.utils.set_random_seed(42)
# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
PROJECTION_DIM * 4,
PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1
# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001
# Data
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5
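You probably noticed that DROPOUT_RATE has been set to 0.0. Dropout is included in the implementation to keep it complete. Smaller models (like the one used in this example) don't need it, but for bigger models, using dropout helps.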
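As a quick sanity check on the constants above, the following sketch derives the sequence lengths the Transformer blocks will see. It uses only values copied from the configuration; the names `patches_per_side`, `seq_len_vit`, `seq_len_deit`, and `values_per_patch` are introduced here for illustration and do not appear in the example itself.

```python
# Illustrative arithmetic only; values copied from the configuration above.
RESOLUTION = 224
PATCH_SIZE = 16
PROJECTION_DIM = 192

# A 224x224 image is cut into a 14x14 grid of non-overlapping 16x16 patches.
patches_per_side = RESOLUTION // PATCH_SIZE
num_patches = patches_per_side ** 2

# ViT prepends one class token; DeiT adds a distillation token as well.
seq_len_vit = num_patches + 1
seq_len_deit = num_patches + 2

# Each flattened RGB patch holds 16 * 16 * 3 values before projection.
values_per_patch = PATCH_SIZE * PATCH_SIZE * 3

print(patches_per_side, num_patches)           # 14 196
print(seq_len_vit, seq_len_deit)               # 197 198
print(values_per_patch, "->", PROJECTION_DIM)  # 768 -> 192
```

The 197 and 198 values are the sequence lengths that the positional embeddings and the Transformer block inputs are sized for later in the example.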
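Load the tf_flowers dataset and prepare preprocessing utilities
The authors use an array of different augmentation techniques, including MixUp (Zhang et al.), RandAugment (Cubuk et al.), and so on. However, to keep the example simple to work through, we'll discard them.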
def preprocess_dataset(is_training=True):
    def fn(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
            image = tf.image.random_crop(image, (RESOLUTION, RESOLUTION, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
        label = tf.one_hot(label, depth=NUM_CLASSES)
        return image, label

    return fn


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)


train_dataset, val_dataset = tfds.load(
"tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367
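Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend it to support DeiT's components.
First, we'll implement a layer for Stochastic Depth (Huang et al.), which is used in DeiT for regularization.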
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x, training=True):
        if training:
            keep_prob = 1 - self.drop_prob
            # One random draw per sample, broadcast over the remaining axes.
            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x


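The mask in the `StochasticDepth` layer above relies on a small trick: `floor(keep_prob + u)` with `u` uniform in [0, 1) equals 1 with probability `keep_prob` and 0 otherwise, and dividing the kept activations by `keep_prob` leaves their expected value unchanged. The dependency-free simulation below is only a sketch of that idea; the helper `drop_path_mask`, the seed, and the sample count are assumptions made for illustration.

```python
import math
import random


def drop_path_mask(batch_size, drop_prob, rng):
    """Return one 0/1 keep flag per sample, computed as floor(keep_prob + u)."""
    keep_prob = 1.0 - drop_prob
    return [float(math.floor(keep_prob + rng.random())) for _ in range(batch_size)]


rng = random.Random(0)  # fixed seed, chosen arbitrarily for this illustration
drop_prob = 0.1
keep_prob = 1.0 - drop_prob
mask = drop_path_mask(100_000, drop_prob, rng)

# Fraction of samples whose residual branch is kept; should be near keep_prob.
kept_fraction = sum(mask) / len(mask)

# Rescaling by 1 / keep_prob keeps the average activation near its input value.
x = 2.0
expected_output = sum((x / keep_prob) * m for m in mask) / len(mask)

print(kept_fraction)
print(expected_output)
```

Because the mask has one entry per sample, an entire residual branch is dropped for a given image rather than individual activations, which is what distinguishes this from ordinary dropout.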
Now, we'll implement the MLP and Transformer blocks.
def mlp(x, dropout_rate: float, hidden_units: List):
    """FFN for a Transformer block."""
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for (idx, units) in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(drop_prob: float, name: str) -> keras.Model:
    """Transformer block with pre-norm."""
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)


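The block above is "pre-norm": each sublayer (attention, then MLP) receives a layer-normalized input, and its output is added back to the un-normalized input. The sketch below shows that wiring on a single feature vector in plain Python. `layer_norm`, `pre_norm_residual`, and the `halve` stand-in sublayer are hypothetical names introduced here, and the learned scale and offset of `LayerNormalization` are deliberately omitted.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize a vector to zero mean and (approximately) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_norm_residual(x, sublayer):
    """Compute x + sublayer(layer_norm(x)): normalize first, then add the skip path."""
    y = sublayer(layer_norm(x))
    return [a + b for a, b in zip(x, y)]


def halve(v):
    """Stand-in for attention or the MLP; any vector-to-vector function works."""
    return [0.5 * t for t in v]


x = [1.0, 2.0, 3.0, 4.0]
# Two residual sublayers in sequence, as in one Transformer block.
out = pre_norm_residual(pre_norm_residual(x, halve), halve)
print(out)
```

Note that the skip path carries the raw input forward untouched, so a sublayer that outputs zeros turns the whole block into an identity mapping.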
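We'll now implement a ViTClassifier class building on top of the components we just developed. Here we follow the original pooling strategy used in the ViT paper: use a class token and classify from the feature representation corresponding to it.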
class ViTClassifier(keras.Model):
    """Vision Transformer base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Patchify + linear projection + reshaping.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # Positional embedding.
        init_shape = (
            1,
            NUM_PATCHES + 1,
            PROJECTION_DIM,
        )
        self.positional_embedding = tf.Variable(
            tf.zeros(init_shape), name="pos_embedding"
        )

        # Transformer blocks.
        dpr = [x for x in tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]

        # CLS token.
        initial_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.cls_token = tf.Variable(
            initial_value=initial_value, trainable=True, name="cls"
        )

        # Other layers.
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append class token if needed.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        encoded_patches = representation[:, 0]

        # Classification head.
        output = self.head(encoded_patches)
        return output


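The `dpr` list in `ViTClassifier.__init__` gives each Transformer block its own drop probability, rising linearly from 0.0 at the first block to `DROP_PATH_RATE` at the last. The pure-Python sketch below reproduces that schedule; the helper name `linear_drop_path_rates` is introduced here for illustration and is not part of the example.

```python
def linear_drop_path_rates(max_rate, num_layers):
    """Evenly spaced rates from 0.0 to max_rate, mirroring tf.linspace."""
    if num_layers == 1:
        return [0.0]
    step = max_rate / (num_layers - 1)
    return [i * step for i in range(num_layers)]


# Same settings as DROP_PATH_RATE = 0.1 and NUM_LAYERS = 12 above.
rates = linear_drop_path_rates(0.1, 12)
for block_index, rate in enumerate(rates):
    print(f"transformer_block_{block_index}: drop_prob ~ {rate:.4f}")
```

Since the first rate is exactly 0.0, the `if drop_prob` checks in `transformer()` skip the `StochasticDepth` layer entirely for the first block, and deeper blocks are dropped progressively more often.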
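This class can be used standalone as a ViT and is end-to-end trainable. Just remove the distilled phrase from MODEL_TYPE (the transformer() helper sizes its input for NUM_PATCHES + 2 tokens whenever MODEL_TYPE contains it) and the class should work with vit_tiny = ViTClassifier(). Let's now extend it to DeiT.
[Figure: schematic of DeiT, taken from the DeiT paper]
Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared to the true labels, and the logits corresponding to the distillation token are compared to the teacher's predictions.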
class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2
        self.regular_training = regular_training

        # CLS and distillation tokens, positional embedding.
        init_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.dist_token = tf.Variable(init_value, name="dist_token")
        self.positional_embedding = tf.Variable(
            tf.zeros(
                (
                    1,
                    NUM_PATCHES + self.num_tokens,
                    PROJECTION_DIM,
                )
            ),
            name="pos_embedding",
        )

        # Head layers.
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
        self.head_dist = layers.Dense(
            NUM_CLASSES,
            name="distillation_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the tokens.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        dist_token = tf.tile(self.dist_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        dist_token = tf.cast(dist_token, projected_patches.dtype)
        projected_patches = tf.concat(
            [cls_token, dist_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Classification heads.
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )

        if not training or self.regular_training:
            # During standard train / finetune, inference average the classifier
            # predictions.
            return (x + x_dist) / 2

        elif training:
            # Only return separate classification predictions when training in distilled
            # mode.
            return x, x_dist


Let's verify that the ViTDistilled class can be initialized and called as expected.
deit_tiny_distilled = ViTDistilled()
dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)
(2, 5)
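With `training=False`, `ViTDistilled.call` returns the element-wise mean of the class-head and distillation-head logits, which is why the output above is a single `(2, 5)` tensor rather than a pair. The sketch below walks through that averaging for one image; the logit values are made up for illustration and are not model outputs.

```python
# Hypothetical logits for a single image over the 5 flower classes.
cls_logits = [2.0, 0.5, -1.0, 0.0, 1.0]
dist_logits = [1.0, 1.5, -1.0, 0.0, 3.0]

# Inference-time fusion used by ViTDistilled: average the two heads.
fused = [(c + d) / 2 for c, d in zip(cls_logits, dist_logits)]

# The predicted class is the index of the largest fused logit.
predicted_class = max(range(len(fused)), key=fused.__getitem__)

print(fused)            # [1.5, 1.0, -1.0, 0.0, 2.0]
print(predicted_class)  # 4
```

In this made-up case the class head alone would pick class 0, while the fused prediction picks class 4, showing how the distillation head can change the final decision.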
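Unlike standard knowledge distillation (Hinton et al.), which uses a temperature-scaled softmax together with KL divergence, the DeiT authors use the following loss function:
$$\mathcal{L} = \frac{1}{2}\,\mathrm{CE}(\psi(Z_s), y) + \frac{1}{2}\,\mathrm{CE}(\psi(Z_s), y_t), \qquad y_t = \operatorname{argmax}_c Z_t(c)$$
Here, CE is cross-entropy, psi ($\psi$) is the softmax function, Z_s denotes the student logits, Z_t the teacher logits, y the true labels, and y_t the teacher's hard predictions. The two terms correspond to student_loss and distillation_loss in train_step.
class DeiT(keras.Model):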
    # Reference:
    # https://keras.org.cn/examples/vision/knowledge_distillation/
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        return metrics

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data.
        x, y = data

        # Forward pass of teacher
        teacher_predictions = tf.nn.softmax(self.teacher(x, training=False), -1)
        teacher_predictions = tf.argmax(teacher_predictions, -1)

        with tf.GradientTape() as tape:
            # Forward pass of student.
            cls_predictions, dist_predictions = self.student(x / 255.0, training=True)

            # Compute losses.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.compiled_metrics.update_state(y, student_predictions)
        self.dist_loss_tracker.update_state(distillation_loss)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        # Unpack the data.
        x, y = data

        # Compute predictions.
        y_prediction = self.student(x / 255.0, training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        return self.student(inputs / 255.0, training=False)


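The loss assembled in `train_step` can be traced by hand for a single image: the class-token logits are scored against the true label, the distillation-token logits against the teacher's argmax, and the two cross-entropies are averaged. The sketch below does this in plain Python. All logit values are arbitrary, the helper names `softmax` and `cross_entropy` are introduced here, and the label smoothing that the example later applies to the student loss is left out for simplicity.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target_index):
    """Cross-entropy between softmax(logits) and a hard (one-hot) label."""
    return -math.log(softmax(logits)[target_index])


# Hypothetical logits for one image; the values are arbitrary.
teacher_logits = [0.2, 3.1, -0.5, 0.0, 1.0]
cls_logits = [1.0, 2.0, 0.0, 0.0, 0.5]   # student class-token head
dist_logits = [0.5, 1.5, 0.0, 0.0, 2.0]  # student distillation-token head
true_label = 1

# Hard teacher label: the index of the teacher's largest logit.
teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)

student_loss = cross_entropy(cls_logits, true_label)
distillation_loss = cross_entropy(dist_logits, teacher_label)
loss = (student_loss + distillation_loss) / 2

print(teacher_label)  # 1
print(student_loss, distillation_loss, loss)
```

Taking the softmax of the teacher output before the argmax, as `train_step` does, does not change the selected index, because softmax preserves the ordering of the logits.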
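The teacher model is based on the BiT family of ResNets (Kolesnikov et al.) and was fine-tuned on the tf_flowers dataset. You can refer to the accompanying notebook to see how the training was performed. The teacher model has about 212 million parameters, which is roughly 40x more than the student.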
!wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
!unzip -q bit_teacher_flowers.zip
bit_teacher_flowers = keras.models.load_model("bit_teacher_flowers")
deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)
lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
optimizer=tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
metrics=["accuracy"],
student_loss_fn=keras.losses.CategoricalCrossentropy(
from_logits=True, label_smoothing=0.1
),
distillation_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
Epoch 1/20
13/13 [==============================] - 44s 2s/step - accuracy: 0.2343 - student_loss: 2.2630 - distillation_loss: 1.7818 - val_accuracy: 0.2234 - val_student_loss: 1.6622 - val_distillation_loss: 0.0000e+00
Epoch 2/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2150 - student_loss: 1.6377 - distillation_loss: 1.6138 - val_accuracy: 0.1907 - val_student_loss: 1.6150 - val_distillation_loss: 0.0000e+00
Epoch 3/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2552 - student_loss: 1.6073 - distillation_loss: 1.5970 - val_accuracy: 0.1907 - val_student_loss: 1.6093 - val_distillation_loss: 0.0000e+00
Epoch 4/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2564 - student_loss: 1.5954 - distillation_loss: 1.5902 - val_accuracy: 0.2997 - val_student_loss: 1.5958 - val_distillation_loss: 0.0000e+00
Epoch 5/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2922 - student_loss: 1.5839 - distillation_loss: 1.5704 - val_accuracy: 0.3488 - val_student_loss: 1.5635 - val_distillation_loss: 0.0000e+00
Epoch 6/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.3815 - student_loss: 1.4865 - distillation_loss: 1.4551 - val_accuracy: 0.3815 - val_student_loss: 1.4975 - val_distillation_loss: 0.0000e+00
Epoch 7/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4151 - student_loss: 1.4027 - distillation_loss: 1.3441 - val_accuracy: 0.3733 - val_student_loss: 1.4083 - val_distillation_loss: 0.0000e+00
Epoch 8/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4423 - student_loss: 1.3616 - distillation_loss: 1.2877 - val_accuracy: 0.4005 - val_student_loss: 1.4014 - val_distillation_loss: 0.0000e+00
Epoch 9/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4475 - student_loss: 1.3095 - distillation_loss: 1.2200 - val_accuracy: 0.4496 - val_student_loss: 1.3211 - val_distillation_loss: 0.0000e+00
Epoch 10/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4959 - student_loss: 1.2638 - distillation_loss: 1.1508 - val_accuracy: 0.4932 - val_student_loss: 1.2839 - val_distillation_loss: 0.0000e+00
Epoch 11/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5431 - student_loss: 1.2063 - distillation_loss: 1.0948 - val_accuracy: 0.5559 - val_student_loss: 1.1938 - val_distillation_loss: 0.0000e+00
Epoch 12/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5771 - student_loss: 1.1742 - distillation_loss: 1.0461 - val_accuracy: 0.5695 - val_student_loss: 1.1362 - val_distillation_loss: 0.0000e+00
Epoch 13/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5601 - student_loss: 1.1724 - distillation_loss: 1.0457 - val_accuracy: 0.5477 - val_student_loss: 1.1929 - val_distillation_loss: 0.0000e+00
Epoch 14/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5777 - student_loss: 1.1717 - distillation_loss: 1.0378 - val_accuracy: 0.5777 - val_student_loss: 1.1171 - val_distillation_loss: 0.0000e+00
Epoch 15/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6173 - student_loss: 1.1232 - distillation_loss: 0.9782 - val_accuracy: 0.5640 - val_student_loss: 1.1229 - val_distillation_loss: 0.0000e+00
Epoch 16/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6237 - student_loss: 1.1091 - distillation_loss: 0.9627 - val_accuracy: 0.5886 - val_student_loss: 1.1371 - val_distillation_loss: 0.0000e+00
Epoch 17/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6261 - student_loss: 1.0880 - distillation_loss: 0.9341 - val_accuracy: 0.6322 - val_student_loss: 1.0972 - val_distillation_loss: 0.0000e+00
Epoch 18/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6427 - student_loss: 1.0688 - distillation_loss: 0.9117 - val_accuracy: 0.6431 - val_student_loss: 1.0548 - val_distillation_loss: 0.0000e+00
Epoch 19/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6458 - student_loss: 1.0529 - distillation_loss: 0.8903 - val_accuracy: 0.6076 - val_student_loss: 1.0761 - val_distillation_loss: 0.0000e+00
Epoch 20/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6382 - student_loss: 1.0641 - distillation_loss: 0.9049 - val_accuracy: 0.6240 - val_student_loss: 1.0521 - val_distillation_loss: 0.0000e+00
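Two numbers in the run above can be checked with simple arithmetic: the learning rate passed to the optimizer, which scales `BASE_LR` linearly with the batch size relative to a reference batch of 512, and the `13/13` step count in each epoch. The sketch below uses only values already stated in this example; `NUM_TRAIN` and `NUM_VAL` are names introduced here for the dataset sizes printed earlier.

```python
import math

BASE_LR = 0.0005
BATCH_SIZE = 256
NUM_TRAIN = 3303  # training examples reported earlier
NUM_VAL = 367     # validation examples reported earlier

# Linear scaling of the base learning rate with the batch size.
lr_scaled = (BASE_LR / 512) * BATCH_SIZE

# The last, partial batch still counts as a step.
steps_per_epoch = math.ceil(NUM_TRAIN / BATCH_SIZE)
val_steps = math.ceil(NUM_VAL / BATCH_SIZE)

print(lr_scaled)        # 0.00025
print(steps_per_epoch)  # 13
print(val_steps)        # 2
```

The `val_distillation_loss` of 0.0 in the log is expected rather than a sign of perfect distillation: `test_step` never calls the teacher and never updates the distillation loss tracker.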
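If we had trained the same model (the ViTClassifier) from scratch with the exact same hyperparameters, it would have scored about 59% accuracy. You can adapt the following code to reproduce this result: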
vit_tiny = ViTClassifier()
inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1./255)(inputs)
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)
model.compile(...)
model.fit(...)
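Thanks to timm for providing easy-to-read implementations; I referred to them extensively while implementing ViT and DeiT in TensorFlow. Thanks also to those who contributed portions of the ViTClassifier code.
Example available on HuggingFace:
Trained Model | Demo |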
---|---|