
Multimodal entailment

Author: Sayak Paul
Date created: 2021/08/08
Last modified: 2021/08/15
Description: Training a multimodal model for predicting entailment.

ⓘ This example uses Keras 2



Introduction

In this example, we will build and train a model for predicting multimodal entailment. We will be using the multimodal entailment dataset recently introduced by Google Research.

What is multimodal entailment?

On social media platforms, to audit and moderate content we may want to find answers to the following questions in near real-time:

  • Does a given piece of information contradict the other?
  • Does a given piece of information imply the other?

In NLP, this task is called analyzing textual entailment. However, that's only when the information comes from text content. In practice, the available information often comes not just from text, but from a multimodal combination of text, images, audio, video, etc. Multimodal entailment is simply the extension of textual entailment to a variety of new input modalities.

Requirements

This example requires TensorFlow 2.5 or higher. In addition, TensorFlow Hub and TensorFlow Text are required for the BERT model (Devlin et al.). These libraries can be installed using the following command:

!pip install -q tensorflow_text

Imports

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow import keras

Define a label map

label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}

Collect the dataset

The original dataset is available here. It comes with URLs of images that are hosted on Twitter's photo storage system called the Photo Blob Storage (PBS for short). We will be working with the downloaded images along with additional data that comes with the original dataset. Thanks to Nilabhra Roy Chowdhury who worked on preparing the image data.

image_base_path = keras.utils.get_file(
    "tweet_images",
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
    untar=True,
)

Read the dataset and apply basic preprocessing

df = pd.read_csv(
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
)
df.sample(10)
id_1 text_1 image_1 id_2 text_2 image_2 label
291 1330800194863190016 #KLM1167 (B738): #AMS (Amsterdam) to #HEL (Van... http://pbs.twimg.com/media/EnfzuZAW4AE236p.png 1378695438480588802 #CKK205 (B77L): #PVG (Shanghai) to #AMS (Amste... http://pbs.twimg.com/media/EyIcMexXEAE6gia.png NoEntailment
37 1366581728312057856 Friends, interested all go to have a look!\n@j... http://pbs.twimg.com/media/EvcS1v4UcAEEXPO.jpg 1373810535066570759 Friends, interested all go to have a look!\n@f... http://pbs.twimg.com/media/ExDBZqwVIAQ4LWk.jpg Contradictory
315 1352551603258052608 #WINk Drops I have earned today🚀\n\nToday:1/22... http://pbs.twimg.com/media/EsTdcLLVcAIiFKT.jpg 1354636016234098688 #WINk Drops I have earned today☀\n\nToday:1/28... http://pbs.twimg.com/media/EsyhK-qU0AgfMAH.jpg NoEntailment
761 1379795999493853189 #buythedip Ready to FLY even HIGHER #pennysto... http://pbs.twimg.com/media/EyYFJCzWgAMfTrT.jpg 1380190250144792576 #buythedip Ready to FLY even HIGHER #pennysto... http://pbs.twimg.com/media/Eydrt0ZXAAMmbfv.jpg NoEntailment
146 1340185132293099523 I know sometimes I am weird to you.\n\nBecause... http://pbs.twimg.com/media/EplLRriWwAAJ2AE.jpg 1359755419883814913 I put my sword down and get on my knees to swe... http://pbs.twimg.com/media/Et7SWWeWYAICK-c.jpg NoEntailment
1351 1381256604926967813 Finally completed the skin rendering. Will sta... http://pbs.twimg.com/media/Eys1j7NVIAgF-YF.jpg 1381630932092784641 Hair rendering. Will finish the hair by tomorr... http://pbs.twimg.com/media/EyyKAoaUUAElm-e.jpg NoEntailment
368 1371883298805403649 📉 $LINK Number of Receiving Addresses (7d MA) ... http://pbs.twimg.com/media/EwnoltOWEAAS4mG.jpg 1373216720974979072 📉 $LINK Number of Receiving Addresses (7d MA) ... http://pbs.twimg.com/media/Ew6lVGYXEAE6Ugi.jpg NoEntailment
1112 1377679115159887873 April is National Distracted Driving Awareness... http://pbs.twimg.com/media/Ex5_u7UVIAARjQ2.jpg 1379075258448281608 April is Distracted Driving Awareness Month. ... http://pbs.twimg.com/media/EyN1YjpWUAMc5ak.jpg NoEntailment
264 1330727515741167619 ♥️Verse Of The Day♥️\n.\n#VerseOfTheDay #Quran... http://pbs.twimg.com/media/EnexnydXIAYuI11.jpg 1332623263495819264 ♥️Verse Of The Day♥️\n.\n#VerseOfTheDay #Quran... http://pbs.twimg.com/media/En5ty1VXUAATALP.jpg NoEntailment
865 1377784616275296261 No white picket fence can keep us in. #TBT 200... http://pbs.twimg.com/media/Ex7fzouWQAITAq8.jpg 1380175915804672012 Sometimes you just need to change your altitud... http://pbs.twimg.com/media/EydernQXIAk2g5v.jpg NoEntailment

The columns we are interested in are the following:

  • text_1
  • image_1
  • text_2
  • image_2
  • label

The entailment task is formulated as the following:

Given the pairs of (text_1, image_1) and (text_2, image_2), do they entail (or not entail or contradict) each other?

We have the images already downloaded. image_1 is downloaded with id_1 as its filename and image_2 is downloaded with id_2 as its filename. In the next step, we will add two more columns to df: the filepaths of the image_1s and image_2s.

images_one_paths = []
images_two_paths = []

for idx in range(len(df)):
    current_row = df.iloc[idx]
    id_1 = current_row["id_1"]
    id_2 = current_row["id_2"]
    extension_one = current_row["image_1"].split(".")[-1]
    extension_two = current_row["image_2"].split(".")[-1]

    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extension_one}")
    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extension_two}")

    images_one_paths.append(image_one_path)
    images_two_paths.append(image_two_path)

df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths

# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])

Dataset visualization

def visualize(idx):
    current_row = df.iloc[idx]
    image_1 = plt.imread(current_row["image_1_path"])
    image_2 = plt.imread(current_row["image_2_path"])
    text_1 = current_row["text_1"]
    text_2 = current_row["text_2"]
    label = current_row["label"]

    plt.subplot(1, 2, 1)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image One")
    plt.subplot(1, 2, 2)
    plt.imshow(image_2)
    plt.axis("off")
    plt.title("Image Two")
    plt.show()

    print(f"Text one: {text_1}")
    print(f"Text two: {text_2}")
    print(f"Label: {label}")


random_idx = np.random.choice(len(df))
visualize(random_idx)

random_idx = np.random.choice(len(df))
visualize(random_idx)


Text one: Friends, interested all go to have a look!
@ThePartyGoddess @OurLadyAngels @BJsWholesale @Richard_Jeni @FashionLavidaG @RapaRooski @DMVTHING @DeMarcoReports @LobidaFo @DeMarcoMorgan https://t.co/cStULl7y7G
Text two: Friends, interested all go to have a look!
@smittyses @CYosabel @crum_7 @CrumDarrell @ElymalikU @jenloarn @SoCodiePrevost @roblowry82 @Crummy_14 @CSchmelzenbach https://t.co/IZphLTNzgl
Label: Contradictory


Text one: 👟 KICK OFF @ MARDEN SPORTS COMPLEX
We're underway in the Round 6 opener!
📺: @Foxtel, @kayosports
📱: My Football Live app https://t.co/wHSpvQaoGC
#WLeague #ADLvMVC #AUFC #MVFC https://t.co/3Smp8KXm8W
Text two: 👟 KICK OFF @ MARSDEN SPORTS COMPLEX
We're underway in sunny Adelaide!
📺: @Foxtel, @kayosports
📱: My Football Live app https://t.co/wHSpvQaoGC
#ADLvCBR #WLeague #AUFC #UnitedAlways https://t.co/fG1PyLQXM4
Label: NoEntailment

Train/test split

The dataset suffers from a class imbalance problem. We can confirm that in the following cell.

df["label"].value_counts()
NoEntailment     1182
Implies           109
Contradictory     109
Name: label, dtype: int64

To account for that we will go for a stratified split.

# 10% for test
train_df, test_df = train_test_split(
    df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% for validation
train_df, val_df = train_test_split(
    train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)

print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")
Total training examples: 1197
Total validation examples: 63
Total test examples: 140

Data input pipeline

TensorFlow Hub provides a variety of BERT family of models. Each of those models comes with a corresponding preprocessing layer. You can learn more about these models and their preprocessing layers from this resource.

To keep the runtime of this example relatively short, we will use a smaller variant of the original BERT model.

# Define TF Hub paths to the BERT encoder and its preprocessor
bert_model_path = (
    "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1"
)
bert_preprocess_path = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

Our text preprocessing code mostly comes from this tutorial. You are highly encouraged to check out the tutorial to learn more about the input preprocessing.

def make_bert_preprocessing_model(sentence_features, seq_length=128):
    """Returns Model mapping string features to BERT inputs.

  Args:
    sentence_features: A list with the names of string-valued features.
    seq_length: An integer that defines the sequence length of BERT inputs.

  Returns:
    A Keras Model that can be called on a list or dict of string Tensors
    (with the order or names, resp., given by sentence_features) and
    returns a dict of tensors for input to BERT.
  """

    input_segments = [
        tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
        for ft in sentence_features
    ]

    # Tokenize the text to word pieces.
    bert_preprocess = hub.load(bert_preprocess_path)
    tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name="tokenizer")
    segments = [tokenizer(s) for s in input_segments]

    # Optional: Trim segments in a smart way to fit seq_length.
    # Simple cases (like this example) can skip this step and let
    # the next step apply a default truncation to approximately equal lengths.
    truncated_segments = segments

    # Pack inputs. The details (start/end token ids, dict of output tensors)
    # are model-dependent, so this gets loaded from the SavedModel.
    packer = hub.KerasLayer(
        bert_preprocess.bert_pack_inputs,
        arguments=dict(seq_length=seq_length),
        name="packer",
    )
    model_inputs = packer(truncated_segments)
    return keras.Model(input_segments, model_inputs)


bert_preprocess_model = make_bert_preprocessing_model(["text_1", "text_2"])
keras.utils.plot_model(bert_preprocess_model, show_shapes=True, show_dtype=True)


Run the preprocessor on a sample input

idx = np.random.choice(len(train_df))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")

test_text = [np.array([sample_text_1]), np.array([sample_text_2])]
text_preprocessed = bert_preprocess_model(test_text)

print("Keys           : ", list(text_preprocessed.keys()))
print("Shape Word Ids : ", text_preprocessed["input_word_ids"].shape)
print("Word Ids       : ", text_preprocessed["input_word_ids"][0, :16])
print("Shape Mask     : ", text_preprocessed["input_mask"].shape)
print("Input Mask     : ", text_preprocessed["input_mask"][0, :16])
print("Shape Type Ids : ", text_preprocessed["input_type_ids"].shape)
print("Type Ids       : ", text_preprocessed["input_type_ids"][0, :16])
Text 1: Renewables met 97% of Scotland's electricity demand in 2020!!!!
https://t.co/wi5c9UFAUF https://t.co/arcuBgh0BP
Text 2: Renewables met 97% of Scotland's electricity demand in 2020 https://t.co/SrhyqPnIkU https://t.co/LORgvTM7Sn
Keys           :  ['input_mask', 'input_word_ids', 'input_type_ids']
Shape Word Ids :  (1, 128)
Word Ids       :  tf.Tensor(
[  101 13918  2015  2777  5989  1003  1997  3885  1005  1055  6451  5157
  1999 12609   999   999], shape=(16,), dtype=int32)
Shape Mask     :  (1, 128)
Input Mask     :  tf.Tensor([1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1], shape=(16,), dtype=int32)
Shape Type Ids :  (1, 128)
Type Ids       :  tf.Tensor([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], shape=(16,), dtype=int32)

We will now create tf.data.Dataset objects from the dataframe.

Note that the text inputs will be preprocessed as a part of the data input pipeline. But the preprocessing modules can also be a part of their corresponding BERT models. This helps reduce the training/serving skew and lets our models operate with raw text inputs. Follow this tutorial to learn more about how to incorporate the preprocessing modules directly inside the models.
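
For illustration, here is a minimal sketch of what folding the preprocessing into the model could look like with the pieces defined above (the function name below is hypothetical; in this example we keep the preprocessing in the data pipeline instead):

def make_end_to_end_text_encoder():
    # Raw-string inputs; preprocessing happens inside the model itself.
    text_1 = keras.Input(shape=(), dtype=tf.string, name="text_1")
    text_2 = keras.Input(shape=(), dtype=tf.string, name="text_2")
    bert_inputs = bert_preprocess_model([text_1, text_2])
    bert = hub.KerasLayer(bert_model_path, name="bert")
    embeddings = bert(bert_inputs)["pooled_output"]
    return keras.Model([text_1, text_2], embeddings)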

def dataframe_to_dataset(dataframe):
    columns = ["image_1_path", "image_2_path", "text_1", "text_2", "label_idx"]
    dataframe = dataframe[columns].copy()
    labels = dataframe.pop("label_idx")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds

Preprocessing utilities

resize = (128, 128)
bert_input_features = ["input_word_ids", "input_type_ids", "input_mask"]


def preprocess_image(image_path):
    # Split on "." to get the file extension (e.g. "jpg" or "png").
    extension = tf.strings.split(image_path, ".")[-1]

    image = tf.io.read_file(image_path)
    if extension == b"jpg":
        image = tf.image.decode_jpeg(image, 3)
    else:
        image = tf.image.decode_png(image, 3)
    image = tf.image.resize(image, resize)
    return image


def preprocess_text(text_1, text_2):
    text_1 = tf.convert_to_tensor([text_1])
    text_2 = tf.convert_to_tensor([text_2])
    output = bert_preprocess_model([text_1, text_2])
    output = {feature: tf.squeeze(output[feature]) for feature in bert_input_features}
    return output


def preprocess_text_and_image(sample):
    image_1 = preprocess_image(sample["image_1_path"])
    image_2 = preprocess_image(sample["image_2_path"])
    text = preprocess_text(sample["text_1"], sample["text_2"])
    return {"image_1": image_1, "image_2": image_2, "text": text}

Create the final datasets

batch_size = 32
auto = tf.data.AUTOTUNE


def prepare_dataset(dataframe, training=True):
    ds = dataframe_to_dataset(dataframe)
    if training:
        ds = ds.shuffle(len(train_df))
    ds = ds.map(lambda x, y: (preprocess_text_and_image(x), y)).cache()
    ds = ds.batch(batch_size).prefetch(auto)
    return ds


train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df, False)
test_ds = prepare_dataset(test_df, False)
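
As an optional sanity check, we can pull a single batch and inspect the nested structure the pipeline produces:

# Fetch one batch; the expected shapes assume batch_size = 32 and seq_length = 128.
sample_batch, sample_labels = next(iter(train_ds))
print(sample_batch["image_1"].shape)  # (32, 128, 128, 3)
print(sample_batch["text"]["input_word_ids"].shape)  # (32, 128)
print(sample_labels.shape)  # (32,)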

Model building utilities

Our final model will accept two images along with their text counterparts. While the images will be directly fed to the model, the text inputs will first be preprocessed and then will make it into the model. Below is a visual illustration of this approach:

The model consists of the following elements:

  • A standalone encoder for the images. We will use a ResNet50V2 pre-trained on the ImageNet-1k dataset for this.
  • A standalone encoder for the text. A pre-trained BERT will be used for this.

After extracting the individual embeddings, they will be projected in an identical space. Finally, their projections will be concatenated and fed to the final classification layer.

This is a multi-class classification problem involving the following classes:

  • NoEntailment
  • Implies
  • Contradictory

The project_embeddings(), create_vision_encoder() and create_text_encoder() utilities are referred from this example.

Projection utilities

def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = tf.nn.gelu(projected_embeddings)
        x = keras.layers.Dense(projection_dims)(x)
        x = keras.layers.Dropout(dropout_rate)(x)
        x = keras.layers.Add()([projected_embeddings, x])
        projected_embeddings = keras.layers.LayerNormalization()(x)
    return projected_embeddings

Vision encoder utilities

def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained ResNet50V2 model to be used as the base encoder.
    resnet_v2 = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in resnet_v2.layers:
        layer.trainable = trainable

    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Preprocess the input image.
    preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
    preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)

    # Generate the embeddings for the images using the resnet_v2 model
    # and concatenate them.
    embeddings_1 = resnet_v2(preprocessed_1)
    embeddings_2 = resnet_v2(preprocessed_2)
    embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model([image_1, image_2], outputs, name="vision_encoder")

Text encoder utilities

def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained BERT model to be used as the base encoder.
    bert = hub.KerasLayer(bert_model_path, name="bert")
    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the text as inputs.
    bert_input_features = ["input_type_ids", "input_mask", "input_word_ids"]
    inputs = {
        feature: keras.Input(shape=(128,), dtype=tf.int32, name=feature)
        for feature in bert_input_features
    }

    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")

Multimodal model utilities

def create_multimodal_model(
    num_projection_layers=1,
    projection_dims=256,
    dropout_rate=0.1,
    vision_trainable=False,
    text_trainable=False,
):
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs.
    bert_input_features = ["input_type_ids", "input_mask", "input_word_ids"]
    text_inputs = {
        feature: keras.Input(shape=(128,), dtype=tf.int32, name=feature)
        for feature in bert_input_features
    }

    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
    return keras.Model([image_1, image_2, text_inputs], outputs)


multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)


You can inspect the structure of the individual encoders as well by setting the expand_nested argument of plot_model() to True. You are encouraged to play with the different hyperparameters involved in building this model and observe how the final performance is affected.
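
For example:

keras.utils.plot_model(multimodal_model, show_shapes=True, expand_nested=True)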


Compile and train the model

multimodal_model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics="accuracy"
)

history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=10)
Epoch 1/10
38/38 [==============================] - 49s 789ms/step - loss: 1.0014 - accuracy: 0.8229 - val_loss: 0.5514 - val_accuracy: 0.8571
Epoch 2/10
38/38 [==============================] - 3s 90ms/step - loss: 0.4019 - accuracy: 0.8814 - val_loss: 0.5866 - val_accuracy: 0.8571
Epoch 3/10
38/38 [==============================] - 3s 90ms/step - loss: 0.3557 - accuracy: 0.8897 - val_loss: 0.5929 - val_accuracy: 0.8571
Epoch 4/10
38/38 [==============================] - 3s 91ms/step - loss: 0.2877 - accuracy: 0.9006 - val_loss: 0.6272 - val_accuracy: 0.8571
Epoch 5/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1796 - accuracy: 0.9398 - val_loss: 0.8545 - val_accuracy: 0.8254
Epoch 6/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1292 - accuracy: 0.9566 - val_loss: 1.2276 - val_accuracy: 0.8413
Epoch 7/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1015 - accuracy: 0.9666 - val_loss: 1.2914 - val_accuracy: 0.7778
Epoch 8/10
38/38 [==============================] - 3s 92ms/step - loss: 0.1253 - accuracy: 0.9524 - val_loss: 1.1944 - val_accuracy: 0.8413
Epoch 9/10
38/38 [==============================] - 3s 92ms/step - loss: 0.3064 - accuracy: 0.9131 - val_loss: 1.2162 - val_accuracy: 0.8095
Epoch 10/10
38/38 [==============================] - 3s 92ms/step - loss: 0.2212 - accuracy: 0.9248 - val_loss: 1.1080 - val_accuracy: 0.8413

Evaluate the model

_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
5/5 [==============================] - 6s 1s/step - loss: 0.8390 - accuracy: 0.8429
Accuracy on the test set: 84.29%.
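
To look beyond a single accuracy number, we can also inspect per-example predictions. A small sketch using the test set prepared earlier:

# Predict class probabilities for one batch of test examples.
sample_batch, sample_labels = next(iter(test_ds))
probabilities = multimodal_model.predict(sample_batch)
predicted_labels = np.argmax(probabilities, axis=-1)
print(predicted_labels[:10], sample_labels.numpy()[:10])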

Additional notes regarding training

Incorporating regularization:

The training logs suggest that the model is starting to overfit and may have benefitted from regularization. Dropout (Srivastava et al.) is a simple yet powerful regularization technique that we can use in our model. But how should we apply it here?

We could always introduce Dropout (keras.layers.Dropout) in between different layers of the model. But here is another recipe. Our model expects inputs from two different data modalities. What if either of the modalities is not present during inference? To account for this, we can introduce Dropout to the individual projections just before they get concatenated:

vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])

Attending to what matters:

Do all parts of the images correspond equally to their textual counterparts? It's likely not the case. To make our model only focus on the most important bits of the images that relate well to their corresponding textual parts, we can use "cross-attention":

# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)

# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
    [vision_projections, text_projections]
)
# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])

To see this in action, refer to this notebook.

Handling class imbalance:

The dataset suffers from class imbalance. Investigating the confusion matrix of the above model reveals that it performs poorly on the minority classes. If we had used a weighted loss then the training would have been more guided. You can check out this notebook that takes class imbalance into account during model training.
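
A minimal sketch of such a weighted setup, computing inverse-frequency weights from the training split and passing them to fit() via the class_weight argument (the exact weighting scheme is a design choice):

# Inverse-frequency class weights derived from the training split.
counts = train_df["label_idx"].value_counts()
class_weight = {idx: len(train_df) / (len(counts) * count) for idx, count in counts.items()}
# multimodal_model.fit(
#     train_ds, validation_data=validation_ds, epochs=10, class_weight=class_weight
# )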

Using only the text inputs:

Also, what if we had only incorporated text inputs for the entailment task? Because of the nature of text inputs encountered on social media platforms, text-only inputs would have hurt the final performance. Under a similar training setup, using only the text inputs we get to 67.14% top-1 accuracy on the same test set. Refer to this notebook for details.
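
A rough sketch of such a text-only baseline, reusing the create_text_encoder() utility from above (the hyperparameters are assumed, and the datasets would need to be remapped to keep only the text features):

text_encoder = create_text_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
text_outputs = keras.layers.Dense(3, activation="softmax")(text_encoder.output)
text_only_model = keras.Model(text_encoder.input, text_outputs)

# The pipeline yields nested features, so keep only the text part, e.g.:
# text_only_train_ds = train_ds.map(lambda x, y: (x["text"], y))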

Finally, here is a table comparing different approaches taken for the entailment task:

Type         Standard Cross-entropy   Loss-weighted Cross-entropy   Focal Loss
Multimodal   77.86%                   67.86%                        86.43%
Only text    67.14%                   11.43%                        37.86%

You can check out this repository to learn more about how the experiments were conducted to obtain these numbers.


Final remarks

  • The architecture we used in this example is too large for the number of data points available for training. It's going to benefit from more data.
  • We used a smaller variant of the original BERT model. Chances are high that with a larger variant, the performance will be improved. TensorFlow Hub provides a number of different BERT models that you can experiment with.
  • We kept the pre-trained models frozen. Fine-tuning them on the multimodal entailment task could lead to better performance (see the one-line sketch after this list).
  • We built a simple baseline model for the multimodal entailment task. Various approaches have been proposed to tackle the entailment problem. This presentation deck from the Recognizing Multimodal Entailment tutorial provides a comprehensive overview.
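
For instance, unfreezing both encoders is a one-line change with the utilities above (training would then likely call for a lower learning rate; this was not run here):

# Build the same model with both base encoders trainable for fine-tuning.
fine_tunable_model = create_multimodal_model(vision_trainable=True, text_trainable=True)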

You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.