Keras 2 API 文档 / 层 API / 核心层 / Input 对象

输入对象

[来源]

Input 函数

tf_keras.Input(
    shape=None,
    batch_size=None,
    name=None,
    dtype=None,
    sparse=None,
    tensor=None,
    ragged=None,
    type_spec=None,
    **kwargs
)

Input() 用于实例化 TF-Keras 张量。

TF-Keras 张量是一种符号式的张量状对象,我们为其增加了某些属性,这些属性允许我们仅通过了解模型的输入和输出来构建 TF-Keras 模型。

例如,如果 abc 是 TF-Keras 张量,则可以执行以下操作:model = Model(input=[a, b], output=c)

参数

  • shape: 一个形状元组(整数),不包括批次大小。例如,shape=(32,) 表示预期的输入将是大小为 32 的向量的批次。此元组的元素可以是 None;'None' 元素表示形状未知的维度。
  • batch_size: 可选的静态批次大小(整数)。
  • name: 层的可选名称字符串。在模型中应唯一(不要重复使用相同的名称)。如果未提供,将自动生成。
  • dtype: 输入期望的数据类型,作为字符串(float32, float64, int32...)。
  • sparse: 一个布尔值,指定要创建的占位符是否是稀疏的。'ragged' 和 'sparse' 中只能有一个为 True。请注意,如果 sparse 为 False,稀疏张量仍然可以传递给输入 - 它们将使用默认值 0 进行密集化。
  • tensor: 可选的现有张量,用于包装到 Input 层中。如果设置,该层将使用此张量的 tf.TypeSpec,而不是创建新的占位符张量。
  • ragged: 一个布尔值,指定要创建的占位符是否是 Ragged 张量(不规则张量)。'ragged' 和 'sparse' 中只能有一个为 True。在这种情况下,'shape' 参数中的 'None' 值表示不规则维度。有关 RaggedTensors 的更多信息,请参见此指南
  • type_spec: 一个 tf.TypeSpec 对象,用于从中创建输入占位符。提供此参数时,除 name 之外的所有其他参数都必须为 None。
  • **kwargs: 已弃用的参数支持。支持 batch_shapebatch_input_shape

返回值

一个 tensor

示例

# this is a logistic regression in Keras
x = Input(shape=(32,))
y = Dense(16, activation='softmax')(x)
model = Model(x, y)

注意,即使启用了急切执行 (eager execution),Input 也会生成一个符号张量状对象(即占位符)。此符号张量状对象可与接受张量作为输入的较低级别 TensorFlow 操作一起使用,例如:

x = Input(shape=(32,))
y = tf.square(x)  # This op will be treated like a layer
model = Model(x, y)

(此行为不适用于更高级别的 TensorFlow API,例如控制流以及由 tf.GradientTape 直接监视的情况)。

但是,生成的模型不会跟踪任何用作 TensorFlow 操作输入的变量。所有变量的使用都必须发生在 TF-Keras 层内部,以确保它们将被模型的权重跟踪。

TF-Keras Input 还可以从任意 tf.TypeSpec 创建占位符,例如:

x = Input(type_spec=tf.RaggedTensorSpec(shape=[None, None],
                                        dtype=tf.float32, ragged_rank=1))
y = x.values
model = Model(x, y)

传递任意 tf.TypeSpec 时,它必须表示整个批次的签名,而不仅仅是一个示例的签名。

抛出

  • ValueError: 如果同时提供了 sparseragged
  • ValueError: 如果同时提供了 shape 和 (batch_input_shapebatch_shape)。
  • ValueError: 如果 shapetensortype_spec 均为 None。
  • ValueError: 如果传递了 type_spec,而除此以外的参数不为 None。
  • ValueError: 如果提供了任何无法识别的参数。