代码示例 / 生成式深度学习 / 神经风格迁移

神经风格迁移

作者: fchollet
创建日期 2016/01/11
最后修改日期 2020/05/02
描述:使用梯度下降将参考图像的风格迁移到目标图像。

ⓘ 此示例使用 Keras 3

在 Colab 中查看 GitHub 源码


简介

风格迁移是指生成一幅与基础图像具有相同“内容”但具有不同图片(通常是艺术图片)“风格”的图像。这是通过优化一个具有三个组成部分的损失函数来实现的:“风格损失”、“内容损失”和“总变差损失”。

  • 总变差损失强制组合图像的像素之间具有局部空间连续性,从而使其具有视觉连贯性。
  • 风格损失是深度学习发挥作用的地方——它使用深度卷积神经网络来定义。准确地说,它包括基础图像和风格参考图像的表示的 Gram 矩阵之间的 L2 距离之和,这些表示是从卷积神经网络(在 ImageNet 上训练)的不同层提取的。总体思路是在不同的空间尺度(相当大的尺度——由所考虑的层的深度定义)捕获颜色/纹理信息。
  • 内容损失是基础图像(从深层提取)的特征与组合图像的特征之间的 L2 距离,使生成的图像足够接近原始图像。

参考: 艺术风格的神经算法


设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np
import tensorflow as tf
from keras.applications import vgg19

base_image_path = keras.utils.get_file("paris.jpg", "https://i.imgur.com/F28w3Ac.jpg")
style_reference_image_path = keras.utils.get_file(
    "starry_night.jpg", "https://i.imgur.com/9ooB60I.jpg"
)
result_prefix = "paris_generated"

# Weights of the different loss components
total_variation_weight = 1e-6
style_weight = 1e-6
content_weight = 2.5e-8

# Dimensions of the generated picture.
width, height = keras.utils.load_img(base_image_path).size
img_nrows = 400
img_ncols = int(width * img_nrows / height)
Downloading data from https://i.imgur.com/F28w3Ac.jpg
 102437/102437 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://i.imgur.com/9ooB60I.jpg
 935806/935806 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

让我们看看我们的基础(内容)图像和我们的风格参考图像

from IPython.display import Image, display

display(Image(base_image_path))
display(Image(style_reference_image_path))

jpeg

jpeg


图像预处理/后处理工具

def preprocess_image(image_path):
    # Util function to open, resize and format pictures into appropriate tensors
    img = keras.utils.load_img(image_path, target_size=(img_nrows, img_ncols))
    img = keras.utils.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = vgg19.preprocess_input(img)
    return tf.convert_to_tensor(img)


def deprocess_image(x):
    # Util function to convert a tensor into a valid image
    x = x.reshape((img_nrows, img_ncols, 3))
    # Remove zero-center by mean pixel
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # 'BGR'->'RGB'
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype("uint8")
    return x

计算风格迁移损失

首先,我们需要定义 4 个实用函数

  • gram_matrix(用于计算风格损失)
  • style_loss 函数,它使生成的图像接近风格参考图像的局部纹理
  • content_loss 函数,它使生成的图像的高级表示接近基础图像的高级表示
  • total_variation_loss 函数,一种正则化损失,使生成的图像在局部保持连贯性
# The gram matrix of an image tensor (feature-wise outer product)


def gram_matrix(x):
    x = tf.transpose(x, (2, 0, 1))
    features = tf.reshape(x, (tf.shape(x)[0], -1))
    gram = tf.matmul(features, tf.transpose(features))
    return gram


# The "style loss" is designed to maintain
# the style of the reference image in the generated image.
# It is based on the gram matrices (which capture style) of
# feature maps from the style reference image
# and from the generated image


def style_loss(style, combination):
    S = gram_matrix(style)
    C = gram_matrix(combination)
    channels = 3
    size = img_nrows * img_ncols
    return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels**2) * (size**2))


# An auxiliary loss function
# designed to maintain the "content" of the
# base image in the generated image


def content_loss(base, combination):
    return tf.reduce_sum(tf.square(combination - base))


# The 3rd loss function, total variation loss,
# designed to keep the generated image locally coherent


def total_variation_loss(x):
    a = tf.square(
        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, 1:, : img_ncols - 1, :]
    )
    b = tf.square(
        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, : img_nrows - 1, 1:, :]
    )
    return tf.reduce_sum(tf.pow(a + b, 1.25))

接下来,让我们创建一个特征提取模型,它检索 VGG19 的中间激活(作为字典,按名称)。

# Build a VGG19 model loaded with pre-trained ImageNet weights
model = vgg19.VGG19(weights="imagenet", include_top=False)

# Get the symbolic outputs of each "key" layer (we gave them unique names).
outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])

# Set up a model that returns the activation values for every layer in
# VGG19 (as a dict).
feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
 80134624/80134624 ━━━━━━━━━━━━━━━━━━━━ 2s 0us/step

最后,这是计算风格迁移损失的代码。

# List of layers to use for the style loss.
style_layer_names = [
    "block1_conv1",
    "block2_conv1",
    "block3_conv1",
    "block4_conv1",
    "block5_conv1",
]
# The layer to use for the content loss.
content_layer_name = "block5_conv2"


def compute_loss(combination_image, base_image, style_reference_image):
    input_tensor = tf.concat(
        [base_image, style_reference_image, combination_image], axis=0
    )
    features = feature_extractor(input_tensor)

    # Initialize the loss
    loss = tf.zeros(shape=())

    # Add content loss
    layer_features = features[content_layer_name]
    base_image_features = layer_features[0, :, :, :]
    combination_features = layer_features[2, :, :, :]
    loss = loss + content_weight * content_loss(
        base_image_features, combination_features
    )
    # Add style loss
    for layer_name in style_layer_names:
        layer_features = features[layer_name]
        style_reference_features = layer_features[1, :, :, :]
        combination_features = layer_features[2, :, :, :]
        sl = style_loss(style_reference_features, combination_features)
        loss += (style_weight / len(style_layer_names)) * sl

    # Add total variation loss
    loss += total_variation_weight * total_variation_loss(combination_image)
    return loss

向损失和梯度计算添加 tf.function 装饰器

编译它,使其速度更快。

@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image):
    with tf.GradientTape() as tape:
        loss = compute_loss(combination_image, base_image, style_reference_image)
    grads = tape.gradient(loss, combination_image)
    return loss, grads

训练循环

重复运行普通梯度下降步骤以最小化损失,并在每 100 次迭代后保存生成的图像。

我们每 100 步将学习率衰减 0.96。

optimizer = keras.optimizers.SGD(
    keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96
    )
)

base_image = preprocess_image(base_image_path)
style_reference_image = preprocess_image(style_reference_image_path)
combination_image = tf.Variable(preprocess_image(base_image_path))

iterations = 4000
for i in range(1, iterations + 1):
    loss, grads = compute_loss_and_grads(
        combination_image, base_image, style_reference_image
    )
    optimizer.apply_gradients([(grads, combination_image)])
    if i % 100 == 0:
        print("Iteration %d: loss=%.2f" % (i, loss))
        img = deprocess_image(combination_image.numpy())
        fname = result_prefix + "_at_iteration_%d.png" % i
        keras.utils.save_img(fname, img)
Iteration 100: loss=11021.63
Iteration 200: loss=8516.82
Iteration 300: loss=7572.36
Iteration 400: loss=7062.23
Iteration 500: loss=6733.57
Iteration 600: loss=6498.27
Iteration 700: loss=6319.11
Iteration 800: loss=6176.94
Iteration 900: loss=6060.49
Iteration 1000: loss=5963.24
Iteration 1100: loss=5880.51
Iteration 1200: loss=5809.23
Iteration 1300: loss=5747.35
Iteration 1400: loss=5692.95
Iteration 1500: loss=5644.84
Iteration 1600: loss=5601.82
Iteration 1700: loss=5563.18
Iteration 1800: loss=5528.38
Iteration 1900: loss=5496.89
Iteration 2000: loss=5468.20
Iteration 2100: loss=5441.97
Iteration 2200: loss=5418.02
Iteration 2300: loss=5396.11
Iteration 2400: loss=5376.00
Iteration 2500: loss=5357.49
Iteration 2600: loss=5340.36
Iteration 2700: loss=5324.49
Iteration 2800: loss=5309.77
Iteration 2900: loss=5296.08
Iteration 3000: loss=5283.33
Iteration 3100: loss=5271.47
Iteration 3200: loss=5260.39
Iteration 3300: loss=5250.02
Iteration 3400: loss=5240.29
Iteration 3500: loss=5231.18
Iteration 3600: loss=5222.65
Iteration 3700: loss=5214.61
Iteration 3800: loss=5207.08
Iteration 3900: loss=5199.98
Iteration 4000: loss=5193.27

4000 次迭代后,您将获得以下结果

display(Image(result_prefix + "_at_iteration_4000.png"))

png