Keras 3 API 文档 / 内置小型数据集 / IMDB 电影评论情感分类数据集

IMDB 电影评论情感分类数据集

[源文件]

load_data 函数

keras.datasets.imdb.load_data(
    path="imdb.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
    **kwargs
)

加载 IMDB 数据集

这是来自 IMDB 的 25,000 条电影评论数据集,按情感(正面/负面)进行标记。评论已经过预处理,每条评论都被编码为一个词汇索引列表(整数)。为方便起见,词汇是根据其在数据集中的总频率进行索引的,例如整数 "3" 编码的是数据中第 3 个最常见的词汇。这允许进行快速过滤操作,例如:“仅考虑排名前 10,000 的常用词,但排除排名前 20 的最常见词汇”。

按照惯例,"0" 不代表特定词汇,而是用于编码填充标记(pad token)。

参数

  • path: 数据缓存位置(相对于 ~/.keras/dataset)。
  • num_words: 整数或 None。词汇根据其出现频率(在训练集中)进行排序,仅保留 num_words 个最常见的词汇。任何频率较低的词汇将在序列数据中以 oov_char 值表示。如果为 None,则保留所有词汇。默认为 None
  • skip_top: 跳过出现频率最高的 N 个词汇(这些词汇可能信息量不大)。这些词汇将在数据集中以 oov_char 值表示。当为 0 时,不跳过任何词汇。默认为 0
  • maxlen: int 或 None。最大序列长度。任何更长的序列将被截断。None 表示不截断。默认为 None
  • seed: int。用于可复现数据混洗的随机种子。
  • start_char: int。序列的开始将用此字符标记。0 通常是填充字符。默认为 1
  • oov_char: int。超出词汇表(Out-of-vocabulary)的字符。因 num_wordsskip_top 限制而被剔除的词汇将替换为此字符。
  • index_from: int。实际词汇将从此索引及更高的索引开始编码。

返回值

  • Numpy 数组元组: (x_train, y_train), (x_test, y_test)

x_train, x_test: 序列列表,其中序列是索引列表(整数)。如果指定了 num_words 参数,则最大可能的索引值为 num_words - 1。如果指定了 maxlen 参数,则最大可能的序列长度为 maxlen

y_train, y_test: 整数标签列表(1 或 0)。

注意: “超出词汇表”字符仅用于那些存在于训练集中但因未达到 num_words 限制而被排除的词汇。在训练集中未出现但在测试集中的词汇则直接被跳过。


[源文件]

get_word_index 函数

keras.datasets.imdb.get_word_index(path="imdb_word_index.json")

检索一个字典,该字典将词汇映射到它们在 IMDB 数据集中的索引。

参数

  • path: 数据缓存位置(相对于 ~/.keras/dataset)。

返回值

词汇索引字典。键是词汇字符串,值是其索引。

示例

# Use the default parameters to keras.datasets.imdb.load_data
start_char = 1
oov_char = 2
index_from = 3
# Retrieve the training sequences.
(x_train, _), _ = keras.datasets.imdb.load_data(
    start_char=start_char, oov_char=oov_char, index_from=index_from
)
# Retrieve the word index file mapping words to indices
word_index = keras.datasets.imdb.get_word_index()
# Reverse the word index to obtain a dict mapping indices to words
# And add `index_from` to indices to sync with `x_train`
inverted_word_index = dict(
    (i + index_from, word) for (word, i) in word_index.items()
)
# Update `inverted_word_index` to include `start_char` and `oov_char`
inverted_word_index[start_char] = "[START]"
inverted_word_index[oov_char] = "[OOV]"
# Decode the first sequence in the dataset
decoded_sequence = " ".join(inverted_word_index[i] for i in x_train[0])