Keras 2 API文档 / 内置小型数据集 / 路透社新闻分类数据集

路透社新闻专线分类数据集

[源代码]

load_data 函数

tf_keras.datasets.reuters.load_data(
    path="reuters.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    test_split=0.2,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
    cache_dir=None,
    **kwargs
)

加载路透社新闻分类数据集。

这是来自路透社的11,228条新闻数据,被标记为46个主题。

这最初是通过解析和预处理经典的Reuters-21578数据集生成的,但预处理代码不再随TF-Keras一起打包。有关更多信息,请参阅此GitHub讨论

每条新闻都被编码为一个单词索引(整数)列表。为了方便起见,单词按数据集中出现的总频率进行索引,例如,整数“3”编码数据中第三个最常见的单词。这允许快速过滤操作,例如:“仅考虑最常见的10,000个单词,但排除最常见的20个单词”。

按照惯例,“0”不代表特定单词,而是用于编码任何未知单词。

参数

  • path:数据缓存位置(相对于 ~/.keras/dataset)。
  • num_words:整数或 None。单词按出现频率(在训练集中)排序,并且只保留 num_words 个最常见的单词。任何不常见的单词将在序列数据中显示为 oov_char 值。如果为 None,则保留所有单词。默认为 None
  • skip_top: 跳过出现频率最高的N个单词(这些单词可能没有信息量)。这些单词将在数据集中显示为oov_char值。0表示不跳过任何单词。默认为0
  • maxlen: int 或 None。最大序列长度。任何更长的序列都将被截断。None表示不截断。默认为None
  • test_split: 介于0.1.之间的浮点数。用于作为测试数据的占数据集的比例。0.2意味着20%的数据集用作测试数据。默认为0.2
  • seed:整数。用于可重现数据混洗的种子。
  • start_char:整数。序列的开头将用此字符标记。0 通常是填充字符。默认为 1
  • oov_char:整数。词汇表外字符。因 num_wordsskip_top 限制而被切除的词语将替换为此字符。
  • index_from:整数。实际词语从此索引及更高索引开始。
  • cache_dir: 用于在本地缓存数据集的目录。如果为None,则默认为~/.keras/datasets
  • **kwargs: 向后兼容使用。

返回

  • Numpy 数组元组(x_train, y_train), (x_test, y_test)

x_train, x_test: 序列列表,每个序列是索引(整数)列表。如果指定了num_words参数,则最大可能索引值为num_words - 1。如果指定了maxlen参数,则最大可能序列长度为maxlen

y_train, y_test: 整数标签列表(1或0)。

注意:“词汇外”(out of vocabulary)字符仅用于在训练集中出现但在num_words限制中未包含的单词。在训练集中未见过但在测试集中出现的单词已被简单地跳过。


[源代码]

get_word_index 函数

tf_keras.datasets.reuters.get_word_index(path="reuters_word_index.json")

检索一个将单词映射到其在路透社数据集中的索引的字典。

实际的单词索引从3开始,其中3个索引保留给:0(填充)、1(开始)、2(OOV)。

例如,'the'的单词索引是1,但在实际的训练数据中,'the'的索引将是1 + 3 = 4。反之,要使用此映射将训练数据中的单词索引翻译回单词,需要减去3。

参数

  • path:数据缓存位置(相对于 ~/.keras/dataset)。

返回

词语索引字典。键是词语字符串,值是其索引。