
Image Segmentation using Composable Fully-Convolutional Networks

Author: Suvaditya Mukherjee
Date created: 2023/06/16
Last modified: 2023/12/25
Description: Using the Fully-Convolutional Network for Image Segmentation.

ⓘ This example uses Keras 3

View in Colab • GitHub source


Introduction

The following example walks through the steps to perform image segmentation on the Oxford-IIIT Pets dataset, using a Fully-Convolutional Network. The model was proposed in the paper Fully Convolutional Networks for Semantic Segmentation by Long et al. (2014). Image segmentation is one of the most common and fundamental tasks in computer vision, in which we extend the problem of image classification from one label per image to a pixel-wise classification problem. In this example, we will assemble the aforementioned fully-convolutional segmentation architecture, which is capable of performing image segmentation. The network extends the pooling-layer outputs from VGG in order to perform upsampling and obtain the final result. The intermediate outputs coming from the 3rd, 4th and 5th max-pooling layers of VGG19 are extracted and upsampled at different levels and factors to get a final output with the same spatial shape as the input, but where each location holds the class of that pixel instead of a pixel intensity value. Different intermediate pooling layers are extracted and processed for the different versions of the network. The FCN architecture has 3 versions of differing quality.

  • FCN-32S
  • FCN-16S
  • FCN-8S

All versions of the model derive their outputs through iterative processing of successive intermediate pooling layers of the backbone used. A better idea of this can be gained from the figure below.

FCN Architecture
Figure 1: Combined Architecture Versions (Source: Paper)
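
To make the factors concrete, here is a small sketch (not part of the original example) that derives the spatial sizes of the relevant VGG pooling outputs for a 224x224 input; the version names correspond to the upsampling factor each one needs.

# Each VGG max-pooling stage halves the spatial size, so pool k yields a
# feature map of side 224 / 2**k. The FCN variant names are simply the
# factors needed to return to full resolution.
for k, name in [(3, "pool3"), (4, "pool4"), (5, "pool5")]:
    side = 224 // 2**k
    print(f"{name}: {side}x{side} -> upsample by {224 // side}")
# pool3: 28x28 -> upsample by 8   (FCN-8S)
# pool4: 14x14 -> upsample by 16  (FCN-16S)
# pool5: 7x7   -> upsample by 32  (FCN-32S)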

To get a better idea of image segmentation, or to find more pre-trained models, feel free to navigate to the Hugging Face Image Segmentation Models page, or the PyImageSearch blog on Semantic Segmentation.


Setup Imports

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import ops
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import numpy as np

AUTOTUNE = tf.data.AUTOTUNE

Set configurations for notebook variables

We set the required parameters for the experiment. The chosen dataset has a total of 4 classes per image, with regard to the segmentation mask. We also set our hyperparameters in this cell.

Mixed precision is also available as an option on systems that support it, to reduce load. This makes most tensors use 16-bit float values instead of 32-bit float values, in places where doing so will not adversely affect computation. This means that, during computation, TensorFlow uses 16-bit float tensors to increase speed at the cost of precision, while storing the values in their original default 32-bit float form.

NUM_CLASSES = 4
INPUT_HEIGHT = 224
INPUT_WIDTH = 224
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 20
BATCH_SIZE = 32
MIXED_PRECISION = True
SHUFFLE = True

# Mixed-precision setting
if MIXED_PRECISION:
    policy = keras.mixed_precision.Policy("mixed_float16")
    keras.mixed_precision.set_global_policy(policy)
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: Quadro RTX 5000, compute capability 7.5
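
As a quick illustration of the policy's effect (a sketch, not part of the original example): once the global policy is set, a freshly created layer computes in 16-bit floats while its variables stay in 32-bit floats.

# With "mixed_float16" active, new layers compute in float16 but store
# their weights in float32.
probe_layer = keras.layers.Dense(8)
print(probe_layer.compute_dtype)   # float16
print(probe_layer.variable_dtype)  # float32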

Load dataset

We make use of the Oxford-IIIT Pets dataset, which contains a total of 7,349 samples and their segmentation masks. It has 37 classes, with roughly 200 samples per class. Our training and validation splits have 3,128 and 552 samples respectively. Aside from this, our test split has a total of 3,669 samples.

We set a batch_size parameter that batches our samples together, and a shuffle parameter that mixes our samples up.

(train_ds, valid_ds, test_ds) = tfds.load(
    "oxford_iiit_pet",
    split=["train[:85%]", "train[85%:]", "test"],
    batch_size=BATCH_SIZE,
    shuffle_files=SHUFFLE,
)

Unpack and preprocess dataset

We define a simple function that resizes our training, validation and test datasets. We perform the same process on the masks as well, to make sure both stay aligned in terms of shape and size.

# Image and Mask Pre-processing
def unpack_resize_data(section):
    image = section["image"]
    segmentation_mask = section["segmentation_mask"]

    resize_layer = keras.layers.Resizing(INPUT_HEIGHT, INPUT_WIDTH)

    image = resize_layer(image)
    segmentation_mask = resize_layer(segmentation_mask)

    return image, segmentation_mask


train_ds = train_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
valid_ds = valid_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
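
As a quick sanity check on the mapping above (a sketch, not in the original example), one batch from the pipeline should now hold images and masks of matching spatial size:

# Pull one batch and confirm that images and masks share the 224x224 size.
sample_images, sample_masks = next(iter(train_ds))
print(sample_images.shape)  # (32, 224, 224, 3)
print(sample_masks.shape)   # (32, 224, 224, 1)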

Visualize one random sample from the pre-processed dataset

We visualize what a random sample from the test split of the dataset looks like, and plot the segmentation mask on top of it to see the effective mask areas. Note that we have performed pre-processing on this dataset too, which makes the image and mask the same size.

# Select random image and mask. Cast to NumPy array
# for Matplotlib visualization.

images, masks = next(iter(test_ds))
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)

test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

# Overlay segmentation mask on top of image.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].set_title("Image")
ax[0].imshow(test_image / 255.0)

ax[1].set_title("Image with segmentation mask overlay")
ax[1].imshow(test_image / 255.0)
ax[1].imshow(
    test_mask,
    cmap="inferno",
    alpha=0.6,
)
plt.show()

png


Perform VGG-specific pre-processing

keras.applications.VGG19 requires the use of a preprocess_input function, which pro-actively performs Imagenet-style standardization on the input.

def preprocess_data(image, segmentation_mask):
    image = keras.applications.vgg19.preprocess_input(image)

    return image, segmentation_mask


train_ds = (
    train_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)
valid_ds = (
    valid_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)
test_ds = (
    test_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)
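
For reference, here is a small probe (a sketch based on the standard caffe-style preprocessing this function uses; not code from the original example) of what preprocess_input does to a pixel: it flips RGB to BGR and subtracts the ImageNet channel means.

# Probe preprocess_input with a single pure-red RGB pixel.
probe_pixel = np.array([[[[255.0, 0.0, 0.0]]]])
print(keras.applications.vgg19.preprocess_input(probe_pixel))
# approx. [[[[-103.939, -116.779, 131.32]]]] -> BGR order, means removed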

Model Definition

The Fully-Convolutional Network boasts a simple architecture composed of only keras.layers.Conv2D layers, keras.layers.Dense layers and keras.layers.Dropout layers.

FCN Architecture
Figure 2: Generic FCN Forward Pass (Source: Paper)

Pixel-wise prediction is performed by having a softmax convolutional layer with the same spatial size as the image, so that we can perform a direct comparison. We can track several important metrics on the network, such as accuracy and mean intersection-over-union (mean IoU).
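
Since mean IoU is tracked throughout the training runs below, here is a minimal sketch (not part of the original example) of how keras.metrics.MeanIoU is fed in this setup: sparse integer masks as ground truth, and dense softmax-style outputs as predictions.

# With sparse_y_pred=False, MeanIoU argmaxes the dense predictions before
# comparing them against the integer ground-truth mask.
miou = keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False)
toy_true = np.array([[0, 1], [2, 3]])          # integer class ids
toy_pred = ops.one_hot(toy_true, NUM_CLASSES)  # a perfect "softmax" output
miou.update_state(toy_true, toy_pred)
print(float(miou.result()))  # 1.0 for a perfect prediction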

Backbone (VGG-19)

We use the VGG-19 network as the backbone, as the paper suggests it to be one of the most effective backbones for this network. We extract different outputs from the network by making use of keras.models.Model. Following this, we add layers on top to make a network that perfectly simulates that of Figure 1. The backbone's keras.layers.Dense layers are converted into keras.layers.Conv2D layers, based on the original Caffe code present here. All 3 networks share the same backbone weights, but their results differ based on their extensions. We make the backbone non-trainable to reduce training time requirements. It is also noted in the paper that making the network trainable does not yield major benefits.

input_layer = keras.Input(shape=(INPUT_HEIGHT, INPUT_WIDTH, 3))

# VGG Model backbone with pre-trained ImageNet weights.
vgg_model = keras.applications.vgg19.VGG19(include_top=True, weights="imagenet")

# Extracting different outputs from same model
fcn_backbone = keras.models.Model(
    inputs=vgg_model.layers[1].input,
    outputs=[
        vgg_model.get_layer(block_name).output
        for block_name in ["block3_pool", "block4_pool", "block5_pool"]
    ],
)

# Setting backbone to be non-trainable
fcn_backbone.trainable = False

x = fcn_backbone(input_layer)

# Converting Dense layers to Conv2D layers
units = [4096, 4096]
dense_convs = []

for filter_idx in range(len(units)):
    dense_conv = keras.layers.Conv2D(
        filters=units[filter_idx],
        kernel_size=(7, 7) if filter_idx == 0 else (1, 1),
        strides=(1, 1),
        activation="relu",
        padding="same",
        use_bias=False,
        kernel_initializer=keras.initializers.Constant(1.0),
    )
    dense_convs.append(dense_conv)
    dropout_layer = keras.layers.Dropout(0.5)
    dense_convs.append(dropout_layer)

dense_convs = keras.Sequential(dense_convs)
dense_convs.trainable = False

x[-1] = dense_convs(x[-1])

pool3_output, pool4_output, pool5_output = x

FCN-32S

We extend the last output, perform a 1x1 convolution and perform 2D bilinear upsampling by a factor of 32 to get an image of the same size as our input. We use a simple keras.layers.UpSampling2D layer over a keras.layers.Conv2DTranspose, since it yields performance benefits by being a deterministic mathematical operation rather than a convolutional one. It is also noted in the paper that making the upsampling parameters trainable does not yield benefits. The original experiments of the paper used upsampling as well. A small shape sketch follows.
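
The sketch below (assuming the 7x7 pool5 map that VGG-19 produces for our 224x224 input; not part of the original example) shows how bilinear UpSampling2D by a factor of 32 restores the full resolution:

# A 7x7 class map upsampled 32x recovers the 224x224 input size.
toy_pool5 = np.zeros((1, 7, 7, NUM_CLASSES), dtype="float32")
toy_upsample = keras.layers.UpSampling2D(size=(32, 32), interpolation="bilinear")
print(toy_upsample(toy_pool5).shape)  # (1, 224, 224, 4)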

# 1x1 convolution to set channels = number of classes
pool5 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="relu",
)

# Get Softmax outputs for all classes
fcn32s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn32s_upsampling = keras.layers.UpSampling2D(
    size=(32, 32),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

final_fcn32s_pool = pool5(pool5_output)
final_fcn32s_output = fcn32s_conv_layer(final_fcn32s_pool)
final_fcn32s_output = fcn32s_upsampling(final_fcn32s_output)

fcn32s_model = keras.Model(inputs=input_layer, outputs=final_fcn32s_output)

FCN-16S

The pooling output from the FCN-32S is extended and added to the 4th-level pooling output of our backbone. Following this, we upsample by a factor of 16 to get an image of the same size as our input.

# 1x1 convolution to set channels = number of classes
# Followed from the original Caffe implementation
pool4 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="linear",
    kernel_initializer=keras.initializers.Zeros(),
)(pool4_output)

# Intermediate up-sample
pool5 = keras.layers.UpSampling2D(
    size=(2, 2),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)(final_fcn32s_pool)

# Get Softmax outputs for all classes
fcn16s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn16s_upsample_layer = keras.layers.UpSampling2D(
    size=(16, 16),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

# Add intermediate outputs
final_fcn16s_pool = keras.layers.Add()([pool4, pool5])
final_fcn16s_output = fcn16s_conv_layer(final_fcn16s_pool)
final_fcn16s_output = fcn16s_upsample_layer(final_fcn16s_output)

fcn16s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn16s_output)

FCN-8S

The pooling output from the FCN-16S is extended once more, and added to the 3rd-level pooling output of our backbone. This result is upsampled by a factor of 8 to get an image of the same size as our input.

# 1x1 convolution to set channels = number of classes
# Followed from the original Caffe implementation
pool3 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="linear",
    kernel_initializer=keras.initializers.Zeros(),
)(pool3_output)

# Intermediate up-sample
intermediate_pool_output = keras.layers.UpSampling2D(
    size=(2, 2),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)(final_fcn16s_pool)

# Get Softmax outputs for all classes
fcn8s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn8s_upsample_layer = keras.layers.UpSampling2D(
    size=(8, 8),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

# Add intermediate outputs
final_fcn8s_pool = keras.layers.Add()([pool3, intermediate_pool_output])
final_fcn8s_output = fcn8s_conv_layer(final_fcn8s_pool)
final_fcn8s_output = fcn8s_upsample_layer(final_fcn8s_output)

fcn8s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn8s_output)

Load weights into backbone

It was noted in the paper, as well as through experimentation, that extracting the weights of the last 2 fully-connected Dense layers from the backbone, reshaping them to fit the keras.layers.Conv2D layers we had converted the keras.layers.Dense layers into, and setting them on those layers yields far better results and a significant increase in mIOU performance.

# VGG's last 2 layers
weights1 = vgg_model.get_layer("fc1").get_weights()[0]
weights2 = vgg_model.get_layer("fc2").get_weights()[0]

weights1 = weights1.reshape(7, 7, 512, 4096)
weights2 = weights2.reshape(1, 1, 4096, 4096)

dense_convs.layers[0].set_weights([weights1])
dense_convs.layers[2].set_weights([weights2])
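
As a quick consistency check (a sketch, not in the original example): fc1's kernel holds exactly 7 x 7 x 512 x 4096 entries, which is what the reshape above relies on.

# fc1 flattens a 7x7x512 feature map into 25,088 inputs for 4,096 units,
# so its kernel reshapes losslessly into a (7, 7, 512, 4096) convolution.
print(vgg_model.get_layer("fc1").get_weights()[0].shape)  # (25088, 4096)
print(25088 == 7 * 7 * 512)                               # True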

Training

The original paper talks about making use of SGD with Momentum as the optimizer of choice. But it was noticed during experimentation that AdamW yielded better results in terms of mIOU and pixel-wise accuracy.

FCN-32S

fcn32s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn32s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn32s_model.compile(
    optimizer=fcn32s_optimizer,
    loss=fcn32s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn32s_history = fcn32s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
Epoch 1/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 31s 171ms/step - loss: 0.9853 - mean_io_u: 0.3056 - sparse_categorical_accuracy: 0.6242 - val_loss: 0.7911 - val_mean_io_u: 0.4022 - val_sparse_categorical_accuracy: 0.7011
Epoch 2/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 22s 131ms/step - loss: 0.7463 - mean_io_u: 0.3978 - sparse_categorical_accuracy: 0.7100 - val_loss: 0.7162 - val_mean_io_u: 0.3968 - val_sparse_categorical_accuracy: 0.7157
Epoch 3/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 120ms/step - loss: 0.6939 - mean_io_u: 0.4139 - sparse_categorical_accuracy: 0.7255 - val_loss: 0.6714 - val_mean_io_u: 0.4383 - val_sparse_categorical_accuracy: 0.7379
Epoch 4/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 117ms/step - loss: 0.6694 - mean_io_u: 0.4239 - sparse_categorical_accuracy: 0.7339 - val_loss: 0.6715 - val_mean_io_u: 0.4258 - val_sparse_categorical_accuracy: 0.7332
Epoch 5/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 115ms/step - loss: 0.6556 - mean_io_u: 0.4279 - sparse_categorical_accuracy: 0.7382 - val_loss: 0.6271 - val_mean_io_u: 0.4483 - val_sparse_categorical_accuracy: 0.7514
Epoch 6/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 120ms/step - loss: 0.6501 - mean_io_u: 0.4295 - sparse_categorical_accuracy: 0.7394 - val_loss: 0.6390 - val_mean_io_u: 0.4375 - val_sparse_categorical_accuracy: 0.7442
Epoch 7/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 0.6464 - mean_io_u: 0.4309 - sparse_categorical_accuracy: 0.7402 - val_loss: 0.6143 - val_mean_io_u: 0.4508 - val_sparse_categorical_accuracy: 0.7553
Epoch 8/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 108ms/step - loss: 0.6363 - mean_io_u: 0.4343 - sparse_categorical_accuracy: 0.7444 - val_loss: 0.6143 - val_mean_io_u: 0.4481 - val_sparse_categorical_accuracy: 0.7541
Epoch 9/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 0.6367 - mean_io_u: 0.4346 - sparse_categorical_accuracy: 0.7445 - val_loss: 0.6222 - val_mean_io_u: 0.4534 - val_sparse_categorical_accuracy: 0.7510
Epoch 10/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 19s 108ms/step - loss: 0.6398 - mean_io_u: 0.4346 - sparse_categorical_accuracy: 0.7426 - val_loss: 0.6123 - val_mean_io_u: 0.4494 - val_sparse_categorical_accuracy: 0.7541
Epoch 11/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 0.6361 - mean_io_u: 0.4365 - sparse_categorical_accuracy: 0.7439 - val_loss: 0.6310 - val_mean_io_u: 0.4405 - val_sparse_categorical_accuracy: 0.7461
Epoch 12/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 110ms/step - loss: 0.6325 - mean_io_u: 0.4362 - sparse_categorical_accuracy: 0.7454 - val_loss: 0.6155 - val_mean_io_u: 0.4441 - val_sparse_categorical_accuracy: 0.7509
Epoch 13/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 112ms/step - loss: 0.6335 - mean_io_u: 0.4368 - sparse_categorical_accuracy: 0.7452 - val_loss: 0.6153 - val_mean_io_u: 0.4430 - val_sparse_categorical_accuracy: 0.7504
Epoch 14/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 113ms/step - loss: 0.6289 - mean_io_u: 0.4380 - sparse_categorical_accuracy: 0.7466 - val_loss: 0.6357 - val_mean_io_u: 0.4309 - val_sparse_categorical_accuracy: 0.7382
Epoch 15/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 113ms/step - loss: 0.6267 - mean_io_u: 0.4369 - sparse_categorical_accuracy: 0.7474 - val_loss: 0.5974 - val_mean_io_u: 0.4619 - val_sparse_categorical_accuracy: 0.7617
Epoch 16/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 109ms/step - loss: 0.6309 - mean_io_u: 0.4368 - sparse_categorical_accuracy: 0.7458 - val_loss: 0.6071 - val_mean_io_u: 0.4463 - val_sparse_categorical_accuracy: 0.7533
Epoch 17/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 112ms/step - loss: 0.6285 - mean_io_u: 0.4382 - sparse_categorical_accuracy: 0.7465 - val_loss: 0.5979 - val_mean_io_u: 0.4576 - val_sparse_categorical_accuracy: 0.7602
Epoch 18/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 111ms/step - loss: 0.6250 - mean_io_u: 0.4403 - sparse_categorical_accuracy: 0.7479 - val_loss: 0.6121 - val_mean_io_u: 0.4451 - val_sparse_categorical_accuracy: 0.7507
Epoch 19/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 111ms/step - loss: 0.6307 - mean_io_u: 0.4386 - sparse_categorical_accuracy: 0.7454 - val_loss: 0.6010 - val_mean_io_u: 0.4532 - val_sparse_categorical_accuracy: 0.7577
Epoch 20/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 114ms/step - loss: 0.6199 - mean_io_u: 0.4403 - sparse_categorical_accuracy: 0.7505 - val_loss: 0.6180 - val_mean_io_u: 0.4339 - val_sparse_categorical_accuracy: 0.7465

FCN-16S

fcn16s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn16s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn16s_model.compile(
    optimizer=fcn16s_optimizer,
    loss=fcn16s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn16s_history = fcn16s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
Epoch 1/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 23s 127ms/step - loss: 6.4519 - mean_io_u_1: 0.3101 - sparse_categorical_accuracy: 0.5649 - val_loss: 5.7052 - val_mean_io_u_1: 0.3842 - val_sparse_categorical_accuracy: 0.6057
Epoch 2/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 19s 110ms/step - loss: 5.2670 - mean_io_u_1: 0.3936 - sparse_categorical_accuracy: 0.6339 - val_loss: 5.8929 - val_mean_io_u_1: 0.3864 - val_sparse_categorical_accuracy: 0.5940
Epoch 3/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 111ms/step - loss: 5.2376 - mean_io_u_1: 0.3945 - sparse_categorical_accuracy: 0.6366 - val_loss: 5.6404 - val_mean_io_u_1: 0.3889 - val_sparse_categorical_accuracy: 0.6079
Epoch 4/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 113ms/step - loss: 5.3014 - mean_io_u_1: 0.3924 - sparse_categorical_accuracy: 0.6323 - val_loss: 5.6516 - val_mean_io_u_1: 0.3874 - val_sparse_categorical_accuracy: 0.6094
Epoch 5/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 112ms/step - loss: 5.3135 - mean_io_u_1: 0.3918 - sparse_categorical_accuracy: 0.6323 - val_loss: 5.6588 - val_mean_io_u_1: 0.3903 - val_sparse_categorical_accuracy: 0.6084
Epoch 6/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 5.2401 - mean_io_u_1: 0.3938 - sparse_categorical_accuracy: 0.6357 - val_loss: 5.6463 - val_mean_io_u_1: 0.3868 - val_sparse_categorical_accuracy: 0.6097
Epoch 7/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 109ms/step - loss: 5.2277 - mean_io_u_1: 0.3921 - sparse_categorical_accuracy: 0.6371 - val_loss: 5.6272 - val_mean_io_u_1: 0.3796 - val_sparse_categorical_accuracy: 0.6136
Epoch 8/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 112ms/step - loss: 5.2479 - mean_io_u_1: 0.3910 - sparse_categorical_accuracy: 0.6360 - val_loss: 5.6303 - val_mean_io_u_1: 0.3823 - val_sparse_categorical_accuracy: 0.6108
Epoch 9/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 112ms/step - loss: 5.1940 - mean_io_u_1: 0.3913 - sparse_categorical_accuracy: 0.6388 - val_loss: 5.8818 - val_mean_io_u_1: 0.3848 - val_sparse_categorical_accuracy: 0.5912
Epoch 10/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 111ms/step - loss: 5.2457 - mean_io_u_1: 0.3898 - sparse_categorical_accuracy: 0.6358 - val_loss: 5.6423 - val_mean_io_u_1: 0.3880 - val_sparse_categorical_accuracy: 0.6087
Epoch 11/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 5.1808 - mean_io_u_1: 0.3905 - sparse_categorical_accuracy: 0.6400 - val_loss: 5.6175 - val_mean_io_u_1: 0.3834 - val_sparse_categorical_accuracy: 0.6090
Epoch 12/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 112ms/step - loss: 5.2730 - mean_io_u_1: 0.3907 - sparse_categorical_accuracy: 0.6341 - val_loss: 5.6322 - val_mean_io_u_1: 0.3878 - val_sparse_categorical_accuracy: 0.6109
Epoch 13/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 5.2501 - mean_io_u_1: 0.3904 - sparse_categorical_accuracy: 0.6359 - val_loss: 5.8711 - val_mean_io_u_1: 0.3859 - val_sparse_categorical_accuracy: 0.5950
Epoch 14/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 107ms/step - loss: 5.2407 - mean_io_u_1: 0.3926 - sparse_categorical_accuracy: 0.6362 - val_loss: 5.6387 - val_mean_io_u_1: 0.3805 - val_sparse_categorical_accuracy: 0.6122
Epoch 15/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 5.2280 - mean_io_u_1: 0.3909 - sparse_categorical_accuracy: 0.6370 - val_loss: 5.6382 - val_mean_io_u_1: 0.3837 - val_sparse_categorical_accuracy: 0.6112
Epoch 16/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 5.2232 - mean_io_u_1: 0.3899 - sparse_categorical_accuracy: 0.6369 - val_loss: 5.6285 - val_mean_io_u_1: 0.3818 - val_sparse_categorical_accuracy: 0.6101
Epoch 17/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 107ms/step - loss: 1.4671 - mean_io_u_1: 0.5928 - sparse_categorical_accuracy: 0.8210 - val_loss: 0.7661 - val_mean_io_u_1: 0.6455 - val_sparse_categorical_accuracy: 0.8504
Epoch 18/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 0.6795 - mean_io_u_1: 0.6508 - sparse_categorical_accuracy: 0.8664 - val_loss: 0.6913 - val_mean_io_u_1: 0.6490 - val_sparse_categorical_accuracy: 0.8562
Epoch 19/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 0.6498 - mean_io_u_1: 0.6530 - sparse_categorical_accuracy: 0.8663 - val_loss: 0.6834 - val_mean_io_u_1: 0.6559 - val_sparse_categorical_accuracy: 0.8577
Epoch 20/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 0.6305 - mean_io_u_1: 0.6563 - sparse_categorical_accuracy: 0.8681 - val_loss: 0.6529 - val_mean_io_u_1: 0.6575 - val_sparse_categorical_accuracy: 0.8657

FCN-8S

fcn8s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn8s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn8s_model.compile(
    optimizer=fcn8s_optimizer,
    loss=fcn8s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn8s_history = fcn8s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
Epoch 1/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 24s 125ms/step - loss: 8.4168 - mean_io_u_2: 0.3116 - sparse_categorical_accuracy: 0.4237 - val_loss: 7.6113 - val_mean_io_u_2: 0.3540 - val_sparse_categorical_accuracy: 0.4682
Epoch 2/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 8.1030 - mean_io_u_2: 0.3423 - sparse_categorical_accuracy: 0.4401 - val_loss: 7.7038 - val_mean_io_u_2: 0.3335 - val_sparse_categorical_accuracy: 0.4481
Epoch 3/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 8.0868 - mean_io_u_2: 0.3433 - sparse_categorical_accuracy: 0.4408 - val_loss: 7.5839 - val_mean_io_u_2: 0.3518 - val_sparse_categorical_accuracy: 0.4722
Epoch 4/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 111ms/step - loss: 8.1508 - mean_io_u_2: 0.3414 - sparse_categorical_accuracy: 0.4365 - val_loss: 7.2391 - val_mean_io_u_2: 0.3519 - val_sparse_categorical_accuracy: 0.4805
Epoch 5/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 112ms/step - loss: 8.1621 - mean_io_u_2: 0.3440 - sparse_categorical_accuracy: 0.4361 - val_loss: 7.2805 - val_mean_io_u_2: 0.3474 - val_sparse_categorical_accuracy: 0.4816
Epoch 6/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 8.1470 - mean_io_u_2: 0.3412 - sparse_categorical_accuracy: 0.4360 - val_loss: 7.5605 - val_mean_io_u_2: 0.3543 - val_sparse_categorical_accuracy: 0.4736
Epoch 7/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 110ms/step - loss: 8.1464 - mean_io_u_2: 0.3430 - sparse_categorical_accuracy: 0.4368 - val_loss: 7.5442 - val_mean_io_u_2: 0.3542 - val_sparse_categorical_accuracy: 0.4702
Epoch 8/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 8.0812 - mean_io_u_2: 0.3463 - sparse_categorical_accuracy: 0.4403 - val_loss: 7.5565 - val_mean_io_u_2: 0.3471 - val_sparse_categorical_accuracy: 0.4614
Epoch 9/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 8.0441 - mean_io_u_2: 0.3463 - sparse_categorical_accuracy: 0.4420 - val_loss: 7.5563 - val_mean_io_u_2: 0.3522 - val_sparse_categorical_accuracy: 0.4734
Epoch 10/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 110ms/step - loss: 8.1385 - mean_io_u_2: 0.3432 - sparse_categorical_accuracy: 0.4363 - val_loss: 7.5236 - val_mean_io_u_2: 0.3506 - val_sparse_categorical_accuracy: 0.4660
Epoch 11/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 111ms/step - loss: 8.1114 - mean_io_u_2: 0.3447 - sparse_categorical_accuracy: 0.4381 - val_loss: 7.2068 - val_mean_io_u_2: 0.3518 - val_sparse_categorical_accuracy: 0.4808
Epoch 12/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 107ms/step - loss: 8.0777 - mean_io_u_2: 0.3451 - sparse_categorical_accuracy: 0.4392 - val_loss: 7.2252 - val_mean_io_u_2: 0.3497 - val_sparse_categorical_accuracy: 0.4815
Epoch 13/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 110ms/step - loss: 8.1355 - mean_io_u_2: 0.3446 - sparse_categorical_accuracy: 0.4366 - val_loss: 7.5587 - val_mean_io_u_2: 0.3500 - val_sparse_categorical_accuracy: 0.4671
Epoch 14/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 20s 107ms/step - loss: 8.1828 - mean_io_u_2: 0.3410 - sparse_categorical_accuracy: 0.4330 - val_loss: 7.2464 - val_mean_io_u_2: 0.3557 - val_sparse_categorical_accuracy: 0.4927
Epoch 15/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 108ms/step - loss: 8.1845 - mean_io_u_2: 0.3432 - sparse_categorical_accuracy: 0.4330 - val_loss: 7.2032 - val_mean_io_u_2: 0.3506 - val_sparse_categorical_accuracy: 0.4805
Epoch 16/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 21s 109ms/step - loss: 8.1183 - mean_io_u_2: 0.3449 - sparse_categorical_accuracy: 0.4374 - val_loss: 7.6210 - val_mean_io_u_2: 0.3460 - val_sparse_categorical_accuracy: 0.4751
Epoch 17/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 111ms/step - loss: 8.1766 - mean_io_u_2: 0.3429 - sparse_categorical_accuracy: 0.4329 - val_loss: 7.5361 - val_mean_io_u_2: 0.3489 - val_sparse_categorical_accuracy: 0.4639
Epoch 18/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 8.0453 - mean_io_u_2: 0.3442 - sparse_categorical_accuracy: 0.4404 - val_loss: 7.1767 - val_mean_io_u_2: 0.3549 - val_sparse_categorical_accuracy: 0.4839
Epoch 19/20

Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

98/98 [==============================] - 20s 109ms/step - loss: 8.0856 - mean_io_u_2: 0.3449 - sparse_categorical_accuracy: 0.4390 - val_loss: 7.1724 - val_mean_io_u_2: 0.3574 - val_sparse_categorical_accuracy: 0.4878
Epoch 20/20

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

98/98 [==============================] - 21s 109ms/step - loss: 8.1378 - mean_io_u_2: 0.3445 - sparse_categorical_accuracy: 0.4358 - val_loss: 7.5449 - val_mean_io_u_2: 0.3521 - val_sparse_categorical_accuracy: 0.4681

Visualizations

Plotting metrics for training run

We perform a comparative study between all 3 versions of the model by tracking the training and validation metrics of accuracy, loss and mean IoU.

total_plots = len(fcn32s_history.history)
cols = total_plots // 2

rows = total_plots // cols

if total_plots % cols != 0:
    rows += 1

# Set all history dictionary objects
fcn32s_dict = fcn32s_history.history
fcn16s_dict = fcn16s_history.history
fcn8s_dict = fcn8s_history.history

pos = range(1, total_plots + 1)
plt.figure(figsize=(15, 10))

for i, ((key_32s, value_32s), (key_16s, value_16s), (key_8s, value_8s)) in enumerate(
    zip(fcn32s_dict.items(), fcn16s_dict.items(), fcn8s_dict.items())
):
    plt.subplot(rows, cols, pos[i])
    plt.plot(range(len(value_32s)), value_32s)
    plt.plot(range(len(value_16s)), value_16s)
    plt.plot(range(len(value_8s)), value_8s)
    plt.title(str(key_32s) + " (combined)")
    plt.legend(["FCN-32S", "FCN-16S", "FCN-8S"])

plt.show()

png

Visualizing predicted segmentation masks

To understand the results and visualize them better, we pick a random image from the test dataset and perform inference on it to see the masks generated by each model. Note: for better results, the model must be trained for a higher number of epochs.

images, masks = next(iter(test_ds))
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)

# Get random test image and mask
test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

pred_image = ops.expand_dims(test_image, axis=0)
pred_image = keras.applications.vgg19.preprocess_input(pred_image)

# Perform inference on FCN-32S
pred_mask_32s = fcn32s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_32s = np.argmax(pred_mask_32s, axis=-1)
pred_mask_32s = pred_mask_32s[0, ...]

# Perform inference on FCN-16S
pred_mask_16s = fcn16s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_16s = np.argmax(pred_mask_16s, axis=-1)
pred_mask_16s = pred_mask_16s[0, ...]

# Perform inference on FCN-8S
pred_mask_8s = fcn8s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_8s = np.argmax(pred_mask_8s, axis=-1)
pred_mask_8s = pred_mask_8s[0, ...]

# Plot all results
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(15, 8))

fig.delaxes(ax[0, 2])

ax[0, 0].set_title("Image")
ax[0, 0].imshow(test_image / 255.0)

ax[0, 1].set_title("Image with ground truth overlay")
ax[0, 1].imshow(test_image / 255.0)
ax[0, 1].imshow(
    test_mask,
    cmap="inferno",
    alpha=0.6,
)

ax[1, 0].set_title("Image with FCN-32S mask overlay")
ax[1, 0].imshow(test_image / 255.0)
ax[1, 0].imshow(pred_mask_32s, cmap="inferno", alpha=0.6)

ax[1, 1].set_title("Image with FCN-16S mask overlay")
ax[1, 1].imshow(test_image / 255.0)
ax[1, 1].imshow(pred_mask_16s, cmap="inferno", alpha=0.6)

ax[1, 2].set_title("Image with FCN-8S mask overlay")
ax[1, 2].imshow(test_image / 255.0)
ax[1, 2].imshow(pred_mask_8s, cmap="inferno", alpha=0.6)

plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png


Conclusion

The Fully-Convolutional Network is an exceptionally simple network that has yielded strong results in image segmentation tasks across different benchmarks. With the advent of better mechanisms like Attention, as used in SegFormer and DeTR, this model serves as a quick way to iterate over and find baselines for this task on unknown data.


Acknowledgements

I thank Aritra Roy Gosthipaty, Ayush Thakur and Ritwik Raha for giving a preliminary review of the example. I also thank the Google Developer Experts program.