
Knowledge Distillation

Author: Kenneth Borup
Date created: 2020/09/01
Last modified: 2020/09/01
Description: Implementation of classical Knowledge Distillation.

ⓘ This example uses Keras 3



Introduction to Knowledge Distillation

Knowledge Distillation is a model-compression procedure in which a small (student) model is trained to match a large pre-trained (teacher) model. Knowledge is transferred from the teacher to the student by minimizing a loss function designed to match both the softened teacher logits and the ground-truth labels.

The logits are softened by applying a "temperature" scaling in the softmax, which effectively smooths out the probability distribution and reveals the inter-class relationships learned by the teacher.
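As an illustration, here is a minimal sketch of temperature-scaled softmax using keras.ops (the logits and the temperature value below are arbitrary, chosen only for demonstration):

from keras import ops

logits = ops.convert_to_tensor([[2.0, 1.0, 0.1]])  # example logits for 3 classes
temperature = 3.0  # larger temperature -> softer, more uniform distribution

hard_probs = ops.softmax(logits, axis=1)                # standard softmax
soft_probs = ops.softmax(logits / temperature, axis=1)  # softened probabilities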



Setup

import os

import keras
from keras import layers
from keras import ops
import numpy as np

Construct Distiller() class

The custom Distiller() class overrides the Model methods compile, compute_loss, and call. In order to use the distiller, we need:

  • A trained teacher model
  • A student model to train
  • A student loss function on the difference between student predictions and ground truth
  • A distillation loss function, along with a temperature parameter, on the difference between the soft student predictions and the soft teacher labels
  • An alpha factor to weight the student loss and the distillation loss
  • An optimizer for the student and (optionally) metrics to evaluate performance

In the compute_loss method, we perform a forward pass of both the teacher and the student, and compute the loss by weighting student_loss and distillation_loss by alpha and 1 - alpha, respectively. Note: only the student's weights are updated.

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        # Forward pass of the teacher in inference mode; the student
        # predictions arrive as y_pred.
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        # Compare the temperature-softened teacher and student distributions.
        # The temperature**2 factor keeps the gradient scale of the soft
        # targets comparable across temperatures.
        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        # Weighted sum of the student loss and the distillation loss.
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        # Inference runs the student model only.
        return self.student(x)

Create student and teacher models

Initially, we create a teacher model and a smaller student model. Both models are convolutional neural networks created with Sequential(), but they could be any Keras model.

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)

Prepare the dataset

The dataset used for training the teacher and distilling the teacher is MNIST; the procedure would be equivalent for any other dataset, e.g. CIFAR-10, with a suitable choice of models. Both the student and the teacher are trained on the training set and evaluated on the test set.

# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

Train the teacher

In knowledge distillation we assume that the teacher is already trained and fixed. Therefore, we start by training the teacher model on the training set in the usual way.

# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 0.2408 - sparse_categorical_accuracy: 0.9259
Epoch 2/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0912 - sparse_categorical_accuracy: 0.9726
Epoch 3/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9777
Epoch 4/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9797
Epoch 5/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9825
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0931 - sparse_categorical_accuracy: 0.9760

[0.09044107794761658, 0.978100061416626]

Distill teacher to student

We have already trained the teacher model, so we only need to initialize a Distiller(student, teacher) instance, compile() it with the desired losses, hyperparameters and optimizer, and distill the teacher to the student.

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Epoch 1/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 1.8752 - sparse_categorical_accuracy: 0.7357
Epoch 2/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9475
Epoch 3/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9621
 313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.0189 - sparse_categorical_accuracy: 0.9629

[0.017046602442860603, 0.969200074672699]

Train student from scratch for comparison

We can also train an equivalent student model from scratch, without the teacher, in order to evaluate the performance gain obtained by knowledge distillation.

# Train student as usual
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 1ms/step - loss: 0.5111 - sparse_categorical_accuracy: 0.8460
Epoch 2/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.1039 - sparse_categorical_accuracy: 0.9687
Epoch 3/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9780
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9737

[0.0629437193274498, 0.9778000712394714]

If the teacher is trained for 5 full epochs and the student is distilled on this teacher for 3 full epochs, you should in this example see a performance boost compared to training the same student model from scratch, and even compared to the teacher itself. You should expect the teacher to have an accuracy of around 97.6%, the student trained from scratch to be around 97.6%, and the distilled student to be around 98.1%. Remove or try out different seeds to use different weight initializations.
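To experiment with different weight initializations, one option (a minimal sketch, assuming your Keras version provides keras.utils.set_random_seed) is to fix or change the global random seed before building the models:

# Hypothetical seed value; change or remove it to get different initializations.
keras.utils.set_random_seed(42)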