作者: Kenneth Borup
创建日期 2020/09/01
上次修改日期 2020/09/01
描述:经典知识蒸馏的实现。
知识蒸馏是一种模型压缩方法,其中一个小模型(学生模型)被训练以匹配一个大型预训练模型(教师模型)。通过最小化损失函数将知识从教师模型转移到学生模型,该损失函数旨在匹配软化的教师 logits 以及真实标签。
通过在 softmax 中应用“温度”缩放函数来软化 logits,有效地平滑概率分布并揭示教师学习的类间关系。
参考文献
import os
import keras
from keras import layers
from keras import ops
import numpy as np
Distiller()
类自定义 Distiller()
类重写了 Model
方法 compile
、compute_loss
和 call
。为了使用蒸馏器,我们需要
temperature
,用于计算软化学生预测与软化教师标签之间的差异alpha
因子,用于加权学生和蒸馏损失在 compute_loss
方法中,我们对教师和学生都进行前向传递,分别使用 alpha
和 1 - alpha
对 student_loss
和 distillation_loss
进行加权计算损失。注意:只有学生权重会被更新。
class Distiller(keras.Model):
def __init__(self, student, teacher):
super().__init__()
self.teacher = teacher
self.student = student
def compile(
self,
optimizer,
metrics,
student_loss_fn,
distillation_loss_fn,
alpha=0.1,
temperature=3,
):
"""Configure the distiller.
Args:
optimizer: Keras optimizer for the student weights
metrics: Keras metrics for evaluation
student_loss_fn: Loss function of difference between student
predictions and ground-truth
distillation_loss_fn: Loss function of difference between soft
student predictions and soft teacher predictions
alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
temperature: Temperature for softening probability distributions.
Larger temperature gives softer distributions.
"""
super().compile(optimizer=optimizer, metrics=metrics)
self.student_loss_fn = student_loss_fn
self.distillation_loss_fn = distillation_loss_fn
self.alpha = alpha
self.temperature = temperature
def compute_loss(
self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
):
teacher_pred = self.teacher(x, training=False)
student_loss = self.student_loss_fn(y, y_pred)
distillation_loss = self.distillation_loss_fn(
ops.softmax(teacher_pred / self.temperature, axis=1),
ops.softmax(y_pred / self.temperature, axis=1),
) * (self.temperature**2)
loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
return loss
def call(self, x):
return self.student(x)
最初,我们创建了一个教师模型和一个更小的学生模型。这两个模型都是卷积神经网络,使用 Sequential()
创建,但可以是任何 Keras 模型。
# Create the teacher
teacher = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="teacher",
)
# Create the student
student = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="student",
)
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
用于训练教师和蒸馏教师的数据集是 MNIST,对于任何其他数据集(例如 CIFAR-10)的过程也是等效的,前提是选择合适的模型。学生和教师都在训练集上进行训练,并在测试集上进行评估。
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
在知识蒸馏中,我们假设教师已经过训练并且是固定的。因此,我们首先以通常的方式在训练集上训练教师模型。
# Train teacher as usual
teacher.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 0.2408 - sparse_categorical_accuracy: 0.9259
Epoch 2/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0912 - sparse_categorical_accuracy: 0.9726
Epoch 3/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9777
Epoch 4/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9797
Epoch 5/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9825
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0931 - sparse_categorical_accuracy: 0.9760
[0.09044107794761658, 0.978100061416626]
我们已经训练了教师模型,我们只需要初始化一个 Distiller(student, teacher)
实例,使用所需的损失、超参数和优化器对其进行 compile()
,并将教师知识蒸馏到学生。
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10,
)
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 1.8752 - sparse_categorical_accuracy: 0.7357
Epoch 2/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9475
Epoch 3/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9621
313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.0189 - sparse_categorical_accuracy: 0.9629
[0.017046602442860603, 0.969200074672699]
我们也可以从头开始训练一个相同学生模型,而无需教师,以便评估通过知识蒸馏获得的性能提升。
# Train student as doen usually
student_scratch.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 1ms/step - loss: 0.5111 - sparse_categorical_accuracy: 0.8460
Epoch 2/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.1039 - sparse_categorical_accuracy: 0.9687
Epoch 3/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9780
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9737
[0.0629437193274498, 0.9778000712394714]
如果教师训练了 5 个完整的 epoch,并且学生在此教师上蒸馏了 3 个完整的 epoch,那么在本例中,您应该会体验到与从头开始训练相同学生模型相比,甚至与教师本身相比的性能提升。您应该期望教师的准确率约为 97.6%,从头开始训练的学生约为 97.6%,而蒸馏后的学生约为 98.1%。删除或尝试不同的种子以使用不同的权重初始化。