Keras 3 API 文档 / 内置小型数据集 / MNIST 数字分类数据集

MNIST 数字分类数据集

[源代码]

load_data 函数

keras.datasets.mnist.load_data(path="mnist.npz")

加载 MNIST 数据集。

这是一个包含 60,000 张 28x28 灰度图像(代表 10 个数字)的数据集,以及一个包含 10,000 张图像的测试集。更多信息请访问 MNIST 主页

参数

  • path: 在本地缓存数据集的路径(相对于 ~/.keras/datasets)。

返回

  • NumPy 数组的元组: (x_train, y_train), (x_test, y_test)

x_train: uint8 NumPy 数组,包含灰度图像数据,形状为 (60000, 28, 28),包含训练数据。像素值范围为 0 到 255。

y_train: uint8 NumPy 数组,包含数字标签(范围 0-9 的整数),形状为 (60000,),用于训练数据。

x_test: uint8 NumPy 数组,包含灰度图像数据,形状为 (10000, 28, 28),包含测试数据。像素值范围为 0 到 255。

y_test: uint8 NumPy 数组,包含数字标签(范围 0-9 的整数),形状为 (10000,),用于测试数据。

示例

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)

许可协议

Yann LeCun 和 Corinna Cortes 拥有 MNIST 数据集的版权,该数据集是原始 NIST 数据集的衍生作品。MNIST 数据集依照 Creative Commons Attribution-Share Alike 3.0 许可协议 提供。