作者: fchollet
创建日期 2020/04/15
最后修改日期 2023/06/25
描述: Keras 中迁移学习和微调的完整指南。
import numpy as np
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
迁移学习是指利用在一个问题上学习到的特征,将其应用于新的、相似的问题。例如,一个已经学会识别浣熊的模型中的特征,可能有助于启动一个旨在识别貉的模型。
迁移学习通常用于数据集数据量太少,无法从头开始训练完整规模模型的情况。
深度学习中迁移学习最常见的形式是以下工作流程
最后一个可选步骤是微调,它包括解冻您上面获得的整个模型(或部分模型),并使用非常低的学习率在新数据上重新训练它。通过逐步调整预训练的特征以适应新数据,这可能会实现有意义的改进。
首先,我们将详细介绍 Keras 的 trainable
API,它是大多数迁移学习和微调工作流程的基础。
然后,我们将通过采用在 ImageNet 数据集上预训练的模型,并在 Kaggle 的“猫狗分类”数据集上重新训练它来演示典型的工作流程。
这是改编自 Deep Learning with Python 和 2016 年的博文 “使用非常少的数据构建强大的图像分类模型”。
trainable
属性层和模型具有三个权重属性
weights
是该层所有权重变量的列表。trainable_weights
是那些旨在通过(梯度下降)更新以在训练期间最小化损失的权重列表。non_trainable_weights
是那些不打算进行训练的权重列表。通常,它们在正向传播期间由模型更新。示例:Dense
层有 2 个可训练的权重(内核和偏置)
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0
通常,所有权重都是可训练的权重。唯一具有不可训练权重的内置层是 BatchNormalization
层。它使用不可训练的权重来跟踪其输入在训练期间的均值和方差。要了解如何在您自己的自定义层中使用不可训练的权重,请参阅从头开始编写新层的指南。
示例:BatchNormalization
层有 2 个可训练的权重和 2 个不可训练的权重
layer = keras.layers.BatchNormalization()
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2
层和模型还具有布尔属性 trainable
。其值可以更改。将 layer.trainable
设置为 False
会将该层的所有权重从可训练权重移动到不可训练权重。这称为“冻结”该层:冻结层的状态在训练期间不会更新(无论是在使用 fit()
进行训练还是在使用任何依赖 trainable_weights
应用梯度更新的自定义循环进行训练时)。
示例:将 trainable
设置为 False
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
layer.trainable = False # Freeze the layer
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2
当一个可训练的权重变为不可训练时,它的值在训练期间不再更新。
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
# Freeze the first layer
layer1.trainable = False
# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()
# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 766ms/step - loss: 0.0615
不要将 layer.trainable
属性与 layer.__call__()
中的参数 training
混淆(它控制该层是否应在推理模式还是训练模式下运行其正向传播)。有关更多信息,请参阅Keras FAQ。
trainable
属性的递归设置如果您在模型或任何具有子层的层上设置 trainable = False
,则所有子层也会变为不可训练。
示例
inner_model = keras.Sequential(
[
keras.Input(shape=(3,)),
keras.layers.Dense(3, activation="relu"),
keras.layers.Dense(3, activation="relu"),
]
)
model = keras.Sequential(
[
keras.Input(shape=(3,)),
inner_model,
keras.layers.Dense(3, activation="sigmoid"),
]
)
model.trainable = False # Freeze the outer model
assert inner_model.trainable == False # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False # `trainable` is propagated recursively
这使我们了解了如何在 Keras 中实现典型的迁移学习工作流程
trainable = False
来冻结基础模型中的所有层。请注意,另一种更轻量级的工作流程也可以是
第二个工作流程的一个主要优点是您只需对您的数据运行一次基础模型,而不是每个训练时期运行一次。因此,它更快更便宜。
不过,第二个工作流程的一个问题是,它不允许您在训练期间动态修改新模型的输入数据,例如,在进行数据增强时这是必需的。当您的新数据集的数据太少而无法从头开始训练完整规模的模型时,通常会使用迁移学习,在这种情况下,数据增强非常重要。因此,在下文中,我们将重点关注第一个工作流程。
这是第一个工作流程在 Keras 中的样子
首先,实例化一个具有预训练权重的基础模型。
base_model = keras.applications.Xception(
weights='imagenet', # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False) # Do not include the ImageNet classifier at the top.
然后,冻结基础模型。
base_model.trainable = False
在顶部创建一个新模型。
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
在新数据上训练模型。
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
一旦您的模型在新数据上收敛,您可以尝试解冻基础模型的全部或部分,并以非常低的学习率端到端地重新训练整个模型。
这是一个可选的最后一步,可能会给您带来增量改进。它也可能导致快速过度拟合 – 请记住这一点。
至关重要的是,只有在冻结层的模型训练收敛后才执行此步骤。如果您将随机初始化的可训练层与保存预训练特征的可训练层混合,则随机初始化的层会在训练期间导致非常大的梯度更新,从而破坏您的预训练特征。
在此阶段使用非常低的学习率也至关重要,因为您正在训练比第一轮训练更大的模型,而数据集通常非常小。因此,如果您应用较大的权重更新,则您有快速过度拟合的风险。在这里,您只想以增量方式重新调整预训练的权重。
这是如何实现整个基础模型的微调
# Unfreeze the base model
base_model.trainable = True
# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5), # Very low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
关于 compile()
和 trainable
的重要说明
在模型上调用 compile()
意味着“冻结”该模型的行为。这意味着在编译模型时 trainable
属性值应在模型整个生命周期中保留,直到再次调用 compile
为止。因此,如果您更改任何 trainable
值,请确保在您的模型上再次调用 compile()
以使您的更改生效。
关于 BatchNormalization
层的重要说明
许多图像模型包含 BatchNormalization
层。该层在各种方面都是一个特例。以下是一些需要注意的事项。
BatchNormalization
包含 2 个不可训练的权重,这些权重在训练期间会更新。这些是跟踪输入均值和方差的变量。bn_layer.trainable = False
时,BatchNormalization
层将以推理模式运行,并且不会更新其均值和方差统计信息。通常情况下,其他层并非如此,因为 权重可训练性和推理/训练模式是两个正交的概念。但对于 BatchNormalization
层,这两者是关联的。BatchNormalization
层的模型时,您应该通过在调用基础模型时传递 training=False
来保持 BatchNormalization
层处于推理模式。否则,应用于不可训练权重的更新会突然破坏模型已学习的内容。您将在本指南末尾的端到端示例中看到此模式的实际应用。
为了巩固这些概念,让我们带您了解一个具体的端到端迁移学习和微调示例。我们将加载在 ImageNet 上预训练的 Xception 模型,并将其用于 Kaggle “猫 vs. 狗” 分类数据集。
首先,让我们使用 TFDS 获取猫狗数据集。如果您有自己的数据集,您可能需要使用实用程序 keras.utils.image_dataset_from_directory
从磁盘上存储在特定类别文件夹中的一组图像生成类似的标记数据集对象。
当处理非常小的数据集时,迁移学习最有用。为了保持我们的数据集较小,我们将使用原始训练数据(25,000 张图像)的 40% 用于训练,10% 用于验证,以及 10% 用于测试。
tfds.disable_progress_bar()
train_ds, validation_ds, test_ds = tfds.load(
"cats_vs_dogs",
# Reserve 10% for validation and 10% for test
split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
as_supervised=True, # Include labels
)
print(f"Number of training samples: {train_ds.cardinality()}")
print(f"Number of validation samples: {validation_ds.cardinality()}")
print(f"Number of test samples: {test_ds.cardinality()}")
Downloading and preparing dataset 786.68 MiB (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/mattdangerw/tensorflow_datasets/cats_vs_dogs/4.0.0...
WARNING:absl:1738 images were corrupted and were skipped
Dataset cats_vs_dogs downloaded and prepared to /home/mattdangerw/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326
这些是训练数据集中的前 9 张图像 – 如您所见,它们的大小都不同。
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image)
plt.title(int(label))
plt.axis("off")
我们还可以看到标签 1 是“狗”,标签 0 是“猫”。
我们的原始图像具有多种尺寸。此外,每个像素由 3 个介于 0 和 255 之间的整数值(RGB 级别值)组成。这不太适合输入神经网络。我们需要做两件事:
Normalization
层作为模型本身的一部分来执行此操作。一般来说,开发以原始数据作为输入的模型是一个好习惯,而不是开发以预处理数据作为输入的模型。原因在于,如果您的模型期望预处理数据,那么每当您导出模型以在其他地方(在 Web 浏览器、移动应用中)使用时,您都需要重新实现完全相同的预处理流程。这会很快变得非常棘手。因此,在输入模型之前,我们应该尽可能少地进行预处理。
在这里,我们将在数据管道中进行图像大小调整(因为深度神经网络只能处理连续的数据批次),并且我们将在创建模型时将其作为模型的一部分进行输入值缩放。
让我们将图像大小调整为 150x150
resize_fn = keras.layers.Resizing(150, 150)
train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))
validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))
test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))
当您没有大型图像数据集时,通过对训练图像应用随机但真实的转换(例如随机水平翻转或小的随机旋转),人为地引入样本多样性是一个好习惯。这有助于模型暴露于训练数据的不同方面,同时减缓过拟合。
augmentation_layers = [
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1),
]
def data_augmentation(x):
for layer in augmentation_layers:
x = layer(x)
return x
train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))
让我们对数据进行批处理,并使用预取来优化加载速度。
from tensorflow import data as tf_data
batch_size = 64
train_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
validation_ds = validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
test_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
让我们可视化第一批的第一个图像在各种随机转换后的样子
for images, labels in train_ds.take(1):
plt.figure(figsize=(10, 10))
first_image = images[0]
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
augmented_image = data_augmentation(np.expand_dims(first_image, 0))
plt.imshow(np.array(augmented_image[0]).astype("int32"))
plt.title(int(labels[0]))
plt.axis("off")
现在让我们构建一个遵循我们之前解释的蓝图的模型。
请注意:
Rescaling
层,将输入值(最初在 [0, 255]
范围内)缩放到 [-1, 1]
范围。Dropout
层,用于正则化。training=False
,以便它以推理模式运行,这样即使我们在微调时解冻基础模型后,批归一化统计信息也不会更新。base_model = keras.applications.Xception(
weights="imagenet", # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False,
) # Do not include the ImageNet classifier at the top.
# Freeze the base_model
base_model.trainable = False
# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(inputs)
# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.summary(show_trainable=True)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83683744/83683744 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Model: "functional_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Trai… ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩ │ input_layer_4 (InputLayer) │ (None, 150, 150, 3) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ rescaling (Rescaling) │ (None, 150, 150, 3) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ xception (Functional) │ (None, 5, 5, 2048) │ 20,861… │ N │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ global_average_pooling2d │ (None, 2048) │ 0 │ - │ │ (GlobalAveragePooling2D) │ │ │ │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ dropout (Dropout) │ (None, 2048) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ dense_7 (Dense) │ (None, 1) │ 2,049 │ Y │ └─────────────────────────────┴──────────────────────────┴─────────┴───────┘
Total params: 20,863,529 (79.59 MB)
Trainable params: 2,049 (8.00 KB)
Non-trainable params: 20,861,480 (79.58 MB)
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 2
print("Fitting the top layer of the model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Fitting the top layer of the model
Epoch 1/2
78/146 ━━━━━━━━━━[37m━━━━━━━━━━ 15s 226ms/step - binary_accuracy: 0.7995 - loss: 0.4088
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
136/146 ━━━━━━━━━━━━━━━━━━[37m━━ 2s 231ms/step - binary_accuracy: 0.8430 - loss: 0.3298
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
143/146 ━━━━━━━━━━━━━━━━━━━[37m━ 0s 231ms/step - binary_accuracy: 0.8464 - loss: 0.3235
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
144/146 ━━━━━━━━━━━━━━━━━━━[37m━ 0s 231ms/step - binary_accuracy: 0.8468 - loss: 0.3226
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
146/146 ━━━━━━━━━━━━━━━━━━━━ 0s 260ms/step - binary_accuracy: 0.8478 - loss: 0.3209
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
146/146 ━━━━━━━━━━━━━━━━━━━━ 54s 317ms/step - binary_accuracy: 0.8482 - loss: 0.3200 - val_binary_accuracy: 0.9667 - val_loss: 0.0877
Epoch 2/2
146/146 ━━━━━━━━━━━━━━━━━━━━ 7s 51ms/step - binary_accuracy: 0.9483 - loss: 0.1232 - val_binary_accuracy: 0.9705 - val_loss: 0.0786
<keras.src.callbacks.history.History at 0x7fc8b7f1db70>
最后,让我们解冻基础模型,并以较低的学习率端到端地训练整个模型。
重要的是,尽管基础模型变得可训练,但它仍然以推理模式运行,因为我们在构建模型时调用它时传递了 training=False
。这意味着内部的批归一化层不会更新其批次统计信息。如果它们这样做,它们会破坏模型目前所学习的表示。
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary(show_trainable=True)
model.compile(
optimizer=keras.optimizers.Adam(1e-5), # Low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 1
print("Fitting the end-to-end model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "functional_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Trai… ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩ │ input_layer_4 (InputLayer) │ (None, 150, 150, 3) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ rescaling (Rescaling) │ (None, 150, 150, 3) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ xception (Functional) │ (None, 5, 5, 2048) │ 20,861… │ Y │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ global_average_pooling2d │ (None, 2048) │ 0 │ - │ │ (GlobalAveragePooling2D) │ │ │ │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ dropout (Dropout) │ (None, 2048) │ 0 │ - │ ├─────────────────────────────┼──────────────────────────┼─────────┼───────┤ │ dense_7 (Dense) │ (None, 1) │ 2,049 │ Y │ └─────────────────────────────┴──────────────────────────┴─────────┴───────┘
Total params: 20,867,629 (79.60 MB)
Trainable params: 20,809,001 (79.38 MB)
Non-trainable params: 54,528 (213.00 KB)
Optimizer params: 4,100 (16.02 KB)
Fitting the end-to-end model
146/146 ━━━━━━━━━━━━━━━━━━━━ 75s 327ms/step - binary_accuracy: 0.8487 - loss: 0.3760 - val_binary_accuracy: 0.9494 - val_loss: 0.1160
<keras.src.callbacks.history.History at 0x7fcd1c755090>
经过 10 个 epoch 后,微调在这里为我们带来了不错的改进。让我们在测试数据集上评估模型
print("Test dataset evaluation")
model.evaluate(test_ds)
Test dataset evaluation
11/37 ━━━━━[37m━━━━━━━━━━━━━━━ 1s 52ms/step - binary_accuracy: 0.9407 - loss: 0.1155
Corrupt JPEG data: 99 extraneous bytes before marker 0xd9
37/37 ━━━━━━━━━━━━━━━━━━━━ 2s 47ms/step - binary_accuracy: 0.9427 - loss: 0.1259
[0.13755160570144653, 0.941300630569458]