作者: Kenneth Borup
创建日期 2020/09/01
最后修改日期 2020/09/01
描述: 经典知识蒸馏的实现。
知识蒸馏是一种模型压缩方法,其中训练一个小型(学生)模型来匹配一个大型预训练(教师)模型。通过最小化损失函数,将知识从教师模型转移到学生模型,该损失函数旨在匹配软化的教师 logits 以及真实标签。
通过在 softmax 中应用“温度”缩放函数来软化 logits,这有效地平滑了概率分布,并揭示了教师模型学习到的类间关系。
参考文献
import os
import keras
from keras import layers
from keras import ops
import numpy as np
Distiller()
类自定义的 Distiller()
类覆盖了 Model
的 compile
、compute_loss
和 call
方法。为了使用蒸馏器,我们需要
temperature
参数alpha
系数在 compute_loss
方法中,我们对教师和学生模型进行前向传播,然后计算损失,其中 student_loss
和 distillation_loss
分别通过 alpha
和 1 - alpha
进行加权。注意:只有学生模型的权重会被更新。
class Distiller(keras.Model):
def __init__(self, student, teacher):
super().__init__()
self.teacher = teacher
self.student = student
def compile(
self,
optimizer,
metrics,
student_loss_fn,
distillation_loss_fn,
alpha=0.1,
temperature=3,
):
"""Configure the distiller.
Args:
optimizer: Keras optimizer for the student weights
metrics: Keras metrics for evaluation
student_loss_fn: Loss function of difference between student
predictions and ground-truth
distillation_loss_fn: Loss function of difference between soft
student predictions and soft teacher predictions
alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
temperature: Temperature for softening probability distributions.
Larger temperature gives softer distributions.
"""
super().compile(optimizer=optimizer, metrics=metrics)
self.student_loss_fn = student_loss_fn
self.distillation_loss_fn = distillation_loss_fn
self.alpha = alpha
self.temperature = temperature
def compute_loss(
self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
):
teacher_pred = self.teacher(x, training=False)
student_loss = self.student_loss_fn(y, y_pred)
distillation_loss = self.distillation_loss_fn(
ops.softmax(teacher_pred / self.temperature, axis=1),
ops.softmax(y_pred / self.temperature, axis=1),
) * (self.temperature**2)
loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
return loss
def call(self, x):
return self.student(x)
首先,我们创建一个教师模型和一个较小的学生模型。这两个模型都是卷积神经网络,使用 Sequential()
创建,但可以是任何 Keras 模型。
# Create the teacher
teacher = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="teacher",
)
# Create the student
student = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="student",
)
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
用于训练教师模型和蒸馏教师模型的数据集是 MNIST,对于任何其他数据集,例如 CIFAR-10,只要选择合适的模型,过程都是类似的。学生和教师模型都在训练集上进行训练,并在测试集上进行评估。
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
在知识蒸馏中,我们假设教师模型已经训练完毕并固定。因此,我们首先以通常的方式在训练集上训练教师模型。
# Train teacher as usual
teacher.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 0.2408 - sparse_categorical_accuracy: 0.9259
Epoch 2/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0912 - sparse_categorical_accuracy: 0.9726
Epoch 3/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9777
Epoch 4/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9797
Epoch 5/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9825
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0931 - sparse_categorical_accuracy: 0.9760
[0.09044107794761658, 0.978100061416626]
我们已经训练好了教师模型,现在只需初始化一个 Distiller(student, teacher)
实例,使用期望的损失函数、超参数和优化器对其进行 compile()
,然后将知识从教师蒸馏到学生模型。
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10,
)
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 1.8752 - sparse_categorical_accuracy: 0.7357
Epoch 2/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9475
Epoch 3/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9621
313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.0189 - sparse_categorical_accuracy: 0.9629
[0.017046602442860603, 0.969200074672699]
我们也可以从头开始训练一个等效的学生模型,不使用教师模型,以评估知识蒸馏带来的性能提升。
# Train student as doen usually
student_scratch.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 1ms/step - loss: 0.5111 - sparse_categorical_accuracy: 0.8460
Epoch 2/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.1039 - sparse_categorical_accuracy: 0.9687
Epoch 3/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9780
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9737
[0.0629437193274498, 0.9778000712394714]
如果教师模型训练了 5 个完整的 epoch,学生模型在此教师模型上蒸馏训练了 3 个完整的 epoch,您应该在此示例中体验到相较于从头训练相同学生模型甚至教师模型本身的性能提升。您可以预期教师模型的准确率在 97.6% 左右,从头训练的学生模型准确率应在 97.6% 左右,而经过蒸馏的学生模型准确率应在 98.1% 左右。移除或尝试不同的随机种子以使用不同的权重初始化。