Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2022/04/08
Description: Distillation of Vision Transformers through attention.
In the original Vision Transformers (ViT) paper (Dosovitskiy et al.), the authors concluded that in order to perform on par with Convolutional Neural Networks (CNNs), ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly due to the lack of inductive biases in the ViT architecture: unlike CNNs, they don't have layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that it is possible to substantially improve the performance of ViTs with stronger regularization and longer training.
Many groups have proposed different ways to deal with the data-intensiveness of ViT training. One such way was shown in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduced a distillation technique that is specific to transformer-based vision models. DeiT is among the first works to show that it is possible to train ViTs well without using larger datasets.
In this example, we implement the distillation recipe proposed in DeiT. This requires us to slightly tweak the original ViT architecture and write a custom training loop to implement the distillation recipe.
To run this example, you'll need TensorFlow Addons, which you can install with the following command:
pip install tensorflow-addons
In order to comfortably navigate through this example, you'll be expected to know how a ViT and knowledge distillation work.
from typing import List
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers
tfds.disable_progress_bar()
tf.keras.utils.set_random_seed(42)
# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1
# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001
# Data
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5
You probably noticed that DROPOUT_RATE has been set to 0.0. Dropout is included in the implementation for completeness: for smaller models (like the one used in this example) you don't need it, but it helps for bigger models.
We'll now load the tf_flowers dataset and prepare the preprocessing utilities. The authors use an array of different augmentation techniques, including MixUp (Zhang et al.), RandAugment (Cubuk et al.), and so on. However, to keep the example simple to work through, we'll discard them (a minimal MixUp sketch is shown after the data pipeline below).
def preprocess_dataset(is_training=True):
    def fn(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
            image = tf.image.random_crop(image, (RESOLUTION, RESOLUTION, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
        label = tf.one_hot(label, depth=NUM_CLASSES)
        return image, label

    return fn
def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367
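For reference, the following is a minimal sketch of what a batch-level MixUp step could look like in this pipeline. It is not part of the recipe implemented here; the mixup helper and alpha=0.2 are assumptions for illustration. It expects one-hot labels (as produced by preprocess_dataset) and would be mapped over the batched training split:

# Hypothetical batch-level MixUp sketch (not used in this example).
def mixup(images, labels, alpha=0.2):
    batch_size = tf.shape(images)[0]
    # Sample per-example mixing coefficients from Beta(alpha, alpha)
    # using the ratio of two Gamma samples.
    g1 = tf.random.gamma([batch_size], alpha)
    g2 = tf.random.gamma([batch_size], alpha)
    lam = g1 / (g1 + g2)
    lam_images = tf.reshape(lam, (batch_size, 1, 1, 1))
    lam_labels = tf.reshape(lam, (batch_size, 1))
    # Mix every example with a randomly chosen counterpart in the batch.
    indices = tf.random.shuffle(tf.range(batch_size))
    images = lam_images * images + (1.0 - lam_images) * tf.gather(images, indices)
    labels = lam_labels * labels + (1.0 - lam_labels) * tf.gather(labels, indices)
    return images, labels

# Usage sketch (training split only):
# train_dataset = train_dataset.map(mixup, num_parallel_calls=AUTO)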
Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend it to support DeiT's components.

First, we'll implement a layer for Stochastic Depth (Huang et al.), which is used in DeiT for regularization.
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prop, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prop

    def call(self, x, training=True):
        if training:
            keep_prob = 1 - self.drop_prob
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
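As a quick illustrative check of the layer's behavior: at inference time it acts as the identity, while during training it zeroes out the whole residual branch for a random subset of samples and rescales the surviving ones by 1 / keep_prob:

sd = StochasticDepth(drop_prop=0.5)
feat = tf.ones((4, 3, 8))  # (batch, tokens, projection_dim)
# Inference: the layer is the identity.
print(tf.reduce_all(sd(feat, training=False) == feat).numpy())  # True
# Training: each sample is either zeroed out entirely or rescaled by
# 1 / keep_prob (2.0 here), so only the values 0.0 and 2.0 can appear.
print(tf.unique(tf.reshape(sd(feat, training=True), [-1]))[0].numpy())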
Now, we'll implement the MLP and the Transformer blocks.
def mlp(x, dropout_rate: float, hidden_units: List):
    """FFN for a Transformer block."""
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for (idx, units) in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
def transformer(drop_prob: float, name: str) -> keras.Model:
    """Transformer block with pre-norm."""
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)
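A quick illustrative shape check: the block maps a token sequence to a sequence of the same shape. With the distilled MODEL_TYPE, each block expects NUM_PATCHES + 2 = 198 tokens:

# Illustrative shape check for a single Transformer block.
block = transformer(drop_prob=0.0, name="sanity_block")
tokens = tf.ones((2, NUM_PATCHES + 2, PROJECTION_DIM))
print(block(tokens).shape)  # (2, 198, 192)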
We'll now implement a ViTClassifier class building on top of the components we just developed. Here, we'll be following the original pooling strategy used in the ViT paper: use a class token and the feature representations corresponding to it for classification.
class ViTClassifier(keras.Model):
    """Vision Transformer base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Patchify + linear projection + reshaping.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # Positional embedding.
        init_shape = (
            1,
            NUM_PATCHES + 1,
            PROJECTION_DIM,
        )
        self.positional_embedding = tf.Variable(
            tf.zeros(init_shape), name="pos_embedding"
        )

        # Transformer blocks.
        dpr = [x for x in tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]

        # CLS token.
        initial_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.cls_token = tf.Variable(
            initial_value=initial_value, trainable=True, name="cls"
        )

        # Other layers.
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append class token if needed.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        encoded_patches = representation[:, 0]

        # Classification head.
        output = self.head(encoded_patches)
        return output
This class can be used standalone as ViT and is end-to-end trainable. Just remove the distilled phrase from MODEL_TYPE and it works with vit_tiny = ViTClassifier(), as sketched below.
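For instance, here is a minimal illustrative sketch. It assumes you temporarily switch MODEL_TYPE (so that the Transformer blocks expect NUM_PATCHES + 1 tokens) and restore it afterwards:

# Illustrative standalone usage of ViTClassifier (not run in this example).
MODEL_TYPE = "deit_tiny_patch16_224"  # the "distilled" phrase removed
vit_tiny = ViTClassifier()
print(vit_tiny(tf.ones((2, RESOLUTION, RESOLUTION, 3)), training=False).shape)  # (2, 5)
MODEL_TYPE = "deit_distilled_tiny_patch16_224"  # restore for the rest of the example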
Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken from the DeiT paper):
Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared with the true labels, and the logits corresponding to the distillation token are compared with the teacher's predictions.
class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2
        self.regular_training = regular_training

        # CLS and distillation tokens, positional embedding.
        init_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.dist_token = tf.Variable(init_value, name="dist_token")
        self.positional_embedding = tf.Variable(
            tf.zeros(
                (
                    1,
                    NUM_PATCHES + self.num_tokens,
                    PROJECTION_DIM,
                )
            ),
            name="pos_embedding",
        )

        # Head layers.
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
        self.head_dist = layers.Dense(
            NUM_CLASSES,
            name="distillation_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the tokens.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        dist_token = tf.tile(self.dist_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        dist_token = tf.cast(dist_token, projected_patches.dtype)
        projected_patches = tf.concat(
            [cls_token, dist_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Classification heads.
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )

        if not training or self.regular_training:
            # During standard train / finetune, inference average the classifier
            # predictions.
            return (x + x_dist) / 2

        elif training:
            # Only return separate classification predictions when training in distilled
            # mode.
            return x, x_dist
Let's verify that the ViTDistilled class can be initialized and called as expected.
deit_tiny_distilled = ViTDistilled()
dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)
(2, 5)
Unlike standard knowledge distillation (Hinton et al.), where a temperature-scaled softmax together with KL divergence is used, the DeiT authors use the following "hard" distillation loss:

$$\mathcal{L}_{\mathrm{global}}^{\mathrm{hardDistill}} = \frac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s), y\big) + \frac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s), y_t\big)$$

Here, $\mathcal{L}_{\mathrm{CE}}$ is the cross-entropy loss, $\psi$ is the softmax function, $Z_s$ denotes the student's logits, $y$ is the true label, and $y_t = \operatorname{argmax}_c Z_t(c)$ is the hard decision of the teacher with logits $Z_t$. In our implementation, the first term is computed from the classification head, the second from the distillation head, and the two are averaged.

class DeiT(keras.Model):
    # Reference:
    # https://keras.io/examples/vision/knowledge_distillation/
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        return metrics

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data.
        x, y = data

        # Forward pass of teacher.
        teacher_predictions = tf.nn.softmax(self.teacher(x, training=False), -1)
        teacher_predictions = tf.argmax(teacher_predictions, -1)

        with tf.GradientTape() as tape:
            # Forward pass of student.
            cls_predictions, dist_predictions = self.student(x / 255.0, training=True)

            # Compute losses.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.compiled_metrics.update_state(y, student_predictions)
        self.dist_loss_tracker.update_state(distillation_loss)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        # Unpack the data.
        x, y = data

        # Compute predictions.
        y_prediction = self.student(x / 255.0, training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        return self.student(inputs / 255.0, training=False)
This model is based on the BiT family of ResNets (Kolesnikov et al.), fine-tuned on the tf_flowers dataset. You can refer to this notebook to know how the training was performed. The teacher model has about 212 million parameters, which is about 40x more than the student.
!wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
!unzip -q bit_teacher_flowers.zip
bit_teacher_flowers = keras.models.load_model("bit_teacher_flowers")
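As a quick illustrative check of the teacher's size:

# The loaded BiT teacher should have roughly 212M parameters.
print(f"Teacher parameters: {bit_teacher_flowers.count_params() / 1e6:.1f}M")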
deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)
lr_scaled = (BASE_LR / 512) * BATCH_SIZE  # The DeiT authors specify the base LR for a batch size of 512 and scale it linearly.
deit_distiller.compile(
    optimizer=tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
    metrics=["accuracy"],
    student_loss_fn=keras.losses.CategoricalCrossentropy(
        from_logits=True, label_smoothing=0.1
    ),
    distillation_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
Epoch 1/20
13/13 [==============================] - 44s 2s/step - accuracy: 0.2343 - student_loss: 2.2630 - distillation_loss: 1.7818 - val_accuracy: 0.2234 - val_student_loss: 1.6622 - val_distillation_loss: 0.0000e+00
Epoch 2/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2150 - student_loss: 1.6377 - distillation_loss: 1.6138 - val_accuracy: 0.1907 - val_student_loss: 1.6150 - val_distillation_loss: 0.0000e+00
Epoch 3/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2552 - student_loss: 1.6073 - distillation_loss: 1.5970 - val_accuracy: 0.1907 - val_student_loss: 1.6093 - val_distillation_loss: 0.0000e+00
Epoch 4/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2564 - student_loss: 1.5954 - distillation_loss: 1.5902 - val_accuracy: 0.2997 - val_student_loss: 1.5958 - val_distillation_loss: 0.0000e+00
Epoch 5/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2922 - student_loss: 1.5839 - distillation_loss: 1.5704 - val_accuracy: 0.3488 - val_student_loss: 1.5635 - val_distillation_loss: 0.0000e+00
Epoch 6/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.3815 - student_loss: 1.4865 - distillation_loss: 1.4551 - val_accuracy: 0.3815 - val_student_loss: 1.4975 - val_distillation_loss: 0.0000e+00
Epoch 7/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4151 - student_loss: 1.4027 - distillation_loss: 1.3441 - val_accuracy: 0.3733 - val_student_loss: 1.4083 - val_distillation_loss: 0.0000e+00
Epoch 8/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4423 - student_loss: 1.3616 - distillation_loss: 1.2877 - val_accuracy: 0.4005 - val_student_loss: 1.4014 - val_distillation_loss: 0.0000e+00
Epoch 9/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4475 - student_loss: 1.3095 - distillation_loss: 1.2200 - val_accuracy: 0.4496 - val_student_loss: 1.3211 - val_distillation_loss: 0.0000e+00
Epoch 10/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4959 - student_loss: 1.2638 - distillation_loss: 1.1508 - val_accuracy: 0.4932 - val_student_loss: 1.2839 - val_distillation_loss: 0.0000e+00
Epoch 11/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5431 - student_loss: 1.2063 - distillation_loss: 1.0948 - val_accuracy: 0.5559 - val_student_loss: 1.1938 - val_distillation_loss: 0.0000e+00
Epoch 12/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5771 - student_loss: 1.1742 - distillation_loss: 1.0461 - val_accuracy: 0.5695 - val_student_loss: 1.1362 - val_distillation_loss: 0.0000e+00
Epoch 13/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5601 - student_loss: 1.1724 - distillation_loss: 1.0457 - val_accuracy: 0.5477 - val_student_loss: 1.1929 - val_distillation_loss: 0.0000e+00
Epoch 14/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5777 - student_loss: 1.1717 - distillation_loss: 1.0378 - val_accuracy: 0.5777 - val_student_loss: 1.1171 - val_distillation_loss: 0.0000e+00
Epoch 15/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6173 - student_loss: 1.1232 - distillation_loss: 0.9782 - val_accuracy: 0.5640 - val_student_loss: 1.1229 - val_distillation_loss: 0.0000e+00
Epoch 16/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6237 - student_loss: 1.1091 - distillation_loss: 0.9627 - val_accuracy: 0.5886 - val_student_loss: 1.1371 - val_distillation_loss: 0.0000e+00
Epoch 17/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6261 - student_loss: 1.0880 - distillation_loss: 0.9341 - val_accuracy: 0.6322 - val_student_loss: 1.0972 - val_distillation_loss: 0.0000e+00
Epoch 18/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6427 - student_loss: 1.0688 - distillation_loss: 0.9117 - val_accuracy: 0.6431 - val_student_loss: 1.0548 - val_distillation_loss: 0.0000e+00
Epoch 19/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6458 - student_loss: 1.0529 - distillation_loss: 0.8903 - val_accuracy: 0.6076 - val_student_loss: 1.0761 - val_distillation_loss: 0.0000e+00
Epoch 20/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6382 - student_loss: 1.0641 - distillation_loss: 0.9049 - val_accuracy: 0.6240 - val_student_loss: 1.0521 - val_distillation_loss: 0.0000e+00
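Since DeiT.call rescales its inputs, the trained distiller can be called directly on raw-pixel batches (an illustrative check):

# Inference through the distiller on a raw-pixel validation batch.
images, _ = next(iter(val_dataset))
print(deit_distiller(images).shape)  # (batch_size, NUM_CLASSES)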
If we had trained the same model (the ViTClassifier) from scratch with the exact same hyperparameters, it would have scored about 59% accuracy. You can adapt the following code to reproduce this result:
vit_tiny = ViTClassifier()  # Requires MODEL_TYPE without the "distilled" phrase (see above).
inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1./255)(inputs)
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)
model.compile(...)
model.fit(...)
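One plausible way to fill in those two calls, mirroring the hyperparameters used for distillation above (a sketch; the exact configuration behind the 59% figure may differ):

model.compile(
    optimizer=tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
    metrics=["accuracy"],
)
model.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)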
Acknowledgements: Ross Wightman for keeping timm updated with readable implementations, which I referred to a lot when implementing ViT and DeiT in TensorFlow, and the contributor who implemented portions of the ViTClassifier.

Example available on HuggingFace:

| Trained Model | Demo |
| --- | --- |