Author: Sayak Paul
Date created: 2021/08/08
Last modified: 2021/08/15
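In this example, we will build and train a model for predicting multimodal entailment. We will use the multimodal entailment dataset recently introduced by Google Research.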
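In NLP, this task is called analyzing textual entailment. However, that only covers the case where the information comes from text content. In practice, information often comes not just from text but from a multimodal combination of text, images, audio, video, and so on. Multimodal entailment is simply the extension of textual entailment to a variety of new input modalities.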
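This example requires TensorFlow 2.5 or higher. In addition, TensorFlow Hub and TensorFlow Text are required for the BERT model (Devlin et al.). These libraries can be installed using the following command: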
!pip install -q tensorflow_text
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow import keras
label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}
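The original dataset is available here. It contains URLs of images hosted on Twitter's photo storage system, called Photo Blob Storage (PBS for short). We will work with the downloaded images along with the additional data that comes with the original dataset. Thanks to Nilabhra Roy Chowdhury, who prepared the image data.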
image_base_path = keras.utils.get_file(
df = pd.read_csv(
| | id_1 | text_1 | image_1 | id_2 | text_2 | image_2 | label |
|---|---|---|---|---|---|---|---|
| 291 | 1330800194863190016 | #KLM1167 (B738): #AMS (Amsterdam) to #HEL (Van... | http://pbs.twimg.com/media/EnfzuZAW4AE236p.png | 1378695438480588802 | #CKK205 (B77L): #PVG (Shanghai) to #AMS (Amste... | http://pbs.twimg.com/media/EyIcMexXEAE6gia.png | NoEntailment |
| 37 | 1366581728312057856 | Friends, interested all go to have a look!\n@j... | http://pbs.twimg.com/media/EvcS1v4UcAEEXPO.jpg | 1373810535066570759 | Friends, interested all go to have a look!\n@f... | http://pbs.twimg.com/media/ExDBZqwVIAQ4LWk.jpg | Contradictory |
| 315 | 1352551603258052608 | #WINk Drops I have earned today🚀\n\nToday:1/22... | http://pbs.twimg.com/media/EsTdcLLVcAIiFKT.jpg | 1354636016234098688 | #WINk Drops I have earned today☀\n\nToday:1/28... | http://pbs.twimg.com/media/EsyhK-qU0AgfMAH.jpg | NoEntailment |
| 761 | 1379795999493853189 | #buythedip Ready to FLY even HIGHER #pennysto... | http://pbs.twimg.com/media/EyYFJCzWgAMfTrT.jpg | 1380190250144792576 | #buythedip Ready to FLY even HIGHER #pennysto... | http://pbs.twimg.com/media/Eydrt0ZXAAMmbfv.jpg | NoEntailment |
| 146 | 1340185132293099523 | I know sometimes I am weird to you.\n\nBecause... | http://pbs.twimg.com/media/EplLRriWwAAJ2AE.jpg | 1359755419883814913 | I put my sword down and get on my knees to swe... | http://pbs.twimg.com/media/Et7SWWeWYAICK-c.jpg | NoEntailment |
| 1351 | 1381256604926967813 | Finally completed the skin rendering. Will sta... | http://pbs.twimg.com/media/Eys1j7NVIAgF-YF.jpg | 1381630932092784641 | Hair rendering. Will finish the hair by tomorr... | http://pbs.twimg.com/media/EyyKAoaUUAElm-e.jpg | NoEntailment |
| 368 | 1371883298805403649 | 📉 $LINK Number of Receiving Addresses (7d MA) ... | http://pbs.twimg.com/media/EwnoltOWEAAS4mG.jpg | 1373216720974979072 | 📉 $LINK Number of Receiving Addresses (7d MA) ... | http://pbs.twimg.com/media/Ew6lVGYXEAE6Ugi.jpg | NoEntailment |
| 1112 | 1377679115159887873 | April is National Distracted Driving Awareness... | http://pbs.twimg.com/media/Ex5_u7UVIAARjQ2.jpg | 1379075258448281608 | April is Distracted Driving Awareness Month. ... | http://pbs.twimg.com/media/EyN1YjpWUAMc5ak.jpg | NoEntailment |
| 264 | 1330727515741167619 | ♥️Verse Of The Day♥️\n.\n#VerseOfTheDay #Quran... | http://pbs.twimg.com/media/EnexnydXIAYuI11.jpg | 1332623263495819264 | ♥️Verse Of The Day♥️\n.\n#VerseOfTheDay #Quran... | http://pbs.twimg.com/media/En5ty1VXUAATALP.jpg | NoEntailment |
| 865 | 1377784616275296261 | No white picket fence can keep us in. #TBT 200... | http://pbs.twimg.com/media/Ex7fzouWQAITAq8.jpg | 1380175915804672012 | Sometimes you just need to change your altitud... | http://pbs.twimg.com/media/EydernQXIAk2g5v.jpg | NoEntailment |
The entailment task is formulated as follows: given the pairs (text_1, image_1) and (text_2, image_2), do they entail each other (or not entail, or contradict)?
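The images have already been downloaded: image_1 is saved with id_1 as its filename, and image_2 is saved with id_2 as its filename. In the next step, we will add two more columns to df containing the file paths of image_1 and image_2.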
images_one_paths = []
images_two_paths = []

for idx in range(len(df)):
    current_row = df.iloc[idx]
    id_1 = current_row["id_1"]
    id_2 = current_row["id_2"]
    # The file extension is taken from the original image URL.
    extension_one = current_row["image_1"].split(".")[-1]
    extension_two = current_row["image_2"].split(".")[-1]

    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extension_one}")
    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extension_two}")

    images_one_paths.append(image_one_path)
    images_two_paths.append(image_two_path)

df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths
# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])
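The path-building loop reduces to one rule: the tweet id, plus the extension taken from the image URL. Below is a minimal pure-Python sketch of that rule applied to the first row of the sample table. The helper `image_path_for` and the base directory `tweet_images` are hypothetical placeholders, not the real download location.

```python
import os


def image_path_for(base_dir, tweet_id, image_url):
    # The file extension is whatever follows the last "." in the image URL.
    extension = image_url.split(".")[-1]
    # Images are stored on disk as "<tweet id>.<extension>".
    return os.path.join(base_dir, f"{tweet_id}.{extension}")


# Values from the first row of the sample table; "tweet_images" is a
# hypothetical base directory used only for illustration.
path = image_path_for(
    "tweet_images",
    1330800194863190016,
    "http://pbs.twimg.com/media/EnfzuZAW4AE236p.png",
)
print(path)
```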
def visualize(idx):
    current_row = df.iloc[idx]
    image_1 = plt.imread(current_row["image_1_path"])
    image_2 = plt.imread(current_row["image_2_path"])
    text_1 = current_row["text_1"]
    text_2 = current_row["text_2"]
    label = current_row["label"]

    # Show the two images side by side.
    plt.subplot(1, 2, 1)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image One")
    plt.subplot(1, 2, 2)
    plt.imshow(image_2)
    plt.axis("off")
    plt.title("Image Two")
    plt.show()

    print(f"Text one: {text_1}")
    print(f"Text two: {text_2}")
    print(f"Label: {label}")
random_idx = np.random.choice(len(df))
visualize(random_idx)

random_idx = np.random.choice(len(df))
visualize(random_idx)
Text one: Friends, interested all go to have a look!
@ThePartyGoddess @OurLadyAngels @BJsWholesale @Richard_Jeni @FashionLavidaG @RapaRooski @DMVTHING @DeMarcoReports @LobidaFo @DeMarcoMorgan https://t.co/cStULl7y7G
Text two: Friends, interested all go to have a look!
@smittyses @CYosabel @crum_7 @CrumDarrell @ElymalikU @jenloarn @SoCodiePrevost @roblowry82 @Crummy_14 @CSchmelzenbach https://t.co/IZphLTNzgl
Label: Contradictory
Text one: We're underway in the Round 6 opener!
📺: @Foxtel, @kayosports
📱: My Football Live app https://t.co/wHSpvQaoGC
#WLeague #ADLvMVC #AUFC #MVFC https://t.co/3Smp8KXm8W
Text two: We're underway in sunny Adelaide!
📺: @Foxtel, @kayosports
📱: My Football Live app https://t.co/wHSpvQaoGC
#ADLvCBR #WLeague #AUFC #UnitedAlways https://t.co/fG1PyLQXM4
Label: NoEntailment
The dataset suffers from class imbalance. We can confirm this with the following cell:
df["label"].value_counts()
NoEntailment 1182
Implies 109
Contradictory 109
Name: label, dtype: int64
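One common way to quantify this imbalance is to turn the counts into per-class weights. The sketch below uses the "balanced" heuristic that scikit-learn's `compute_class_weight` also uses, `n_samples / (n_classes * count)`, and keys the result by the integer ids from `label_map`, which is the form `Model.fit(class_weight=...)` accepts. This is an illustration only, not necessarily the weighting used in the notebooks referenced later.

```python
# Class counts printed by df["label"].value_counts() above.
counts = {"NoEntailment": 1182, "Implies": 109, "Contradictory": 109}
label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}

n_samples = sum(counts.values())
n_classes = len(counts)

# "Balanced" weighting: rare classes get proportionally larger weights,
# so every class contributes the same total weight (n_samples / n_classes).
class_weight = {
    label_map[name]: n_samples / (n_classes * count) for name, count in counts.items()
}
for idx in sorted(class_weight):
    print(idx, round(class_weight[idx], 4))
```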
# 10% for test
train_df, test_df = train_test_split(
    df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% for validation
train_df, val_df = train_test_split(
    train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)
print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")
Total training examples: 1197
Total validation examples: 63
Total test examples: 140
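The printed sizes follow directly from the two fractions. The sketch below reproduces them, assuming `train_test_split` rounds a fractional test size up and assigns the remainder to training (its internal rounding rule, as far as I know). Stratification changes which rows land in each split, not how many.

```python
import math


def split_sizes(n_samples, test_size):
    # Assumed rounding rule: the test split is rounded up, the rest is training.
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


total = 1182 + 109 + 109  # class counts shown above
train_val, test = split_sizes(total, 0.1)   # 10% for test
train, val = split_sizes(train_val, 0.05)   # 5% of the remainder for validation
print(train, val, test)
```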
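TensorFlow Hub provides a variety of BERT-family models. Each of these models comes with a corresponding preprocessing layer. You can learn more about these models and their preprocessing layers from this resource.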
To keep the runtime of this example relatively short, we will use a smaller variant of the original BERT model.
# Define TF Hub paths to the BERT encoder and its preprocessor
bert_model_path = (
bert_preprocess_path = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
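Our text preprocessing code mostly comes from this tutorial. You are highly encouraged to check out the tutorial to learn more about input preprocessing.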
def make_bert_preprocessing_model(sentence_features, seq_length=128):
    """Returns Model mapping string features to BERT inputs.

    Args:
      sentence_features: A list with the names of string-valued features.
      seq_length: An integer that defines the sequence length of BERT inputs.

    Returns:
      A Keras Model that can be called on a list or dict of string Tensors
      (with the order or names, resp., given by sentence_features) and
      returns a dict of tensors for input to BERT.
    """

    input_segments = [
        tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
        for ft in sentence_features
    ]

    # Tokenize the text to word pieces.
    bert_preprocess = hub.load(bert_preprocess_path)
    tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name="tokenizer")
    segments = [tokenizer(s) for s in input_segments]

    # Optional: Trim segments in a smart way to fit seq_length.
    # Simple cases (like this example) can skip this step and let
    # the next step apply a default truncation to approximately equal lengths.
    truncated_segments = segments

    # Pack inputs. The details (start/end token ids, dict of output tensors)
    # are model-dependent, so this gets loaded from the SavedModel.
    packer = hub.KerasLayer(
        bert_preprocess.bert_pack_inputs,
        arguments=dict(seq_length=seq_length),
        name="packer",
    )
    model_inputs = packer(truncated_segments)
    return keras.Model(input_segments, model_inputs)
bert_preprocess_model = make_bert_preprocessing_model(["text_1", "text_2"])
keras.utils.plot_model(bert_preprocess_model, show_shapes=True, show_dtype=True)
idx = np.random.choice(len(train_df))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")
test_text = [np.array([sample_text_1]), np.array([sample_text_2])]
text_preprocessed = bert_preprocess_model(test_text)
print("Keys : ", list(text_preprocessed.keys()))
print("Shape Word Ids : ", text_preprocessed["input_word_ids"].shape)
print("Word Ids : ", text_preprocessed["input_word_ids"][0, :16])
print("Shape Mask : ", text_preprocessed["input_mask"].shape)
print("Input Mask : ", text_preprocessed["input_mask"][0, :16])
print("Shape Type Ids : ", text_preprocessed["input_type_ids"].shape)
print("Type Ids : ", text_preprocessed["input_type_ids"][0, :16])
Text 1: Renewables met 97% of Scotland's electricity demand in 2020!!!!
https://t.co/wi5c9UFAUF https://t.co/arcuBgh0BP
Text 2: Renewables met 97% of Scotland's electricity demand in 2020 https://t.co/SrhyqPnIkU https://t.co/LORgvTM7Sn
Keys : ['input_mask', 'input_word_ids', 'input_type_ids']
Shape Word Ids : (1, 128)
Word Ids : tf.Tensor(
[ 101 13918 2015 2777 5989 1003 1997 3885 1005 1055 6451 5157
1999 12609 999 999], shape=(16,), dtype=int32)
Shape Mask : (1, 128)
Input Mask : tf.Tensor([1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1], shape=(16,), dtype=int32)
Shape Type Ids : (1, 128)
Type Ids : tf.Tensor([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], shape=(16,), dtype=int32)
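To make the three outputs concrete, here is a pure-Python sketch of how a pair of already-tokenized segments is packed into `input_word_ids`, `input_mask`, and `input_type_ids`: `[CLS] A [SEP] B [SEP]`, padded to `seq_length`. The ids 101, 102, and 0 are the `[CLS]`, `[SEP]`, and `[PAD]` ids of the uncased BERT vocabulary; the sample word-piece ids are borrowed from the printed output above. The helper `pack_pair` is hypothetical, and its trimming rule (cut the longer segment first) is a simplification that may not match what `bert_pack_inputs` does exactly.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token ids in the uncased BERT vocabulary


def pack_pair(tokens_a, tokens_b, seq_length=128):
    """Packs two lists of word-piece ids into fixed-length BERT inputs."""
    a, b = list(tokens_a), list(tokens_b)
    budget = seq_length - 3  # room for [CLS] and two [SEP] tokens
    # Simplified trimming: drop tokens from the end of the longer segment first.
    while len(a) + len(b) > budget:
        if len(a) >= len(b):
            a.pop()
        else:
            b.pop()

    word_ids = [CLS_ID] + a + [SEP_ID] + b + [SEP_ID]
    # Segment A ([CLS] ... [SEP]) gets type 0, segment B (... [SEP]) gets type 1.
    type_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1)
    n_pad = seq_length - len(word_ids)
    return {
        "input_word_ids": word_ids + [PAD_ID] * n_pad,
        "input_mask": [1] * len(word_ids) + [0] * n_pad,  # 1 = real token, 0 = padding
        "input_type_ids": type_ids + [0] * n_pad,
    }


packed = pack_pair([13918, 2015, 2777], [5989, 1003], seq_length=10)
for key, value in packed.items():
    print(key, value)
```

With both segments short, the mask ends in zeros (padding). With a long first text, every id in the first 16 positions belongs to segment A, which is why the printed `input_type_ids` slice above is all zeros.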
We will now create tf.data.Dataset objects from the dataframes.
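Note that the text inputs are preprocessed as part of the data input pipeline. However, the preprocessing module can also be part of its corresponding BERT model. Doing so helps reduce training/serving skew and lets our model operate directly on raw text inputs. Follow this tutorial to learn more about how to incorporate the preprocessing module directly into the model.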
def dataframe_to_dataset(dataframe):
    columns = ["image_1_path", "image_2_path", "text_1", "text_2", "label_idx"]
    dataframe = dataframe[columns].copy()
    labels = dataframe.pop("label_idx")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds


resize = (128, 128)
bert_input_features = ["input_word_ids", "input_type_ids", "input_mask"]


def preprocess_image(image_path):
    # Split on "." so that the last piece is the file extension.
    extension = tf.strings.split(image_path, sep=".")[-1]

    image = tf.io.read_file(image_path)
    if extension == b"jpg":
        image = tf.image.decode_jpeg(image, 3)
    else:
        image = tf.image.decode_png(image, 3)
    image = tf.image.resize(image, resize)
    return image


def preprocess_text(text_1, text_2):
    text_1 = tf.convert_to_tensor([text_1])
    text_2 = tf.convert_to_tensor([text_2])
    output = bert_preprocess_model([text_1, text_2])
    output = {feature: tf.squeeze(output[feature]) for feature in bert_input_features}
    return output


def preprocess_text_and_image(sample):
    image_1 = preprocess_image(sample["image_1_path"])
    image_2 = preprocess_image(sample["image_2_path"])
    text = preprocess_text(sample["text_1"], sample["text_2"])
    return {"image_1": image_1, "image_2": image_2, "text": text}


batch_size = 32
auto = tf.data.AUTOTUNE


def prepare_dataset(dataframe, training=True):
    ds = dataframe_to_dataset(dataframe)
    # Cache the preprocessed samples first and shuffle afterwards, so that
    # the training order is reshuffled at every epoch.
    ds = ds.map(lambda x, y: (preprocess_text_and_image(x), y)).cache()
    if training:
        ds = ds.shuffle(len(dataframe))
    ds = ds.batch(batch_size).prefetch(auto)
    return ds
train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df, False)
test_ds = prepare_dataset(test_df, False)
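Because `batch()` keeps the final partial batch by default, the number of steps per pass is the split size divided by `batch_size`, rounded up. This small sketch reproduces the `38/38` and `5/5` step counts that appear in the training and evaluation logs below.

```python
import math

batch_size = 32
split_sizes = {"train": 1197, "validation": 63, "test": 140}

# The last batch of each split is smaller, so the step count is rounded up.
steps = {name: math.ceil(n / batch_size) for name, n in split_sizes.items()}
print(steps)
```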
The project_embeddings(), create_vision_encoder(), and create_text_encoder() utilities are adapted from this example.
def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = tf.nn.gelu(projected_embeddings)
        x = keras.layers.Dense(projection_dims)(x)
        x = keras.layers.Dropout(dropout_rate)(x)
        # Residual connection followed by layer normalization.
        x = keras.layers.Add()([projected_embeddings, x])
        projected_embeddings = keras.layers.LayerNormalization()(x)
    return projected_embeddings
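To see what the projection head computes, here is a NumPy sketch of one forward pass at inference time (so dropout is a no-op): a linear projection, then per layer GELU, a dense layer, a residual add, and layer normalization. Random weights stand in for trained kernels, biases and the learned LayerNorm scale/shift are omitted, and the 4096-d input assumes two concatenated 2048-d ResNet50V2 embeddings. The helper names are illustrative, not Keras APIs.

```python
import math

import numpy as np

rng = np.random.default_rng(0)
erf = np.vectorize(math.erf)


def gelu(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))


def layer_norm(x, eps=1e-3):
    # Normalize each row to zero mean and (almost) unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def project_embeddings_np(embeddings, num_projection_layers, projection_dims):
    # Randomly initialised weights stand in for trained Dense kernels (no biases).
    w0 = rng.normal(size=(embeddings.shape[-1], projection_dims)) * 0.02
    projected = embeddings @ w0
    for _ in range(num_projection_layers):
        w = rng.normal(size=(projection_dims, projection_dims)) * 0.02
        x = gelu(projected) @ w  # GELU -> Dense; dropout is skipped at inference
        projected = layer_norm(projected + x)  # residual connection + LayerNorm
    return projected


out = project_embeddings_np(rng.normal(size=(4, 4096)), 1, 256)
print(out.shape)
```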
def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained ResNet50V2 model to be used as the base encoder.
    resnet_v2 = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in resnet_v2.layers:
        layer.trainable = trainable

    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Preprocess the input images.
    preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
    preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)

    # Generate the embeddings for the images using the resnet_v2 model
    # and concatenate them.
    embeddings_1 = resnet_v2(preprocessed_1)
    embeddings_2 = resnet_v2(preprocessed_2)
    embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model([image_1, image_2], outputs, name="vision_encoder")
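Two details of the vision encoder can be checked without TensorFlow. First, `keras.applications.resnet_v2.preprocess_input` uses the "tf" preprocessing mode, which scales pixel values from [0, 255] to [-1, 1]; the NumPy function below mirrors that scaling (it is an illustration, not the Keras implementation). Second, both images pass through the same ResNet50V2 instance, so the weights are shared, and with `pooling="avg"` each image yields a 2048-d vector.

```python
import numpy as np


def resnet_v2_style_preprocess(images):
    # Scale pixels from [0, 255] to [-1, 1], as in the "tf" preprocessing mode.
    return images.astype("float32") / 127.5 - 1.0


pixels = np.array([0.0, 127.5, 255.0])
print(resnet_v2_style_preprocess(pixels))

# Each image yields a 2048-d pooled vector from the shared ResNet50V2,
# so the concatenated vision embedding has 2 * 2048 dimensions.
embedding_dim = 2048
concatenated_dim = 2 * embedding_dim
print(concatenated_dim)
```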
def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained BERT model to be used as the base encoder.
    bert = hub.KerasLayer(bert_model_path, name="bert")
    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the text as inputs.
    bert_input_features = ["input_type_ids", "input_mask", "input_word_ids"]
    inputs = {
        feature: keras.Input(shape=(128,), dtype=tf.int32, name=feature)
        for feature in bert_input_features
    }

    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")
def create_multimodal_model(
    num_projection_layers=1,
    projection_dims=256,
    dropout_rate=0.1,
    vision_trainable=False,
    text_trainable=False,
):
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs.
    bert_input_features = ["input_type_ids", "input_mask", "input_word_ids"]
    text_inputs = {
        feature: keras.Input(shape=(128,), dtype=tf.int32, name=feature)
        for feature in bert_input_features
    }

    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
    return keras.Model([image_1, image_2, text_inputs], outputs)
multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)
You can also inspect the structure of the individual encoders by setting the expand_nested argument of plot_model() to True.
multimodal_model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics="accuracy"
)
history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=10)
Epoch 1/10
38/38 [==============================] - 49s 789ms/step - loss: 1.0014 - accuracy: 0.8229 - val_loss: 0.5514 - val_accuracy: 0.8571
Epoch 2/10
38/38 [==============================] - 3s 90ms/step - loss: 0.4019 - accuracy: 0.8814 - val_loss: 0.5866 - val_accuracy: 0.8571
Epoch 3/10
38/38 [==============================] - 3s 90ms/step - loss: 0.3557 - accuracy: 0.8897 - val_loss: 0.5929 - val_accuracy: 0.8571
Epoch 4/10
38/38 [==============================] - 3s 91ms/step - loss: 0.2877 - accuracy: 0.9006 - val_loss: 0.6272 - val_accuracy: 0.8571
Epoch 5/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1796 - accuracy: 0.9398 - val_loss: 0.8545 - val_accuracy: 0.8254
Epoch 6/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1292 - accuracy: 0.9566 - val_loss: 1.2276 - val_accuracy: 0.8413
Epoch 7/10
38/38 [==============================] - 3s 91ms/step - loss: 0.1015 - accuracy: 0.9666 - val_loss: 1.2914 - val_accuracy: 0.7778
Epoch 8/10
38/38 [==============================] - 3s 92ms/step - loss: 0.1253 - accuracy: 0.9524 - val_loss: 1.1944 - val_accuracy: 0.8413
Epoch 9/10
38/38 [==============================] - 3s 92ms/step - loss: 0.3064 - accuracy: 0.9131 - val_loss: 1.2162 - val_accuracy: 0.8095
Epoch 10/10
38/38 [==============================] - 3s 92ms/step - loss: 0.2212 - accuracy: 0.9248 - val_loss: 1.1080 - val_accuracy: 0.8413
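The `val_loss` column above is lowest after the first epoch and mostly rises afterwards. `keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)` is designed to stop training in this situation. The pure-Python sketch below only replays a simplified version of that patience rule on the logged values. It is not the Keras implementation, and the patience of 3 is an arbitrary choice.

```python
# val_loss values copied from the training log above.
val_losses = [0.5514, 0.5866, 0.5929, 0.6272, 0.8545, 1.2276, 1.2914, 1.1944, 1.2162, 1.1080]


def early_stopping_epoch(losses, patience):
    """Returns (best_epoch, stop_epoch), both 1-indexed, when minimizing a loss."""
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0  # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, len(losses)


print(early_stopping_epoch(val_losses, patience=3))
```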
_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
5/5 [==============================] - 6s 1s/step - loss: 0.8390 - accuracy: 0.8429
Accuracy on the test set: 84.29%.
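The model outputs one softmax probability per class. To read a prediction as a label, take the argmax and invert `label_map`. In the sketch below, the probabilities are hypothetical; with the trained model they would come from something like `multimodal_model.predict(test_ds)`.

```python
import numpy as np

label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}
# Invert the mapping so that class indices can be turned back into names.
idx_to_label = {idx: name for name, idx in label_map.items()}

# Hypothetical softmax outputs for two samples, shape (batch, 3).
probs = np.array([[0.1, 0.7, 0.2], [0.05, 0.05, 0.9]])
predicted = [idx_to_label[int(i)] for i in probs.argmax(axis=-1)]
print(predicted)
```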
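The training logs suggest that the model is starting to overfit and may benefit from regularization. Dropout (Srivastava et al.) is a simple yet powerful regularization technique that we can use in our model. But how should we apply it here?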
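We could always introduce Dropout (keras.layers.Dropout) between different layers of the model. But here is another recipe. Our model expects inputs from two different data modalities. What if either modality is missing during inference? To account for this, we can apply Dropout to the individual projections just before they are concatenated: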
vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
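As a reminder of what `keras.layers.Dropout` does during training, here is a NumPy sketch of inverted dropout applied separately to each projection before concatenation. Units are zeroed with probability `rate` and the survivors are scaled by `1 / (1 - rate)`, so the expected activation is unchanged. At inference, the input passes through untouched. The toy projections and `rate = 0.5` are hypothetical values.

```python
import numpy as np

rng = np.random.default_rng(42)


def dropout(x, rate, training=True):
    # Inverted dropout: zero units with probability `rate` and rescale
    # the kept ones by 1 / (1 - rate); do nothing outside training.
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)


vision_projections = np.ones((2, 4))      # toy vision projections
text_projections = np.full((2, 4), 2.0)   # toy text projections
rate = 0.5
concatenated = np.concatenate(
    [dropout(vision_projections, rate), dropout(text_projections, rate)], axis=-1
)
print(concatenated)
```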
Another option is to let the two modalities attend to each other with cross-attention:
# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)

# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
    [vision_projections, text_projections]
)

# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])
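`keras.layers.Attention` computes Luong-style (dot-product) attention. When it is called on `[query, value]`, the value also serves as the key: scores are `query @ value^T` (multiplied by a learned scalar when `use_scale=True`), a softmax over the value positions turns them into weights, and the output is the weighted sum of values. The layer is documented to take a query of shape `(batch_size, Tq, dim)` and a value of shape `(batch_size, Tv, dim)`, so this NumPy sketch works on 3-D arrays. The numbers of visual and text "tokens" are hypothetical, and the scale is fixed at 1.0.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def dot_product_attention(query, value, scale=1.0):
    """query: (batch, Tq, dim), value: (batch, Tv, dim) -> (batch, Tq, dim)."""
    scores = scale * query @ np.swapaxes(value, -1, -2)  # (batch, Tq, Tv)
    weights = softmax(scores, axis=-1)  # each query's weights sum to 1
    return weights @ value, weights


rng = np.random.default_rng(0)
vision_tokens = rng.normal(size=(2, 3, 8))  # hypothetical: 3 visual tokens of size 8
text_tokens = rng.normal(size=(2, 5, 8))    # hypothetical: 5 text tokens of size 8
context, weights = dot_product_attention(vision_tokens, text_tokens)
print(context.shape, weights.shape)
```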
To see this in action, refer to this notebook.
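The dataset suffers from class imbalance. Inspecting the confusion matrix of the model above shows that it performs poorly on the minority classes. A weighted loss would have given the training a better signal on those classes. You can check out this notebook, which takes class imbalance into account during model training.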
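A class-weighted loss scales each sample's cross-entropy by a weight looked up from its true class, so mistakes on rare classes cost more. The NumPy sketch below shows sparse categorical cross-entropy with optional class weights. It simply averages the weighted per-sample losses over the batch. The probabilities are hypothetical, and the weights are rounded values of the "balanced" heuristic for this dataset's class counts, not necessarily what the linked notebook used.

```python
import numpy as np


def sparse_categorical_crossentropy(probs, labels, class_weight=None):
    """Mean of -log p[label] over the batch, optionally weighted per class."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    per_sample = -np.log(probs[np.arange(len(labels)), labels])
    if class_weight is not None:
        # Look up one weight per sample from its true class.
        per_sample = per_sample * np.array([class_weight[int(y)] for y in labels])
    return per_sample.mean()


probs = [[0.2, 0.1, 0.7], [0.1, 0.1, 0.8]]  # hypothetical softmax outputs
labels = [1, 2]  # an "Implies" sample and a "NoEntailment" sample
unweighted = sparse_categorical_crossentropy(probs, labels)
weighted = sparse_categorical_crossentropy(
    probs, labels, class_weight={0: 4.28, 1: 4.28, 2: 0.39}
)
print(round(unweighted, 4), round(weighted, 4))
```

Here the misclassified minority-class sample dominates the weighted loss, which is the intended effect.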
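What if we used only the text inputs for the entailment task? Given the nature of the text found on social media platforms, text inputs alone hurt the final performance. Under a similar training setup, using only text inputs, we get 67.14% top-1 accuracy on the same test set. Refer to this notebook for details.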
| Type | Standard cross-entropy | Loss-weighted cross-entropy | Focal loss |
|---|---|---|---|
| Multimodal | 77.86% | 67.86% | 86.43% |
| Only text | 67.14% | 11.43% | 37.86% |
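Focal loss, the best-performing multimodal variant in the table, multiplies each sample's cross-entropy by `(1 - p_t) ** gamma`, where `p_t` is the predicted probability of the true class. This down-weights examples the model already classifies confidently. Below is a NumPy sketch of a multi-class version. `gamma=2.0` and `alpha=1.0` are common illustrative defaults, not the settings behind the numbers above, which are not stated here. With `gamma=0` it reduces to plain cross-entropy.

```python
import numpy as np


def sparse_focal_loss(probs, labels, gamma=2.0, alpha=1.0):
    """Multi-class focal loss: mean of -alpha * (1 - p_t)**gamma * log(p_t)."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    p_t = probs[np.arange(len(labels)), labels]  # probability of the true class
    return np.mean(-alpha * (1.0 - p_t) ** gamma * np.log(p_t))


probs = [[0.2, 0.1, 0.7], [0.1, 0.1, 0.8]]  # hypothetical softmax outputs
labels = [1, 2]
ce_like = sparse_focal_loss(probs, labels, gamma=0.0)  # reduces to cross-entropy
focal = sparse_focal_loss(probs, labels, gamma=2.0)
print(round(ce_like, 4), round(focal, 4))
```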
You can check out this repository to learn more about how the experiments that produced these numbers were conducted.
You can use the trained model hosted on the Hugging Face Hub and try the demo on Hugging Face Spaces.