text_dataset_from_directory 函数keras.utils.text_dataset_from_directory(
directory,
labels="inferred",
label_mode="int",
class_names=None,
batch_size=32,
max_length=None,
shuffle=True,
seed=None,
validation_split=None,
subset=None,
follow_links=False,
format="tf",
verbose=True,
)
从目录中的文本文件生成数据集。
如果你的目录结构是
main_directory/
...class_a/
......a_text_1.txt
......a_text_2.txt
...class_b/
......b_text_1.txt
......b_text_2.txt
调用 text_dataset_from_directory(main_directory, labels='inferred') 将返回一个数据集,该数据集会生成来自子目录 class_a 和 class_b 的文本批次,以及标签 0 和 1(0 对应 class_a,1 对应 class_b)。
目前仅支持 .txt 文件。
默认情况下,此函数将返回一个 tf.data.Dataset 对象。您可以设置 format="grain" 来改返回 grain.IterDataset 对象,这将移除 TensorFlow 依赖。
参数
labels 为 "inferred",则该目录应包含子目录,每个子目录包含一个类的文本文件。否则,将忽略目录结构。"inferred"(标签从目录结构推断生成)、None(无标签)或一个与目录中找到的文本文件数量相等的整数标签列表/元组。标签应根据文本文件路径的字母数字顺序进行排序(在 Python 中通过 os.walk(directory) 获取)。labels 编码的字符串。选项有:"int":表示标签编码为整数(例如,用于 sparse_categorical_crossentropy 损失)。"categorical" 表示标签被编码为分类向量(例如,用于 categorical_crossentropy 损失)。"binary":表示标签(只能有两个)编码为 float32 标量,值为 0 或 1(例如,用于 binary_crossentropy)。None(无标签)。"labels" 为 "inferred" 时有效。这是类名的显式列表(必须与子目录名称匹配)。用于控制类的顺序(否则使用字母数字顺序)。None,则数据将不进行批处理(数据集将产生单个样本)。默认为 32。max_length。False,则按字母数字顺序对数据进行排序。默认为 True。"training"、"validation" 或 "both" 之一。仅在设置了 validation_split 时使用。当 subset="both" 时,该实用程序返回一个包含两个数据集的元组(分别是训练集和验证集)。False。"tf"。可用选项包括:"tf":返回一个 tf.data.Dataset 对象。需要安装 TensorFlow。"grain":返回一个 grain.IterDataset 对象。需要安装 Grain。True。返回
一个 tf.data.Dataset(format="tf")或 grain.IterDataset(format="grain")对象。
当 format="tf" 时: - 如果 label_mode 为 None,则生成形状为 (batch_size,) 的 string 张量,包含一批文本文件的内容。 - 否则,生成一个元组 (texts, labels),其中 texts 的形状为 (batch_size,),labels 的格式如下所述。
当 format="grain" 时: - 如果 label_mode 为 None,则生成一个包含一批文本文件内容的 Python 字符串列表。 - 否则,生成一个元组 (texts, labels),其中 texts 是一个 Python 字符串列表,labels 的格式如下所述。
标签格式规则
label_mode 是 int,则标签是形状为 (batch_size,) 的 int32 张量。label_mode 是 binary,则标签是形状为 (batch_size, 1) 的 float32 张量,包含 1 和 0。label_mode 是 categorical,则标签是形状为 (batch_size, num_classes) 的 float32 张量,表示类索引的独热编码。