
Large-scale multi-label text classification

Authors: Sayak Paul, Soumik Rakshit
Date created: 2020/09/25
Last modified: 2025/02/27
Description: Implementing a large-scale multi-label text classification model.

ⓘ This example uses Keras 3



Introduction

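In this example, we will build a multi-label text classifier to predict the subject areas of arXiv papers from their abstract bodies. This type of classifier can be useful for conference submission portals like OpenReview. Given a paper abstract, the portal could suggest which areas the paper would best belong to.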

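The dataset was collected using the arXiv Python library, which provides a wrapper around the original arXiv API. To learn more about the data collection process, please refer to this notebook. Additionally, you can also find the dataset on Kaggle.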


Imports

import os

os.environ["KERAS_BACKEND"] = "jax"  # or tensorflow, or torch

import keras
from keras import layers, ops

from sklearn.model_selection import train_test_split

from ast import literal_eval
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

Perform exploratory data analysis

In this section, we first load the dataset into a pandas dataframe and then perform some basic exploratory data analysis (EDA).

arxiv_data = pd.read_csv(
    "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
)
arxiv_data.head()
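   titles                                                 summaries                                                terms
0  Survey on Semantic Stereo Matching / Semantic ...      Stereo matching is one of the widely used techniques...  ['cs.CV', 'cs.LG']
1  FUTURE-AI: Guiding Principles and Consensus ...        The recent advancements in artificial intelligence...    ['cs.CV', 'cs.AI', 'cs.LG']
2  Enforcing Mutual Consistency of Hard Regions ...       In this paper, we propose a novel mutual...              ['cs.CV', 'cs.AI']
3  Parameter Decoupling Strategy for Semi-supervised ...  Consistency training has proven to be an advanced...     ['cs.CV']
4  Background-Foreground Segmentation for Interior ...    To ensure safety in automated driving, ...               ['cs.CV', 'cs.LG']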

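Our text features are present in the summaries column and their corresponding labels are in terms. As you can notice, there are multiple categories associated with a particular entry.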

print(f"There are {len(arxiv_data)} rows in the dataset.")
There are 2000 rows in the dataset.

Real-world data is noisy. One of the most commonly observed sources of noise is data duplication. Here we notice that our dataset contains 9 duplicate titles.

total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
print(f"There are {total_duplicate_titles} duplicate titles.")
There are 9 duplicate titles.

We drop these entries before proceeding further.

arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")

# There are some terms with occurrence as low as 1.
print(sum(arxiv_data["terms"].value_counts() == 1))

# How many unique terms?
print(arxiv_data["terms"].nunique())
There are 1991 rows in the deduplicated dataset.
208
275

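As observed above, out of 275 unique combinations of terms, 208 occur only once. To prepare our train, validation, and test sets with stratification, we need to drop these rare combinations, because stratified splitting requires at least two samples per class.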

# Filtering the rare terms.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
arxiv_data_filtered.shape
(1783, 3)
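Why the rare combinations have to go: scikit-learn's stratified splitting needs at least two samples per class, so a label combination that occurs only once makes train_test_split(..., stratify=...) raise a ValueError. The toy frame below is hypothetical and only illustrates the failure and the same groupby(...).filter(...) fix used above.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: the label "c" occurs only once.
toy = pd.DataFrame(
    {
        "text": [f"doc {i}" for i in range(9)],
        "terms": ["a", "a", "a", "a", "b", "b", "b", "b", "c"],
    }
)


def try_stratified_split(df):
    """Return [train, test], or the error message if stratification fails."""
    try:
        return train_test_split(
            df, test_size=0.5, stratify=df["terms"].values, random_state=42
        )
    except ValueError as err:
        return str(err)


# Fails: a class with a single member cannot be split in a stratified way.
print(try_stratified_split(toy))

# Same filter as above: keep only label combinations seen more than once.
toy_filtered = toy.groupby("terms").filter(lambda x: len(x) > 1)
train, test = try_stratified_split(toy_filtered)
print(len(train), len(test))
```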

Convert the string labels to lists of strings

The initial labels are represented as raw strings. Here we make them List[str] for a more compact representation.

arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
    lambda x: literal_eval(x)
)
arxiv_data_filtered["terms"].values[:5]
array([list(['cs.CV', 'cs.LG']), list(['cs.CV', 'cs.AI', 'cs.LG']),
       list(['cs.CV', 'cs.AI']), list(['cs.CV']),
       list(['cs.CV', 'cs.LG'])], dtype=object)
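ast.literal_eval only evaluates Python literals (strings, numbers, tuples, lists, dicts, sets, booleans and None), which makes it a safer choice than eval for parsing stringified lists read back from a CSV file. A minimal illustration with a hand-written label string:

```python
from ast import literal_eval

raw = "['cs.CV', 'cs.LG']"  # how a list of terms looks after a CSV round trip
parsed = literal_eval(raw)
print(type(parsed).__name__, parsed)  # list ['cs.CV', 'cs.LG']

# Anything that is not a plain literal is rejected instead of executed.
try:
    literal_eval("__import__('os').getcwd()")
except ValueError as err:
    print("rejected:", type(err).__name__)  # rejected: ValueError
```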

Use stratified splits because of class imbalance

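The dataset has a class imbalance problem. So, to get a fair evaluation result, we need to ensure that the splits are sampled with stratification. To learn more about different strategies for dealing with class imbalance, you can follow this tutorial. For an end-to-end demonstration of classification with imbalanced data, refer to Imbalanced classification: credit card fraud detection.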

test_split = 0.1

# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")
Number of rows in training set: 1604
Number of rows in validation set: 90
Number of rows in test set: 89
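Note that the second split is a plain random halving of the held-out set rather than a stratified one. The sketch below, on a hypothetical 10-row frame, checks the property the code above relies on: sample(frac=0.5) followed by drop(val_df.index) yields two disjoint sets that together cover the original rows. The fixed random_state is added here only for reproducibility, and the approach assumes the index has unique labels.

```python
import pandas as pd

# Hypothetical held-out frame standing in for test_df.
held_out = pd.DataFrame({"summaries": [f"abstract {i}" for i in range(10)]})

val_part = held_out.sample(frac=0.5, random_state=0)
test_part = held_out.drop(val_part.index)

# Disjoint and exhaustive: every row lands in exactly one split.
print(len(val_part), len(test_part))  # 5 5
print(set(val_part.index).isdisjoint(test_part.index))  # True
```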

Multi-label binarization

Now we preprocess our labels using the StringLookup layer.

# For RaggedTensor
import tensorflow as tf

terms = tf.ragged.constant(train_df["terms"].values)
lookup = layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)
vocab = lookup.get_vocabulary()


def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)


print("Vocabulary:\n")
print(vocab)
Vocabulary:
['[UNK]', 'cs.CV', 'cs.LG', 'cs.AI', 'stat.ML', 'eess.IV', 'cs.NE', 'cs.RO', 'cs.CL', 'cs.SI', 'cs.MM', 'math.NA', 'cs.CG', 'cs.CR', 'I.4.6', 'math.OC', 'cs.GR', 'cs.NA', 'cs.HC', 'cs.DS', '68U10', 'stat.ME', 'q-bio.NC', 'math.AP', 'eess.SP', 'cs.DM', '62H30']

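Here we are separating the individual unique classes available in the label pool and then using this information to represent a given label set with 0s and 1s. Below is an example.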

sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")
Original label: ['cs.CV']

Label-binarized representation: [[0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
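Conceptually, the multi_hot output mode maps each term to its index in the vocabulary and sets that position to 1. The framework-free sketch below uses a small hypothetical vocabulary (index 0 reserved for out-of-vocabulary terms, mirroring the [UNK] entry above) and shows both directions, including the same argwhere trick used by invert_multi_hot. It illustrates the idea only and is not the StringLookup implementation.

```python
import numpy as np

# Hypothetical, shortened vocabulary; index 0 plays the role of "[UNK]".
toy_vocab = ["[UNK]", "cs.CV", "cs.LG", "cs.AI", "stat.ML"]
term_to_index = {term: i for i, term in enumerate(toy_vocab)}


def encode_multi_hot(terms):
    """Set a 1 at the index of every term; unknown terms map to index 0."""
    vector = np.zeros(len(toy_vocab), dtype="float32")
    for term in terms:
        vector[term_to_index.get(term, 0)] = 1.0
    return vector


def decode_multi_hot(vector):
    """Inverse mapping: collect the vocabulary entries whose slot is 1."""
    hot_indices = np.argwhere(vector == 1.0)[..., 0]
    return np.take(toy_vocab, hot_indices)


encoded = encode_multi_hot(["cs.CV", "stat.ML"])
print(encoded)  # [0. 1. 0. 0. 1.]
print(decode_multi_hot(encoded))  # ['cs.CV' 'stat.ML']
```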

Data preprocessing and tf.data.Dataset objects

We first get percentile estimates of the sequence lengths. The purpose will become clear in a moment.

train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()
count    1604.000000
mean      158.151496
std        41.543130
min        25.000000
25%       130.000000
50%       156.000000
75%       184.250000
max       283.000000
Name: summaries, dtype: float64

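Notice that the median (50th percentile) abstract length is 156 words (you may get a different number depending on the split). So, any number close to that value is a good enough approximation for the maximum sequence length.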

Now, we implement utilities to prepare our datasets.

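# Note: with output_mode="tf_idf" (used below), TextVectorization emits one
# fixed-size vector per abstract, so `max_seqlen` and `padding_token` are not
# consumed by this pipeline; they would only matter for output_mode="int".
max_seqlen = 150
batch_size = 128
padding_token = "<pad>"
auto = tf.data.AUTOTUNE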


def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)

Now we can prepare the tf.data.Dataset objects.

train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)

Dataset preview

text_batch, label_batch = next(iter(train_dataset))

for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    print(" ")
Abstract: b"For the Domain Generalization (DG) problem where the hypotheses are composed\nof a common representation function followed by a labeling function, we point\nout a shortcoming in existing approaches that fail to explicitly optimize for a\nterm, appearing in a well-known and widely adopted upper bound to the risk on\nthe unseen domain, that is dependent on the representation to be learned. To\nthis end, we first derive a novel upper bound to the prediction risk. We show\nthat imposing a mild assumption on the representation to be learned, namely\nmanifold restricted invertibility, is sufficient to deal with this issue.\nFurther, unlike existing approaches, our novel upper bound doesn't require the\nassumption of Lipschitzness of the loss function. In addition, the\ndistributional discrepancy in the representation space is handled via the\nWasserstein-2 barycenter cost. In this context, we creatively leverage old and\nrecent transport inequalities, which link various optimal transport metrics, in\nparticular the $L^1$ distance (also known as the total variation distance) and\nthe Wasserstein-2 distances, with the Kullback-Liebler divergence. These\nanalyses and insights motivate a new representation learning cost for DG that\nadditively balances three competing objectives: 1) minimizing classification\nerror across seen domains via cross-entropy, 2) enforcing domain-invariance in\nthe representation space via the Wasserstein-2 barycenter cost, and 3)\npromoting non-degenerate, nearly-invertible representation via one of two\nmechanisms, viz., an autoencoder-based reconstruction loss or a mutual\ninformation loss. It is to be noted that the proposed algorithms completely\nbypass the use of any adversarial training mechanism that is typical of many\ncurrent domain generalization approaches. Simulation results on several\nstandard datasets demonstrate superior performance compared to several\nwell-known DG algorithms."
Label(s): ['cs.LG' 'stat.ML']

Abstract: b'Image segmentation of touching objects plays a key role in providing accurate\nclassification for computer vision technologies. A new line profile based\nimaging segmentation algorithm has been developed to provide a robust and\naccurate segmentation of a group of touching corns. The performance of the line\nprofile based algorithm has been compared to a watershed based imaging\nsegmentation algorithm. Both algorithms are tested on three different patterns\nof images, which are isolated corns, single-lines, and random distributed\nformations. The experimental results show that the algorithm can segment a\nlarge number of touching corn kernels efficiently and accurately.'
Label(s): ['cs.CV']

Abstract: b'Semantic image segmentation is a principal problem in computer vision, where\nthe aim is to correctly classify each individual pixel of an image into a\nsemantic label. Its widespread use in many areas, including medical imaging and\nautonomous driving, has fostered extensive research in recent years. Empirical\nimprovements in tackling this task have primarily been motivated by successful\nexploitation of Convolutional Neural Networks (CNNs) pre-trained for image\nclassification and object recognition. However, the pixel-wise labelling with\nCNNs has its own unique challenges: (1) an accurate deconvolution, or\nupsampling, of low-resolution output into a higher-resolution segmentation mask\nand (2) an inclusion of global information, or context, within locally\nextracted features. To address these issues, we propose a novel architecture to\nconduct the equivalent of the deconvolution operation globally and acquire\ndense predictions. We demonstrate that it leads to improved performance of\nstate-of-the-art semantic segmentation models on the PASCAL VOC 2012 benchmark,\nreaching 74.0% mean IU accuracy on the test set.'
Label(s): ['cs.CV']

Abstract: b'Modern deep learning models have revolutionized the field of computer vision.\nBut, a significant drawback of most of these models is that they require a\nlarge number of labelled examples to generalize properly. Recent developments\nin few-shot learning aim to alleviate this requirement. In this paper, we\npropose a novel lightweight CNN architecture for 1-shot image segmentation. The\nproposed model is created by taking inspiration from well-performing\narchitectures for semantic segmentation and adapting it to the 1-shot domain.\nWe train our model using 4 meta-learning algorithms that have worked well for\nimage classification and compare the results. For the chosen dataset, our\nproposed model has a 70% lower parameter count than the benchmark, while having\nbetter or comparable mean IoU scores using all 4 of the meta-learning\nalgorithms.'
Label(s): ['cs.CV' 'cs.LG' 'eess.IV']

Abstract: b'In this work, we propose CARLS, a novel framework for augmenting the capacity\nof existing deep learning frameworks by enabling multiple components -- model\ntrainers, knowledge makers and knowledge banks -- to concertedly work together\nin an asynchronous fashion across hardware platforms. The proposed CARLS is\nparticularly suitable for learning paradigms where model training benefits from\nadditional knowledge inferred or discovered during training, such as node\nembeddings for graph neural networks or reliable pseudo labels from model\npredictions. We also describe three learning paradigms -- semi-supervised\nlearning, curriculum learning and multimodal learning -- as examples that can\nbe scaled up efficiently by CARLS. One version of CARLS has been open-sourced\nand available for download at:\nhttps://github.com/tensorflow/neural-structured-learning/tree/master/research/carls'
Label(s): ['cs.LG']

Vectorization

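Before we feed the data to our model, we need to vectorize it (represent it in a numerical form). For that purpose, we will use the TextVectorization layer. It can operate as part of your main model, so the core preprocessing logic travels with the model instead of living in separate code. This greatly reduces the chances of training/serving skew during inference.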

We first calculate the number of unique words present in the abstracts.

# Source: https://stackoverflow.com/a/18937309/7636462
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(vocabulary_size)
20498
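The set.update trick above collects every distinct lowercase, whitespace-separated token. This is only an approximation of what TextVectorization will see: its default standardize="lower_and_strip_punctuation" also removes punctuation, so tokens such as "model." and "model" are merged there but counted separately here. A small illustration on made-up text; the regex is a rough stand-in and not the exact pattern Keras uses.

```python
import re

import pandas as pd

# Hypothetical abstracts, for illustration only.
texts = pd.Series(["A model. The model works.", "the Model fails"])

# Same approach as above: lowercase, split on whitespace, collect unique tokens.
raw_tokens = set()
texts.str.lower().str.split().apply(raw_tokens.update)
print(sorted(raw_tokens))  # ['a', 'fails', 'model', 'model.', 'the', 'works.']

# Rough stand-in for punctuation stripping before splitting.
stripped_tokens = set()
texts.str.lower().str.replace(r"[^\w\s]", "", regex=True).str.split().apply(
    stripped_tokens.update
)
print(sorted(stripped_tokens))  # ['a', 'fails', 'model', 'the', 'works']
```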

Now we create our vectorization layer and map() it over the tf.data.Datasets created earlier.

text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# `TextVectorization` layer needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))

train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)

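A batch of raw text will first go through the TextVectorization layer, which will generate its numerical representation. Internally, the TextVectorization layer will first create unigrams and bigrams from the sequences (ngrams=2 means all n-grams up to length 2) and then represent them using TF-IDF, producing one fixed-size floating-point vector per abstract. The output representations will then be passed to the shallow model responsible for text classification.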

To learn more about other possible configurations of TextVectorization, please consult the official documentation.

Note: Setting the max_tokens argument to a pre-calculated vocabulary size is not a requirement.
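To make the TF-IDF step concrete, the pure-Python sketch below computes unigram plus bigram TF-IDF weights for a made-up three-document corpus. The smoothed IDF formula log(1 + N / (1 + df)) is an assumption for illustration and may not match Keras's internals exactly. The point it shows: the output is a set of float weights over n-grams (not a sequence of integers), and rarer n-grams receive larger weights.

```python
import math
from collections import Counter

# Hypothetical, already standardized and tokenized corpus.
corpus = [
    ["deep", "learning", "model"],
    ["deep", "learning", "for", "vision"],
    ["graph", "model"],
]


def ngrams_up_to_2(tokens):
    """Unigrams plus adjacent-word bigrams (all n-grams up to length 2)."""
    bigrams = [" ".join(pair) for pair in zip(tokens, tokens[1:])]
    return tokens + bigrams


docs = [ngrams_up_to_2(tokens) for tokens in corpus]
num_docs = len(docs)
# Document frequency: in how many documents each n-gram appears.
doc_freq = Counter(term for doc in docs for term in set(doc))
# Assumed smoothed IDF for illustration; Keras's exact formula may differ.
idf = {term: math.log(1 + num_docs / (1 + df)) for term, df in doc_freq.items()}


def tf_idf(doc):
    """Term count within the document multiplied by the term's IDF."""
    return {term: count * idf[term] for term, count in Counter(doc).items()}


weights = tf_idf(docs[0])
for term, weight in sorted(weights.items(), key=lambda kv: -kv[1]):
    print(f"{term!r}: {weight:.3f}")
```

In the real layer, such weights are laid out in a vector with one slot per vocabulary entry, with zeros for n-grams absent from the abstract.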


Create a text classification model

We will keep our model simple: it will be a small stack of fully-connected layers with ReLU as the non-linearity.

def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]  # More on why "sigmoid" has been used here in a moment.
    )
    return shallow_mlp_model
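To get a feel for the model's size, the parameter count of the three Dense layers can be worked out by hand: each layer has inputs × units weights plus units biases. The sketch assumes a 20,498-dimensional TF-IDF input (the max_tokens value computed above) and the 27-entry label vocabulary printed earlier; both numbers depend on your split.

```python
def dense_params(fan_in, units):
    """Weights plus biases of a fully-connected layer."""
    return fan_in * units + units


# Assumed sizes: TF-IDF input width and label vocabulary size from this run.
input_dim, num_labels = 20498, 27
layer_sizes = [input_dim, 512, 256, num_labels]

per_layer = [
    dense_params(fan_in, units)
    for fan_in, units in zip(layer_sizes[:-1], layer_sizes[1:])
]
print(per_layer)  # [10495488, 131328, 6939]
print(sum(per_layer))  # 10633755
```

Almost all of the roughly 10.6M parameters sit in the first layer, a direct consequence of the wide TF-IDF input.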

Train the model

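We train our model using the binary crossentropy loss. This is because the labels are not disjoint: a given abstract may belong to multiple categories. So, we divide the prediction task into a series of independent binary classification problems, one per label. This is also why we kept the activation function of the classification layer in our model as sigmoid. Researchers have used other combinations of loss function and activation function as well. For example, in Exploring the Limits of Weakly Supervised Pretraining, Mahajan et al. used the softmax activation function and cross-entropy loss to train their models.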
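A small numeric sketch of the difference, with made-up logits for a four-label problem: sigmoid plus binary cross-entropy scores each label independently, so several labels can receive high probability at once, whereas softmax makes the labels compete for a total probability of 1. The binary_crossentropy helper is a plain NumPy re-derivation averaged over labels, not the Keras implementation.

```python
import numpy as np

# Hypothetical logits for one abstract over four labels; two labels are true.
logits = np.array([2.0, 1.5, -1.0, -3.0])
y_true = np.array([1.0, 1.0, 0.0, 0.0])


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x):
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()


def binary_crossentropy(y, p, eps=1e-7):
    """Mean over labels of -[y * log(p) + (1 - y) * log(1 - p)]."""
    p = np.clip(p, eps, 1 - eps)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))


p_sigmoid = sigmoid(logits)  # each label scored independently
p_softmax = softmax(logits)  # labels compete; probabilities sum to 1

print(np.round(p_sigmoid, 3))
print(np.round(p_softmax, 3))
print(binary_crossentropy(y_true, p_sigmoid))
```

With sigmoid, both true labels clear 0.5; with softmax, the second true label falls below 0.5 only because it shares probability mass with the first.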

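There are several metrics that can be used for multi-label classification. To keep this code example narrow, we decided to use the binary accuracy metric. For an explanation of why this metric was chosen, we refer to this pull request. There are also other suitable metrics for multi-label classification, like F1 score or Hamming loss.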
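With a 0.5 threshold, binary accuracy is the fraction of individual label slots predicted correctly, which equals 1 minus the Hamming loss. Because each abstract has only a few positive labels out of 27, even a model that predicts no labels at all can score highly on this metric, which is worth keeping in mind when reading the numbers below. The sketch uses made-up predictions and scikit-learn's hamming_loss and f1_score for comparison.

```python
import numpy as np
from sklearn.metrics import f1_score, hamming_loss

# Hypothetical ground truth for 3 abstracts over 6 labels (mostly zeros).
y_true = np.array(
    [
        [1, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 0],
    ]
)
y_prob = np.array(
    [
        [0.9, 0.4, 0.1, 0.0, 0.1, 0.0],
        [0.8, 0.2, 0.0, 0.1, 0.0, 0.0],
        [0.3, 0.7, 0.6, 0.0, 0.0, 0.2],
    ]
)
y_pred = (y_prob > 0.5).astype(int)

# Fraction of label slots predicted correctly, and the same via Hamming loss.
binary_accuracy = np.mean(y_pred == y_true)
print(binary_accuracy, 1 - hamming_loss(y_true, y_pred))

# Micro-averaged F1 pools true/false positives over all label slots.
print(f1_score(y_true, y_pred, average="micro"))

# Trivial baseline: predict "no label" everywhere.
all_zeros = np.zeros_like(y_true)
print(np.mean(all_zeros == y_true))
```

Here one missed label out of 18 slots gives a binary accuracy of about 0.94, while the all-zeros baseline already reaches about 0.72.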

epochs = 20

shallow_mlp_model = make_model()
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)

history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)


def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("binary_accuracy")

13/13 ━━━━━━━━━━━━━━━━━━━━ 7s 402ms/step - binary_accuracy: 0.8369 - loss: 0.4532 - val_binary_accuracy: 0.9782 - val_loss: 0.0867

[Figure: Train and Validation loss Over Epochs]

[Figure: Train and Validation binary_accuracy Over Epochs]

During training, we notice an initial sharp fall in the loss followed by a gradual decay.

Evaluate the model

_, binary_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Binary accuracy on the test set: {round(binary_acc * 100, 2)}%.")

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 486ms/step - binary_accuracy: 0.9734 - loss: 0.0927

Binary accuracy on the test set: 97.34%.

The trained model gives us a binary accuracy of ~97% on the test set.


Inference

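An important feature of the preprocessing layers provided by Keras is that they can be included inside a keras.Model. Here we export an inference model whose predict() method first runs the text_vectorizer layer on the raw strings and then passes the result through the layers of shallow_mlp_model. This allows our inference model to operate directly on raw strings.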

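Note that during training it is always preferable to use these preprocessing layers as part of the data input pipeline rather than the model, to avoid creating bottlenecks for the hardware accelerators. This also allows for asynchronous data processing.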

# We create a custom Model to override the predict method so
# that it first vectorizes text data
class ModelEndtoEnd(keras.Model):

    def predict(self, inputs):
        indices = text_vectorizer(inputs)
        return super().predict(indices)


def get_inference_model(model):
    # Reuse the graph of the trained model passed in; `ModelEndtoEnd.predict`
    # takes care of vectorizing the raw text first.
    inputs = model.inputs
    outputs = model.outputs
    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
    end_to_end_model.compile(
        optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model


model_for_inference = get_inference_model(shallow_mlp_model)

# Create a small dataset just for demonstrating inference.
inference_dataset = make_dataset(test_df.sample(2), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))
predicted_probabilities = model_for_inference.predict(text_batch)


# Perform inference.
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    # Rank all vocabulary terms by predicted probability and keep the top 3.
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join(top_3_labels)})")
    print(" ")

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 142ms/step

Abstract: b'High-resolution image segmentation remains challenging and error-prone due to\nthe enormous size of intermediate feature maps. Conventional methods avoid this\nproblem by using patch based approaches where each patch is segmented\nindependently. However, independent patch segmentation induces errors,\nparticularly at the patch boundary due to the lack of contextual information in\nvery high-resolution images where the patch size is much smaller compared to\nthe full image. To overcome these limitations, in this paper, we propose a\nnovel framework to segment a particular patch by incorporating contextual\ninformation from its neighboring patches. This allows the segmentation network\nto see the target patch with a wider field of view without the need of larger\nfeature maps. Comparative analysis from a number of experiments shows that our\nproposed framework is able to segment high resolution images with significantly\nimproved mean Intersection over Union and overall accuracy.'
Label(s): ['cs.CV']
Predicted Label(s): (cs.CV, eess.IV, cs.LG)

Abstract: b"Convolutional neural networks for visual recognition require large amounts of\ntraining samples and usually benefit from data augmentation. This paper\nproposes PatchMix, a data augmentation method that creates new samples by\ncomposing patches from pairs of images in a grid-like pattern. These new\nsamples' ground truth labels are set as proportional to the number of patches\nfrom each image. We then add a set of additional losses at the patch-level to\nregularize and to encourage good representations at both the patch and image\nlevels. A ResNet-50 model trained on ImageNet using PatchMix exhibits superior\ntransfer learning capabilities across a wide array of benchmarks. Although\nPatchMix can rely on random pairings and random grid-like patterns for mixing,\nwe explore evolutionary search as a guiding strategy to discover optimal\ngrid-like patterns and image pairing jointly. For this purpose, we conceive a\nfitness function that bypasses the need to re-train a model to evaluate each\nchoice. In this way, PatchMix outperforms a base model on CIFAR-10 (+1.91),\nCIFAR-100 (+5.31), Tiny Imagenet (+3.52), and ImageNet (+1.16) by significant\nmargins, also outperforming previous state-of-the-art pairwise augmentation\nstrategies."
Label(s): ['cs.CV' 'cs.LG' 'cs.NE']
Predicted Label(s): (cs.CV, cs.LG, stat.ML)
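The loop above always reports the three highest-scoring terms, even when the model is confident about only one or two of them. An alternative worth considering for multi-label output is to keep every term whose sigmoid probability clears a threshold, optionally capped at k terms. The sketch below uses hypothetical probabilities and a shortened vocabulary; the 0.5 threshold is an assumption you would normally tune on the validation set.

```python
import numpy as np

# Hypothetical vocabulary and sigmoid outputs for one abstract.
toy_vocab = np.array(["[UNK]", "cs.CV", "cs.LG", "eess.IV", "stat.ML"])
probabilities = np.array([0.01, 0.93, 0.41, 0.62, 0.05])


def top_k_labels(probs, vocab, k=3):
    """The k most probable terms, highest first (what the loop above prints)."""
    order = np.argsort(probs)[::-1][:k]
    return vocab[order].tolist()


def thresholded_labels(probs, vocab, threshold=0.5, k=None):
    """Terms whose probability exceeds `threshold`, highest first, capped at k."""
    order = np.argsort(probs)[::-1]
    kept = [str(vocab[i]) for i in order if probs[i] > threshold]
    return kept if k is None else kept[:k]


print(top_k_labels(probabilities, toy_vocab))  # ['cs.CV', 'eess.IV', 'cs.LG']
print(thresholded_labels(probabilities, toy_vocab))  # ['cs.CV', 'eess.IV']
```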


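The prediction results are not that great, but not below par for a simple model like ours. We can improve this performance with models that consider word order, such as LSTMs, or even those that use Transformers (Vaswani et al.).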


Acknowledgements

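We would like to thank Matt Watson for helping us tackle the multi-label binarization part and inverse-transforming the processed labels to the original form.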

Thanks to Cingis Kratochvil for suggesting and extending this code example by introducing binary accuracy as the evaluation metric.