Keras 3 API 文档 / 层 API / 核心层 / InputSpec 对象

InputSpec 对象

[源代码]

InputSpec

keras.InputSpec(
    dtype=None,
    shape=None,
    ndim=None,
    max_ndim=None,
    min_ndim=None,
    axes=None,
    allow_last_axis_squeeze=False,
    name=None,
    optional=False,
)

指定层的每个输入的秩、dtype 和形状。

层可以公开(如果适用)一个 input_spec 属性:InputSpec 的一个实例,或 InputSpec 实例的嵌套结构(每个输入张量一个)。这些对象使层能够对 Layer.__call__ 的第一个参数运行输入兼容性检查,包括输入结构、输入秩、输入形状和输入 dtype。

形状中的 None 条目与任何维度兼容。

参数

  • dtype:输入的预期 dtype。
  • shape:形状元组,输入的预期形状(可以包含 None 以表示动态轴)。包含批次大小。
  • ndim:整数,输入的预期秩。
  • max_ndim:整数,输入的最大秩。
  • min_ndim:整数,输入的最小秩。
  • axes:字典,将整数轴映射到特定的维度值。
  • allow_last_axis_squeeze:如果为 True,则允许秩为 N+1 的输入,只要输入的最后一个轴为 1;也允许秩为 N-1 的输入,只要规范的最后一个轴为 1。
  • name:当以字典形式传递数据时,与此输入对应的预期键。
  • optional:布尔值,指示输入是否为可选。可选输入可以接受 None 值。

示例

class MyLayer(Layer):
    def __init__(self):
        super().__init__()
        # The layer will accept inputs with
        # shape (*, 28, 28) & (*, 28, 28, 1)
        # and raise an appropriate error message otherwise.
        self.input_spec = InputSpec(
            shape=(None, 28, 28, 1),
            allow_last_axis_squeeze=True)