
Multimodal entailment

Author: Sayak Paul
Date created: 2021/08/08
Last modified: 2021/08/15
Description: Training a multimodal model for predicting entailment.

ⓘ This example uses Keras 2


Introduction

In this example, we will build and train a model for predicting multimodal entailment. We will be using the multimodal entailment dataset recently introduced by Google Research.

What is multimodal entailment?

On social media platforms, to audit and moderate content we may want to find answers to the following questions in near real-time:

  • Does a given piece of information contradict the other?
  • Does a given piece of information imply the other?

In NLP, this task is called analyzing textual entailment. However, that is only when the information comes from text content. In practice, the available information often comes not just from text, but from a multimodal combination of text, images, audio, video, etc. Multimodal entailment is simply the extension of textual entailment to a variety of new input modalities.

Requirements

This example requires TensorFlow 2.5 or higher. In addition, TensorFlow Hub and TensorFlow Text are required for the BERT model (Devlin et al.). These libraries can be installed using the following command:

!pip install -q tensorflow_text

Imports

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow import keras

Define a label map

label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}
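
Since the model will emit integer class indices, it can be handy to also keep the inverse mapping around for decoding predictions back into strings. A one-line sketch (the idx_to_label name is our own addition, not part of the original example):

# Inverse mapping (our addition): decode integer predictions back into
# human-readable string labels.
idx_to_label = {v: k for k, v in label_map.items()}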

Collect the dataset

The original dataset is available here. It comes with URLs of images that are hosted on Twitter's photo storage system, the Photo Blob Storage (PBS for short). We will be working with the downloaded images along with the additional data that comes with the original dataset. Thanks to Nilabhra Roy Chowdhury, who worked on preparing the image data.

image_base_path = keras.utils.get_file(
    "tweet_images",
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
    untar=True,
)

Read the dataset and apply basic preprocessing

df = pd.read_csv(
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
)
df.sample(10)
id_1 text_1 image_1 id_2 text_2 image_2 label
291 1330800194863190016 #KLM1167 (B738): #AMS (Amsterdam) to #HEL (Van... http://pbs.twimg.com/media/EnfzuZAW4AE236p.png 1378695438480588802 #CKK205 (B77L): #PVG (Shanghai) to #AMS (Amste... http://pbs.twimg.com/media/EyIcMexXEAE6gia.png NoEntailment
37 1366581728312057856 Friends, interested all go to have a look!\n@j... http://pbs.twimg.com/media/EvcS1v4UcAEEXPO.jpg 1373810535066570759 Friends, interested all go to have a look!\n@f... http://pbs.twimg.com/media/ExDBZqwVIAQ4LWk.jpg Contradictory
315 1352551603258052608 #WINk Drops I have earned today🚀\n\nToday:1/22... http://pbs.twimg.com/media/EsTdcLLVcAIiFKT.jpg 1354636016234098688 #WINk Drops I have earned today☀\n\nToday:1/28... http://pbs.twimg.com/media/EsyhK-qU0AgfMAH.jpg NoEntailment
761 1379795999493853189 #buythedip Ready to FLY even HIGHER #pennysto... http://pbs.twimg.com/media/EyYFJCzWgAMfTrT.jpg 1380190250144792576 #buythedip Ready to FLY even HIGHER #pennysto... http://pbs.twimg.com/media/Eydrt0ZXAAMmbfv.jpg NoEntailment
146 1340185132293099523 I know sometimes I am weird to you.\n\nBecause... http://pbs.twimg.com/media/EplLRriWwAAJ2AE.jpg 1359755419883814913 I lay down my sword and kneel down to pray.... http://pbs.twimg.com/media/Et7SWWeWYAICK-c.jpg NoEntailment
1351 1381256604926967813 Finally finished the skin rendering. Will start... http://pbs.twimg.com/media/Eys1j7NVIAgF-YF.jpg 1381630932092784641 Hair rendering. Will finish the hair tomorrow.... http://pbs.twimg.com/media/EyyKAoaUUAElm-e.jpg NoEntailment
368 1371883298805403649 📉 $LINK Number of receiving addresses (7-day moving average)... http://pbs.twimg.com/media/EwnoltOWEAAS4mG.jpg 1373216720974979072 📉 $LINK Number of receiving addresses (7-day moving average)... http://pbs.twimg.com/media/Ew6lVGYXEAE6Ugi.jpg NoEntailment
1112 1377679115159887873 April is National Distracted Driving Awareness... http://pbs.twimg.com/media/Ex5_u7UVIAARjQ2.jpg 1379075258448281608 April is Distracted Driving Awareness Month.... http://pbs.twimg.com/media/EyN1YjpWUAMc5ak.jpg NoEntailment
264 1330727515741167619 ♥️Verse of the Day♥️\n.\n#VerseOfTheDay #Quran... http://pbs.twimg.com/media/EnexnydXIAYuI11.jpg 1332623263495819264 ♥️Verse of the Day♥️\n.\n#VerseOfTheDay #Quran... http://pbs.twimg.com/media/En5ty1VXUAATALP.jpg NoEntailment
865 1377784616275296261 No white picket fence can keep us in. #TBT 200... http://pbs.twimg.com/media/Ex7fzouWQAITAq8.jpg 1380175915804672012 Sometimes you just need to change your mindset.... http://pbs.twimg.com/media/EydernQXIAk2g5v.jpg NoEntailment

The columns we are interested in are the following:

  • text_1
  • image_1
  • text_2
  • image_2
  • label

The entailment task is formulated as follows:

Given the pairs of (text_1, image_1) and (text_2, image_2), do they entail (or not entail or contradict) each other?

We already have the images downloaded: image_1 was downloaded with id_1 as its filename, and image_2 was downloaded with id_2 as its filename. In the next step, we will add two more columns to df: the filepaths of image_1 and image_2.

images_one_paths = []
images_two_paths = []

for idx in range(len(df)):
    current_row = df.iloc[idx]
    id_1 = current_row["id_1"]
    id_2 = current_row["id_2"]
    extension_one = current_row["image_1"].split(".")[-1]
    extension_two = current_row["image_2"].split(".")[-1]

    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extension_one}")
    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extension_two}")

    images_one_paths.append(image_one_path)
    images_two_paths.append(image_two_path)

df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths

# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])

Dataset visualization

def visualize(idx):
    current_row = df.iloc[idx]
    image_1 = plt.imread(current_row["image_1_path"])
    image_2 = plt.imread(current_row["image_2_path"])
    text_1 = current_row["text_1"]
    text_2 = current_row["text_2"]
    label = current_row["label"]

    plt.subplot(1, 2, 1)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image One")
    plt.subplot(1, 2, 2)
    plt.imshow(image_2)
    plt.axis("off")
    plt.title("Image Two")
    plt.show()

    print(f"Text one: {text_1}")
    print(f"Text two: {text_2}")
    print(f"Label: {label}")


random_idx = np.random.choice(len(df))
visualize(random_idx)

random_idx = np.random.choice(len(df))
visualize(random_idx)

Text one: Friends, interested all go to have a look!
@ThePartyGoddess @OurLadyAngels @BJsWholesale @Richard_Jeni @FashionLavidaG @RapaRooski @DMVTHING @DeMarcoReports @LobidaFo @DeMarcoMorgan https://t.co/cStULl7y7G
Text two: Friends, interested all go to have a look!
@smittyses @CYosabel @crum_7 @CrumDarrell @ElymalikU @jenloarn @SoCodiePrevost @roblowry82 @Crummy_14 @CSchmelzenbach https://t.co/IZphLTNzgl
Label: Contradictory

Text one: 👟 KICK OFF @ MARDEN SPORTS COMPLEX
We're underway in the Round 6 opener!
📺: @Foxtel, @kayosports
📱: My Football Live app https://t.co/wHSpvQaoGC
#WLeague #ADLvMVC #AUFC #MVFC https://t.co/3Smp8KXm8W
Text two: 👟 KICK OFF @ MARSDEN SPORTS COMPLEX
We're underway in sunny Adelaide!
📺: @Foxtel, @kayosports
📱: My Football Live app https://t.co/wHSpvQaoGC
#ADLvCBR #WLeague #AUFC #UnitedAlways https://t.co/fG1PyLQXM4
Label: NoEntailment

Train/test splits

The dataset suffers from a class imbalance problem. We can confirm this in the following cell.

df["label"].value_counts()
NoEntailment     1182
Implies           109
Contradictory     109
Name: label, dtype: int64

To account for that, we will go for a stratified split.

# 10% for test
train_df, test_df = train_test_split(
    df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% for validation
train_df, val_df = train_test_split(
    train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)

print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")
Total training examples: 1197
Total validation examples: 63
Total test examples: 140

Data input pipeline

TensorFlow Hub provides a variety of BERT model families. Each of them comes with a corresponding preprocessing layer. You can learn more about these models and their preprocessing layers from this resource.

To keep the runtime of this example relatively short, we will use a smaller variant of the original BERT model.

# Define TF Hub paths to the BERT encoder and its preprocessor
bert_model_path = (
    "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1"
)
bert_preprocess_path = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

Our text preprocessing code mostly comes from this tutorial. You are highly encouraged to check out that tutorial to learn more about the input preprocessing.

def make_bert_preprocessing_model(sentence_features, seq_length=128):
    """Returns Model mapping string features to BERT inputs.

  Args:
    sentence_features: A list with the names of string-valued features.
    seq_length: An integer that defines the sequence length of BERT inputs.

  Returns:
    A Keras Model that can be called on a list or dict of string Tensors
    (with the order or names, resp., given by sentence_features) and
    returns a dict of tensors for input to BERT.
  """

    input_segments = [
        tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
        for ft in sentence_features
    ]

    # Tokenize the text to word pieces.
    bert_preprocess = hub.load(bert_preprocess_path)
    tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name="tokenizer")
    segments = [tokenizer(s) for s in input_segments]

    # Optional: Trim segments in a smart way to fit seq_length.
    # Simple cases (like this example) can skip this step and let
    # the next step apply a default truncation to approximately equal lengths.
    truncated_segments = segments

    # Pack inputs. The details (start/end token ids, dict of output tensors)
    # are model-dependent, so this gets loaded from the SavedModel.
    packer = hub.KerasLayer(
        bert_preprocess.bert_pack_inputs,
        arguments=dict(seq_length=seq_length),
        name="packer",
    )
    model_inputs = packer(truncated_segments)
    return keras.Model(input_segments, model_inputs)


bert_preprocess_model = make_bert_preprocessing_model(["text_1", "text_2"])
keras.utils.plot_model(bert_preprocess_model, show_shapes=True, show_dtype=True)

Run the preprocessor on a sample input

idx = np.random.choice(len(train_df))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")

test_text = [np.array([sample_text_1]), np.array([sample_text_2])]
text_preprocessed = bert_preprocess_model(test_text)

print("Keys           : ", list(text_preprocessed.keys()))
print("Shape Word Ids : ", text_preprocessed["input_word_ids"].shape)
print("Word Ids       : ", text_preprocessed["input_word_ids"][0, :16])
print("Shape Mask     : ", text_preprocessed["input_mask"].shape)
print("Input Mask     : ", text_preprocessed["input_mask"][0, :16])
print("Shape Type Ids : ", text_preprocessed["input_type_ids"].shape)
print("Type Ids       : ", text_preprocessed["input_type_ids"][0, :16])
Text 1: Renewables met 97% of Scotland's electricity demand in 2020!!!!
https://t.co/wi5c9UFAUF https://t.co/arcuBgh0BP
Text 2: Renewables met 97% of Scotland's electricity demand in 2020 https://t.co/SrhyqPnIkU https://t.co/LORgvTM7Sn
Keys           :  ['input_mask', 'input_word_ids', 'input_type_ids']
Shape Word Ids :  (1, 128)
Word Ids       :  tf.Tensor(
[  101 13918  2015  2777  5989  1003  1997  3885  1005  1055  6451  5157
  1999 12609   999   999], shape=(16,), dtype=int32)
Shape Mask     :  (1, 128)
Input Mask     :  tf.Tensor([1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1], shape=(16,), dtype=int32)
Shape Type Ids :  (1, 128)
Type Ids       :  tf.Tensor([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], shape=(16,), dtype=int32)

We will now create tf.data.Dataset objects from the dataframes.

Note that the text inputs will be preprocessed as a part of the data input pipeline. But the preprocessing modules could also be a part of their corresponding BERT models. This helps reduce training/serving skew and lets our models operate with raw text inputs. Follow this tutorial to learn more about how to incorporate the preprocessing modules directly inside the models.
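
As a rough sketch of that idea (assuming the multimodal_model built later in this example, with the images still fed as decoded tensors), the preprocessing model and the classifier can be chained into a single end-to-end model that accepts raw text strings:

# Sketch (our addition): fold the BERT preprocessing into a single
# serving model that takes raw strings for the two text inputs.
raw_text_1 = keras.Input(shape=(), dtype=tf.string, name="text_1")
raw_text_2 = keras.Input(shape=(), dtype=tf.string, name="text_2")
image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

bert_inputs = bert_preprocess_model([raw_text_1, raw_text_2])
predictions = multimodal_model([image_1, image_2, bert_inputs])
serving_model = keras.Model([image_1, image_2, raw_text_1, raw_text_2], predictions)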

def dataframe_to_dataset(dataframe):
    columns = ["image_1_path", "image_2_path", "text_1", "text_2", "label_idx"]
    dataframe = dataframe[columns].copy()
    labels = dataframe.pop("label_idx")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds

Preprocessing utilities

resize = (128, 128)
bert_input_features = ["input_word_ids", "input_type_ids", "input_mask"]


def preprocess_image(image_path):
    extension = tf.strings.split(image_path)[-1]

    image = tf.io.read_file(image_path)
    if extension == b"jpg":
        image = tf.image.decode_jpeg(image, 3)
    else:
        image = tf.image.decode_png(image, 3)
    image = tf.image.resize(image, resize)
    return image


def preprocess_text(text_1, text_2):
    text_1 = tf.convert_to_tensor([text_1])
    text_2 = tf.convert_to_tensor([text_2])
    output = bert_preprocess_model([text_1, text_2])
    output = {feature: tf.squeeze(output[feature]) for feature in bert_input_features}
    return output


def preprocess_text_and_image(sample):
    image_1 = preprocess_image(sample["image_1_path"])
    image_2 = preprocess_image(sample["image_2_path"])
    text = preprocess_text(sample["text_1"], sample["text_2"])
    return {"image_1": image_1, "image_2": image_2, "text": text}

Create the final datasets

batch_size = 32
auto = tf.data.AUTOTUNE


def prepare_dataset(dataframe, training=True):
    ds = dataframe_to_dataset(dataframe)
    if training:
        ds = ds.shuffle(len(train_df))
    ds = ds.map(lambda x, y: (preprocess_text_and_image(x), y)).cache()
    ds = ds.batch(batch_size).prefetch(auto)
    return ds


train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df, False)
test_ds = prepare_dataset(test_df, False)

Model building utilities

Our final model will accept two images along with their text counterparts. While the images will be directly fed to the model, the text inputs will first be preprocessed and only then make it into the model.

The model consists of the following elements:

  • A standalone encoder for the images. We will use a ResNet50V2 pre-trained on the ImageNet-1k dataset for this.
  • A standalone encoder for the text. We will use a pre-trained BERT for this.

After extracting the individual embeddings, they will be projected into an identical space. Finally, their projections will be concatenated and fed to the final classification layer.

This is a multi-class classification problem involving the following classes:

  • NoEntailment
  • Implies
  • Contradictory

The project_embeddings(), create_vision_encoder() and create_text_encoder() utilities are referred from this example.

Projection utilities

def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = tf.nn.gelu(projected_embeddings)
        x = keras.layers.Dense(projection_dims)(x)
        x = keras.layers.Dropout(dropout_rate)(x)
        x = keras.layers.Add()([projected_embeddings, x])
        projected_embeddings = keras.layers.LayerNormalization()(x)
    return projected_embeddings

Vision encoder utilities

def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained ResNet50V2 model to be used as the base encoder.
    resnet_v2 = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in resnet_v2.layers:
        layer.trainable = trainable

    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Preprocess the input image.
    preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
    preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)

    # Generate the embeddings for the images using the resnet_v2 model
    # and concatenate them.
    embeddings_1 = resnet_v2(preprocessed_1)
    embeddings_2 = resnet_v2(preprocessed_2)
    embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model([image_1, image_2], outputs, name="vision_encoder")

Text encoder utilities

def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained BERT model to be used as the base encoder.
    bert = hub.KerasLayer(bert_model_path, name="bert")
    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the text as inputs.
    bert_input_features = ["input_type_ids", "input_mask", "input_word_ids"]
    inputs = {
        feature: keras.Input(shape=(128,), dtype=tf.int32, name=feature)
        for feature in bert_input_features
    }

    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")

Multimodal model utilities

def create_multimodal_model(
    num_projection_layers=1,
    projection_dims=256,
    dropout_rate=0.1,
    vision_trainable=False,
    text_trainable=False,
):
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs.
    bert_input_features = ["input_type_ids", "input_mask", "input_word_ids"]
    text_inputs = {
        feature: keras.Input(shape=(128,), dtype=tf.int32, name=feature)
        for feature in bert_input_features
    }

    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
    return keras.Model([image_1, image_2, text_inputs], outputs)


multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)

You can inspect the structure of the individual encoders as well by setting the expand_nested argument of plot_model() to True. You are encouraged to play with the different hyperparameters involved in building this model and observe how the final performance is affected.
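
For example:

# Also expand the nested vision and text encoders in the plot
# (expand_nested is a standard keras.utils.plot_model argument).
keras.utils.plot_model(multimodal_model, show_shapes=True, expand_nested=True)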


Compile and train the model

multimodal_model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics="accuracy"
)

history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=10)
Epoch 1/10
38/38 [==============================] - 49s 789ms/step - loss: 1.0014 - accuracy: 0.8229 - val_loss: 0.5514 - val_accuracy: 0.8571
Epoch 2/10
38/38 [==============================] - 3s 90ms/step - loss: 0.4019 - accuracy: 0.8814 - val_loss: 0.5866 - val_accuracy: 0.8571
Epoch 3/10
38/38 [==============================] - 3s 90ms/step - loss: 0.3557 - accuracy: 0.8897 - val_loss: 0.5929 - val_accuracy: 0.8571
Epoch 4/10
38/38 [==============================] - 3s 91ms/step - loss: 0.2877 - accuracy: 0.9006 - val_loss: 0.6272 - val_accuracy: 0.8571
Epoch 5/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1796 - accuracy: 0.9398 - val_loss: 0.8545 - val_accuracy: 0.8254
Epoch 6/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1292 - accuracy: 0.9566 - val_loss: 1.2276 - val_accuracy: 0.8413
Epoch 7/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1015 - accuracy: 0.9666 - val_loss: 1.2914 - val_accuracy: 0.7778
Epoch 8/10
38/38 [==============================] - 3s 92ms/step - loss: 0.1253 - accuracy: 0.9524 - val_loss: 1.1944 - val_accuracy: 0.8413
Epoch 9/10
38/38 [==============================] - 3s 92ms/step - loss: 0.3064 - accuracy: 0.9131 - val_loss: 1.2162 - val_accuracy: 0.8095
Epoch 10/10
38/38 [==============================] - 3s 92ms/step - loss: 0.2212 - accuracy: 0.9248 - val_loss: 1.1080 - val_accuracy: 0.8413

Evaluate the model

_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
5/5 [==============================] - 6s 1s/step - loss: 0.8390 - accuracy: 0.8429
Accuracy on the test set: 84.29%.
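
The aggregate accuracy hides how the model does on each class. Below is a quick sketch of computing a confusion matrix on the test set (the use of sklearn.metrics here is our addition; labels and predictions are collected in a single pass because the underlying dataset was shuffled when it was created):

from sklearn.metrics import confusion_matrix

# Sketch (our addition): collect labels and predictions batch by batch
# so their ordering stays aligned, then tabulate the confusion matrix.
y_true, y_pred = [], []
for batch, labels in test_ds:
    probs = multimodal_model.predict_on_batch(batch)
    y_true.append(labels.numpy())
    y_pred.append(np.argmax(probs, axis=-1))
print(confusion_matrix(np.concatenate(y_true), np.concatenate(y_pred)))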

Additional notes regarding training

Incorporating regularization:

The training logs suggest that the model is starting to overfit and may have benefitted from regularization. Dropout (Srivastava et al.) is a simple yet powerful regularization technique that we can use in our model. But how should we apply it here?

We could always introduce Dropout (keras.layers.Dropout) in between different layers of our model. But here is another recipe. Our model expects inputs from two different data modalities. What if either of the modalities is not present during inference? To account for this, we can introduce Dropout to the individual projections just before they get concatenated:

vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])

Attending to what matters:

Do all parts of the images correspond equally to their textual counterparts? This is likely not the case. To make our model focus only on the most important bits of the images that relate well to their corresponding textual parts, we can use "cross-attention":

# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)

# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
    [vision_projections, text_projections]
)
# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])

To see this in action, refer to this notebook.

Handling class imbalance:

The dataset suffers from class imbalance. Investigating the confusion matrix of the above model reveals that it performs poorly on the minority classes. If we had used a weighted loss, the training would have been more guided. You can check out this notebook, which takes class imbalance into account during model training.
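
A minimal sketch of the weighted-loss idea (the inverse-frequency weighting below is a common heuristic and our own choice, not necessarily what that notebook uses):

# Sketch (our addition): weight classes inversely to their frequency so
# that Keras scales each sample's loss contribution during training.
counts = train_df["label_idx"].value_counts()
class_weight = {
    idx: len(train_df) / (len(counts) * count) for idx, count in counts.items()
}
multimodal_model.fit(
    train_ds, validation_data=validation_ds, epochs=10, class_weight=class_weight
)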

Using only text inputs:

Also, what if we had only incorporated text inputs for the entailment task? Because of the nature of text inputs encountered on social media platforms, text inputs alone would have hurt the final performance. Under a similar training setup, using only text inputs we reach a top-1 accuracy of 67.14% on the same test set. Refer to this notebook for details.

Finally, here is a table comparing the different approaches taken for the entailment task:

Type         Standard Cross-entropy   Loss-weighted Cross-entropy   Focal Loss
Multimodal   77.86%                   67.86%                        86.43%
Only text    67.14%                   11.43%                        37.86%

You can check out this repository to learn more about how the experiments were conducted to obtain these numbers.


Final remarks

  • The architecture we used in this example is too large for the number of data points available for training. It would benefit from more data.
  • We used a smaller variant of the original BERT model. Chances are high that the performance would improve with a larger variant. TensorFlow Hub provides a number of different BERT models that you can experiment with.
  • We kept the pre-trained models frozen. Fine-tuning them on the multimodal entailment task could lead to better performance (see the sketch after this list).
  • We built a simple baseline model for the multimodal entailment task. Various approaches have been proposed to tackle the entailment problem. This presentation deck from the Recognizing Multimodal Entailment tutorial provides a comprehensive overview.
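
As a rough sketch of the fine-tuning idea from the third bullet (the learning rate and epoch count below are arbitrary choices of ours, not values from the original example):

# Sketch (our addition): unfreeze the encoders of the trained model and
# continue training with a much lower learning rate.
for layer in multimodal_model.layers:
    layer.trainable = True  # propagates to the nested encoder layers
multimodal_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss="sparse_categorical_crossentropy",
    metrics="accuracy",
)
multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=3)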

You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.