Keras 2 API 文档 / 层 API / 重塑层 / 展平层

展平层

[源代码]

Flatten

tf_keras.layers.Flatten(data_format=None, **kwargs)

展平输入。不影响批次大小。

注意:如果输入形状为 (batch,) 且没有特征轴,则展平会添加一个额外的通道维度,输出形状为 (batch, 1)

参数

  • data_format: 字符串,取值 channels_last(默认)或 channels_first 之一。输入中维度的顺序。channels_last 对应形状为 (batch, ..., channels) 的输入,而 channels_first 对应形状为 (batch, channels, ...) 的输入。如果未指定,则使用 TF-Keras 配置文件 ~/.keras/keras.json(如果存在)中找到的 image_data_format 值,否则为 'channels_last'。默认为 'channels_last'。

示例

>>> model = tf.keras.Sequential()
>>> model.add(tf.keras.layers.Conv2D(64, 3, 3, input_shape=(3, 32, 32)))
>>> model.output_shape
(None, 1, 10, 64)
>>> model.add(Flatten())
>>> model.output_shape
(None, 640)