作者: Sayak PaulSachin Prasad
创建日期 2021/03/13
最后修改日期 2023/12/12
描述: RandAugment 用于训练具有改进鲁棒性的图像分类模型。
数据增强是一种非常有用的技术,可以帮助提高卷积神经网络 (CNN) 的平移不变性。RandAugment 是一种用于视觉数据的随机数据增强例程,并在 RandAugment: Practical automated data augmentation with a reduced search space 中提出。它由强大的增强变换(如颜色抖动、高斯模糊、饱和度等)以及更传统的增强变换(如随机裁剪)组成。
这些参数是针对给定的数据集和网络架构进行调整的。RandAugment 的作者还在原始论文(图 2)中提供了 RandAugment 的伪代码。
最近,它已成为 Noisy Student Training 和 Unsupervised Data Augmentation for Consistency Training 等工作的关键组成部分。它也是 EfficientNets 成功的核心。
pip install keras-cv
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_cv
from keras import ops
from keras import layers
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
keras.utils.set_random_seed(42)
在此示例中,我们将使用 CIFAR10 数据集。
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print(f"Total training examples: {len(x_train)}")
print(f"Total test examples: {len(x_test)}")
Total training examples: 50000
Total test examples: 10000
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 128
EPOCHS = 1
IMAGE_SIZE = 72
RandAugment 对象现在,我们将从 imgaug.augmenters 模块初始化一个 RandAugment 对象,并使用 RandAugment 作者建议的参数。
rand_augment = keras_cv.layers.RandAugment(
value_range=(0, 255), augmentations_per_image=3, magnitude=0.8
)
Dataset 对象train_ds_rand = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(BATCH_SIZE * 100)
.batch(BATCH_SIZE)
.map(
lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
num_parallel_calls=AUTO,
)
.map(
lambda x, y: (rand_augment(tf.cast(x, tf.uint8)), y),
num_parallel_calls=AUTO,
)
.prefetch(AUTO)
)
test_ds = (
tf.data.Dataset.from_tensor_slices((x_test, y_test))
.batch(BATCH_SIZE)
.map(
lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
num_parallel_calls=AUTO,
)
.prefetch(AUTO)
)
为了进行比较,我们还定义了一个简单的增强管道,它由随机翻转、随机旋转和随机缩放组成。
simple_aug = keras.Sequential(
[
layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
layers.RandomFlip("horizontal"),
layers.RandomRotation(factor=0.02),
layers.RandomZoom(height_factor=0.2, width_factor=0.2),
]
)
# Now, map the augmentation pipeline to our training dataset
train_ds_simple = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(BATCH_SIZE * 100)
.batch(BATCH_SIZE)
.map(lambda x, y: (simple_aug(x), y), num_parallel_calls=AUTO)
.prefetch(AUTO)
)
sample_images, _ = next(iter(train_ds_rand))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image.numpy().astype("int"))
plt.axis("off")

鼓励您运行几次上面的代码块,以查看不同的变化。
simple_aug 增强的数据集sample_images, _ = next(iter(train_ds_simple))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image.numpy().astype("int"))
plt.axis("off")

现在,我们定义一个基于 ResNet50V2 架构的 CNN 模型。另外,请注意网络内部已经有一个重缩放层。这消除了对我们的数据集进行任何单独预处理的需要,并且对于部署目的特别有用。
def get_training_model():
resnet50_v2 = keras.applications.ResNet50V2(
weights=None,
include_top=True,
input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
classes=10,
)
model = keras.Sequential(
[
layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
layers.Rescaling(scale=1.0 / 127.5, offset=-1),
resnet50_v2,
]
)
return model
get_training_model().summary()
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ rescaling (Rescaling) │ (None, 72, 72, 3) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ resnet50v2 (Functional) │ (None, 10) │ 23,585,290 │ └─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 23,585,290 (89.97 MB)
Trainable params: 23,539,850 (89.80 MB)
Non-trainable params: 45,440 (177.50 KB)
我们将在我们数据集的两个不同版本上训练这个网络
simple_aug 进行增强。由于 RandAugment 以增强模型对常见扰动和腐化的鲁棒性而闻名,因此我们还将评估模型在 CIFAR-10-C 数据集上的性能,该数据集由 Hendrycks 等人在 Benchmarking Neural Network Robustness to Common Corruptions and Perturbations 中提出。CIFAR-10-C 数据集包含 19 种不同的图像腐化和扰动(例如斑点噪声、雾、高斯模糊等),并且具有不同的严重程度。在此示例中,我们将使用以下配置:cifar10_corrupted/saturate_5。此配置下的图像外观如下:

为了确保结果可复现,我们对浅层网络的初始随机权重进行了序列化。
initial_model = get_training_model()
initial_model.save_weights("initial.weights.h5")
rand_aug_model = get_training_model()
rand_aug_model.load_weights("initial.weights.h5")
rand_aug_model.compile(
loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
rand_aug_model.fit(train_ds_rand, validation_data=test_ds, epochs=EPOCHS)
_, test_acc = rand_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
391/391 ━━━━━━━━━━━━━━━━━━━━ 1146s 3s/step - accuracy: 0.1677 - loss: 2.3232 - val_accuracy: 0.2818 - val_loss: 1.9966
79/79 ━━━━━━━━━━━━━━━━━━━━ 39s 489ms/step - accuracy: 0.2803 - loss: 2.0073
Test accuracy: 28.18%
simple_aug 训练模型simple_aug_model = get_training_model()
simple_aug_model.load_weights("initial.weights.h5")
simple_aug_model.compile(
loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
simple_aug_model.fit(train_ds_simple, validation_data=test_ds, epochs=EPOCHS)
_, test_acc = simple_aug_model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
391/391 ━━━━━━━━━━━━━━━━━━━━ 1132s 3s/step - accuracy: 0.3673 - loss: 1.7929 - val_accuracy: 0.4789 - val_loss: 1.4296
79/79 ━━━━━━━━━━━━━━━━━━━━ 39s 494ms/step - accuracy: 0.4762 - loss: 1.4368
Test accuracy: 47.89%
# Load and prepare the CIFAR-10-C dataset
# (If it's not already downloaded, it takes ~10 minutes of time to download)
cifar_10_c = tfds.load("cifar10_corrupted/saturate_5", split="test", as_supervised=True)
cifar_10_c = cifar_10_c.batch(BATCH_SIZE).map(
lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
num_parallel_calls=AUTO,
)
# Evaluate `rand_aug_model`
_, test_acc = rand_aug_model.evaluate(cifar_10_c, verbose=0)
print(
"Accuracy with RandAugment on CIFAR-10-C (saturate_5): {:.2f}%".format(
test_acc * 100
)
)
# Evaluate `simple_aug_model`
_, test_acc = simple_aug_model.evaluate(cifar_10_c, verbose=0)
print(
"Accuracy with simple_aug on CIFAR-10-C (saturate_5): {:.2f}%".format(
test_acc * 100
)
)
Downloading and preparing dataset 2.72 GiB (download: 2.72 GiB, generated: Unknown size, total: 2.72 GiB) to /home/sachinprasad/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0...
Dataset cifar10_corrupted downloaded and prepared to /home/sachinprasad/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0. Subsequent calls will reuse this data.
Accuracy with RandAugment on CIFAR-10-C (saturate_5): 30.36%
Accuracy with simple_aug on CIFAR-10-C (saturate_5): 37.18%
为了本示例的目的,我们仅将模型训练了一个 epoch。在 CIFAR-10-C 数据集上,使用 RandAugment 的模型可以比使用 simple_aug 训练的模型具有更高的准确率(例如,在一项实验中为 76.64%),而使用 simple_aug 的模型(例如,为 64.80%)。RandAugment 也有助于稳定训练。
在 notebook 中,您可能会注意到,尽管使用 RandAugment 会增加训练时间,但我们能够在 CIFAR-10-C 数据集上获得更好的性能。您可以尝试使用运行相同的 CIFAR-10-C 数据集附带的其他腐化和扰动设置,看看 RandAugment 是否有所帮助。
您还可以尝试 RandAugment 对象中 n 和 m 的不同值。在 原始论文中,作者们展示了单个增强变换对特定任务的影响以及一系列消融研究。欢迎您进行查阅。
RandAugment 在提高计算机视觉深度模型的鲁棒性方面取得了巨大进展,正如 Noisy Student Training 和 FixMatch 等工作中所示。这使得 RandAugment 成为训练不同视觉模型的一个非常有用的方法。
您可以使用托管在 Hugging Face Hub 上的训练模型,并在 Hugging Face Spaces 上尝试演示。