load_data
函数tf_keras.datasets.imdb.load_data(
path="imdb.npz",
num_words=None,
skip_top=0,
maxlen=None,
seed=113,
start_char=1,
oov_char=2,
index_from=3,
cache_dir=None,
**kwargs
)
加载 IMDB 数据集。
该数据集包含来自 IMDB 的 25,000 条电影评论,按情感(积极/消极)标记。评论已预处理,每条评论都编码为一个词索引列表(整数)。为了方便起见,词汇根据其在数据集中的总体出现频率进行索引,例如,整数“3”编码数据中出现频率排名第三的词。这使得可以进行快速过滤操作,例如:“仅考虑前 10,000 个最常见的词,但排除前 20 个最常见的词”。
按照惯例,“0”不代表特定的词,而是用于编码填充标记。
参数
~/.keras/dataset
)。None
。词汇根据其出现频率(在训练集中)进行排名,并且只保留 num_words
个最频繁的词。任何频率较低的词将以 oov_char
的值出现在序列数据中。如果为 None
,则保留所有词。默认为 None
。oov_char
的值出现在数据集中。当为 0 时,不跳过任何词。默认为 0
。None
。最大序列长度。任何更长的序列将被截断。None
表示不截断。默认为 None
。1
。num_words
或 skip_top
限制而被移除的词将被替换为此字符。None
,则默认为 ~/.keras/datasets
。返回值
(x_train, y_train), (x_test, y_test)
。x_train, x_test:序列列表,每个序列是索引(整数)列表。如果指定了 num_words
参数,则最大可能的索引值为 num_words - 1
。如果指定了 maxlen
参数,则最大可能的序列长度为 maxlen
。
y_train, y_test:整数标签列表(1 或 0)。
抛出
maxlen
过低导致无法保留任何输入序列。请注意,'词汇表外'字符仅用于那些出现在训练集中但由于未达到 num_words
限制而未包含进来的词。未在训练集中出现但在测试集中的词已被简单跳过。
get_word_index
函数tf_keras.datasets.imdb.get_word_index(path="imdb_word_index.json")
检索一个字典,该字典将 IMDB 数据集中的词汇映射到其索引。
参数
~/.keras/dataset
)。返回值
词汇索引字典。键为词汇字符串,值为其对应的索引。
示例
# Use the default parameters to keras.datasets.imdb.load_data
start_char = 1
oov_char = 2
index_from = 3
# Retrieve the training sequences.
(x_train, _), _ = keras.datasets.imdb.load_data(
start_char=start_char, oov_char=oov_char, index_from=index_from
)
# Retrieve the word index file mapping words to indices
word_index = keras.datasets.imdb.get_word_index()
# Reverse the word index to obtain a dict mapping indices to words
# And add `index_from` to indices to sync with `x_train`
inverted_word_index = dict(
(i + index_from, word) for (word, i) in word_index.items()
)
# Update `inverted_word_index` to include `start_char` and `oov_char`
inverted_word_index[start_char] = "[START]"
inverted_word_index[oov_char] = "[OOV]"
# Decode the first sequence in the dataset
decoded_sequence = " ".join(inverted_word_index[i] for i in x_train[0])