Keras 2 API文档 / 内置小型数据集 / CIFAR10小型图像分类数据集

CIFAR10 小图像分类数据集

[源代码]

load_data 函数

tf_keras.datasets.cifar10.load_data(cache_dir=None)

加载CIFAR10数据集。

这是一个包含50,000张32x32彩色训练图像和10,000张测试图像的数据集,分为10个类别。更多信息请参见 CIFAR主页

类别如下

标签 描述
0 飞机
1 汽车
2
3
4 鹿
5
6 青蛙
7
8
9 卡车

参数

  • cache_dir: 用于在本地缓存数据集的目录。如果为None,则默认为~/.keras/datasets

返回

  • NumPy 数组元组(x_train, y_train), (x_test, y_test)

x_train: uint8 NumPy数组,包含形状为 (50000, 32, 32, 3) 的图像数据,是训练数据。像素值范围为0到255。

y_train: uint8 NumPy数组,包含形状为 (50000, 1) 的标签(0-9范围内的整数),是训练数据的标签。

x_test: uint8 NumPy数组,包含形状为 (10000, 32, 32, 3) 的图像数据,是测试数据。像素值范围为0到255。

y_test: uint8 NumPy数组,包含形状为 (10000, 1) 的标签(0-9范围内的整数),是测试数据的标签。

示例

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)