Keras 3 API 文档 / 层 API / 核心层 / InputSpec 对象

InputSpec 对象

[源代码]

InputSpec

keras.InputSpec(
    dtype=None,
    shape=None,
    ndim=None,
    max_ndim=None,
    min_ndim=None,
    axes=None,
    allow_last_axis_squeeze=False,
    name=None,
    optional=False,
)

指定层接收的每个输入的秩 (rank)、数据类型 (dtype) 和形状 (shape)。

层可以公开(如果适用)一个 input_spec 属性:它可以是 InputSpec 的一个实例,也可以是一个由 InputSpec 实例组成的嵌套结构(每个输入张量对应一个)。这些对象使层能够对 Layer.__call__ 的第一个参数进行输入兼容性检查,包括输入结构、输入秩、输入形状和输入数据类型。

形状中的 None 表示与任何维度兼容。

参数

  • dtype: 输入的预期数据类型。
  • shape: 形状元组,输入的预期形状(动态轴可以使用 None 表示)。包括批次大小。
  • ndim: 整数,输入的预期秩。
  • max_ndim: 整数,输入的最高秩。
  • min_ndim: 整数,输入的最低秩。
  • axes: 字典,将整数轴映射到特定的维度值。
  • allow_last_axis_squeeze: 如果为 True,则允许秩为 N+1 的输入(只要输入的最后一个轴为 1),也允许秩为 N-1 的输入(只要 spec 的最后一个轴为 1)。
  • name: 当以字典形式传递数据时,与此输入对应的预期键。
  • optional: 布尔值,表示输入是否是可选的。可选输入可以接受 None 值。

示例

class MyLayer(Layer):
    def __init__(self):
        super().__init__()
        # The layer will accept inputs with
        # shape (*, 28, 28) & (*, 28, 28, 1)
        # and raise an appropriate error message otherwise.
        self.input_spec = InputSpec(
            shape=(None, 28, 28, 1),
            allow_last_axis_squeeze=True)