Keras 2 API文档 / 内置小型数据集 / MNIST数字分类数据集

MNIST 数字分类数据集

[源代码]

load_data 函数

tf_keras.datasets.mnist.load_data(path="mnist.npz", cache_dir=None)

加载MNIST数据集。

这是一个包含60,000张28x28灰度图像的数据集,涵盖10个数字,以及一个包含10,000张图像的测试集。更多信息可在 MNIST主页 上找到。

参数

  • path:数据集本地缓存的路径,相对于cache_dir。
  • cache_dir:数据集本地缓存的目录位置。如果为None,则默认为 ~/.keras/datasets

返回

  • NumPy 数组元组(x_train, y_train), (x_test, y_test)

x_train:形状为 (60000, 28, 28) 的uint8 NumPy数组,包含灰度图像数据,用于训练。像素值范围为0到255。

y_train:形状为 (60000,) 的uint8 NumPy数组,包含训练数据的数字标签(0-9范围内的整数)。

x_test:形状为 (10000, 28, 28) 的uint8 NumPy数组,包含灰度图像数据,用于测试。像素值范围为0到255。

y_test:形状为 (10000,) 的uint8 NumPy数组,包含测试数据的数字标签(0-9范围内的整数)。

示例

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)

许可证:MNIST数据集的版权归Yann LeCun和Corinna Cortes所有,它是NIST原始数据集的衍生作品。MNIST数据集可根据 Creative Commons Attribution-Share Alike 3.0 许可证 的条款使用。