Keras 2 API 文档 / 层 API / 核心层 / Input 对象

输入对象

[源代码]

Input 函数

tf_keras.Input(
    shape=None,
    batch_size=None,
    name=None,
    dtype=None,
    sparse=None,
    tensor=None,
    ragged=None,
    type_spec=None,
    **kwargs
)

Input() 用于实例化一个 TF-Keras 张量。

TF-Keras 张量是一个符号化的、类似张量的对象,我们为其添加了特定属性,从而仅通过了解模型的输入和输出来构建 TF-Keras 模型。

例如,如果 abc 是 TF-Keras 张量,则可以执行:model = Model(input=[a, b], output=c)

参数

  • shape: 一个形状元组(整数),不包括批次大小。例如,shape=(32,) 表示预期的输入是 32 维向量的批次。此元组的元素可以为 None;'None' 元素表示形状未知的维度。
  • batch_size: 可选的静态批次大小(整数)。
  • name: 该层可选的名称字符串。在一个模型中应唯一(不要重复使用相同的名称)。如果未提供,则会自动生成。
  • dtype: 输入期望的数据类型,以字符串形式(float32, float64, int32...)
  • sparse: 一个布尔值,指定要创建的占位符是否为稀疏的。'ragged' 和 'sparse' 只能有一个为 True。请注意,如果 sparse 为 False,仍然可以将稀疏张量传递到输入中,它们将被密集化,默认值为 0。
  • tensor: 可选的现有张量,用于包装到 Input 层中。如果设置,该层将使用此张量的 tf.TypeSpec,而不是创建一个新的占位符张量。
  • ragged: 一个布尔值,指定要创建的占位符是否为不规则的。'ragged' 和 'sparse' 只能有一个为 True。在这种情况下,'shape' 参数中的 'None' 值表示不规则维度。有关 RaggedTensors 的更多信息,请参阅 此指南
  • type_spec: 一个 tf.TypeSpec 对象,用于从中创建输入占位符。提供此参数时,除 name 外所有其他参数必须为 None。
  • **kwargs: 已弃用的参数支持。支持 batch_shapebatch_input_shape

返回

一个 tensor

示例

# this is a logistic regression in Keras
x = Input(shape=(32,))
y = Dense(16, activation='softmax')(x)
model = Model(x, y)

请注意,即使启用了 eager 执行,Input 也会生成一个符号化的、类似张量的对象(即一个占位符)。此符号化的、类似张量的对象可与采用张量作为输入的低级 TensorFlow 运算一起使用,例如:

x = Input(shape=(32,))
y = tf.square(x)  # This op will be treated like a layer
model = Model(x, y)

(此行为不适用于高级 TensorFlow API,例如控制流,并且不能直接由 tf.GradientTape 监视)。

但是,生成的模型不会跟踪用作 TensorFlow 运算输入的任何变量。所有变量的使用都必须发生在 TF-Keras 层内部,以确保它们会被模型的权重所跟踪。

TF-Keras Input 还可以从任意 tf.TypeSpec 创建占位符,例如:

x = Input(type_spec=tf.RaggedTensorSpec(shape=[None, None],
                                        dtype=tf.float32, ragged_rank=1))
y = x.values
model = Model(x, y)

传递任意 tf.TypeSpec 时,它必须表示整个批次的签名,而不是单个示例。

引发

  • ValueError: 如果同时提供了 sparseragged
  • ValueError: 如果同时提供了 shape 和(batch_input_shapebatch_shape)。
  • ValueError: 如果 shapetensortype_spec 均为 None。
  • ValueError: 如果在传递 type_spec 时,除 type_spec 以外的其他参数不为 None。
  • ValueError: 如果提供了任何无法识别的参数。