Keras 3 API 文档 / 内置小型数据集 / 路透社新闻专线分类数据集

路透社新闻专线分类数据集

[源代码]

load_data 函数

keras.datasets.reuters.load_data(
    path="reuters.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    test_split=0.2,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
)

加载路透社新闻专线分类数据集。

这是一个包含来自路透社的 11,228 条新闻专线的数据集,标记了 46 个主题。

这最初是通过解析和预处理经典的 Reuters-21578 数据集生成的,但预处理代码不再与 Keras 打包在一起。有关更多信息,请参阅此GitHub 讨论

每条新闻专线都被编码为单词索引(整数)的列表。为了方便起见,单词是按数据集中出现的总体频率进行索引的,因此例如,整数“3”编码了数据中第三个最常见的单词。这允许快速进行过滤操作,例如:“只考虑最常见的 10,000 个单词,但消除最常见的 20 个单词”。

按照惯例,“0”不代表特定的单词,而是用于编码任何未知单词。

参数

  • path: 缓存数据的位置(相对于 ~/.keras/dataset)。
  • num_words: 整数或 None。单词按其出现频率(在训练集中)进行排名,并且只保留 num_words 个最常见的单词。任何不常见的单词都会在序列数据中显示为 oov_char 值。如果为 None,则保留所有单词。默认为 None
  • skip_top: 跳过 N 个最常出现的单词(可能不提供信息)。这些单词在数据集中将显示为 oov_char 值。0 表示不跳过任何单词。默认为 0
  • maxlen: int 或 None。最大序列长度。任何更长的序列都将被截断。None 表示不截断。默认为 None
  • test_split: 介于 0.1. 之间的浮点数。用作测试数据集的数据集比例。0.2 表示 20% 的数据集用作测试数据。默认为 0.2
  • seed: int。用于可重复数据洗牌的种子。
  • start_char: int。序列的开始将用此字符标记。0 通常是填充字符。默认为 1
  • oov_char: int。词汇表外字符。由于 num_wordsskip_top 限制而被删除的单词将被替换为此字符。
  • index_from: int。使用此索引及更高的索引对实际单词进行索引。

返回值

  • Numpy 数组的元组: (x_train, y_train), (x_test, y_test)

x_train, x_test: 序列列表,它们是索引(整数)的列表。如果指定了 num_words 参数,则最大可能的索引值是 num_words - 1。如果指定了 maxlen 参数,则最大可能的序列长度是 maxlen

y_train, y_test: 整数标签的列表 (1 或 0)。

注意: “词汇表外”字符仅用于训练集中存在但由于不符合此处的 num_words 条件而未包含的单词。在训练集中未看到但在测试集中出现的单词已被简单地跳过。


[源代码]

get_word_index 函数

keras.datasets.reuters.get_word_index(path="reuters_word_index.json")

检索将单词映射到其在路透社数据集中的索引的字典。

实际的单词索引从 3 开始,其中 3 个索引保留给:0(填充)、1(开始)、2(oov)。

例如,单词“the”的索引为 1,但在实际的训练数据中,“the”的索引将为 1 + 3 = 4。反之亦然,要使用此映射将训练数据中的单词索引转换回单词,索引需要减去 3。

参数

  • path: 缓存数据的位置(相对于 ~/.keras/dataset)。

返回值

单词索引字典。键是单词字符串,值是它们的索引。