Keras 2 API 文档 / 数据加载 / 文本数据加载

文本数据加载

[源文件]

text_dataset_from_directory 函数

tf_keras.utils.text_dataset_from_directory(
    directory,
    labels="inferred",
    label_mode="int",
    class_names=None,
    batch_size=32,
    max_length=None,
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    follow_links=False,
)

从目录中的文本文件生成一个 tf.data.Dataset 对象。

如果你的目录结构是

main_directory/
...class_a/
......a_text_1.txt
......a_text_2.txt
...class_b/
......b_text_1.txt
......b_text_2.txt

那么调用 text_dataset_from_directory(main_directory, labels='inferred') 将返回一个 tf.data.Dataset,它将从子目录 class_aclass_b 中生成文本批次,以及标签 0 和 1 (0 对应于 class_a,1 对应于 class_b)。

目前只支持 .txt 文件。

参数

  • directory:数据所在的目录。如果 labels"inferred",它应该包含子目录,每个子目录包含一个类的文本文件。否则,目录结构将被忽略。
  • labels:可以是 "inferred"(标签从目录结构生成),None(无标签),或者一个整数标签的列表/元组,其大小与目录中找到的文本文件数量相同。标签应该按照文本文件路径的字母数字顺序排序(通过 Python 中的 os.walk(directory) 获取)。
  • label_mode:描述 labels 编码的字符串。选项包括
    • "int":表示标签被编码为整数(例如,用于 sparse_categorical_crossentropy 损失函数)。
    • "categorical":表示标签被编码为分类向量(例如,用于 categorical_crossentropy 损失函数)。
    • "binary":表示标签(只能是 2 类)被编码为 float32 值为 0 或 1 的标量(例如,用于 binary_crossentropy)。
    • None(无标签)。
  • class_names:仅在 "labels""inferred" 时有效。这是类的名称的明确列表(必须与子目录名称匹配)。用于控制类的顺序(否则使用字母数字顺序)。
  • batch_size:数据批次的大小。默认为 32。如果为 None,数据将不会被分批(数据集将生成单个样本)。
  • max_length:文本字符串的最大长度。超过此长度的文本将被截断到 max_length
  • shuffle:是否打乱数据。默认为 True。如果设置为 False,数据将按字母数字顺序排序。
  • seed:用于打乱和转换的可选随机种子。
  • validation_split:一个可选的 0 到 1 之间的浮点数,表示用于验证的数据比例。
  • subset:要返回的数据子集。"training""validation""both" 之一。仅在设置了 validation_split 时使用。当 subset="both" 时,该工具函数返回一个包含两个数据集的元组(分别是训练数据集和验证数据集)。
  • follow_links:是否访问符号链接指向的子目录。默认为 False

返回值

一个 tf.data.Dataset 对象。

  • 如果 label_modeNone,它将生成形状为 (batch_size,)string 张量,包含一批文本文件的内容。
  • 否则,它将生成一个元组 (texts, labels),其中 texts 的形状为 (batch_size,),并且 labels 的格式如下所述。

关于标签格式的规则

  • 如果 label_modeint,标签是一个形状为 (batch_size,)int32 张量。
  • 如果 label_modebinary,标签是一个形状为 (batch_size, 1) 的由 1 和 0 组成的 float32 张量。
  • 如果 label_modecategorical,标签是一个形状为 (batch_size, num_classes)float32 张量,表示类索引的 one-hot 编码。