Keras 3 API 文档 / 数据加载

数据加载

Keras 的数据加载工具位于 keras.utils 模块中,它们可以帮助您将磁盘上的原始数据转换为 tf.data.Dataset 对象,以便高效地训练模型。

在训练之前,您还可以将这些加载工具与预处理层结合使用,以进一步转换输入数据集。

下面是一个简单的例子:假设您有 10 个文件夹,每个文件夹包含 10,000 张来自不同类别的图像,您想训练一个将图像映射到其类别的分类器。

您的训练数据文件夹看起来会像这样

training_data/
...class_a/
......a_image_1.jpg
......a_image_2.jpg
...class_b/
......b_image_1.jpg
......b_image_2.jpg
etc.

您可能还有一个 validation_data/ 验证数据文件夹,结构与训练数据文件夹相同。

您可以简单地这样做

import keras

train_ds = keras.utils.image_dataset_from_directory(
    directory='training_data/',
    labels='inferred',
    label_mode='categorical',
    batch_size=32,
    image_size=(256, 256))
validation_ds = keras.utils.image_dataset_from_directory(
    directory='validation_data/',
    labels='inferred',
    label_mode='categorical',
    batch_size=32,
    image_size=(256, 256))

model = keras.applications.Xception(
    weights=None, input_shape=(256, 256, 3), classes=10)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model.fit(train_ds, epochs=10, validation_data=validation_ds)

可用的数据集加载工具

图像数据加载

时间序列数据加载

文本数据加载

音频数据加载