Author: Sayak Paul
Date created: 2021/08/08
Last modified: 2025/01/03
Description: Training a multimodal model for predicting entailment.
In this example, we will build and train a model for predicting multimodal entailment. We will be using the multimodal entailment dataset recently introduced by Google Research.
On social media platforms, to audit and moderate content we may need to find answers to the following questions in near real-time: does a given piece of information contradict the other, or does it imply the other?
In NLP, this task is called analyzing textual entailment. However, that's only when the information comes from text content. In practice, the available information often comes not just from text content, but from a multimodal combination of text, images, audio, video, etc. Multimodal entailment is simply the extension of textual entailment to a variety of new input modalities.
This example requires TensorFlow 2.5 or higher. In addition, TensorFlow Hub and TensorFlow Text are required for the BERT model (Devlin et al.). These libraries can be installed using the following command:
!pip install -q tensorflow_text
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random
import math
from skimage.io import imread
from skimage.transform import resize
from PIL import Image
import os
os.environ["KERAS_BACKEND"] = "jax" # or tensorflow, or torch
import keras
import keras_hub
from keras.utils import PyDataset
label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}
The original dataset is available here. It comes with image URLs hosted on Twitter's photo storage system called the Photo Blob Storage (PBS for short). We will be working with the downloaded images along with additional data that comes with the original dataset. Thanks to Nilabhra Roy Chowdhury who worked on preparing the image data.
image_base_path = keras.utils.get_file(
"tweet_images",
"https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
untar=True,
)
df = pd.read_csv(
"https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
).iloc[
0:1000
] # Resources conservation since these are examples and not SOTA
df.sample(10)
| | id_1 | text_1 | image_1 | id_2 | text_2 | image_2 | label |
|---|---|---|---|---|---|---|---|
| 815 | 1370730009921343490 | Sticky bombs with magnets are a threat... | http://pbs.twimg.com/media/EwXOFrgVIAEkfjR.jpg | 1370731764906295307 | Sticky bombs with magnets are a threat... | http://pbs.twimg.com/media/EwXRK_3XEAA6Q6F.jpg | NoEntailment |
| 615 | 1364119737446395905 | #Cancer 2.23.21 Daily Horoscope ♊️❤️✨ #Hor... | http://pbs.twimg.com/media/Eu5Te44VgAIo1jZ.jpg | 1365218087906078720 | #Cancer 2.26.21 Daily Horoscope ♊️❤️✨ #Hor... | http://pbs.twimg.com/media/EvI6nW4WQAA4_E_.jpg | NoEntailment |
| 624 | 1335542260923068417 | The reindeer run is back and this year's race... | http://pbs.twimg.com/media/Eoi99DyXEAE0AFV.jpg | 1335872932267122689 | Get your red noses and antlers on for the 2020 ... | http://pbs.twimg.com/media/Eon5Wk7XUAE-CxN.jpg | NoEntailment |
| 970 | 1345058844439949312 | Participants needed for an online survey!\n\nTopic... | http://pbs.twimg.com/media/Eqqb4_MXcAA-Pvu.jpg | 1361211461792632835 | Participants needed for a leading study on su... | http://pbs.twimg.com/media/EuPz0GwXMAMDklt.jpg | NoEntailment |
| 456 | 1379831489043521545 | Commission for @NanoBiteTSF\nenjoy the bros and ... | http://pbs.twimg.com/media/EyVf0_VXMAMtRaL.jpg | 1380660763749142531 | Another commission for @NanoBiteTSF\nhope you... | http://pbs.twimg.com/media/EykW0iXXAAA2SBC.jpg | NoEntailment |
| 917 | 1336180735191891968 | (2/10)\n(Jung-gu, Seoul) Market cluster ->\n... | http://pbs.twimg.com/media/EosRFpGVQAIeuYG.jpg | 1356113330536996866 | (3/11)\n(Dongdaemun-gu, Seoul) Goshiwon cluster... | http://pbs.twimg.com/media/EtHhj7QVcAAibvF.jpg | NoEntailment |
| 276 | 1339270210029834241 | Today the message of freedom reached Kisoro, R... | http://pbs.twimg.com/media/EpVK3pfXcAAZ5Du.jpg | 1340881971132698625 | Today the message of freedom reached the p... | http://pbs.twimg.com/media/EpvDorkXYAEyz4g.jpg | Implies |
| 35 | 1360186999836200961 | Bitcoin in Argentina - Google Trends https://t... | http://pbs.twimg.com/media/EuBa3UxXYAMb99_.jpg | 1382778703055228929 | Argentina wants #Bitcoin https://127.0.0.1/9lNxJdxX... | http://pbs.twimg.com/media/EzCbUFNXMAABwPD.jpg | Implies |
| 762 | 1370824756400959491 | $HSBA.L: The long term trend is positive and... | http://pbs.twimg.com/media/EwYl2hPWYAE2niq.png | 1374347458126475269 | Although the technical rating is only medium, ... | http://pbs.twimg.com/media/ExKpuwrWgAAktg4.png | NoEntailment |
| 130 | 1373789433607172097 | I just watched episode S01 \| E05 of Ted Lasso... | http://pbs.twimg.com/media/ExCuNbDXAAQaPiL.jpg | 1374913509662806016 | I just watched episode S01 \| E06 of Ted Lasso... | http://pbs.twimg.com/media/ExSsjRQWgAUVRPz.jpg | Contradictory |
The columns we are interested in are the following:
text_1
image_1
text_2
image_2
label
The entailment task is formulated as follows:
Given the pairs of (text_1, image_1) and (text_2, image_2), do they entail (or not entail or contradict) each other?
We have already downloaded the images. image_1 is downloaded with id1 as its filename and image2 is downloaded with id2 as its filename. In the next step, we will add two more columns to df: the filepaths of image_1 and image_2.
images_one_paths = []
images_two_paths = []
for idx in range(len(df)):
current_row = df.iloc[idx]
id_1 = current_row["id_1"]
id_2 = current_row["id_2"]
extension_one = current_row["image_1"].split(".")[-1]
extension_two = current_row["image_2"].split(".")[-1]
image_one_path = os.path.join(image_base_path, str(id_1) + f".{extension_one}")
image_two_path = os.path.join(image_base_path, str(id_2) + f".{extension_two}")
images_one_paths.append(image_one_path)
images_two_paths.append(image_two_path)
df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths
# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])
def visualize(idx):
current_row = df.iloc[idx]
image_1 = plt.imread(current_row["image_1_path"])
image_2 = plt.imread(current_row["image_2_path"])
text_1 = current_row["text_1"]
text_2 = current_row["text_2"]
label = current_row["label"]
plt.subplot(1, 2, 1)
plt.imshow(image_1)
plt.axis("off")
plt.title("Image One")
plt.subplot(1, 2, 2)
plt.imshow(image_2)
plt.axis("off")
plt.title("Image Two")
plt.show()
print(f"Text one: {text_1}")
print(f"Text two: {text_2}")
print(f"Label: {label}")
random_idx = random.choice(range(len(df)))
visualize(random_idx)
random_idx = random.choice(range(len(df)))
visualize(random_idx)
Text one: World #water day reminds that we should follow the #guidelines to save water for us. This Day is an #opportunity to learn more about water related issues, be #inspired to tell others and take action to make a difference. Just remember, every #drop counts.
#WorldWaterDay2021 https://127.0.0.1/bQ9Hp53qUj
Text two: Water is an extremely precious resource without which life would be impossible. We need to ensure that water is used judiciously, this #WorldWaterDay, let us pledge to reduce water wastage and conserve it.
#WorldWaterDay2021 https://127.0.0.1/0KWnd8Kn8r
Label: NoEntailment
Text one: 🎧 𝗘𝗣𝗜𝗦𝗢𝗗𝗘 𝟯𝟬: 𝗗𝗬𝗟𝗔𝗡 𝗙𝗜𝗧𝗭𝗦𝗜𝗠𝗢𝗡𝗦
Dylan Fitzsimons is a young passionate greyhound supporter.
He and @Drakesport enjoy a great chat about everything greyhounds!
Listen: https://127.0.0.1/B2XgMp0yaO
#GoGreyhoundRacing #ThisRunsDeep #TalkingDogs https://127.0.0.1/crBiSqHUvp
Text two: 🎧 𝗘𝗣𝗜𝗦𝗢𝗗𝗘 𝟯𝟳: 𝗣𝗜𝗢 𝗕𝗔𝗥𝗥𝗬 🎧
Well known within greyhound circles, Pio Barry shares some wonderful greyhound racing stories with @Drakesport in this podcast episode.
A great chat.
Listen: https://127.0.0.1/mJTVlPHzp0
#TalkingDogs #GoGreyhoundRacing #ThisRunsDeep https://127.0.0.1/QbxtCpLcGm
Label: NoEntailment
The dataset suffers from class imbalance. We can confirm that in the following cell.
df["label"].value_counts()
label
NoEntailment 819
Contradictory 92
Implies 89
Name: count, dtype: int64
To account for that, we will go for a stratified split.
# 10% for test
train_df, test_df = train_test_split(
df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% for validation
train_df, val_df = train_test_split(
train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)
print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")
Total training examples: 855
Total validation examples: 45
Total test examples: 100
KerasHub provides a variety of BERT family models. Each of those models comes with a corresponding preprocessing layer. You can learn more about these models and their preprocessing layers from this resource.
To keep the runtime of this example relatively short, we will use a base uncased variant of the original BERT model.
Text preprocessing with KerasHub
text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
"bert_base_en_uncased",
sequence_length=128,
)
idx = random.choice(range(len(train_df)))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")
test_text = [sample_text_1, sample_text_2]
text_preprocessed = text_preprocessor(test_text)
print("Keys : ", list(text_preprocessed.keys()))
print("Shape Token Ids : ", text_preprocessed["token_ids"].shape)
print("Token Ids : ", text_preprocessed["token_ids"][0, :16])
print(" Shape Padding Mask : ", text_preprocessed["padding_mask"].shape)
print("Padding Mask : ", text_preprocessed["padding_mask"][0, :16])
print("Shape Segment Ids : ", text_preprocessed["segment_ids"].shape)
print("Segment Ids : ", text_preprocessed["segment_ids"][0, :16])
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Text 1: The RPF Lohardaga and Hatia Post of Ranchi Division have recovered 02 bags on 20.02.2021 at Station platform and in T/No.08310 Spl. respectively and handed over to their actual owner correctly. @RPF_INDIA https://127.0.0.1/bdEBl2egIc
Text 2: The RPF Lohardaga and Hatia Post of Ranchi Division have recovered 02 bags on 20.02.2021 at Station platform and in T/No.08310 (JAT-SBP) Spl. respectively and handed over to their actual owner correctly. @RPF_INDIA https://127.0.0.1/Q5l2AtA4uq
Keys : ['token_ids', 'padding_mask', 'segment_ids']
Shape Token Ids : (2, 128)
Token Ids : [ 101 1996 1054 14376 8840 11783 16098 1998 6045 2401 2695 1997
8086 2072 2407 2031]
Shape Padding Mask : (2, 128)
Padding Mask : [ True True True True True True True True True True True True
True True True True]
Shape Segment Ids : (2, 128)
Segment Ids : [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Now we will create batched dataset objects from the dataframes, using a keras.utils.PyDataset so that the same input pipeline works across backends.
Note that the text inputs will be preprocessed as a part of the data input pipeline. But the preprocessing modules can also be a part of their corresponding BERT models. This helps reduce the training/serving skew and lets our models operate with raw text inputs. Follow this tutorial to learn more about how to incorporate the preprocessing modules directly inside the models.
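For instance, the KerasHub task-level models bundle their matching preprocessor, so they can consume raw strings directly. Here is a minimal, purely illustrative sketch of that approach (it is not used in the rest of this example) using the same bert_base_en_uncased preset:
# Illustrative only: a task model that tokenizes raw strings internally,
# which helps keep training-time and serving-time preprocessing identical.
raw_text_classifier = keras_hub.models.BertTextClassifier.from_preset(
    "bert_base_en_uncased", num_classes=3
)
preds = raw_text_classifier.predict(["First tweet text.", "Second tweet text."])
print(preds.shape)  # (2, 3)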
def dataframe_to_dataset(dataframe):
columns = ["image_1_path", "image_2_path", "text_1", "text_2", "label_idx"]
ds = UnifiedPyDataset(
dataframe,
batch_size=32,
workers=4,
)
return ds
bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
def preprocess_text(text_1, text_2):
output = text_preprocessor([text_1, text_2])
output = {
feature: keras.ops.reshape(output[feature], [-1])
for feature in bert_input_features
}
return output
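As a quick sanity check (purely illustrative), we can call preprocess_text on a pair of strings and confirm that each BERT feature is flattened into a single vector of length 2 * 128 = 256 (two texts of 128 tokens each):
# Illustrative only: inspect the flattened feature shapes for one text pair.
sample_output = preprocess_text("A first sample tweet.", "A second sample tweet.")
for feature in bert_input_features:
    print(feature, sample_output[feature].shape)  # each feature has shape (256,)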
class UnifiedPyDataset(PyDataset):
"""A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch."""
def __init__(
self,
df,
batch_size=32,
workers=4,
use_multiprocessing=False,
max_queue_size=10,
**kwargs,
):
"""
Args:
df: pandas DataFrame with data
batch_size: Batch size for dataset
workers: Number of workers to use for parallel loading (Keras)
use_multiprocessing: Whether to use multiprocessing
max_queue_size: Maximum size of the data queue for parallel loading
"""
super().__init__(**kwargs)
self.dataframe = df
columns = ["image_1_path", "image_2_path", "text_1", "text_2"]
# image files
self.image_x_1 = self.dataframe["image_1_path"]
self.image_x_2 = self.dataframe["image_2_path"]
self.image_y = self.dataframe["label_idx"]
# text files
self.text_x_1 = self.dataframe["text_1"]
self.text_x_2 = self.dataframe["text_2"]
self.text_y = self.dataframe["label_idx"]
# general
self.batch_size = batch_size
self.workers = workers
self.use_multiprocessing = use_multiprocessing
self.max_queue_size = max_queue_size
def __getitem__(self, index):
"""
Fetches a batch of data from the dataset at the given index.
"""
# Return x, y for batch idx.
low = index * self.batch_size
# Cap upper bound at array length; the last batch may be smaller
# if the total number of items is not a multiple of batch size.
high_image_1 = min(low + self.batch_size, len(self.image_x_1))
high_image_2 = min(low + self.batch_size, len(self.image_x_2))
high_text_1 = min(low + self.batch_size, len(self.text_x_1))
high_text_2 = min(low + self.batch_size, len(self.text_x_2))
# images files
batch_image_x_1 = self.image_x_1[low:high_image_1]
batch_image_y_1 = self.image_y[low:high_image_1]
batch_image_x_2 = self.image_x_2[low:high_image_2]
batch_image_y_2 = self.image_y[low:high_image_2]
# text files
batch_text_x_1 = self.text_x_1[low:high_text_1]
batch_text_y_1 = self.text_y[low:high_text_1]
batch_text_x_2 = self.text_x_2[low:high_text_2]
batch_text_y_2 = self.text_y[low:high_text_2]
# image number 1 inputs
image_1 = [
resize(imread(file_name), (128, 128)) for file_name in batch_image_x_1
]
image_1 = [
( # some images have 4 channels (RGBA); convert them to 3-channel RGB
np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))
if img.shape[2] == 4
else img
)
for img in image_1
]
image_1 = np.array(image_1)
# Both text inputs to the model, return a dict for inputs to BertBackbone
text = {
key: np.array(
[
d[key]
for d in [
preprocess_text(file_path1, file_path2)
for file_path1, file_path2 in zip(
batch_text_x_1, batch_text_x_2
)
]
]
)
for key in ["padding_mask", "token_ids", "segment_ids"]
}
# Image number 2 model inputs
image_2 = [
resize(imread(file_name), (128, 128)) for file_name in batch_image_x_2
]
image_2 = [
( # some images have 4 channels (RGBA); convert them to 3-channel RGB
np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))
if img.shape[2] == 4
else img
)
for img in image_2
]
# Stack the list comprehension to an nd.array
image_2 = np.array(image_2)
return (
{
"image_1": image_1,
"image_2": image_2,
"padding_mask": text["padding_mask"],
"segment_ids": text["segment_ids"],
"token_ids": text["token_ids"],
},
# Target labels
np.array(batch_image_y_1),
)
def __len__(self):
"""
Returns the number of batches in the dataset.
"""
return math.ceil(len(self.dataframe) / self.batch_size)
Create train, validation and test datasets
def prepare_dataset(dataframe):
ds = dataframe_to_dataset(dataframe)
return ds
train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df)
test_ds = prepare_dataset(test_df)
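Each of these datasets yields batches of (inputs, labels). As a quick check (illustrative only), we can fetch the first batch and inspect its structure:
# Illustrative only: pull one batch and print the shapes of its components.
batch_inputs, batch_labels = train_ds[0]
print({name: value.shape for name, value in batch_inputs.items()})
print(batch_labels.shape)  # (32,)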
Our final model will accept two images along with their text counterparts. While the images will be directly fed to the model, the text inputs will first be preprocessed and then will make it into the model. Below is a visual illustration of this approach:
The model consists of the following elements:
A standalone encoder for the images
A standalone encoder for the text
After extracting the individual embeddings, they will be projected into an identical space. Finally, their projections will be concatenated and fed to the final classification layer.
This is a multi-class classification problem involving the following classes: NoEntailment, Implies, and Contradictory.
The project_embeddings(), create_vision_encoder(), and create_text_encoder() utilities are referred from this example.
Projection utilities
def project_embeddings(
embeddings, num_projection_layers, projection_dims, dropout_rate
):
projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
for _ in range(num_projection_layers):
x = keras.ops.nn.gelu(projected_embeddings)
x = keras.layers.Dense(projection_dims)(x)
x = keras.layers.Dropout(dropout_rate)(x)
x = keras.layers.Add()([projected_embeddings, x])
projected_embeddings = keras.layers.LayerNormalization()(x)
return projected_embeddings
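Since Keras layers can be called eagerly, we can sanity-check the projection head on a random batch of embeddings (purely illustrative; the input dimension below is arbitrary):
# Illustrative only: project random embeddings and check the output shape.
dummy_embeddings = np.random.rand(2, 512).astype("float32")
projected = project_embeddings(
    dummy_embeddings, num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
print(projected.shape)  # (2, 256)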
Vision encoder utilities
def create_vision_encoder(
num_projection_layers, projection_dims, dropout_rate, trainable=False
):
# Load the pre-trained ResNet50V2 model to be used as the base encoder.
resnet_v2 = keras.applications.ResNet50V2(
include_top=False, weights="imagenet", pooling="avg"
)
# Set the trainability of the base encoder.
for layer in resnet_v2.layers:
layer.trainable = trainable
# Receive the images as inputs.
image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
image_2 = keras.Input(shape=(128, 128, 3), name="image_2")
# Preprocess the input image.
preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)
# Generate the embeddings for the images using the resnet_v2 model
# concatenate them.
embeddings_1 = resnet_v2(preprocessed_1)
embeddings_2 = resnet_v2(preprocessed_2)
embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])
# Project the embeddings produced by the model.
outputs = project_embeddings(
embeddings, num_projection_layers, projection_dims, dropout_rate
)
# Create the vision encoder model.
return keras.Model([image_1, image_2], outputs, name="vision_encoder")
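As a quick shape check (illustrative only; instantiating the encoder downloads the ImageNet weights for ResNet50V2), the vision encoder maps a pair of image batches to a single batch of projected embeddings:
# Illustrative only: run the vision encoder on random images.
vision_encoder = create_vision_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
fake_images = np.random.rand(2, 128, 128, 3).astype("float32")
print(vision_encoder.predict([fake_images, fake_images]).shape)  # (2, 256)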
Text encoder utilities
def create_text_encoder(
num_projection_layers, projection_dims, dropout_rate, trainable=False
):
# Load the pre-trained BERT BackBone using KerasHub.
bert = keras_hub.models.BertBackbone.from_preset(
"bert_base_en_uncased", num_classes=3
)
# Set the trainability of the base encoder.
bert.trainable = trainable
# Receive the text as inputs.
bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
inputs = {
feature: keras.Input(shape=(256,), dtype="int32", name=feature)
for feature in bert_input_features
}
# Generate embeddings for the preprocessed text using the BERT model.
embeddings = bert(inputs)["pooled_output"]
# Project the embeddings produced by the model.
outputs = project_embeddings(
embeddings, num_projection_layers, projection_dims, dropout_rate
)
# Create the text encoder model.
return keras.Model(inputs, outputs, name="text_encoder")
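Similarly (illustrative only), the text encoder expects the three flattened BERT features of length 256 and returns one projected vector per example:
# Illustrative only: run the text encoder on dummy integer features.
text_encoder = create_text_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
dummy_features = {
    "padding_mask": np.ones((2, 256), dtype="int32"),
    "segment_ids": np.zeros((2, 256), dtype="int32"),
    "token_ids": np.zeros((2, 256), dtype="int32"),
}
print(text_encoder.predict(dummy_features).shape)  # (2, 256)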
Multimodal model utilities
def create_multimodal_model(
num_projection_layers=1,
projection_dims=256,
dropout_rate=0.1,
vision_trainable=False,
text_trainable=False,
):
# Receive the images as inputs.
image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
image_2 = keras.Input(shape=(128, 128, 3), name="image_2")
# Receive the text as inputs.
bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
text_inputs = {
feature: keras.Input(shape=(256,), dtype="int32", name=feature)
for feature in bert_input_features
}
text_inputs = list(text_inputs.values())
# Create the encoders.
vision_encoder = create_vision_encoder(
num_projection_layers, projection_dims, dropout_rate, vision_trainable
)
text_encoder = create_text_encoder(
num_projection_layers, projection_dims, dropout_rate, text_trainable
)
# Fetch the embedding projections.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)
# Concatenate the projections and pass through the classification layer.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
return keras.Model([image_1, image_2, *text_inputs], outputs)
multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)
You can inspect the structure of the individual encoders as well by setting the expand_nested argument of plot_model() to True. You are encouraged to play with the different hyperparameters involved in building this model and observe how the final performance is affected.
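For instance (this requires the pydot and graphviz packages to be installed):
# Illustrative only: also render the nested vision and text encoders.
keras.utils.plot_model(
    multimodal_model, show_shapes=True, expand_nested=True, to_file="model_nested.png"
)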
multimodal_model.compile(
optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:248: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: {'padding_mask': 'padding_mask', 'segment_ids': 'segment_ids', 'token_ids': 'token_ids'}
Received: inputs=['Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))', 'Tensor(shape=(32, 256))']
warnings.warn(msg)
1/27 ━━━━━━━━━━━━━━━━━━━━ 45:45 106s/step - accuracy: 0.0625 - loss: 1.6335
2/27 ━━━━━━━━━━━━━━━━━━━━ 42:14 101s/step - accuracy: 0.2422 - loss: 1.9508
3/27 ━━━━━━━━━━━━━━━━━━━━ 38:49 97s/step - accuracy: 0.3524 - loss: 2.0126
4/27 ━━━━━━━━━━━━━━━━━━━━ 37:09 97s/step - accuracy: 0.4284 - loss: 1.9870
5/27 ━━━━━━━━━━━━━━━━━━━━ 35:08 96s/step - accuracy: 0.4815 - loss: 1.9855
6/27 ━━━━━━━━━━━━━━━━━━━━ 31:56 91s/step - accuracy: 0.5210 - loss: 1.9939
7/27 ━━━━━━━━━━━━━━━━━━━━ 29:30 89s/step - accuracy: 0.5512 - loss: 1.9980
8/27 ━━━━━━━━━━━━━━━━━━━━ 27:12 86s/step - accuracy: 0.5750 - loss: 2.0061
9/27 ━━━━━━━━━━━━━━━━━━━━ 25:15 84s/step - accuracy: 0.5956 - loss: 1.9959
10/27 ━━━━━━━━━━━━━━━━━━━━ 23:33 83s/step - accuracy: 0.6120 - loss: 1.9738
11/27 ━━━━━━━━━━━━━━━━━━━━ 22:09 83s/step - accuracy: 0.6251 - loss: 1.9579
12/27 ━━━━━━━━━━━━━━━━━━━━ 20:59 84s/step - accuracy: 0.6357 - loss: 1.9524
13/27 ━━━━━━━━━━━━━━━━━━━━ 19:44 85s/step - accuracy: 0.6454 - loss: 1.9439
14/27 ━━━━━━━━━━━━━━━━━━━━ 18:22 85s/step - accuracy: 0.6540 - loss: 1.9346
15/27 ━━━━━━━━━━━━━━━━━━━━ 16:52 84s/step - accuracy: 0.6621 - loss: 1.9213
16/27 ━━━━━━━━━━━━━━━━━━━━ 15:29 85s/step - accuracy: 0.6693 - loss: 1.9101
17/27 ━━━━━━━━━━━━━━━━━━━━ 14:08 85s/step - accuracy: 0.6758 - loss: 1.9021
18/27 ━━━━━━━━━━━━━━━━━━━━ 12:45 85s/step - accuracy: 0.6819 - loss: 1.8916
19/27 ━━━━━━━━━━━━━━━━━━━━ 11:24 86s/step - accuracy: 0.6874 - loss: 1.8851
20/27 ━━━━━━━━━━━━━━━━━━━━ 10:00 86s/step - accuracy: 0.6925 - loss: 1.8791
21/27 ━━━━━━━━━━━━━━━━━━━━ 8:36 86s/step - accuracy: 0.6976 - loss: 1.8699
22/27 ━━━━━━━━━━━━━━━━━━━━ 7:11 86s/step - accuracy: 0.7020 - loss: 1.8623
23/27 ━━━━━━━━━━━━━━━━━━━━ 5:46 87s/step - accuracy: 0.7061 - loss: 1.8573
24/27 ━━━━━━━━━━━━━━━━━━━━ 4:20 87s/step - accuracy: 0.7100 - loss: 1.8534
25/27 ━━━━━━━━━━━━━━━━━━━━ 2:54 87s/step - accuracy: 0.7136 - loss: 1.8494
26/27 ━━━━━━━━━━━━━━━━━━━━ 1:27 87s/step - accuracy: 0.7170 - loss: 1.8449
27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 88s/step - accuracy: 0.7201 - loss: 1.8414
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/PIL/Image.py:1054: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
warnings.warn(
27/27 ━━━━━━━━━━━━━━━━━━━━ 2508s 92s/step - accuracy: 0.7231 - loss: 1.8382 - val_accuracy: 0.8222 - val_loss: 1.7304
_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
1/4 ━━━━━━━━━━━━━━━━━━━━ 5:32 111s/step - accuracy: 0.7812 - loss: 1.9384
2/4 ━━━━━━━━━━━━━━━━━━━━ 2:10 65s/step - accuracy: 0.7969 - loss: 1.8931
3/4 ━━━━━━━━━━━━━━━━━━━━ 1:05 65s/step - accuracy: 0.8056 - loss: 1.8200
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 49s/step - accuracy: 0.8092 - loss: 1.8075
4/4 ━━━━━━━━━━━━━━━━━━━━ 256s 49s/step - accuracy: 0.8113 - loss: 1.8000
Accuracy on the test set: 82.0%.
Incorporate regularization:
The training logs suggest that the model is starting to overfit and may have benefitted from regularization. Dropout (Srivastava et al.) is a simple yet powerful regularization technique that we can use in our model. But how should we apply it here?
We could always introduce Dropout (keras.layers.Dropout) in between different layers of the model. But here is another recipe. Our model expects inputs from two different data modalities. What if either of the modalities is not present during inference? To account for this, we can introduce Dropout to the individual projections just before they get concatenated:
vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
Attend to what matters:
Do all parts of the images correspond equally to their textual counterparts? It's likely not the case. To make our model only focus on the most important bits of the images that relate well to their corresponding textual parts, we can use "cross-attention":
# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)
# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
[vision_projections, text_projections]
)
# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])
To see this in action, refer to this notebook.
Handle class imbalance:
The dataset suffers from class imbalance. Investigating the confusion matrix of the above model reveals that it performs poorly on the minority classes. If we had used a weighted loss, the training would have been more guided. You can check out this notebook that takes class imbalance into account during model training.
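One simple recipe (a sketch only, not necessarily how the numbers reported below were obtained) is to derive per-class weights from the training labels and turn them into per-sample weights, which a PyDataset batch can return as its third element:
from sklearn.utils.class_weight import compute_class_weight

# Derive a weight for each class, inversely proportional to its frequency.
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(train_df["label_idx"]),
    y=train_df["label_idx"],
)
class_weight_map = dict(zip(np.unique(train_df["label_idx"]), class_weights))
print(class_weight_map)

# Inside `UnifiedPyDataset.__getitem__`, each batch could then be returned as
# (inputs, labels, sample_weights), for example:
# sample_weights = np.array([class_weight_map[y] for y in batch_image_y_1])
# return inputs, np.array(batch_image_y_1), sample_weights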
Only text input:
Also, what if we had only incorporated text inputs for the entailment task? Due to the nature of the text inputs encountered on social media platforms, text-only inputs would have hurt the final performance. Under a similar training setup, using only text inputs we reach a top-1 accuracy of 67.14% on the same test set. Refer to this notebook for details.
Finally, here is a table comparing different approaches taken for the entailment task:
| Type | Standard Cross-entropy | Loss-weighted Cross-entropy | Focal Loss |
|---|---|---|---|
| Multimodal | 77.86% | 67.86% | 86.43% |
| Only text | 67.14% | 11.43% | 37.86% |
You can check out this repository to learn more about how the experiments were conducted to obtain these numbers.
You can use the trained model hosted on the Hugging Face Hub and try the demo on Hugging Face Spaces.