
Large-scale multi-label text classification

Authors: Sayak Paul, Soumik Rakshit
Date created: 2020/09/25
Last modified: 2020/12/23
Description: Implementing a large-scale multi-label text classification model.

ⓘ This example uses Keras 2



Introduction

In this example, we will build a multi-label text classifier to predict the subject areas of arXiv papers from their abstract bodies. This type of classifier can be useful for conference submission portals like OpenReview. Given a paper abstract, the portal could provide suggestions on which areas the paper would best belong to.

The dataset was collected using the arXiv Python library, which provides a wrapper around the original arXiv API. To learn more about the data collection process, please refer to this notebook. Additionally, you can also find the dataset on Kaggle.


Imports

from tensorflow.keras import layers
from tensorflow import keras
import tensorflow as tf

from sklearn.model_selection import train_test_split
from ast import literal_eval

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

Perform exploratory data analysis

In this section, we first load the dataset into a pandas dataframe and then perform some basic exploratory data analysis (EDA).

arxiv_data = pd.read_csv(
    "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
)
arxiv_data.head()
titles	summaries	terms
0	Survey on Semantic Stereo Matching / Semantic ...	Stereo matching is one of the widely used tech...	['cs.CV', 'cs.LG']
1	FUTURE-AI: Guiding Principles and Consensus Re...	The recent advancements in artificial intellig...	['cs.CV', 'cs.AI', 'cs.LG']
2	Enforcing Mutual Consistency of Hard Regions f...	In this paper, we proposed a novel mutual cons...	['cs.CV', 'cs.AI']
3	Parameter Decoupling Strategy for Semi-supervi...	Consistency training has proved to be an advan...	['cs.CV']
4	Background-Foreground Segmentation for Interio...	To ensure safety in automated driving, the cor...	['cs.CV', 'cs.LG']

Our text features live in the summaries column and their corresponding labels are in terms. As you can notice, a given entry can be associated with multiple categories.

print(f"There are {len(arxiv_data)} rows in the dataset.")
There are 51774 rows in the dataset.

Real-world data is noisy. One of the most commonly observed sources of noise is data duplication. Here we notice that our initial dataset has about 13k duplicate entries.

total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
print(f"There are {total_duplicate_titles} duplicate titles.")
There are 12802 duplicate titles.

Before proceeding further, we drop these entries.

arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")

# There are some terms with occurrence as low as 1.
print(sum(arxiv_data["terms"].value_counts() == 1))

# How many unique terms?
print(arxiv_data["terms"].nunique())
There are 38972 rows in the deduplicated dataset.
2321
3157

As observed above, out of 3,157 unique combinations of terms, 2,321 entries occur only once. To prepare our train, validation, and test sets with stratification, we need to drop these terms.

# Filtering the rare terms.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
arxiv_data_filtered.shape
(36651, 3)

Convert the string labels to lists of strings

The initial labels are represented as raw strings. Here we make them List[str] for a more compact representation.

arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
    lambda x: literal_eval(x)
)
arxiv_data_filtered["terms"].values[:5]
array([list(['cs.CV', 'cs.LG']), list(['cs.CV', 'cs.AI', 'cs.LG']),
       list(['cs.CV', 'cs.AI']), list(['cs.CV']),
       list(['cs.CV', 'cs.LG'])], dtype=object)

Use stratified splits because of class imbalance

The dataset has a class imbalance problem. So, to get a fair evaluation result, we need to ensure the datasets are sampled with stratification. To know more about different strategies to deal with the class imbalance problem, you can follow this tutorial. For an end-to-end demonstration of classification with imbalanced data, refer to Imbalanced classification: credit card fraud detection.

test_split = 0.1

# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")
Number of rows in training set: 32985
Number of rows in validation set: 1833
Number of rows in test set: 1833

Multi-label binarization

Now we preprocess our labels using the StringLookup layer.

terms = tf.ragged.constant(train_df["terms"].values)
lookup = tf.keras.layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)
vocab = lookup.get_vocabulary()


def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)


print("Vocabulary:\n")
print(vocab)
Vocabulary:
['[UNK]', 'cs.CV', 'cs.LG', 'stat.ML', 'cs.AI', 'eess.IV', 'cs.RO', 'cs.CL', 'cs.NE', 'cs.CR', 'math.OC', 'eess.SP', 'cs.GR', 'cs.SI', 'cs.MM', 'cs.SY', 'cs.IR', 'cs.MA', 'eess.SY', 'cs.HC', 'math.IT', 'cs.IT', 'cs.DC', 'cs.CY', 'stat.AP', 'stat.TH', 'math.ST', 'stat.ME', 'eess.AS', 'cs.SD', 'q-bio.QM', 'q-bio.NC', 'cs.DS', 'cs.GT', 'cs.CG', 'cs.SE', 'cs.NI', 'I.2.6', 'stat.CO', 'math.NA', 'cs.NA', 'physics.chem-ph', 'cs.DB', 'q-bio.BM', 'cs.PL', 'cs.LO', 'cond-mat.dis-nn', '68T45', 'math.PR', 'physics.comp-ph', 'I.2.10', 'cs.CE', 'cs.AR', 'q-fin.ST', 'cond-mat.stat-mech', '68T05', 'quant-ph', 'math.DS', 'physics.data-an', 'cs.CC', 'I.4.6', 'physics.soc-ph', 'physics.ao-ph', 'cs.DM', 'econ.EM', 'q-bio.GN', 'physics.med-ph', 'astro-ph.IM', 'I.4.8', 'math.AT', 'cs.PF', 'cs.FL', 'I.4', 'q-fin.TR', 'I.5.4', 'I.2', '68U10', 'hep-ex', 'cond-mat.mtrl-sci', '68T10', 'physics.optics', 'physics.geo-ph', 'physics.flu-dyn', 'math.CO', 'math.AP', 'I.4; I.5', 'I.4.9', 'I.2.6; I.2.8', '68T01', '65D19', 'q-fin.CP', 'nlin.CD', 'cs.MS', 'I.2.6; I.5.1', 'I.2.10; I.4; I.5', 'I.2.0; I.2.6', '68T07', 'q-fin.GN', 'cs.SC', 'cs.ET', 'K.3.2', 'I.2.8', '68U01', '68T30', 'q-fin.EC', 'q-bio.MN', 'econ.GN', 'I.4.9; I.5.4', 'I.4.5', 'I.2; I.5', 'I.2; I.4; I.5', 'I.2.6; I.2.7', 'I.2.10; I.4.8', '68T99', '68Q32', '68', '62H30', 'q-fin.RM', 'q-fin.PM', 'q-bio.TO', 'q-bio.OT', 'physics.bio-ph', 'nlin.AO', 'math.LO', 'math.FA', 'hep-ph', 'cond-mat.soft', 'I.4.6; I.4.8', 'I.4.4', 'I.4.3', 'I.4.0', 'I.2; J.2', 'I.2; I.2.6; I.2.7', 'I.2.7', 'I.2.6; I.5.4', 'I.2.6; I.2.9', 'I.2.6; I.2.7; H.3.1; H.3.3', 'I.2.6; I.2.10', 'I.2.6, I.5.4', 'I.2.1; J.3', 'I.2.10; I.5.1; I.4.8', 'I.2.10; I.4.8; I.5.4', 'I.2.10; I.2.6', 'I.2.1', 'H.3.1; I.2.6; I.2.7', 'H.3.1; H.3.3; I.2.6; I.2.7', 'G.3', 'F.2.2; I.2.7', 'E.5; E.4; E.2; H.1.1; F.1.1; F.1.3', '68Txx', '62H99', '62H35', '14J60 (Primary) 14F05, 14J26 (Secondary)']

Here we are separating the individual unique classes available from the label pool and then using this information to represent a given label set with 0's and 1's. Below is an example.

sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")
Original label: ['cs.LG', 'cs.CV', 'eess.IV']
Label-binarized representation: [[0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0.]]

Data preprocessing and tf.data.Dataset objects

We first get percentile estimates of the sequence lengths. The purpose will be clear in a moment.

train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()
count    32985.000000
mean       156.497105
std         41.528225
min          5.000000
25%        128.000000
50%        154.000000
75%        183.000000
max        462.000000
Name: summaries, dtype: float64

Notice that 50% of the abstracts have a length of 154 words (you may get a different number based on the split). So, any number close to that value is a good enough approximation for the maximum sequence length.

Now, we implement utilities to prepare our datasets.

max_seqlen = 150
batch_size = 128
padding_token = "<pad>"
auto = tf.data.AUTOTUNE


def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)

Now we can prepare the tf.data.Dataset objects.

train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)

Dataset preview

text_batch, label_batch = next(iter(train_dataset))

for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    print(" ")
Abstract: b"In this paper we show how using satellite images can improve the accuracy of\nhousing price estimation models. Using Los Angeles County's property assessment\ndataset, by transferring learning from an Inception-v3 model pretrained on\nImageNet, we could achieve an improvement of ~10% in R-squared score compared\nto two baseline models that only use non-image features of the house."
Label(s): ['cs.LG' 'stat.ML']

Abstract: b'Learning from data streams is an increasingly important topic in data mining,\nmachine learning, and artificial intelligence in general. A major focus in the\ndata stream literature is on designing methods that can deal with concept\ndrift, a challenge where the generating distribution changes over time. A\ngeneral assumption in most of this literature is that instances are\nindependently distributed in the stream. In this work we show that, in the\ncontext of concept drift, this assumption is contradictory, and that the\npresence of concept drift necessarily implies temporal dependence; and thus\nsome form of time series. This has important implications on model design and\ndeployment. We explore and highlight the these implications, and show that\nHoeffding-tree based ensembles, which are very popular for learning in streams,\nare not naturally suited to learning \\emph{within} drift; and can perform in\nthis scenario only at significant computational cost of destructive adaptation.\nOn the other hand, we develop and parameterize gradient-descent methods and\ndemonstrate how they can perform \\emph{continuous} adaptation with no explicit\ndrift-detection mechanism, offering major advantages in terms of accuracy and\nefficiency. As a consequence of our theoretical discussion and empirical\nobservations, we outline a number of recommendations for deploying methods in\nconcept-drifting streams.'
Label(s): ['cs.LG' 'stat.ML']

Abstract: b"As reinforcement learning (RL) achieves more success in solving complex\ntasks, more care is needed to ensure that RL research is reproducible and that\nalgorithms herein can be compared easily and fairly with minimal bias. RL\nresults are, however, notoriously hard to reproduce due to the algorithms'\nintrinsic variance, the environments' stochasticity, and numerous (potentially\nunreported) hyper-parameters. In this work we investigate the many issues\nleading to irreproducible research and how to manage those. We further show how\nto utilise a rigorous and standardised evaluation approach for easing the\nprocess of documentation, evaluation and fair comparison of different\nalgorithms, where we emphasise the importance of choosing the right measurement\nmetrics and conducting proper statistics on the results, for unbiased reporting\nof the results."
Label(s): ['cs.LG' 'stat.ML' 'cs.AI' 'cs.RO']

Abstract: b'Estimating dense correspondences between images is a long-standing image\nunder-standing task. Recent works introduce convolutional neural networks\n(CNNs) to extract high-level feature maps and find correspondences through\nfeature matching. However,high-level feature maps are in low spatial resolution\nand therefore insufficient to provide accurate and fine-grained features to\ndistinguish intra-class variations for correspondence matching. To address this\nproblem, we generate robust features by dynamically selecting features at\ndifferent scales. To resolve two critical issues in feature selection,i.e.,how\nmany and which scales of features to be selected, we frame the feature\nselection process as a sequential Markov decision-making process (MDP) and\nintroduce an optimal selection strategy using reinforcement learning (RL). We\ndefine an RL environment for image matching in which each individual action\neither requires new features or terminates the selection episode by referring a\nmatching score. Deep neural networks are incorporated into our method and\ntrained for decision making. Experimental results show that our method achieves\ncomparable/superior performance with state-of-the-art methods on three\nbenchmarks, demonstrating the effectiveness of our feature selection strategy.'
Label(s): ['cs.CV']

Abstract: b'Dense reconstructions often contain errors that prior work has so far\nminimised using high quality sensors and regularising the output. Nevertheless,\nerrors still persist. This paper proposes a machine learning technique to\nidentify errors in three dimensional (3D) meshes. Beyond simply identifying\nerrors, our method quantifies both the magnitude and the direction of depth\nestimate errors when viewing the scene. This enables us to improve the\nreconstruction accuracy.\n  We train a suitably deep network architecture with two 3D meshes: a\nhigh-quality laser reconstruction, and a lower quality stereo image\nreconstruction. The network predicts the amount of error in the lower quality\nreconstruction with respect to the high-quality one, having only view the\nformer through its input. We evaluate our approach by correcting\ntwo-dimensional (2D) inverse-depth images extracted from the 3D model, and show\nthat our method improves the quality of these depth reconstructions by up to a\nrelative 10% RMSE.'
Label(s): ['cs.CV' 'cs.RO']

Vectorization

Before we feed the data to our model, we need to vectorize it (represent it in a numerical form). For that purpose, we will use the TextVectorization layer. It can operate as a part of your main model so that the core preprocessing logic is not separated from the model. This greatly reduces the chances of training / serving skew during inference.

We first calculate the number of unique words present in the abstracts.

# Source: https://stackoverflow.com/a/18937309/7636462
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(vocabulary_size)
153338

We now create our vectorization layer and map() it to the tf.data.Datasets created earlier.

text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# `TextVectorization` layer needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))

train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)

A batch of raw text will first go through the TextVectorization layer, which will generate their integer representations. Internally, the TextVectorization layer will first create bi-grams out of the sequences and then represent them using TF-IDF. The output representations will then be passed to the shallow model responsible for text classification.
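To make this concrete, here is a minimal, standalone sketch of the same bi-gram + TF-IDF configuration applied to a toy corpus (the corpus and the names below are illustrative, not part of the pipeline above; it reuses the layers import from the top of this example).

# A minimal sketch of the bi-gram + TF-IDF behaviour on a toy corpus.
toy_corpus = ["deep learning for vision", "deep learning for text"]

toy_vectorizer = layers.TextVectorization(ngrams=2, output_mode="tf_idf")
toy_vectorizer.adapt(toy_corpus)

# The vocabulary holds both unigrams and bi-grams; each input string maps
# to one fixed-length TF-IDF vector over that vocabulary.
print(toy_vectorizer.get_vocabulary())
print(toy_vectorizer(["deep learning for vision"]))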

To learn more about other possible configurations with TextVectorizer, please consult the official documentation.

Note: Setting the max_tokens argument to a pre-calculated vocabulary size is not a requirement.


Create a text classification model

We will keep our model simple: a small stack of fully-connected layers with ReLU as the non-linearity.

def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]  # More on why "sigmoid" has been used here in a moment.
    )
    return shallow_mlp_model

Train the model

We will train our model using the binary crossentropy loss. This is because the labels are not disjoint: for a given abstract, we may have multiple categories. So we divide the prediction task into a series of multiple binary classification problems. This is also why we kept the activation function of the classification layer in our model as sigmoid. Researchers have used other combinations of loss function and activation function as well. For example, in Exploring the Limits of Weakly Supervised Pretraining, Mahajan et al. used the softmax activation function and cross-entropy loss to train their models.
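As a quick illustration of this framing, here is a small sketch with made-up numbers: binary crossentropy scores every class slot as an independent yes/no problem, which is what allows several labels to be active at once.

# Made-up multi-hot target over 4 classes and sigmoid outputs for it.
y_true = tf.constant([[1.0, 0.0, 1.0, 0.0]])
y_pred = tf.constant([[0.9, 0.2, 0.7, 0.1]])

# Binary crossentropy averages an independent log-loss over every class
# slot, unlike softmax crossentropy, which assumes exactly one true class.
print(keras.losses.BinaryCrossentropy()(y_true, y_pred).numpy())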

Several metrics can be used in multi-label classification. To keep this code example narrow, we decided to use the binary accuracy metric. To see the explanation of why this metric is used, refer to this pull request. There are other suitable metrics for multi-label classification, like F1 score or Hamming loss.
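For intuition, the following small sketch (with made-up arrays) shows how binary accuracy relates to Hamming loss once predictions are thresholded at 0.5, the Keras default for binary_accuracy.

# Made-up multi-hot ground truth and thresholded predictions, 2 samples.
labels_true = np.array([[1, 0, 1, 0], [0, 1, 0, 0]])
labels_pred = np.array([[1, 0, 0, 0], [0, 1, 0, 1]])

# Binary accuracy: the fraction of individual label slots that match.
print(np.mean(labels_true == labels_pred))  # 0.75

# Hamming loss is its complement: the fraction of mismatched label slots.
print(np.mean(labels_true != labels_pred))  # 0.25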

epochs = 20

shallow_mlp_model = make_model()
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)

history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)


def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("binary_accuracy")
Epoch 1/20
258/258 [==============================] - 87s 332ms/step - loss: 0.0326 - binary_accuracy: 0.9893 - val_loss: 0.0189 - val_binary_accuracy: 0.9943
Epoch 2/20
258/258 [==============================] - 100s 387ms/step - loss: 0.0033 - binary_accuracy: 0.9990 - val_loss: 0.0271 - val_binary_accuracy: 0.9940
Epoch 3/20
258/258 [==============================] - 99s 384ms/step - loss: 7.8393e-04 - binary_accuracy: 0.9999 - val_loss: 0.0328 - val_binary_accuracy: 0.9939
Epoch 4/20
258/258 [==============================] - 109s 421ms/step - loss: 3.0132e-04 - binary_accuracy: 1.0000 - val_loss: 0.0366 - val_binary_accuracy: 0.9939
Epoch 5/20
258/258 [==============================] - 105s 405ms/step - loss: 1.6006e-04 - binary_accuracy: 1.0000 - val_loss: 0.0399 - val_binary_accuracy: 0.9939
Epoch 6/20
258/258 [==============================] - 107s 414ms/step - loss: 1.2400e-04 - binary_accuracy: 1.0000 - val_loss: 0.0412 - val_binary_accuracy: 0.9939
Epoch 7/20
258/258 [==============================] - 110s 425ms/step - loss: 7.7131e-05 - binary_accuracy: 1.0000 - val_loss: 0.0439 - val_binary_accuracy: 0.9940
Epoch 8/20
258/258 [==============================] - 105s 405ms/step - loss: 5.5611e-05 - binary_accuracy: 1.0000 - val_loss: 0.0446 - val_binary_accuracy: 0.9940
Epoch 9/20
258/258 [==============================] - 103s 397ms/step - loss: 4.5994e-05 - binary_accuracy: 1.0000 - val_loss: 0.0454 - val_binary_accuracy: 0.9940
Epoch 10/20
258/258 [==============================] - 105s 405ms/step - loss: 3.5126e-05 - binary_accuracy: 1.0000 - val_loss: 0.0472 - val_binary_accuracy: 0.9939
Epoch 11/20
258/258 [==============================] - 109s 422ms/step - loss: 2.9927e-05 - binary_accuracy: 1.0000 - val_loss: 0.0466 - val_binary_accuracy: 0.9940
Epoch 12/20
258/258 [==============================] - 133s 516ms/step - loss: 2.5748e-05 - binary_accuracy: 1.0000 - val_loss: 0.0484 - val_binary_accuracy: 0.9940
Epoch 13/20
258/258 [==============================] - 129s 497ms/step - loss: 4.3529e-05 - binary_accuracy: 1.0000 - val_loss: 0.0500 - val_binary_accuracy: 0.9940
Epoch 14/20
258/258 [==============================] - 158s 611ms/step - loss: 8.1068e-04 - binary_accuracy: 0.9998 - val_loss: 0.0377 - val_binary_accuracy: 0.9936
Epoch 15/20
258/258 [==============================] - 144s 558ms/step - loss: 0.0016 - binary_accuracy: 0.9995 - val_loss: 0.0418 - val_binary_accuracy: 0.9935
Epoch 16/20
258/258 [==============================] - 131s 506ms/step - loss: 0.0018 - binary_accuracy: 0.9995 - val_loss: 0.0479 - val_binary_accuracy: 0.9931
Epoch 17/20
258/258 [==============================] - 127s 491ms/step - loss: 0.0012 - binary_accuracy: 0.9997 - val_loss: 0.0521 - val_binary_accuracy: 0.9931
Epoch 18/20
258/258 [==============================] - 153s 594ms/step - loss: 6.3144e-04 - binary_accuracy: 0.9998 - val_loss: 0.0549 - val_binary_accuracy: 0.9934
Epoch 19/20
258/258 [==============================] - 142s 550ms/step - loss: 3.1753e-04 - binary_accuracy: 0.9999 - val_loss: 0.0589 - val_binary_accuracy: 0.9934
Epoch 20/20
258/258 [==============================] - 153s 594ms/step - loss: 2.0258e-04 - binary_accuracy: 1.0000 - val_loss: 0.0585 - val_binary_accuracy: 0.9933

[Plot: training and validation loss over epochs]

[Plot: training and validation binary_accuracy over epochs]

During training, we notice an initial sharp fall in the loss followed by a gradual decay.

Evaluate the model

_, binary_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Binary accuracy on the test set: {round(binary_acc * 100, 2)}%.")
15/15 [==============================] - 3s 196ms/step - loss: 0.0580 - binary_accuracy: 0.9933
Binary accuracy on the test set: 99.33%.

The trained model gives us an evaluation accuracy of ~99%.


Inference

An important feature of the preprocessing layers provided by Keras is that they can be included inside a tf.keras.Model. We will export an inference model by including the text_vectorizer layer on top of shallow_mlp_model. This will allow our inference model to directly operate on raw strings.

Note that during training, it is always preferable to use these preprocessing layers as a part of the data input pipeline rather than the model, to avoid creating bottlenecks for the hardware accelerators. This also allows for asynchronous data processing.

# Create a model for inference.
model_for_inference = keras.Sequential([text_vectorizer, shallow_mlp_model])

# Create a small dataset just for demoing inference.
inference_dataset = make_dataset(test_df.sample(100), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))
predicted_probabilities = model_for_inference.predict(text_batch)

# Perform inference.
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    predicted_proba = [proba for proba in predicted_probabilities[i]]
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join([label for label in top_3_labels])})")
    print(" ")
4/4 [==============================] - 0s 62ms/step
Abstract: b'We investigate the training of sparse layers that use different parameters\nfor different inputs based on hashing in large Transformer models.\nSpecifically, we modify the feedforward layer to hash to different sets of\nweights depending on the current token, over all tokens in the sequence. We\nshow that this procedure either outperforms or is competitive with\nlearning-to-route mixture-of-expert methods such as Switch Transformers and\nBASE Layers, while requiring no routing parameters or extra terms in the\nobjective function such as a load balancing loss, and no sophisticated\nassignment algorithm. We study the performance of different hashing techniques,\nhash sizes and input features, and show that balanced and random hashes focused\non the most local features work best, compared to either learning clusters or\nusing longer-range context. We show our approach works well both on large\nlanguage modeling and dialogue tasks, and on downstream fine-tuning tasks.'
Label(s): ['cs.LG' 'cs.CL']
Predicted Label(s): (cs.LG, cs.CL, stat.ML)

Abstract: b'We present the first method capable of photorealistically reconstructing\ndeformable scenes using photos/videos captured casually from mobile phones. Our\napproach augments neural radiance fields (NeRF) by optimizing an additional\ncontinuous volumetric deformation field that warps each observed point into a\ncanonical 5D NeRF. We observe that these NeRF-like deformation fields are prone\nto local minima, and propose a coarse-to-fine optimization method for\ncoordinate-based models that allows for more robust optimization. By adapting\nprinciples from geometry processing and physical simulation to NeRF-like\nmodels, we propose an elastic regularization of the deformation field that\nfurther improves robustness. We show that our method can turn casually captured\nselfie photos/videos into deformable NeRF models that allow for photorealistic\nrenderings of the subject from arbitrary viewpoints, which we dub "nerfies." We\nevaluate our method by collecting time-synchronized data using a rig with two\nmobile phones, yielding train/validation images of the same pose at different\nviewpoints. We show that our method faithfully reconstructs non-rigidly\ndeforming scenes and reproduces unseen views with high fidelity.'
Label(s): ['cs.CV' 'cs.GR']
Predicted Label(s): (cs.CV, cs.GR, cs.RO)

Abstract: b'We propose to jointly learn multi-view geometry and warping between views of\nthe same object instances for robust cross-view object detection. What makes\nmulti-view object instance detection difficult are strong changes in viewpoint,\nlighting conditions, high similarity of neighbouring objects, and strong\nvariability in scale. By turning object detection and instance\nre-identification in different views into a joint learning task, we are able to\nincorporate both image appearance and geometric soft constraints into a single,\nmulti-view detection process that is learnable end-to-end. We validate our\nmethod on a new, large data set of street-level panoramas of urban objects and\nshow superior performance compared to various baselines. Our contribution is\nthreefold: a large-scale, publicly available data set for multi-view instance\ndetection and re-identification; an annotation tool custom-tailored for\nmulti-view instance detection; and a novel, holistic multi-view instance\ndetection and re-identification method that jointly models geometry and\nappearance across views.'
Label(s): ['cs.CV' 'cs.LG' 'stat.ML']
Predicted Label(s): (cs.CV, cs.RO, cs.MM)

Abstract: b'Learning graph convolutional networks (GCNs) is an emerging field which aims\nat generalizing deep learning to arbitrary non-regular domains. Most of the\nexisting GCNs follow a neighborhood aggregation scheme, where the\nrepresentation of a node is recursively obtained by aggregating its neighboring\nnode representations using averaging or sorting operations. However, these\noperations are either ill-posed or weak to be discriminant or increase the\nnumber of training parameters and thereby the computational complexity and the\nrisk of overfitting. In this paper, we introduce a novel GCN framework that\nachieves spatial graph convolution in a reproducing kernel Hilbert space\n(RKHS). The latter makes it possible to design, via implicit kernel\nrepresentations, convolutional graph filters in a high dimensional and more\ndiscriminating space without increasing the number of training parameters. The\nparticularity of our GCN model also resides in its ability to achieve\nconvolutions without explicitly realigning nodes in the receptive fields of the\nlearned graph filters with those of the input graphs, thereby making\nconvolutions permutation agnostic and well defined. Experiments conducted on\nthe challenging task of skeleton-based action recognition show the superiority\nof the proposed method against different baselines as well as the related work.'
Label(s): ['cs.CV']
Predicted Label(s): (cs.LG, cs.CV, cs.NE)

Abstract: b'Recurrent meta reinforcement learning (meta-RL) agents are agents that employ\na recurrent neural network (RNN) for the purpose of "learning a learning\nalgorithm". After being trained on a pre-specified task distribution, the\nlearned weights of the agent\'s RNN are said to implement an efficient learning\nalgorithm through their activity dynamics, which allows the agent to quickly\nsolve new tasks sampled from the same distribution. However, due to the\nblack-box nature of these agents, the way in which they work is not yet fully\nunderstood. In this study, we shed light on the internal working mechanisms of\nthese agents by reformulating the meta-RL problem using the Partially\nObservable Markov Decision Process (POMDP) framework. We hypothesize that the\nlearned activity dynamics is acting as belief states for such agents. Several\nillustrative experiments suggest that this hypothesis is true, and that\nrecurrent meta-RL agents can be viewed as agents that learn to act optimally in\npartially observable environments consisting of multiple related tasks. This\nview helps in understanding their failure cases and some interesting\nmodel-based results reported in the literature.'
Label(s): ['cs.LG' 'cs.AI']
Predicted Label(s): (stat.ML, cs.LG, cs.AI)

The prediction results are not that great, but not below par for a simple model like ours. We can improve this performance with models that consider word order, like LSTMs, or even with models that use Transformers (Vaswani et al.).
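As a rough, untuned sketch of what such an order-aware variant could look like (the integer-mode vectorizer, embedding width, and LSTM size below are illustrative assumptions, not settings from this example):

# Hypothetical sequence model: integer token ids -> embeddings -> LSTM.
sequence_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, output_mode="int", output_sequence_length=max_seqlen
)
with tf.device("/CPU:0"):
    sequence_vectorizer.adapt(train_df["summaries"].values)

lstm_model = keras.Sequential(
    [
        sequence_vectorizer,
        layers.Embedding(vocabulary_size, 128),  # embedding width is arbitrary
        layers.LSTM(64),  # so is the LSTM width
        layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
    ]
)
lstm_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)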


Acknowledgements

We would like to thank Matt Watson for helping us tackle the multi-label binarization part and the inverse-transformation of the processed labels back to their original form.

Thanks to Cingis Kratochvil for suggesting and extending this code example with binary accuracy.