
Image Segmentation using Composable Fully-Convolutional Networks

Author: Suvaditya Mukherjee
Date created: 2023/06/16
Last modified: 2023/12/25
Description: Using the Fully-Convolutional Network for Image Segmentation.

ⓘ This example uses Keras 3



Introduction

The following example walks through the steps to perform image segmentation on the Oxford-IIIT Pets dataset. The model was proposed in the paper Fully Convolutional Networks for Semantic Segmentation by Long et al. (2014). Image segmentation is one of the most common and introductory tasks in computer vision, where we extend the problem of image classification from one label per image to a pixel-wise classification problem. In this example, we will assemble the aforementioned fully-convolutional segmentation architecture and use it to perform image segmentation. The network extends the pooling-layer outputs from VGG in order to perform up-sampling and produce the final result. The intermediate outputs of the 3rd, 4th and 5th max-pooling layers of VGG19 are extracted and up-sampled at different levels and by different factors to obtain a final output with the same spatial shape as the input, but holding the class of each pixel at each location instead of pixel intensity values. Different intermediate pooling layers are extracted and processed to obtain the different versions of the network. The FCN architecture has 3 versions of differing quality:

  • FCN-32S
  • FCN-16S
  • FCN-8S

All versions of the model derive their outputs through iterative processing of successive intermediate pooling layers of the backbone. This can be understood better from the figure below.

FCN Architecture
Figure 1: Combined architecture versions (source: paper)
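
To make the variant names concrete: with a 224x224 input, VGG-19's successive max-pooling layers shrink the spatial resolution by factors of 8, 16 and 32, which is exactly the up-sampling factor each variant must apply to recover full resolution. A quick, purely illustrative check (plain Python, not part of the original example):

# Illustrative only: spatial size of each VGG-19 pooling output for a
# 224x224 input, and the up-sampling factor each FCN variant must apply
# to recover full resolution.
input_size = 224
for layer_name, downsample_factor in [
    ("block3_pool", 8),   # used by FCN-8S
    ("block4_pool", 16),  # used by FCN-16S
    ("block5_pool", 32),  # used by FCN-32S
]:
    print(layer_name, input_size // downsample_factor)  # 28, 14, 7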

For a better idea of image segmentation, or to find more pre-trained models, feel free to navigate to the Hugging Face Image Segmentation Models page, or the PyImageSearch blog on semantic segmentation.


Setup imports

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import ops
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import numpy as np

AUTOTUNE = tf.data.AUTOTUNE

Set configurations for notebook variables

We set the required parameters for the experiment. The chosen dataset has a total of 4 classes per image, with regards to the segmentation mask. We also set our hyperparameters in this cell.

Mixed precision is also available as an option on systems that support it, to reduce load. This makes most tensors use 16-bit float values instead of 32-bit floats, wherever it will not adversely affect computation. In practice, TensorFlow computes with 16-bit float tensors for speed at the cost of precision, while storing the values in their original default 32-bit float form.

NUM_CLASSES = 4
INPUT_HEIGHT = 224
INPUT_WIDTH = 224
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 20
BATCH_SIZE = 32
MIXED_PRECISION = True
SHUFFLE = True

# Mixed-precision setting
if MIXED_PRECISION:
    policy = keras.mixed_precision.Policy("mixed_float16")
    keras.mixed_precision.set_global_policy(policy)
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: Quadro RTX 5000, compute capability 7.5
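
As a quick sanity check (not part of the original example), the policy object created above reports the two dtypes in play: computation happens in float16 while variables are stored in float32.

# Optional check (assumes MIXED_PRECISION was enabled above): computation
# dtype vs. variable storage dtype under the mixed_float16 policy.
print(policy.compute_dtype)   # float16
print(policy.variable_dtype)  # float32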

Load dataset

We use the Oxford-IIIT Pets dataset, which contains a total of 7,349 samples with their segmentation masks. It has 37 classes, with roughly 200 samples per class. Our training and validation splits have 3,128 and 552 samples, respectively. Aside from these, our test split has a total of 3,669 samples.

We set a batch_size parameter that batches our samples together, and a shuffle parameter that mixes them.

(train_ds, valid_ds, test_ds) = tfds.load(
    "oxford_iiit_pet",
    split=["train[:85%]", "train[85%:]", "test"],
    batch_size=BATCH_SIZE,
    shuffle_files=SHUFFLE,
)

Unpack and preprocess dataset

We define a simple function that performs a resize over our training, validation and test datasets. We apply the same process to the masks as well, to make sure both stay aligned in shape and size.

# Image and Mask Pre-processing
def unpack_resize_data(section):
    image = section["image"]
    segmentation_mask = section["segmentation_mask"]

    resize_layer = keras.layers.Resizing(INPUT_HEIGHT, INPUT_WIDTH)

    image = resize_layer(image)
    segmentation_mask = resize_layer(segmentation_mask)

    return image, segmentation_mask


train_ds = train_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
valid_ds = valid_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)

Visualize one random sample from the pre-processed dataset

We visualize what a random sample from the test split of the dataset looks like, and plot the segmentation mask on top of it to see the effective mask areas. Note that this dataset has been pre-processed as well, which makes the image and mask the same size.

# Select random image and mask. Cast to NumPy array
# for Matplotlib visualization.

images, masks = next(iter(test_ds))
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)

test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

# Overlay segmentation mask on top of image.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].set_title("Image")
ax[0].imshow(test_image / 255.0)

ax[1].set_title("Image with segmentation mask overlay")
ax[1].imshow(test_image / 255.0)
ax[1].imshow(
    test_mask,
    cmap="inferno",
    alpha=0.6,
)
plt.show()



Perform VGG-specific pre-processing

keras.applications.VGG19 requires the use of a preprocess_input function, which applies the ImageNet-style normalization expected by the pre-trained weights (converting images from RGB to BGR and zero-centering each channel with respect to the ImageNet mean).

def preprocess_data(image, segmentation_mask):
    image = keras.applications.vgg19.preprocess_input(image)

    return image, segmentation_mask


train_ds = (
    train_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)
valid_ds = (
    valid_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)
test_ds = (
    test_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)

Model Definition

The Fully-Convolutional Network boasts a simple architecture composed only of keras.layers.Conv2D, keras.layers.Dense and keras.layers.Dropout layers.

FCN Architecture
Figure 2: Generic FCN forward pass (source: paper)

Pixel-wise prediction is performed by a softmax convolutional layer with the same spatial size as the image, so that we can compare the prediction and the ground-truth mask directly. This lets us track several important metrics on the network, such as accuracy and mean intersection-over-union (mIoU).
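
A minimal sketch of this per-pixel comparison with random dummy tensors (illustrative only, not part of the training pipeline): since the prediction and the mask share the same spatial shape, keras.metrics.MeanIoU can compare them location by location.

# Dummy per-pixel comparison: a sparse integer mask against softmax-like scores.
dummy_true = np.random.randint(0, NUM_CLASSES, size=(1, INPUT_HEIGHT, INPUT_WIDTH, 1))
dummy_pred = np.random.uniform(size=(1, INPUT_HEIGHT, INPUT_WIDTH, NUM_CLASSES))

miou_metric = keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False)
miou_metric.update_state(dummy_true, dummy_pred)
print(miou_metric.result())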

Backbone (VGG-19)

We use the VGG-19 network as the backbone, as the paper suggests it to be one of the most effective backbones for this network. We extract different outputs from the network by making use of keras.models.Model. Following this, we add layers on top to create a network that perfectly simulates that of Figure 1. The backbone's keras.layers.Dense layers are converted into keras.layers.Conv2D layers, following the original Caffe code. All 3 networks share the same backbone weights, but their results differ based on their extensions. We make the backbone non-trainable to reduce training time. The paper also notes that making the network trainable does not yield major benefits.

input_layer = keras.Input(shape=(INPUT_HEIGHT, INPUT_WIDTH, 3))

# VGG Model backbone with pre-trained ImageNet weights.
vgg_model = keras.applications.vgg19.VGG19(include_top=True, weights="imagenet")

# Extracting different outputs from same model
fcn_backbone = keras.models.Model(
    inputs=vgg_model.layers[1].input,
    outputs=[
        vgg_model.get_layer(block_name).output
        for block_name in ["block3_pool", "block4_pool", "block5_pool"]
    ],
)

# Setting backbone to be non-trainable
fcn_backbone.trainable = False

x = fcn_backbone(input_layer)

# Converting Dense layers to Conv2D layers
units = [4096, 4096]
dense_convs = []

for filter_idx in range(len(units)):
    dense_conv = keras.layers.Conv2D(
        filters=units[filter_idx],
        kernel_size=(7, 7) if filter_idx == 0 else (1, 1),
        strides=(1, 1),
        activation="relu",
        padding="same",
        use_bias=False,
        kernel_initializer=keras.initializers.Constant(1.0),
    )
    dense_convs.append(dense_conv)
    dropout_layer = keras.layers.Dropout(0.5)
    dense_convs.append(dropout_layer)

dense_convs = keras.Sequential(dense_convs)
dense_convs.trainable = False

x[-1] = dense_convs(x[-1])

pool3_output, pool4_output, pool5_output = x

FCN-32S

We extend the last output, perform a 1x1 convolution, and up-sample it by a factor of 32 with 2D bilinear up-sampling to get an image of the same size as the input. We use a simple keras.layers.UpSampling2D layer over keras.layers.Conv2DTranspose, since it yields a performance benefit by being a deterministic mathematical operation rather than a convolutional one. The paper also notes that making the up-sampling parameters trainable does not yield benefits; the paper's original experiments used up-sampling as well.

# 1x1 convolution to set channels = number of classes
pool5 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="relu",
)

# Get Softmax outputs for all classes
fcn32s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn32s_upsampling = keras.layers.UpSampling2D(
    size=(32, 32),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

final_fcn32s_pool = pool5(pool5_output)
final_fcn32s_output = fcn32s_conv_layer(final_fcn32s_pool)
final_fcn32s_output = fcn32s_upsampling(final_fcn32s_output)

fcn32s_model = keras.Model(inputs=input_layer, outputs=final_fcn32s_output)
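
As an optional sanity check (not part of the original example), the 7x7 pool5 feature map, up-sampled 32x, should give back the input's spatial size with one channel per class:

# Optional check: output spatial size matches the input, with NUM_CLASSES channels.
print(fcn32s_model.outputs[0].shape)  # (None, 224, 224, 4)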

FCN-16S

The pooling output from FCN-32S is extended and added to the pooling output of the 4th block of our backbone. Following this, we up-sample by a factor of 16 to get an image of the same size as the input.

# 1x1 convolution to set channels = number of classes
# Followed from the original Caffe implementation
pool4 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="linear",
    kernel_initializer=keras.initializers.Zeros(),
)(pool4_output)

# Intermediate up-sample
pool5 = keras.layers.UpSampling2D(
    size=(2, 2),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)(final_fcn32s_pool)

# Get Softmax outputs for all classes
fcn16s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn16s_upsample_layer = keras.layers.UpSampling2D(
    size=(16, 16),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

# Add intermediate outputs
final_fcn16s_pool = keras.layers.Add()([pool4, pool5])
final_fcn16s_output = fcn16s_conv_layer(final_fcn16s_pool)
final_fcn16s_output = fcn16s_upsample_layer(final_fcn16s_output)

fcn16s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn16s_output)

FCN-8S

The pooling output from FCN-16S is extended once more and added to the pooling output of the 3rd block of our backbone. This result is up-sampled by a factor of 8 to get an image of the same size as the input.

# 1x1 convolution to set channels = number of classes
# Followed from the original Caffe implementation
pool3 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="linear",
    kernel_initializer=keras.initializers.Zeros(),
)(pool3_output)

# Intermediate up-sample
intermediate_pool_output = keras.layers.UpSampling2D(
    size=(2, 2),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)(final_fcn16s_pool)

# Get Softmax outputs for all classes
fcn8s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn8s_upsample_layer = keras.layers.UpSampling2D(
    size=(8, 8),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

# Add intermediate outputs
final_fcn8s_pool = keras.layers.Add()([pool3, intermediate_pool_output])
final_fcn8s_output = fcn8s_conv_layer(final_fcn8s_pool)
final_fcn8s_output = fcn8s_upsample_layer(final_fcn8s_output)

fcn8s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn8s_output)

Load weights into backbone

It was noted in the paper, as well as through experimentation, that extracting the weights of the backbone's last 2 fully-connected Dense layers, reshaping them to fit the keras.layers.Conv2D layers we converted the keras.layers.Dense layers into earlier, and setting them on those layers yields far better results and a significant increase in mIoU performance.

# VGG's last 2 layers
weights1 = vgg_model.get_layer("fc1").get_weights()[0]
weights2 = vgg_model.get_layer("fc2").get_weights()[0]

weights1 = weights1.reshape(7, 7, 512, 4096)
weights2 = weights2.reshape(1, 1, 4096, 4096)

dense_convs.layers[0].set_weights([weights1])
dense_convs.layers[2].set_weights([weights2])
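
A quick check (not part of the original example) that the reshaped VGG weights landed on the converted layers:

# Optional check: the converted Conv2D kernels now carry the reshaped fc1/fc2
# weights (these layers were built with use_bias=False, so only kernels exist).
print(dense_convs.layers[0].get_weights()[0].shape)  # (7, 7, 512, 4096)
print(dense_convs.layers[2].get_weights()[0].shape)  # (1, 1, 4096, 4096)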

Training

The original paper discusses using SGD with momentum as the optimizer of choice. But during experimentation it was found that AdamW yields better results in terms of mIoU and pixel-wise accuracy. A reference sketch of the paper's choice is shown below.
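
For reference, a sketch of the paper's preferred optimizer (the momentum value follows the paper; the learning rate simply reuses the constant defined above). It is not used in the runs below.

# Reference only: SGD with momentum, as preferred in the original paper.
sgd_optimizer = keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum=0.9)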

FCN-32S

fcn32s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn32s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn32s_model.compile(
    optimizer=fcn32s_optimizer,
    loss=fcn32s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn32s_history = fcn32s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
Epoch 1/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 31s 171ms/step - loss: 0.9853 - mean_io_u: 0.3056 - sparse_categorical_accuracy: 0.6242 - val_loss: 0.7911 - val_mean_io_u: 0.4022 - val_sparse_categorical_accuracy: 0.7011
Epoch 2/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 22s 131ms/step - loss: 0.7463 - mean_io_u: 0.3978 - sparse_categorical_accuracy: 0.7100 - val_loss: 0.7162 - val_mean_io_u: 0.3968 - val_sparse_categorical_accuracy: 0.7157
Epoch 3/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 120ms/step - loss: 0.6939 - mean_io_u: 0.4139 - sparse_categorical_accuracy: 0.7255 - val_loss: 0.6714 - val_mean_io_u: 0.4383 - val_sparse_categorical_accuracy: 0.7379
Epoch 4/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 117ms/step - loss: 0.6694 - mean_io_u: 0.4239 - sparse_categorical_accuracy: 0.7339 - val_loss: 0.6715 - val_mean_io_u: 0.4258 - val_sparse_categorical_accuracy: 0.7332
Epoch 5/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 115ms/step - loss: 0.6556 - mean_io_u: 0.4279 - sparse_categorical_accuracy: 0.7382 - val_loss: 0.6271 - val_mean_io_u: 0.4483 - val_sparse_categorical_accuracy: 0.7514
Epoch 6/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 120ms/step - loss: 0.6501 - mean_io_u: 0.4295 - sparse_categorical_accuracy: 0.7394 - val_loss: 0.6390 - val_mean_io_u: 0.4375 - val_sparse_categorical_accuracy: 0.7442
Epoch 7/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 0.6464 - mean_io_u: 0.4309 - sparse_categorical_accuracy: 0.7402 - val_loss: 0.6143 - val_mean_io_u: 0.4508 - val_sparse_categorical_accuracy: 0.7553
Epoch 8/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 108ms/step - loss: 0.6363 - mean_io_u: 0.4343 - sparse_categorical_accuracy: 0.7444 - val_loss: 0.6143 - val_mean_io_u: 0.4481 - val_sparse_categorical_accuracy: 0.7541
Epoch 9/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 0.6367 - mean_io_u: 0.4346 - sparse_categorical_accuracy: 0.7445 - val_loss: 0.6222 - val_mean_io_u: 0.4534 - val_sparse_categorical_accuracy: 0.7510
Epoch 10/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 19s 108ms/step - loss: 0.6398 - mean_io_u: 0.4346 - sparse_categorical_accuracy: 0.7426 - val_loss: 0.6123 - val_mean_io_u: 0.4494 - val_sparse_categorical_accuracy: 0.7541
Epoch 11/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 0.6361 - mean_io_u: 0.4365 - sparse_categorical_accuracy: 0.7439 - val_loss: 0.6310 - val_mean_io_u: 0.4405 - val_sparse_categorical_accuracy: 0.7461
Epoch 12/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 110ms/step - loss: 0.6325 - mean_io_u: 0.4362 - sparse_categorical_accuracy: 0.7454 - val_loss: 0.6155 - val_mean_io_u: 0.4441 - val_sparse_categorical_accuracy: 0.7509
Epoch 13/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 112ms/step - loss: 0.6335 - mean_io_u: 0.4368 - sparse_categorical_accuracy: 0.7452 - val_loss: 0.6153 - val_mean_io_u: 0.4430 - val_sparse_categorical_accuracy: 0.7504
Epoch 14/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 113ms/step - loss: 0.6289 - mean_io_u: 0.4380 - sparse_categorical_accuracy: 0.7466 - val_loss: 0.6357 - val_mean_io_u: 0.4309 - val_sparse_categorical_accuracy: 0.7382
Epoch 15/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 113ms/step - loss: 0.6267 - mean_io_u: 0.4369 - sparse_categorical_accuracy: 0.7474 - val_loss: 0.5974 - val_mean_io_u: 0.4619 - val_sparse_categorical_accuracy: 0.7617
Epoch 16/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 109ms/step - loss: 0.6309 - mean_io_u: 0.4368 - sparse_categorical_accuracy: 0.7458 - val_loss: 0.6071 - val_mean_io_u: 0.4463 - val_sparse_categorical_accuracy: 0.7533
Epoch 17/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 112ms/step - loss: 0.6285 - mean_io_u: 0.4382 - sparse_categorical_accuracy: 0.7465 - val_loss: 0.5979 - val_mean_io_u: 0.4576 - val_sparse_categorical_accuracy: 0.7602
Epoch 18/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 111ms/step - loss: 0.6250 - mean_io_u: 0.4403 - sparse_categorical_accuracy: 0.7479 - val_loss: 0.6121 - val_mean_io_u: 0.4451 - val_sparse_categorical_accuracy: 0.7507
Epoch 19/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 111ms/step - loss: 0.6307 - mean_io_u: 0.4386 - sparse_categorical_accuracy: 0.7454 - val_loss: 0.6010 - val_mean_io_u: 0.4532 - val_sparse_categorical_accuracy: 0.7577
Epoch 20/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 114ms/step - loss: 0.6199 - mean_io_u: 0.4403 - sparse_categorical_accuracy: 0.7505 - val_loss: 0.6180 - val_mean_io_u: 0.4339 - val_sparse_categorical_accuracy: 0.7465

FCN-16S

fcn16s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn16s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn16s_model.compile(
    optimizer=fcn16s_optimizer,
    loss=fcn16s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn16s_history = fcn16s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
Epoch 1/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 23s 127ms/step - loss: 6.4519 - mean_io_u_1: 0.3101 - sparse_categorical_accuracy: 0.5649 - val_loss: 5.7052 - val_mean_io_u_1: 0.3842 - val_sparse_categorical_accuracy: 0.6057
Epoch 2/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 19s 110ms/step - loss: 5.2670 - mean_io_u_1: 0.3936 - sparse_categorical_accuracy: 0.6339 - val_loss: 5.8929 - val_mean_io_u_1: 0.3864 - val_sparse_categorical_accuracy: 0.5940
Epoch 3/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 111ms/step - loss: 5.2376 - mean_io_u_1: 0.3945 - sparse_categorical_accuracy: 0.6366 - val_loss: 5.6404 - val_mean_io_u_1: 0.3889 - val_sparse_categorical_accuracy: 0.6079
Epoch 4/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 113ms/step - loss: 5.3014 - mean_io_u_1: 0.3924 - sparse_categorical_accuracy: 0.6323 - val_loss: 5.6516 - val_mean_io_u_1: 0.3874 - val_sparse_categorical_accuracy: 0.6094
Epoch 5/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 112ms/step - loss: 5.3135 - mean_io_u_1: 0.3918 - sparse_categorical_accuracy: 0.6323 - val_loss: 5.6588 - val_mean_io_u_1: 0.3903 - val_sparse_categorical_accuracy: 0.6084
Epoch 6/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 5.2401 - mean_io_u_1: 0.3938 - sparse_categorical_accuracy: 0.6357 - val_loss: 5.6463 - val_mean_io_u_1: 0.3868 - val_sparse_categorical_accuracy: 0.6097
Epoch 7/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 109ms/step - loss: 5.2277 - mean_io_u_1: 0.3921 - sparse_categorical_accuracy: 0.6371 - val_loss: 5.6272 - val_mean_io_u_1: 0.3796 - val_sparse_categorical_accuracy: 0.6136
Epoch 8/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 112ms/step - loss: 5.2479 - mean_io_u_1: 0.3910 - sparse_categorical_accuracy: 0.6360 - val_loss: 5.6303 - val_mean_io_u_1: 0.3823 - val_sparse_categorical_accuracy: 0.6108
Epoch 9/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 112ms/step - loss: 5.1940 - mean_io_u_1: 0.3913 - sparse_categorical_accuracy: 0.6388 - val_loss: 5.8818 - val_mean_io_u_1: 0.3848 - val_sparse_categorical_accuracy: 0.5912
Epoch 10/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 111ms/step - loss: 5.2457 - mean_io_u_1: 0.3898 - sparse_categorical_accuracy: 0.6358 - val_loss: 5.6423 - val_mean_io_u_1: 0.3880 - val_sparse_categorical_accuracy: 0.6087
Epoch 11/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 5.1808 - mean_io_u_1: 0.3905 - sparse_categorical_accuracy: 0.6400 - val_loss: 5.6175 - val_mean_io_u_1: 0.3834 - val_sparse_categorical_accuracy: 0.6090
Epoch 12/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 112ms/step - loss: 5.2730 - mean_io_u_1: 0.3907 - sparse_categorical_accuracy: 0.6341 - val_loss: 5.6322 - val_mean_io_u_1: 0.3878 - val_sparse_categorical_accuracy: 0.6109
Epoch 13/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 5.2501 - mean_io_u_1: 0.3904 - sparse_categorical_accuracy: 0.6359 - val_loss: 5.8711 - val_mean_io_u_1: 0.3859 - val_sparse_categorical_accuracy: 0.5950
Epoch 14/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 107ms/step - loss: 5.2407 - mean_io_u_1: 0.3926 - sparse_categorical_accuracy: 0.6362 - val_loss: 5.6387 - val_mean_io_u_1: 0.3805 - val_sparse_categorical_accuracy: 0.6122
Epoch 15/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 5.2280 - mean_io_u_1: 0.3909 - sparse_categorical_accuracy: 0.6370 - val_loss: 5.6382 - val_mean_io_u_1: 0.3837 - val_sparse_categorical_accuracy: 0.6112
Epoch 16/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 5.2232 - mean_io_u_1: 0.3899 - sparse_categorical_accuracy: 0.6369 - val_loss: 5.6285 - val_mean_io_u_1: 0.3818 - val_sparse_categorical_accuracy: 0.6101
Epoch 17/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 107ms/step - loss: 1.4671 - mean_io_u_1: 0.5928 - sparse_categorical_accuracy: 0.8210 - val_loss: 0.7661 - val_mean_io_u_1: 0.6455 - val_sparse_categorical_accuracy: 0.8504
Epoch 18/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 0.6795 - mean_io_u_1: 0.6508 - sparse_categorical_accuracy: 0.8664 - val_loss: 0.6913 - val_mean_io_u_1: 0.6490 - val_sparse_categorical_accuracy: 0.8562
Epoch 19/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 0.6498 - mean_io_u_1: 0.6530 - sparse_categorical_accuracy: 0.8663 - val_loss: 0.6834 - val_mean_io_u_1: 0.6559 - val_sparse_categorical_accuracy: 0.8577
Epoch 20/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 0.6305 - mean_io_u_1: 0.6563 - sparse_categorical_accuracy: 0.8681 - val_loss: 0.6529 - val_mean_io_u_1: 0.6575 - val_sparse_categorical_accuracy: 0.8657

FCN-8S

fcn8s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn8s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn8s_model.compile(
    optimizer=fcn8s_optimizer,
    loss=fcn8s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn8s_history = fcn8s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
Epoch 1/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 24s 125ms/step - loss: 8.4168 - mean_io_u_2: 0.3116 - sparse_categorical_accuracy: 0.4237 - val_loss: 7.6113 - val_mean_io_u_2: 0.3540 - val_sparse_categorical_accuracy: 0.4682
Epoch 2/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 8.1030 - mean_io_u_2: 0.3423 - sparse_categorical_accuracy: 0.4401 - val_loss: 7.7038 - val_mean_io_u_2: 0.3335 - val_sparse_categorical_accuracy: 0.4481
Epoch 3/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 8.0868 - mean_io_u_2: 0.3433 - sparse_categorical_accuracy: 0.4408 - val_loss: 7.5839 - val_mean_io_u_2: 0.3518 - val_sparse_categorical_accuracy: 0.4722
Epoch 4/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 111ms/step - loss: 8.1508 - mean_io_u_2: 0.3414 - sparse_categorical_accuracy: 0.4365 - val_loss: 7.2391 - val_mean_io_u_2: 0.3519 - val_sparse_categorical_accuracy: 0.4805
Epoch 5/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 112ms/step - loss: 8.1621 - mean_io_u_2: 0.3440 - sparse_categorical_accuracy: 0.4361 - val_loss: 7.2805 - val_mean_io_u_2: 0.3474 - val_sparse_categorical_accuracy: 0.4816
Epoch 6/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 8.1470 - mean_io_u_2: 0.3412 - sparse_categorical_accuracy: 0.4360 - val_loss: 7.5605 - val_mean_io_u_2: 0.3543 - val_sparse_categorical_accuracy: 0.4736
Epoch 7/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 8.1464 - mean_io_u_2: 0.3430 - sparse_categorical_accuracy: 0.4368 - val_loss: 7.5442 - val_mean_io_u_2: 0.3542 - val_sparse_categorical_accuracy: 0.4702
Epoch 8/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 8.0812 - mean_io_u_2: 0.3463 - sparse_categorical_accuracy: 0.4403 - val_loss: 7.5565 - val_mean_io_u_2: 0.3471 - val_sparse_categorical_accuracy: 0.4614
Epoch 9/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 8.0441 - mean_io_u_2: 0.3463 - sparse_categorical_accuracy: 0.4420 - val_loss: 7.5563 - val_mean_io_u_2: 0.3522 - val_sparse_categorical_accuracy: 0.4734
Epoch 10/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 8.1385 - mean_io_u_2: 0.3432 - sparse_categorical_accuracy: 0.4363 - val_loss: 7.5236 - val_mean_io_u_2: 0.3506 - val_sparse_categorical_accuracy: 0.4660
Epoch 11/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 111ms/step - loss: 8.1114 - mean_io_u_2: 0.3447 - sparse_categorical_accuracy: 0.4381 - val_loss: 7.2068 - val_mean_io_u_2: 0.3518 - val_sparse_categorical_accuracy: 0.4808
Epoch 12/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 107ms/step - loss: 8.0777 - mean_io_u_2: 0.3451 - sparse_categorical_accuracy: 0.4392 - val_loss: 7.2252 - val_mean_io_u_2: 0.3497 - val_sparse_categorical_accuracy: 0.4815
Epoch 13/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 110ms/step - loss: 8.1355 - mean_io_u_2: 0.3446 - sparse_categorical_accuracy: 0.4366 - val_loss: 7.5587 - val_mean_io_u_2: 0.3500 - val_sparse_categorical_accuracy: 0.4671
Epoch 14/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 107ms/step - loss: 8.1828 - mean_io_u_2: 0.3410 - sparse_categorical_accuracy: 0.4330 - val_loss: 7.2464 - val_mean_io_u_2: 0.3557 - val_sparse_categorical_accuracy: 0.4927
Epoch 15/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 8.1845 - mean_io_u_2: 0.3432 - sparse_categorical_accuracy: 0.4330 - val_loss: 7.2032 - val_mean_io_u_2: 0.3506 - val_sparse_categorical_accuracy: 0.4805
Epoch 16/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 109ms/step - loss: 8.1183 - mean_io_u_2: 0.3449 - sparse_categorical_accuracy: 0.4374 - val_loss: 7.6210 - val_mean_io_u_2: 0.3460 - val_sparse_categorical_accuracy: 0.4751
Epoch 17/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 111ms/step - loss: 8.1766 - mean_io_u_2: 0.3429 - sparse_categorical_accuracy: 0.4329 - val_loss: 7.5361 - val_mean_io_u_2: 0.3489 - val_sparse_categorical_accuracy: 0.4639
Epoch 18/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 8.0453 - mean_io_u_2: 0.3442 - sparse_categorical_accuracy: 0.4404 - val_loss: 7.1767 - val_mean_io_u_2: 0.3549 - val_sparse_categorical_accuracy: 0.4839
Epoch 19/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 8.0856 - mean_io_u_2: 0.3449 - sparse_categorical_accuracy: 0.4390 - val_loss: 7.1724 - val_mean_io_u_2: 0.3574 - val_sparse_categorical_accuracy: 0.4878
Epoch 20/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 109ms/step - loss: 8.1378 - mean_io_u_2: 0.3445 - sparse_categorical_accuracy: 0.4358 - val_loss: 7.5449 - val_mean_io_u_2: 0.3521 - val_sparse_categorical_accuracy: 0.4681

Visualizations

Plotting metrics for training run

We perform a comparative study of all 3 versions of the model by tracking training and validation metrics for accuracy, loss and mean IoU.

total_plots = len(fcn32s_history.history)
cols = total_plots // 2

rows = total_plots // cols

if total_plots % cols != 0:
    rows += 1

# Set all history dictionary objects
fcn32s_dict = fcn32s_history.history
fcn16s_dict = fcn16s_history.history
fcn8s_dict = fcn8s_history.history

pos = range(1, total_plots + 1)
plt.figure(figsize=(15, 10))

for i, ((key_32s, value_32s), (key_16s, value_16s), (key_8s, value_8s)) in enumerate(
    zip(fcn32s_dict.items(), fcn16s_dict.items(), fcn8s_dict.items())
):
    plt.subplot(rows, cols, pos[i])
    plt.plot(range(len(value_32s)), value_32s)
    plt.plot(range(len(value_16s)), value_16s)
    plt.plot(range(len(value_8s)), value_8s)
    plt.title(str(key_32s) + " (combined)")
    plt.legend(["FCN-32S", "FCN-16S", "FCN-8S"])

plt.show()


Visualizing predicted segmentation masks

To understand the results better and see them in action, we pick a random image from the test dataset and run inference on it to see the masks generated by each model. Note: for better results, the model must be trained for a higher number of epochs.

images, masks = next(iter(test_ds))
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)

# Get random test image and mask
test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

pred_image = ops.expand_dims(test_image, axis=0)
pred_image = keras.applications.vgg19.preprocess_input(pred_image)

# Perform inference on FCN-32S
pred_mask_32s = fcn32s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_32s = np.argmax(pred_mask_32s, axis=-1)
pred_mask_32s = pred_mask_32s[0, ...]

# Perform inference on FCN-16S
pred_mask_16s = fcn16s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_16s = np.argmax(pred_mask_16s, axis=-1)
pred_mask_16s = pred_mask_16s[0, ...]

# Perform inference on FCN-8S
pred_mask_8s = fcn8s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_8s = np.argmax(pred_mask_8s, axis=-1)
pred_mask_8s = pred_mask_8s[0, ...]

# Plot all results
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(15, 8))

fig.delaxes(ax[0, 2])

ax[0, 0].set_title("Image")
ax[0, 0].imshow(test_image / 255.0)

ax[0, 1].set_title("Image with ground truth overlay")
ax[0, 1].imshow(test_image / 255.0)
ax[0, 1].imshow(
    test_mask,
    cmap="inferno",
    alpha=0.6,
)

ax[1, 0].set_title("Image with FCN-32S mask overlay")
ax[1, 0].imshow(test_image / 255.0)
ax[1, 0].imshow(pred_mask_32s, cmap="inferno", alpha=0.6)

ax[1, 1].set_title("Image with FCN-16S mask overlay")
ax[1, 1].imshow(test_image / 255.0)
ax[1, 1].imshow(pred_mask_16s, cmap="inferno", alpha=0.6)

ax[1, 2].set_title("Image with FCN-8S mask overlay")
ax[1, 2].imshow(test_image / 255.0)
ax[1, 2].imshow(pred_mask_8s, cmap="inferno", alpha=0.6)

plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Conclusion

The Fully-Convolutional Network is an exceptionally simple network that has yielded strong results in image segmentation tasks across different benchmarks. With the advent of better mechanisms like attention, as used in SegFormer and DeTR, this model serves as a quick way to iterate over and establish baselines for this task on unknown data.


Acknowledgements

I thank Aritra Roy Gosthipaty, Ayush Thakur and Ritwik Raha for giving a preliminary review of the example. I also thank the Google Developer Experts program.