代码示例 / 自然语言处理 / 多模态蕴涵

多模态蕴涵

作者:Sayak Paul
创建日期 2021/08/08
最后修改日期 2025/01/03
描述: 训练一个多模态模型来预测蕴涵。

ⓘ 此示例使用 Keras 2

在 Colab 中查看 GitHub 源


简介

在此示例中,我们将构建并训练一个用于预测多模态蕴涵的模型。我们将使用 Google Research 最近引入的多模态蕴涵数据集

什么是多模态蕴涵?

在社交媒体平台上,为了审核和管理内容,我们可能希望近乎实时地找到以下问题的答案:

  • 一条信息是否与另一条信息矛盾?
  • 一条信息是否暗示另一条信息?

在自然语言处理中,此任务称为分析文本蕴涵。然而,这仅限于信息来自文本内容的情况。实际上,信息往往不仅来自文本内容,还来自文本、图像、音频、视频等的多模态组合。多模态蕴涵只是将文本蕴涵扩展到各种新的输入模态。

要求

本示例需要 TensorFlow 2.5 或更高版本。此外,BERT 模型 (Devlin 等人) 需要 TensorFlow Hub 和 TensorFlow Text。这些库可以使用以下命令安装:

!pip install -q tensorflow_text
 [notice] A new release of pip is available: 24.0 -> 24.3.1
 [notice] To update, run: pip install --upgrade pip

导入

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random
import math
from skimage.io import imread
from skimage.transform import resize
from PIL import Image
import os

os.environ["KERAS_BACKEND"] = "jax"  # or tensorflow, or torch

import keras
import keras_hub
from keras.utils import PyDataset

定义标签映射

label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}

收集数据集

原始数据集可在此处获取。它附带图像的 URL,这些图像托管在 Twitter 的照片存储系统(简称 Photo Blob Storage (PBS))上。我们将使用已下载的图像以及原始数据集附带的其他数据。感谢 Nilabhra Roy Chowdhury 参与了图像数据的准备工作。

image_base_path = keras.utils.get_file(
    "tweet_images",
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
    untar=True,
)

读取数据集并应用基本预处理

df = pd.read_csv(
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
).iloc[
    0:1000
]  # Resources conservation since these are examples and not SOTA
df.sample(10)
id_1 text_1 image_1 id_2 text_2 image_2 标签
815 1370730009921343490 粘性炸弹是威胁,因为它们有磁铁…… http://pbs.twimg.com/media/EwXOFrgVIAEkfjR.jpg 1370731764906295307 粘性炸弹是威胁,因为它们有磁铁…… http://pbs.twimg.com/media/EwXRK_3XEAA6Q6F.jpg 无蕴涵
615 1364119737446395905 #巨蟹座 2.23.21 ♊️❤️✨ #Hor... 每日星座运势 http://pbs.twimg.com/media/Eu5Te44VgAIo1jZ.jpg 1365218087906078720 #巨蟹座 2.26.21 ♊️❤️✨ #Hor... 每日星座运势 http://pbs.twimg.com/media/EvI6nW4WQAA4_E_.jpg 无蕴涵
624 1335542260923068417 驯鹿跑回来了,今年的跑步…… http://pbs.twimg.com/media/Eoi99DyXEAE0AFV.jpg 1335872932267122689 为 2020 年的活动戴上你的红鼻子和鹿角…… http://pbs.twimg.com/media/Eon5Wk7XUAE-CxN.jpg 无蕴涵
970 1345058844439949312 需要参与者进行在线调查!\n\n主题…… http://pbs.twimg.com/media/Eqqb4_MXcAA-Pvu.jpg 1361211461792632835 需要参与者进行关于 Sur... 的顶级研究 http://pbs.twimg.com/media/EuPz0GwXMAMDklt.jpg 无蕴涵
456 1379831489043521545 为@NanoBiteTSF 委托\n享受兄弟们和…… http://pbs.twimg.com/media/EyVf0_VXMAMtRaL.jpg 1380660763749142531 为@NanoBiteTSF 再次委托\n希望你…… http://pbs.twimg.com/media/EykW0iXXAAA2SBC.jpg 无蕴涵
917 1336180735191891968 (2/10)\n(首尔中区)市场集群 ->\n…… http://pbs.twimg.com/media/EosRFpGVQAIeuYG.jpg 1356113330536996866 (3/11)\n(首尔东大门区)高士泰集群…… http://pbs.twimg.com/media/EtHhj7QVcAAibvF.jpg 无蕴涵
276 1339270210029834241 今天,自由的信息传到了卢旺达的基索罗…… http://pbs.twimg.com/media/EpVK3pfXcAAZ5Du.jpg 1340881971132698625 今天,自由的信息正在传达给人民…… http://pbs.twimg.com/media/EpvDorkXYAEyz4g.jpg 蕴含
35 1360186999836200961 阿根廷的比特币 - Google Trends https://t... http://pbs.twimg.com/media/EuBa3UxXYAMb99_.jpg 1382778703055228929 阿根廷想要#比特币 https://#/9lNxJdxX... http://pbs.twimg.com/media/EzCbUFNXMAABwPD.jpg 蕴含
762 1370824756400959491 $HSBA.L:长期趋势是积极的,并且…… http://pbs.twimg.com/media/EwYl2hPWYAE2niq.png 1374347458126475269 尽管技术评级仅为中等,但…… http://pbs.twimg.com/media/ExKpuwrWgAAktg4.png 无蕴涵
130 1373789433607172097 我刚刚看了《泰德·拉索》S01 | E05 集…… http://pbs.twimg.com/media/ExCuNbDXAAQaPiL.jpg 1374913509662806016 我刚刚看了《泰德·拉索》S01 | E06 集…… http://pbs.twimg.com/media/ExSsjRQWgAUVRPz.jpg 矛盾

我们感兴趣的列如下:

  • text_1
  • image_1
  • text_2
  • image_2
  • 标签

蕴涵任务的表述如下:

给定(text_1image_1)和(text_2image_2)的对,它们是否相互蕴涵(或不蕴涵或矛盾)?

我们已经下载了图像。`image_1` 以 `id1` 作为其文件名下载,`image2` 以 `id2` 作为其文件名下载。在下一步中,我们将向 `df` 添加两列——`image_1` 和 `image_2` 的文件路径。

images_one_paths = []
images_two_paths = []

for idx in range(len(df)):
    current_row = df.iloc[idx]
    id_1 = current_row["id_1"]
    id_2 = current_row["id_2"]
    extentsion_one = current_row["image_1"].split(".")[-1]
    extentsion_two = current_row["image_2"].split(".")[-1]

    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extentsion_one}")
    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extentsion_two}")

    images_one_paths.append(image_one_path)
    images_two_paths.append(image_two_path)

df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths

# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])

数据集可视化

def visualize(idx):
    current_row = df.iloc[idx]
    image_1 = plt.imread(current_row["image_1_path"])
    image_2 = plt.imread(current_row["image_2_path"])
    text_1 = current_row["text_1"]
    text_2 = current_row["text_2"]
    label = current_row["label"]

    plt.subplot(1, 2, 1)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image One")
    plt.subplot(1, 2, 2)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image Two")
    plt.show()

    print(f"Text one: {text_1}")
    print(f"Text two: {text_2}")
    print(f"Label: {label}")


random_idx = random.choice(range(len(df)))
visualize(random_idx)

random_idx = random.choice(range(len(df)))
visualize(random_idx)

png

Text one: World #water day reminds that we should follow the #guidelines to save water for us. This Day is an #opportunity to learn more about water related issues, be #inspired to tell others and take action to make a difference. Just remember, every #drop counts.
#WorldWaterDay2021 https://#/bQ9Hp53qUj
Text two: Water is an extremely precious resource without which life would be impossible. We need to ensure that water is used judiciously, this #WorldWaterDay, let us pledge to reduce water wastage and conserve it.
#WorldWaterDay2021 https://#/0KWnd8Kn8r
Label: NoEntailment

png

Text one: 🎧 𝗘𝗣𝗜𝗦𝗢𝗗𝗘 𝟯𝟬: 𝗗𝗬𝗟𝗔𝗡 𝗙𝗜𝗧𝗭𝗦𝗜𝗠𝗢𝗡𝗦
Dylan Fitzsimons is a young passionate greyhound supporter. 
He and @Drakesport enjoy a great chat about everything greyhounds!
Listen: https://#/B2XgMp0yaO
#GoGreyhoundRacing #ThisRunsDeep #TalkingDogs https://#/crBiSqHUvp
Text two: 🎧 𝗘𝗣𝗜𝗦𝗢𝗗𝗘 𝟯𝟳: 𝗣𝗜𝗢 𝗕𝗔𝗥𝗥𝗬 🎧
Well known within greyhound circles, Pio Barry shares some wonderful greyhound racing stories with @Drakesport in this podcast episode.
A great chat. 
Listen: https://#/mJTVlPHzp0
#TalkingDogs #GoGreyhoundRacing #ThisRunsDeep https://#/QbxtCpLcGm
Label: NoEntailment

训练/测试拆分

数据集存在类别不平衡问题。我们可以在以下单元格中证实这一点。

df["label"].value_counts()
label
NoEntailment     819
Contradictory     92
Implies           89
Name: count, dtype: int64

为了解决这个问题,我们将采用分层抽样。

# 10% for test
train_df, test_df = train_test_split(
    df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% for validation
train_df, val_df = train_test_split(
    train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)

print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")
Total training examples: 855
Total validation examples: 45
Total test examples: 100

数据输入管道

Keras Hub 提供了各种 BERT 系列模型。每个模型都带有一个相应的预处理层。您可以从此资源了解有关这些模型及其预处理层的更多信息。

为了使本示例的运行时间相对较短,我们将使用原始 BERT 模型的 base_unacased 变体。

使用 KerasHub 进行文本预处理

text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=128,
)

对示例输入运行预处理器

idx = random.choice(range(len(train_df)))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")

test_text = [sample_text_1, sample_text_2]
text_preprocessed = text_preprocessor(test_text)

print("Keys           : ", list(text_preprocessed.keys()))
print("Shape Token Ids : ", text_preprocessed["token_ids"].shape)
print("Token Ids       : ", text_preprocessed["token_ids"][0, :16])
print(" Shape Padding Mask     : ", text_preprocessed["padding_mask"].shape)
print("Padding Mask     : ", text_preprocessed["padding_mask"][0, :16])
print("Shape Segment Ids : ", text_preprocessed["segment_ids"].shape)
print("Segment Ids       : ", text_preprocessed["segment_ids"][0, :16])
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Text 1: The RPF Lohardaga and Hatia Post of Ranchi Division have recovered  02 bags on 20.02.2021 at Station platform and in T/No.08310 Spl. respectively and  handed over to their actual owner correctly. @RPF_INDIA https://#/bdEBl2egIc
Text 2: The RPF Lohardaga and Hatia Post of Ranchi Division have recovered  02 bags on 20.02.2021 at Station platform and in T/No.08310 (JAT-SBP) Spl. respectively and  handed over to their actual owner correctly. @RPF_INDIA https://#/Q5l2AtA4uq
Keys           :  ['token_ids', 'padding_mask', 'segment_ids']
Shape Token Ids :  (2, 128)
Token Ids       :  [  101  1996  1054 14376  8840 11783 16098  1998  6045  2401  2695  1997
  8086  2072  2407  2031]
 Shape Padding Mask     :  (2, 128)
Padding Mask     :  [ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]
Shape Segment Ids :  (2, 128)
Segment Ids       :  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

我们现在将从数据帧创建 tf.data.Dataset 对象。

请注意,文本输入将作为数据输入管道的一部分进行预处理。但预处理模块也可以是其相应 BERT 模型的一部分。这有助于减少训练/服务偏差,并使我们的模型能够使用原始文本输入。请按照本教程了解如何将预处理模块直接集成到模型中。

def dataframe_to_dataset(dataframe):
    columns = ["image_1_path", "image_2_path", "text_1", "text_2", "label_idx"]
    ds = UnifiedPyDataset(
        dataframe,
        batch_size=32,
        workers=4,
    )
    return ds

预处理工具

bert_input_features = ["padding_mask", "segment_ids", "token_ids"]


def preprocess_text(text_1, text_2):
    output = text_preprocessor([text_1, text_2])
    output = {
        feature: keras.ops.reshape(output[feature], [-1])
        for feature in bert_input_features
    }
    return output

创建最终数据集,方法改编自 PyDataset 文档字符串。

class UnifiedPyDataset(PyDataset):
    """A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch."""

    def __init__(
        self,
        df,
        batch_size=32,
        workers=4,
        use_multiprocessing=False,
        max_queue_size=10,
        **kwargs,
    ):
        """
        Args:
            df: pandas DataFrame with data
            batch_size: Batch size for dataset
            workers: Number of workers to use for parallel loading (Keras)
            use_multiprocessing: Whether to use multiprocessing
            max_queue_size: Maximum size of the data queue for parallel loading
        """
        super().__init__(**kwargs)
        self.dataframe = df
        columns = ["image_1_path", "image_2_path", "text_1", "text_2"]

        # image files
        self.image_x_1 = self.dataframe["image_1_path"]
        self.image_x_2 = self.dataframe["image_1_path"]
        self.image_y = self.dataframe["label_idx"]

        # text files
        self.text_x_1 = self.dataframe["text_1"]
        self.text_x_2 = self.dataframe["text_2"]
        self.text_y = self.dataframe["label_idx"]

        # general
        self.batch_size = batch_size
        self.workers = workers
        self.use_multiprocessing = use_multiprocessing
        self.max_queue_size = max_queue_size

    def __getitem__(self, index):
        """
        Fetches a batch of data from the dataset at the given index.
        """

        # Return x, y for batch idx.
        low = index * self.batch_size
        # Cap upper bound at array length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.

        high_image_1 = min(low + self.batch_size, len(self.image_x_1))
        high_image_2 = min(low + self.batch_size, len(self.image_x_2))

        high_text_1 = min(low + self.batch_size, len(self.text_x_1))
        high_text_2 = min(low + self.batch_size, len(self.text_x_1))

        # images files
        batch_image_x_1 = self.image_x_1[low:high_image_1]
        batch_image_y_1 = self.image_y[low:high_image_1]

        batch_image_x_2 = self.image_x_2[low:high_image_2]
        batch_image_y_2 = self.image_y[low:high_image_2]

        # text files
        batch_text_x_1 = self.text_x_1[low:high_text_1]
        batch_text_y_1 = self.text_y[low:high_text_1]

        batch_text_x_2 = self.text_x_2[low:high_text_2]
        batch_text_y_2 = self.text_y[low:high_text_2]

        # image number 1 inputs
        image_1 = [
            resize(imread(file_name), (128, 128)) for file_name in batch_image_x_1
        ]
        image_1 = [
            (  # exeperienced some shapes which were different from others.
                np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))
                if img.shape[2] == 4
                else img
            )
            for img in image_1
        ]
        image_1 = np.array(image_1)

        # Both text inputs to the model, return a dict for inputs to BertBackbone
        text = {
            key: np.array(
                [
                    d[key]
                    for d in [
                        preprocess_text(file_path1, file_path2)
                        for file_path1, file_path2 in zip(
                            batch_text_x_1, batch_text_x_2
                        )
                    ]
                ]
            )
            for key in ["padding_mask", "token_ids", "segment_ids"]
        }

        # Image number 2 model inputs
        image_2 = [
            resize(imread(file_name), (128, 128)) for file_name in batch_image_x_2
        ]
        image_2 = [
            (  # exeperienced some shapes which were different from others
                np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))
                if img.shape[2] == 4
                else img
            )
            for img in image_2
        ]
        # Stack the list comprehension to an nd.array
        image_2 = np.array(image_2)

        return (
            {
                "image_1": image_1,
                "image_2": image_2,
                "padding_mask": text["padding_mask"],
                "segment_ids": text["segment_ids"],
                "token_ids": text["token_ids"],
            },
            # Target lables
            np.array(batch_image_y_1),
        )

    def __len__(self):
        """
        Returns the number of batches in the dataset.
        """
        return math.ceil(len(self.dataframe) / self.batch_size)

创建训练集、验证集和测试集

def prepare_dataset(dataframe):
    ds = dataframe_to_dataset(dataframe)
    return ds


train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df)
test_ds = prepare_dataset(test_df)

模型构建工具

我们的最终模型将接受两张图像及其对应的文本。图像将直接输入模型,而文本输入将首先进行预处理,然后才能进入模型。以下是此方法的视觉说明:

该模型由以下元素组成:

  • 一个独立的图像编码器。我们将为此使用在 ImageNet-1k 数据集上预训练的 ResNet50V2
  • 一个独立的图像编码器。为此将使用预训练的 BERT。

提取单独的嵌入后,它们将被投影到相同的空间中。最后,它们的投影将被连接并馈送到最终的分类层。

这是一个多分类问题,涉及以下类别:

  • 无蕴涵
  • 蕴含
  • 矛盾

`project_embeddings()`、`create_vision_encoder()` 和 `create_text_encoder()` 工具参考自此示例

投影工具

def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = keras.ops.nn.gelu(projected_embeddings)
        x = keras.layers.Dense(projection_dims)(x)
        x = keras.layers.Dropout(dropout_rate)(x)
        x = keras.layers.Add()([projected_embeddings, x])
        projected_embeddings = keras.layers.LayerNormalization()(x)
    return projected_embeddings

视觉编码器工具

def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained ResNet50V2 model to be used as the base encoder.
    resnet_v2 = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in resnet_v2.layers:
        layer.trainable = trainable

    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Preprocess the input image.
    preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
    preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)

    # Generate the embeddings for the images using the resnet_v2 model
    # concatenate them.
    embeddings_1 = resnet_v2(preprocessed_1)
    embeddings_2 = resnet_v2(preprocessed_2)
    embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model([image_1, image_2], outputs, name="vision_encoder")

文本编码器工具

def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained BERT BackBone using KerasHub.
    bert = keras_hub.models.BertBackbone.from_preset(
        "bert_base_en_uncased", num_classes=3
    )

    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the text as inputs.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }

    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")

多模态模型工具

def create_multimodal_model(
    num_projection_layers=1,
    projection_dims=256,
    dropout_rate=0.1,
    vision_trainable=False,
    text_trainable=False,
):
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    text_inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }
    text_inputs = list(text_inputs.values())
    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
    return keras.Model([image_1, image_2, *text_inputs], outputs)


multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)

png

您也可以通过将 `plot_model()` 的 `expand_nested` 参数设置为 `True` 来检查各个编码器的结构。我们鼓励您尝试构建此模型中涉及的不同超参数,并观察最终性能如何受到影响。


编译并训练模型

multimodal_model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)

history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)

1/27 [37m━━━━━━━━━━━━━━━━━━━━ 45:45 106s/步 - 准确率:0.0625 - 损失:1.6335

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



2/27 ━ [37m━━━━━━━━━━━━━━━━━━━ 42:14 101s/步 - 准确率:0.2422 - 损失:1.9508

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



3/27 ━━ [37m━━━━━━━━━━━━━━━━━━ 38:49 97s/步 - 准确率:0.3524 - 损失:2.0126

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



4/27 ━━ [37m━━━━━━━━━━━━━━━━━━ 37:09 97s/步 - 准确率:0.4284 - 损失:1.9870

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



5/27 ━━━ [37m━━━━━━━━━━━━━━━━━ 35:08 96s/步 - 准确率:0.4815 - 损失:1.9855

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



6/27 ━━━━ [37m━━━━━━━━━━━━━━━━ 31:56 91s/步 - 准确率:0.5210 - 损失:1.9939



7/27 ━━━━━ [37m━━━━━━━━━━━━━━━ 29:30 89s/步 - 准确率:0.5512 - 损失:1.9980



8/27 ━━━━━ [37m━━━━━━━━━━━━━━━ 27:12 86s/步 - 准确率:0.5750 - 损失:2.0061



9/27 ━━━━━━ [37m━━━━━━━━━━━━━━ 25:15 84s/步 - 准确率:0.5956 - 损失:1.9959



10/27 ━━━━━━━ [37m━━━━━━━━━━━━━ 23:33 83s/步 - 准确率:0.6120 - 损失:1.9738

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



11/27 ━━━━━━━━ [37m━━━━━━━━━━━━ 22:09 83s/步 - 准确率:0.6251 - 损失:1.9579

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



12/27 ━━━━━━━━ [37m━━━━━━━━━━━━ 20:59 84s/步 - 准确率:0.6357 - 损失:1.9524

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



13/27 ━━━━━━━━━ [37m━━━━━━━━━━━ 19:44 85s/步 - 准确率:0.6454 - 损失:1.9439

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



14/27 ━━━━━━━━━━ [37m━━━━━━━━━━ 18:22 85s/步 - 准确率:0.6540 - 损失:1.9346

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(23, 256))', 'Tensor(shape=(23, 256))', 'Tensor(shape=(23, 256))']
  warnings.warn(msg)



15/27 ━━━━━━━━━━━ [37m━━━━━━━━━ 16:52 84s/步 - 准确率:0.6621 - 损失:1.9213

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



16/27 ━━━━━━━━━━━ [37m━━━━━━━━━ 15:29 85s/步 - 准确率:0.6693 - 损失:1.9101

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



17/27 ━━━━━━━━━━━━ [37m━━━━━━━━ 14:08 85s/步 - 准确率:0.6758 - 损失:1.9021

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



18/27 ━━━━━━━━━━━━━ [37m━━━━━━━ 12:45 85s/步 - 准确率:0.6819 - 损失:1.8916

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



19/27 ━━━━━━━━━━━━━━ [37m━━━━━━ 11:24 86s/步 - 准确率:0.6874 - 损失:1.8851

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



20/27 ━━━━━━━━━━━━━━ [37m━━━━━━ 10:00 86s/步 - 准确率:0.6925 - 损失:1.8791

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



21/27 ━━━━━━━━━━━━━━━ [37m━━━━━ 8:36 86s/步 - 准确率:0.6976 - 损失:1.8699

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



22/27 ━━━━━━━━━━━━━━━━ [37m━━━━ 7:11 86s/步 - 准确率:0.7020 - 损失:1.8623

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



23/27 ━━━━━━━━━━━━━━━━━ [37m━━━ 5:46 87s/步 - 准确率:0.7061 - 损失:1.8573

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



24/27 ━━━━━━━━━━━━━━━━━ [37m━━━ 4:20 87s/步 - 准确率:0.7100 - 损失:1.8534

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



25/27 ━━━━━━━━━━━━━━━━━━ [37m━━ 2:54 87s/步 - 准确率:0.7136 - 损失:1.8494

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



26/27 ━━━━━━━━━━━━━━━━━━━ [37m━ 1:27 87s/步 - 准确率:0.7170 - 损失:1.8449

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)



27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 88s/步 - 准确率:0.7201 - 损失:1.8414

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(13, 256))', 'Tensor(shape=(13, 256))', 'Tensor(shape=(13, 256))']
  warnings.warn(msg)



27/27 ━━━━━━━━━━━━━━━━━━━━ 2508s 92s/步 - 准确率:0.7231 - 损失:1.8382 - val_accuracy: 0.8222 - val_loss: 1.7304


评估模型

_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
  warnings.warn(msg)

1/4 ━━━━━ [37m━━━━━━━━━━━━━━━ 5:32 111s/步 - 准确率:0.7812 - 损失:1.9384



2/4 ━━━━━━━━━━ [37m━━━━━━━━━━ 2:10 65s/步 - 准确率:0.7969 - 损失:1.8931



3/4 ━━━━━━━━━━━━━━━ [37m━━━━━ 1:05 65s/步 - 准确率:0.8056 - 损失:1.8200

/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(4, 256))', 'Tensor(shape=(4, 256))', 'Tensor(shape=(4, 256))']
  warnings.warn(msg)



4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 49s/步 - 准确率:0.8092 - 损失:1.8075



4/4 ━━━━━━━━━━━━━━━━━━━━ 256s 49s/步 - 准确率:0.8113 - 损失:1.8000

Accuracy on the test set: 82.0%.

关于训练的补充说明

加入正则化:

训练日志表明模型开始过拟合,可能受益于正则化。Dropout (Srivastava 等人) 是一种简单而强大的正则化技术,我们可以在模型中使用它。但是我们应该如何在这里应用它呢?

我们总是可以在模型的不同层之间引入 Dropout (keras.layers.Dropout)。但这里还有另一种方法。我们的模型需要来自两种不同数据模态的输入。如果推理期间其中一种模态不存在怎么办?为了解决这个问题,我们可以在单独的投影中引入 Dropout,就在它们被连接之前:

vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])

关注重要之处:

图像的所有部分是否与其文本对应部分同样重要?可能并非如此。为了让我们的模型只关注图像中与相应文本部分密切相关的最重要部分,我们可以使用“交叉注意力”。

# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)

# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
    [vision_projections, text_projections]
)
# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])

要查看实际效果,请参阅此笔记本

处理类别不平衡:

数据集存在类别不平衡问题。调查上述模型的混淆矩阵显示,它在少数类别上的表现不佳。如果我们使用加权损失,训练会更有指导性。您可以查看此笔记本,该笔记本在模型训练期间考虑了类别不平衡。

仅使用文本输入:

此外,如果我们只使用文本输入进行蕴涵任务会怎样?由于社交媒体平台上遇到的文本输入的性质,仅使用文本输入会损害最终性能。在类似的训练设置下,仅使用文本输入,我们在同一测试集上获得了 67.14% 的 top-1 准确率。有关详细信息,请参阅此笔记本

最后,这里有一个比较蕴涵任务不同方法的表格:

类型 标准
交叉熵
损失加权
交叉熵
焦点损失
多模态 77.86% 67.86% 86.43%
仅文本 67.14% 11.43% 37.86%

您可以查看此仓库,了解更多关于如何进行实验以获得这些数字的信息。


最后说明

  • 我们在此示例中使用的架构对于可用于训练的数据点数量来说太大了。它将受益于更多数据。
  • 我们使用了原始 BERT 模型的一个较小变体。很有可能使用较大的变体,性能会得到改善。TensorFlow Hub 提供了许多不同的 BERT 模型,您可以进行实验。
  • 我们保持预训练模型冻结。在多模态蕴涵任务上对它们进行微调可能会带来更好的性能。
  • 我们为多模态蕴涵任务构建了一个简单的基线模型。已经提出了各种方法来解决蕴涵问题。识别多模态蕴涵教程的此演示文稿提供了全面的概述。

您可以使用托管在 Hugging Face Hub 上的训练模型,并在 Hugging Face Spaces 上尝试演示。