代码示例 / 生成式深度学习 / 通过文本反转教授 Stable Diffusion 新概念

通过文本反转教授 Stable Diffusion 新概念

作者:Ian Stenbit,lukewood
创建日期 2022/12/09
上次修改日期 2022/12/09
描述:使用 KerasCV 的 Stable Diffusion 实现学习新的视觉概念。

ⓘ 此示例使用 Keras 2

在 Colab 中查看 GitHub 源代码


文本反转

自发布以来,Stable Diffusion 迅速成为生成式机器学习社区的宠儿。大量的用户流量导致了开源社区的贡献改进、大量的提示工程,甚至发明了新颖的算法。

也许最令人印象深刻的新算法是文本反转,在一张图片胜过千言万语:使用文本反转个性化文本到图像生成中提出。

文本反转是通过微调来教授图像生成器特定视觉概念的过程。在下图中,您可以看到此过程的一个示例,其中作者教授模型新概念,称为“S_*”。

https://i.imgur.com/KqEeBsM.jpg

从概念上讲,文本反转通过学习新文本标记的标记嵌入来工作,同时保持 Stable Diffusion 的其余组件冻结。

本指南向您展示如何使用文本反转算法微调 KerasCV 中提供的 Stable Diffusion 模型。在本指南结束时,您将能够编写“灰袍甘道夫作为 <my-funny-cat-token>”。

https://i.imgur.com/rcb1Yfx.png

首先,让我们导入所需的包,并创建一个 Stable Diffusion 实例,以便我们可以使用其一些子组件进行微调。

!pip install -q git+https://github.com/keras-team/keras-cv.git
!pip install -q tensorflow==2.11.0
import math

import keras_cv
import numpy as np
import tensorflow as tf
from keras_cv import layers as cv_layers
from keras_cv.models.stable_diffusion import NoiseScheduler
from tensorflow import keras
import matplotlib.pyplot as plt

stable_diffusion = keras_cv.models.StableDiffusion()
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE

接下来,让我们定义一个可视化实用程序来展示生成的图像

def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")

组装文本-图像对数据集

为了训练我们新标记的嵌入,我们首先必须组装一个由文本-图像对组成的数据集。数据集中的每个样本都必须包含我们正在教授 Stable Diffusion 的概念的图像,以及准确表示图像内容的标题。在本教程中,我们将教授 Stable Diffusion Luke 和 Ian 的 GitHub 头像的概念

gh-avatars

首先,让我们构建一个猫娃娃的图像数据集

def assemble_image_dataset(urls):
    # Fetch all remote files
    files = [tf.keras.utils.get_file(origin=url) for url in urls]

    # Resize images
    resize = keras.layers.Resizing(height=512, width=512, crop_to_aspect_ratio=True)
    images = [keras.utils.load_img(img) for img in files]
    images = [keras.utils.img_to_array(img) for img in images]
    images = np.array([resize(img) for img in images])

    # The StableDiffusion image encoder requires images to be normalized to the
    # [-1, 1] pixel value range
    images = images / 127.5 - 1

    # Create the tf.data.Dataset
    image_dataset = tf.data.Dataset.from_tensor_slices(images)

    # Shuffle and introduce random noise
    image_dataset = image_dataset.shuffle(50, reshuffle_each_iteration=True)
    image_dataset = image_dataset.map(
        cv_layers.RandomCropAndResize(
            target_size=(512, 512),
            crop_area_factor=(0.8, 1.0),
            aspect_ratio_factor=(1.0, 1.0),
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    image_dataset = image_dataset.map(
        cv_layers.RandomFlip(mode="horizontal"),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return image_dataset

接下来,我们组装一个文本数据集

MAX_PROMPT_LENGTH = 77
placeholder_token = "<my-funny-cat-token>"


def pad_embedding(embedding):
    return embedding + (
        [stable_diffusion.tokenizer.end_of_text] * (MAX_PROMPT_LENGTH - len(embedding))
    )


stable_diffusion.tokenizer.add_tokens(placeholder_token)


def assemble_text_dataset(prompts):
    prompts = [prompt.format(placeholder_token) for prompt in prompts]
    embeddings = [stable_diffusion.tokenizer.encode(prompt) for prompt in prompts]
    embeddings = [np.array(pad_embedding(embedding)) for embedding in embeddings]
    text_dataset = tf.data.Dataset.from_tensor_slices(embeddings)
    text_dataset = text_dataset.shuffle(100, reshuffle_each_iteration=True)
    return text_dataset

最后,我们将数据集压缩在一起以生成文本-图像对数据集。

def assemble_dataset(urls, prompts):
    image_dataset = assemble_image_dataset(urls)
    text_dataset = assemble_text_dataset(prompts)
    # the image dataset is quite short, so we repeat it to match the length of the
    # text prompt dataset
    image_dataset = image_dataset.repeat()
    # we use the text prompt dataset to determine the length of the dataset.  Due to
    # the fact that there are relatively few prompts we repeat the dataset 5 times.
    # we have found that this anecdotally improves results.
    text_dataset = text_dataset.repeat(5)
    return tf.data.Dataset.zip((image_dataset, text_dataset))

为了确保我们的提示具有描述性,我们使用极其通用的提示。

让我们用一些示例图像和提示来试试。

train_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/VIedH1X.jpg",
        "https://i.imgur.com/eBw13hE.png",
        "https://i.imgur.com/oJ3rSg7.png",
        "https://i.imgur.com/5mCL6Df.jpg",
        "https://i.imgur.com/4Q6WWyI.jpg",
    ],
    prompts=[
        "a photo of a {}",
        "a rendering of a {}",
        "a cropped photo of the {}",
        "the photo of a {}",
        "a photo of a clean {}",
        "a dark photo of the {}",
        "a photo of my {}",
        "a photo of the cool {}",
        "a close-up photo of a {}",
        "a bright photo of the {}",
        "a cropped photo of a {}",
        "a photo of the {}",
        "a good photo of the {}",
        "a photo of one {}",
        "a close-up photo of the {}",
        "a rendition of the {}",
        "a photo of the clean {}",
        "a rendition of a {}",
        "a photo of a nice {}",
        "a good photo of a {}",
        "a photo of the nice {}",
        "a photo of the small {}",
        "a photo of the weird {}",
        "a photo of the large {}",
        "a photo of a cool {}",
        "a photo of a small {}",
    ],
)

关于提示准确性的重要性

在我们第一次尝试编写本指南时,我们在数据集中包含了这些猫娃娃组的图像,但继续使用上面列出的通用提示。我们的结果在轶事上很差。例如,这是使用此方法生成的猫娃娃甘道夫

mediocre-wizard

从概念上讲它很接近,但它并不像它可以的那样好。

为了解决这个问题,我们开始尝试将图像分成单个猫娃娃和猫娃娃组的图像。在进行此拆分后,我们为组照片提出了新的提示。

在准确表示内容的文本-图像对上进行训练大大提高了我们结果的质量。这说明了提示准确性的重要性。

除了将图像分成单个和组图像外,我们还删除了一些不准确的提示;例如“{} 的深色照片”

牢记这一点,我们在下面组装了我们的最终训练数据集

single_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/VIedH1X.jpg",
        "https://i.imgur.com/eBw13hE.png",
        "https://i.imgur.com/oJ3rSg7.png",
        "https://i.imgur.com/5mCL6Df.jpg",
        "https://i.imgur.com/4Q6WWyI.jpg",
    ],
    prompts=[
        "a photo of a {}",
        "a rendering of a {}",
        "a cropped photo of the {}",
        "the photo of a {}",
        "a photo of a clean {}",
        "a photo of my {}",
        "a photo of the cool {}",
        "a close-up photo of a {}",
        "a bright photo of the {}",
        "a cropped photo of a {}",
        "a photo of the {}",
        "a good photo of the {}",
        "a photo of one {}",
        "a close-up photo of the {}",
        "a rendition of the {}",
        "a photo of the clean {}",
        "a rendition of a {}",
        "a photo of a nice {}",
        "a good photo of a {}",
        "a photo of the nice {}",
        "a photo of the small {}",
        "a photo of the weird {}",
        "a photo of the large {}",
        "a photo of a cool {}",
        "a photo of a small {}",
    ],
)

https://i.imgur.com/gQCRjK6.png

看起来很棒!

接下来,我们组装一个我们 GitHub 头像组的数据集

group_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/yVmZ2Qa.jpg",
        "https://i.imgur.com/JbyFbZJ.jpg",
        "https://i.imgur.com/CCubd3q.jpg",
    ],
    prompts=[
        "a photo of a group of {}",
        "a rendering of a group of {}",
        "a cropped photo of the group of {}",
        "the photo of a group of {}",
        "a photo of a clean group of {}",
        "a photo of my group of {}",
        "a photo of a cool group of {}",
        "a close-up photo of a group of {}",
        "a bright photo of the group of {}",
        "a cropped photo of a group of {}",
        "a photo of the group of {}",
        "a good photo of the group of {}",
        "a photo of one group of {}",
        "a close-up photo of the group of {}",
        "a rendition of the group of {}",
        "a photo of the clean group of {}",
        "a rendition of a group of {}",
        "a photo of a nice group of {}",
        "a good photo of a group of {}",
        "a photo of the nice group of {}",
        "a photo of the small group of {}",
        "a photo of the weird group of {}",
        "a photo of the large group of {}",
        "a photo of a cool group of {}",
        "a photo of a small group of {}",
    ],
)

https://i.imgur.com/GY9Pf3D.png

最后,我们将两个数据集连接起来

train_ds = single_ds.concatenate(group_ds)
train_ds = train_ds.batch(1).shuffle(
    train_ds.cardinality(), reshuffle_each_iteration=True
)

向文本编码器添加新标记

接下来,我们为 Stable Diffusion 模型创建一个新的文本编码器,并将我们新的嵌入添加到 '' 模型中。

tokenized_initializer = stable_diffusion.tokenizer.encode("cat")[1]
new_weights = stable_diffusion.text_encoder.layers[2].token_embedding(
    tf.constant(tokenized_initializer)
)

# Get len of .vocab instead of tokenizer
new_vocab_size = len(stable_diffusion.tokenizer.vocab)

# The embedding layer is the 2nd layer in the text encoder
old_token_weights = stable_diffusion.text_encoder.layers[
    2
].token_embedding.get_weights()
old_position_weights = stable_diffusion.text_encoder.layers[
    2
].position_embedding.get_weights()

old_token_weights = old_token_weights[0]
new_weights = np.expand_dims(new_weights, axis=0)
new_weights = np.concatenate([old_token_weights, new_weights], axis=0)

让我们构建一个新的 TextEncoder 并进行准备。

# Have to set download_weights False so we can init (otherwise tries to load weights)
new_encoder = keras_cv.models.stable_diffusion.TextEncoder(
    keras_cv.models.stable_diffusion.stable_diffusion.MAX_PROMPT_LENGTH,
    vocab_size=new_vocab_size,
    download_weights=False,
)
for index, layer in enumerate(stable_diffusion.text_encoder.layers):
    # Layer 2 is the embedding layer, so we omit it from our weight-copying
    if index == 2:
        continue
    new_encoder.layers[index].set_weights(layer.get_weights())


new_encoder.layers[2].token_embedding.set_weights([new_weights])
new_encoder.layers[2].position_embedding.set_weights(old_position_weights)

stable_diffusion._text_encoder = new_encoder
stable_diffusion._text_encoder.compile(jit_compile=True)

训练

现在我们可以继续进行激动人心的部分:训练!

在 TextualInversion 中,唯一被训练的模型部分是嵌入向量。让我们冻结模型的其余部分。

stable_diffusion.diffusion_model.trainable = False
stable_diffusion.decoder.trainable = False
stable_diffusion.text_encoder.trainable = True

stable_diffusion.text_encoder.layers[2].trainable = True


def traverse_layers(layer):
    if hasattr(layer, "layers"):
        for layer in layer.layers:
            yield layer
    if hasattr(layer, "token_embedding"):
        yield layer.token_embedding
    if hasattr(layer, "position_embedding"):
        yield layer.position_embedding


for layer in traverse_layers(stable_diffusion.text_encoder):
    if isinstance(layer, keras.layers.Embedding) or "clip_embedding" in layer.name:
        layer.trainable = True
    else:
        layer.trainable = False

new_encoder.layers[2].position_embedding.trainable = False

让我们确认设置了正确的权重以进行训练。

all_models = [
    stable_diffusion.text_encoder,
    stable_diffusion.diffusion_model,
    stable_diffusion.decoder,
]
print([[w.shape for w in model.trainable_weights] for model in all_models])
[[TensorShape([49409, 768])], [], []]

训练新的嵌入

为了训练嵌入,我们需要一些实用程序。我们从 KerasCV 导入 NoiseScheduler,并在下面定义以下实用程序

  • sample_from_encoder_outputs 是围绕基本 Stable Diffusion 图像编码器的包装器,它从图像编码器生成的统计分布中采样,而不是仅获取平均值(如许多其他 SD 应用程序一样)
  • get_timestep_embedding 为指定的扩散模型时间步生成嵌入
  • get_position_ids 为文本编码器生成位置 ID 张量(它只是一个从[1, MAX_PROMPT_LENGTH]开始的序列)
# Remove the top layer from the encoder, which cuts off the variance and only returns
# the mean
training_image_encoder = keras.Model(
    stable_diffusion.image_encoder.input,
    stable_diffusion.image_encoder.layers[-2].output,
)


def sample_from_encoder_outputs(outputs):
    mean, logvar = tf.split(outputs, 2, axis=-1)
    logvar = tf.clip_by_value(logvar, -30.0, 20.0)
    std = tf.exp(0.5 * logvar)
    sample = tf.random.normal(tf.shape(mean))
    return mean + std * sample


def get_timestep_embedding(timestep, dim=320, max_period=10000):
    half = dim // 2
    freqs = tf.math.exp(
        -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
    )
    args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
    embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
    return embedding


def get_position_ids():
    return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)

接下来,我们实现了一个StableDiffusionFineTuner,它是keras.Model的子类,它重写了train_step以训练文本编码器的标记嵌入。这是文本反转算法的核心。

从抽象的角度来说,训练步骤从冻结的 SD 图像编码器潜在分布的训练图像输出中获取样本,向该样本添加噪声,然后将该噪声样本传递给冻结的扩散模型。扩散模型的隐藏状态是对应于图像的提示的文本编码器的输出。

我们的最终目标状态是扩散模型能够使用文本编码作为隐藏状态从样本中分离噪声,因此我们的损失是噪声和扩散模型输出的均方误差(理想情况下,它已经从噪声中去除了图像潜变量)。

我们仅计算文本编码器标记嵌入的梯度,并且在训练步骤中,我们清零除我们正在学习的标记之外的所有标记的梯度。

有关训练步骤的更多详细信息,请参阅内联代码注释。

class StableDiffusionFineTuner(keras.Model):
    def __init__(self, stable_diffusion, noise_scheduler, **kwargs):
        super().__init__(**kwargs)
        self.stable_diffusion = stable_diffusion
        self.noise_scheduler = noise_scheduler

    def train_step(self, data):
        images, embeddings = data

        with tf.GradientTape() as tape:
            # Sample from the predicted distribution for the training image
            latents = sample_from_encoder_outputs(training_image_encoder(images))
            # The latents must be downsampled to match the scale of the latents used
            # in the training of StableDiffusion.  This number is truly just a "magic"
            # constant that they chose when training the model.
            latents = latents * 0.18215

            # Produce random noise in the same shape as the latent sample
            noise = tf.random.normal(tf.shape(latents))
            batch_dim = tf.shape(latents)[0]

            # Pick a random timestep for each sample in the batch
            timesteps = tf.random.uniform(
                (batch_dim,),
                minval=0,
                maxval=noise_scheduler.train_timesteps,
                dtype=tf.int64,
            )

            # Add noise to the latents based on the timestep for each sample
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

            # Encode the text in the training samples to use as hidden state in the
            # diffusion model
            encoder_hidden_state = self.stable_diffusion.text_encoder(
                [embeddings, get_position_ids()]
            )

            # Compute timestep embeddings for the randomly-selected timesteps for each
            # sample in the batch
            timestep_embeddings = tf.map_fn(
                fn=get_timestep_embedding,
                elems=timesteps,
                fn_output_signature=tf.float32,
            )

            # Call the diffusion model
            noise_pred = self.stable_diffusion.diffusion_model(
                [noisy_latents, timestep_embeddings, encoder_hidden_state]
            )

            # Compute the mean-squared error loss and reduce it.
            loss = self.compiled_loss(noise_pred, noise)
            loss = tf.reduce_mean(loss, axis=2)
            loss = tf.reduce_mean(loss, axis=1)
            loss = tf.reduce_mean(loss)

        # Load the trainable weights and compute the gradients for them
        trainable_weights = self.stable_diffusion.text_encoder.trainable_weights
        grads = tape.gradient(loss, trainable_weights)

        # Gradients are stored in indexed slices, so we have to find the index
        # of the slice(s) which contain the placeholder token.
        index_of_placeholder_token = tf.reshape(tf.where(grads[0].indices == 49408), ())
        condition = grads[0].indices == 49408
        condition = tf.expand_dims(condition, axis=-1)

        # Override the gradients, zeroing out the gradients for all slices that
        # aren't for the placeholder token, effectively freezing the weights for
        # all other tokens.
        grads[0] = tf.IndexedSlices(
            values=tf.where(condition, grads[0].values, 0),
            indices=grads[0].indices,
            dense_shape=grads[0].dense_shape,
        )

        self.optimizer.apply_gradients(zip(grads, trainable_weights))
        return {"loss": loss}

在开始训练之前,让我们看看 Stable Diffusion 为我们的标记生成什么。

generated = stable_diffusion.text_to_image(
    f"an oil painting of {placeholder_token}", seed=1337, batch_size=3
)
plot_images(generated)
25/25 [==============================] - 19s 314ms/step

png

如您所见,模型仍然认为我们的标记是猫,因为这是我们用来初始化自定义标记的种子标记。

现在,要开始训练,我们只需像任何其他 Keras 模型一样compile()我们的模型。在此之前,我们还为训练实例化了一个噪声调度程序,并配置了学习率和优化器等训练参数。

noise_scheduler = NoiseScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    train_timesteps=1000,
)
trainer = StableDiffusionFineTuner(stable_diffusion, noise_scheduler, name="trainer")
EPOCHS = 50
learning_rate = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-4, decay_steps=train_ds.cardinality() * EPOCHS
)
optimizer = keras.optimizers.Adam(
    weight_decay=0.004, learning_rate=learning_rate, epsilon=1e-8, global_clipnorm=10
)

trainer.compile(
    optimizer=optimizer,
    # We are performing reduction manually in our train step, so none is required here.
    loss=keras.losses.MeanSquaredError(reduction="none"),
)

为了监控训练,我们可以生成一个keras.callbacks.Callback,以在每个时期使用我们的自定义标记生成一些图像。

我们创建了三个具有不同提示的回调,以便我们可以看到它们在训练过程中是如何发展的。我们使用固定的种子,以便我们可以轻松地看到学习到的标记的进展。

class GenerateImages(keras.callbacks.Callback):
    def __init__(
        self, stable_diffusion, prompt, steps=50, frequency=10, seed=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.stable_diffusion = stable_diffusion
        self.prompt = prompt
        self.seed = seed
        self.frequency = frequency
        self.steps = steps

    def on_epoch_end(self, epoch, logs):
        if epoch % self.frequency == 0:
            images = self.stable_diffusion.text_to_image(
                self.prompt, batch_size=3, num_steps=self.steps, seed=self.seed
            )
            plot_images(
                images,
            )


cbs = [
    GenerateImages(
        stable_diffusion, prompt=f"an oil painting of {placeholder_token}", seed=1337
    ),
    GenerateImages(
        stable_diffusion, prompt=f"gandalf the gray as a {placeholder_token}", seed=1337
    ),
    GenerateImages(
        stable_diffusion,
        prompt=f"two {placeholder_token} getting married, photorealistic, high quality",
        seed=1337,
    ),
]

现在,剩下的就是调用model.fit()了!

trainer.fit(
    train_ds,
    epochs=EPOCHS,
    callbacks=cbs,
)
Epoch 1/50
50/50 [==============================] - 16s 318ms/step
50/50 [==============================] - 16s 318ms/step
50/50 [==============================] - 16s 318ms/step
250/250 [==============================] - 194s 469ms/step - loss: 0.1533
Epoch 2/50
250/250 [==============================] - 68s 269ms/step - loss: 0.1557
Epoch 3/50
250/250 [==============================] - 68s 269ms/step - loss: 0.1359
Epoch 4/50
250/250 [==============================] - 68s 269ms/step - loss: 0.1693
Epoch 5/50
250/250 [==============================] - 68s 269ms/step - loss: 0.1475
Epoch 6/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1472
Epoch 7/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1533
Epoch 8/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1450
Epoch 9/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1639
Epoch 10/50
250/250 [==============================] - 68s 269ms/step - loss: 0.1351
Epoch 11/50
50/50 [==============================] - 16s 316ms/step
50/50 [==============================] - 16s 316ms/step
50/50 [==============================] - 16s 317ms/step
250/250 [==============================] - 116s 464ms/step - loss: 0.1474
Epoch 12/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1737
Epoch 13/50
250/250 [==============================] - 68s 269ms/step - loss: 0.1427
Epoch 14/50
250/250 [==============================] - 68s 269ms/step - loss: 0.1698
Epoch 15/50
250/250 [==============================] - 68s 270ms/step - loss: 0.1424
Epoch 16/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1339
Epoch 17/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1397
Epoch 18/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1469
Epoch 19/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1649
Epoch 20/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1582
Epoch 21/50
50/50 [==============================] - 16s 315ms/step
50/50 [==============================] - 16s 316ms/step
50/50 [==============================] - 16s 316ms/step
250/250 [==============================] - 116s 462ms/step - loss: 0.1331
Epoch 22/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1319
Epoch 23/50
250/250 [==============================] - 68s 267ms/step - loss: 0.1521
Epoch 24/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1486
Epoch 25/50
250/250 [==============================] - 68s 267ms/step - loss: 0.1449
Epoch 26/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1349
Epoch 27/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1454
Epoch 28/50
250/250 [==============================] - 68s 268ms/step - loss: 0.1394
Epoch 29/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1489
Epoch 30/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1338
Epoch 31/50
50/50 [==============================] - 16s 315ms/step
50/50 [==============================] - 16s 320ms/step
50/50 [==============================] - 16s 315ms/step
250/250 [==============================] - 116s 462ms/step - loss: 0.1328
Epoch 32/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1693
Epoch 33/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1420
Epoch 34/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1255
Epoch 35/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1239
Epoch 36/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1558
Epoch 37/50
250/250 [==============================] - 68s 267ms/step - loss: 0.1527
Epoch 38/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1461
Epoch 39/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1555
Epoch 40/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1515
Epoch 41/50
50/50 [==============================] - 16s 315ms/step
50/50 [==============================] - 16s 315ms/step
50/50 [==============================] - 16s 315ms/step
250/250 [==============================] - 116s 461ms/step - loss: 0.1291
Epoch 42/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1474
Epoch 43/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1908
Epoch 44/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1506
Epoch 45/50
250/250 [==============================] - 68s 267ms/step - loss: 0.1424
Epoch 46/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1601
Epoch 47/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1312
Epoch 48/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1524
Epoch 49/50
250/250 [==============================] - 67s 266ms/step - loss: 0.1477
Epoch 50/50
250/250 [==============================] - 67s 267ms/step - loss: 0.1397

<keras.callbacks.History at 0x7f183aea3eb8>

png

png

png

png

png

png

png

png

png

png

png

png

png

png

png

看到模型如何随着时间的推移学习我们的新标记非常有趣。尝试一下,看看如何调整训练参数和训练数据集以生成最佳图像!


试用微调后的模型

现在是最有趣的部分。我们已经学习了自定义标记的标记嵌入,因此现在我们可以像使用任何其他标记一样使用 Stable Diffusion 生成图像!

以下是一些有趣的示例提示,可帮助您入门,以及我们猫娃娃标记的一些示例输出!

generated = stable_diffusion.text_to_image(
    f"Gandalf as a {placeholder_token} fantasy art drawn by disney concept artists, "
    "golden colour, high quality, highly detailed, elegant, sharp focus, concept art, "
    "character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 316ms/step

png

generated = stable_diffusion.text_to_image(
    f"A masterpiece of a {placeholder_token} crying out to the heavens. "
    f"Behind the {placeholder_token}, an dark, evil shade looms over it - sucking the "
    "life right out of it.",
    batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 314ms/step

png

generated = stable_diffusion.text_to_image(
    f"An evil {placeholder_token}.", batch_size=3
)
plot_images(generated)
25/25 [==============================] - 8s 322ms/step

png

generated = stable_diffusion.text_to_image(
    f"A mysterious {placeholder_token} approaches the great pyramids of egypt.",
    batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 315ms/step

png


结论

使用文本反转算法,你可以教会Stable Diffusion新的概念!

一些可能的下一步

  • 尝试你自己的提示词
  • 教会模型一种风格
  • 收集你最喜欢的宠物猫或狗的数据集,并教会模型关于它们的信息