开发者指南 / 迁移学习与微调

迁移学习与微调

作者: fchollet
创建日期 2020/04/15
最后修改日期 2023/06/25
描述: Keras 中迁移学习与微调的完整指南。

在 Colab 中查看 GitHub 源代码


设置

import numpy as np
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

简介

迁移学习 是指采用在一个问题上学到的特征,并将其应用于一个新的、相似的问题。例如,从一个识别浣熊的模型中学到的特征,可能有助于启动一个用于识别狸猫的模型。

迁移学习通常用于您数据集不足以从头开始训练完整模型的情况。

在深度学习的背景下,最常见的迁移学习形式是以下工作流程:

  1. 从先前训练过的模型中提取层。
  2. 冻结它们,以避免在将来的训练轮次中破坏它们包含的任何信息。
  3. 在冻结的层之上添加一些新的、可训练的层。这些层将学习将旧特征转化为新数据集上的预测。
  4. 在新数据集上训练新层。

最后一步是可选的 微调,它包括解冻您上面得到的整个模型(或其中一部分),并以非常低的学习率在新数据上重新训练它。这可能会通过逐步适应预训练特征以适应新数据来取得有意义的改进。

首先,我们将详细介绍 Keras 的 `trainable` API,它构成了大多数迁移学习和微调工作流程的基础。

然后,我们将通过使用在 ImageNet 数据集上预训练的模型,并在 Kaggle 的“猫狗”分类数据集上重新训练它来演示典型的工作流程。

这改编自 《Python 深度学习》 和 2016 年的博客文章 “使用极少数据构建强大的图像分类模型”


冻结层:理解 `trainable` 属性

层和模型有三个权重属性:

  • weights 是层所有权重变量的列表。
  • trainable_weights 是那些旨在通过(梯度下降)更新以在训练期间最小化损失的权重列表。
  • non_trainable_weights 是那些不打算训练的权重列表。通常它们在前向传播过程中由模型更新。

示例:`Dense` 层有两个可训练权重(核和偏置)。

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

通常,所有权重都是可训练的。唯一内置的具有不可训练权重的层是 `BatchNormalization` 层。它使用不可训练的权重来跟踪训练期间输入的均值和方差。要了解如何在自己的自定义层中使用不可训练权重,请参阅 从头开始编写新层指南

示例:`BatchNormalization` 层有 2 个可训练权重和 2 个不可训练权重。

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

层和模型还具有一个布尔属性 `trainable`。可以更改其值。将 `layer.trainable` 设置为 `False` 会将该层的所有权重从可训练移至不可训练。这被称为“冻结”该层:冻结的层的状态在训练期间不会被更新(无论是使用 `fit()` 训练还是使用任何依赖于 `trainable_weights` 来应用梯度更新的自定义循环)。

示例:将 `trainable` 设置为 `False`

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

当一个可训练权重变为不可训练时,其值在训练期间将不再被更新。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 766ms/step - loss: 0.0615

请勿将 `layer.trainable` 属性与 `layer.__call__()` 中的 `training` 参数混淆(该参数控制层是应以推理模式还是训练模式运行其前向传播)。有关更多信息,请参阅 Keras FAQ


递归设置 `trainable` 属性

如果您在模型或任何具有子层的层上设置 `trainable = False`,则所有子层也将变为不可训练。

示例

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        inner_model,
        keras.layers.Dense(3, activation="sigmoid"),
    ]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的迁移学习工作流程

这引出了如何在 Keras 中实现典型的迁移学习工作流程。

  1. 实例化一个基础模型并加载预训练的权重到其中。
  2. 通过设置 `trainable = False` 来冻结基础模型中的所有层。
  3. 在基础模型的一个(或多个)层的输出之上创建一个新模型。
  4. 在新数据集上训练您的新模型。

请注意,一个更轻量级的替代工作流程也可以是:

  1. 实例化一个基础模型并加载预训练的权重到其中。
  2. 使用您的新数据集运行它,并记录基础模型的一个(或多个)层的输出。这被称为特征提取
  3. 将该输出用作一个新、更小的模型的输入数据。

第二个工作流程的一个关键优势是,您只需要在数据上运行基础模型一次,而不是在每个训练 epoch 中运行一次。因此,它速度更快,成本更低。

然而,第二个工作流程的一个问题是,它不允许您在训练期间动态修改新模型的输入数据,而这在进行数据增强时是必需的。迁移学习通常用于新数据集数据量不足以从头开始训练完整模型的情况,在这些情况下,数据增强非常重要。因此,在接下来的内容中,我们将重点关注第一个工作流程。

以下是第一个工作流程在 Keras 中的实现方式:

首先,实例化一个带有预训练权重的基础模型。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

然后,冻结基础模型。

base_model.trainable = False

在顶部创建一个新模型。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

在新数据上训练模型。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微调

一旦您的模型在新数据上收敛,您可以尝试解冻基础模型的所有层或部分层,并以非常低的学习率对整个模型进行端到端重新训练。

这是可选的最后一步,可能会带来渐进式的改进。但也要注意,它也可能导致快速过拟合。

只有在冻结层的模型训练到收敛之后,才执行此步骤至关重要。如果您将随机初始化的可训练层与持有预训练特征的可训练层混合在一起,随机初始化的层将在训练期间导致非常大的梯度更新,这将破坏您的预训练特征。

在此阶段使用非常低的学习率也至关重要,因为您正在训练一个比第一轮训练更大的模型,而且数据集通常非常小。因此,如果您应用大的权重更新,您很快就有过拟合的风险。在这里,您只想以渐进的方式重新调整预训练权重的适应性。

这是如何实现整个基础模型微调的:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

关于 `compile()` 和 `trainable` 的重要说明

在模型上调用 `compile()` 旨在“冻结”该模型的行为。这意味着模型编译时 `trainable` 属性的值应该在该模型整个生命周期内保持不变,直到再次调用 `compile`。因此,如果您更改任何 `trainable` 值,请务必再次调用模型的 `compile()`,以使您的更改生效。

关于 `BatchNormalization` 层的注意事项

许多图像模型都包含 `BatchNormalization` 层。该层在各种情况下都是一个特殊情况。以下是一些需要牢记的事情:

  • `BatchNormalization` 包含 2 个不可训练的权重,它们在训练期间会被更新。这些是跟踪输入均值和方差的变量。
  • 当您设置 `bn_layer.trainable = False` 时,`BatchNormalization` 层将以推理模式运行,并且不会更新其均值和方差统计信息。通常,其他层并非如此,因为 权重可训练性与推理/训练模式是两个正交的概念。但对于 `BatchNormalization` 层来说,这两者是绑定的。
  • 当您解冻包含 `BatchNormalization` 层的模型以进行微调时,您应该通过在调用基础模型时传递 `training=False` 来将 `BatchNormalization` 层保留在推理模式下。否则,对不可训练权重的更新将突然破坏模型迄今为止学到的内容。

您将在本指南结尾的端到端示例中看到此模式的应用。


一个端到端的例子:在猫狗数据集上微调图像分类模型

为了巩固这些概念,我们将通过一个具体的端到端迁移学习和微调示例来介绍。我们将加载在 ImageNet 上预训练的 Xception 模型,并在 Kaggle 的“猫狗”分类数据集上使用它。

获取数据

首先,让我们使用 TFDS 获取猫狗数据集。如果您有自己的数据集,您可能希望使用实用工具 `keras.utils.image_dataset_from_directory` 从磁盘上按类别文件夹组织的图像集合生成类似的有标签数据集对象。

迁移学习在处理非常小的数据集时最有用。为了保持数据集较小,我们将使用原始训练数据的 40%(25,000 张图像)进行训练,10% 用于验证,10% 用于测试。

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print(f"Number of training samples: {train_ds.cardinality()}")
print(f"Number of validation samples: {validation_ds.cardinality()}")
print(f"Number of test samples: {test_ds.cardinality()}")
 Downloading and preparing dataset 786.68 MiB (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/mattdangerw/tensorflow_datasets/cats_vs_dogs/4.0.0...

WARNING:absl:1738 images were corrupted and were skipped

 Dataset cats_vs_dogs downloaded and prepared to /home/mattdangerw/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

这是训练数据集中的前 9 张图像——正如您所见,它们的大小各不相同。

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

我们还可以看到标签 1 是“狗”,标签 0 是“猫”。

标准化数据

我们的原始图像尺寸各不相同。此外,每个像素由 3 个介于 0 和 255 之间的整数值(RGB 色彩值)组成。这不太适合输入神经网络。我们需要做两件事:

  • 标准化到固定图像大小。我们选择 150x150。
  • 将像素值归一化到 -1 和 1 之间。我们将使用 `Normalization` 层作为模型本身的一部分来完成。

通常,开发接受原始数据作为输入的模型,而不是接受已预处理数据的模型,是一个好习惯。原因是,如果您的模型需要预处理数据,那么任何时候您将模型导出到其他地方使用(在浏览器、移动应用中),您都需要重新实现完全相同的预处理管道。这很快就会变得非常棘手。所以我们在模型接收输入之前,应该只做最少量的预处理。

在这里,我们将在数据管道中进行图像大小调整(因为深度神经网络只能处理连续的数据批次),并且在创建模型时,我们将把输入值缩放作为模型的一部分。

让我们将图像大小调整为 150x150。

resize_fn = keras.layers.Resizing(150, 150)

train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))
validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))
test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))

使用随机数据增强

当您没有大型图像数据集时,通过对训练图像应用随机但真实的变换,例如随机水平翻转或小的随机旋转,来人为地引入样本多样性,这是一个好习惯。这有助于模型接触训练数据的不同方面,同时减缓过拟合。

augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(x):
    for layer in augmentation_layers:
        x = layer(x)
    return x


train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))

让我们对数据进行批处理并使用预取以优化加载速度。

from tensorflow import data as tf_data

batch_size = 64

train_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
validation_ds = validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
test_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()

让我们可视化第一批图像的第一个图像经过各种随机变换后的样子。

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(np.expand_dims(first_image, 0))
        plt.imshow(np.array(augmented_image[0]).astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")

png


构建一个模型

现在让我们构建一个遵循我们之前解释过的蓝图的模型。

请注意:

  • 我们添加了一个 `Rescaling` 层来缩放输入值(最初在 `[0, 255]` 范围内)到 `[-1, 1]` 范围。
  • 我们在分类层之前添加了一个 `Dropout` 层,用于正则化。
  • 我们在调用基础模型时,确保传递 `training=False`,以便它以推理模式运行,这样即使在解冻基础模型进行微调后,批处理归一化统计信息也不会被更新。
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(inputs)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary(show_trainable=True)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
 83683744/83683744 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Model: "functional_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃ Layer (type)                 Output Shape              Param #  Trai… ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ input_layer_4 (InputLayer)  │ (None, 150, 150, 3)      │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ rescaling (Rescaling)       │ (None, 150, 150, 3)      │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ xception (Functional)       │ (None, 5, 5, 2048)       │ 20,861…N   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ global_average_pooling2d    │ (None, 2048)             │       0-   │
│ (GlobalAveragePooling2D)    │                          │         │       │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ dropout (Dropout)           │ (None, 2048)             │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ dense_7 (Dense)             │ (None, 1)                │   2,049Y   │
└─────────────────────────────┴──────────────────────────┴─────────┴───────┘
 Total params: 20,863,529 (79.59 MB)
 Trainable params: 2,049 (8.00 KB)
 Non-trainable params: 20,861,480 (79.58 MB)

训练顶层

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 2
print("Fitting the top layer of the model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Fitting the top layer of the model
Epoch 1/2
  78/146 ━━━━━━━━━━━━━━━━━━━━  15s 226ms/step - binary_accuracy: 0.7995 - loss: 0.4088

Corrupt JPEG data: 65 extraneous bytes before marker 0xd9

 136/146 ━━━━━━━━━━━━━━━━━━━━  2s 231ms/step - binary_accuracy: 0.8430 - loss: 0.3298

Corrupt JPEG data: 239 extraneous bytes before marker 0xd9

 143/146 ━━━━━━━━━━━━━━━━━━━━  0s 231ms/step - binary_accuracy: 0.8464 - loss: 0.3235

Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9

 144/146 ━━━━━━━━━━━━━━━━━━━━  0s 231ms/step - binary_accuracy: 0.8468 - loss: 0.3226

Corrupt JPEG data: 228 extraneous bytes before marker 0xd9

 146/146 ━━━━━━━━━━━━━━━━━━━━ 0s 260ms/step - binary_accuracy: 0.8478 - loss: 0.3209

Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9

 146/146 ━━━━━━━━━━━━━━━━━━━━ 54s 317ms/step - binary_accuracy: 0.8482 - loss: 0.3200 - val_binary_accuracy: 0.9667 - val_loss: 0.0877
Epoch 2/2
 146/146 ━━━━━━━━━━━━━━━━━━━━ 7s 51ms/step - binary_accuracy: 0.9483 - loss: 0.1232 - val_binary_accuracy: 0.9705 - val_loss: 0.0786

<keras.src.callbacks.history.History at 0x7fc8b7f1db70>

进行一轮整个模型的微调

最后,让我们解冻基础模型,并以低学习率对整个模型进行端到端训练。

重要的是,虽然基础模型变得可训练,但由于我们在构建模型时调用它时传递了 `training=False`,它仍然以推理模式运行。这意味着其中的批处理归一化层不会更新其批处理统计信息。如果它们更新了,它们将严重破坏模型迄今为止学到的表示。

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary(show_trainable=True)

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 1
print("Fitting the end-to-end model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "functional_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃ Layer (type)                 Output Shape              Param #  Trai… ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ input_layer_4 (InputLayer)  │ (None, 150, 150, 3)      │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ rescaling (Rescaling)       │ (None, 150, 150, 3)      │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ xception (Functional)       │ (None, 5, 5, 2048)       │ 20,861…Y   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ global_average_pooling2d    │ (None, 2048)             │       0-   │
│ (GlobalAveragePooling2D)    │                          │         │       │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ dropout (Dropout)           │ (None, 2048)             │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ dense_7 (Dense)             │ (None, 1)                │   2,049Y   │
└─────────────────────────────┴──────────────────────────┴─────────┴───────┘
 Total params: 20,867,629 (79.60 MB)
 Trainable params: 20,809,001 (79.38 MB)
 Non-trainable params: 54,528 (213.00 KB)
 Optimizer params: 4,100 (16.02 KB)
Fitting the end-to-end model
 146/146 ━━━━━━━━━━━━━━━━━━━━ 75s 327ms/step - binary_accuracy: 0.8487 - loss: 0.3760 - val_binary_accuracy: 0.9494 - val_loss: 0.1160

<keras.src.callbacks.history.History at 0x7fcd1c755090>

经过 10 个 epoch 的训练,微调在这里带来了显著的改进。让我们在测试数据集上评估模型。

print("Test dataset evaluation")
model.evaluate(test_ds)
Test dataset evaluation
 11/37 ━━━━━━━━━━━━━━━━━━━━  1s 52ms/step - binary_accuracy: 0.9407 - loss: 0.1155

Corrupt JPEG data: 99 extraneous bytes before marker 0xd9

 37/37 ━━━━━━━━━━━━━━━━━━━━ 2s 47ms/step - binary_accuracy: 0.9427 - loss: 0.1259

[0.13755160570144653, 0.941300630569458]