Keras 3 API 文档 / 层 API / 核心层 / InputSpec 对象

InputSpec 对象

[源代码]

InputSpec

keras.InputSpec(
    dtype=None,
    shape=None,
    ndim=None,
    max_ndim=None,
    min_ndim=None,
    axes=None,
    allow_last_axis_squeeze=False,
    name=None,
    optional=False,
)

指定层每个输入的秩、数据类型和形状。

层可以(如果合适)公开一个 input_spec 属性:一个 InputSpec 实例,或一个嵌套的 InputSpec 实例结构(每个输入张量一个)。这些对象使层能够对 Layer.__call__ 的第一个参数的输入结构、输入秩、输入形状和输入数据类型运行输入兼容性检查。

形状中的 None 条目与任何维度兼容。

参数

  • dtype: 预期输入的数据类型。
  • shape: 形状元组,预期输入的形状(可能包含 None 用于动态轴)。包括批次大小。
  • ndim: 整数,预期输入的秩。
  • max_ndim: 整数,输入的最大秩。
  • min_ndim: 整数,输入的最小秩。
  • axes: 字典,将整数轴映射到特定维度值。
  • allow_last_axis_squeeze: 如果为 True,则允许输入秩为 N+1,只要输入的最后一个轴为 1,以及输入秩为 N-1,只要规范的最后一个轴为 1。
  • name: 当将数据作为字典传递时,对应于此输入的预期键。
  • optional: 布尔值,输入是否为可选。可选输入可以接受 None 值。

示例

class MyLayer(Layer):
    def __init__(self):
        super().__init__()
        # The layer will accept inputs with
        # shape (*, 28, 28) & (*, 28, 28, 1)
        # and raise an appropriate error message otherwise.
        self.input_spec = InputSpec(
            shape=(None, 28, 28, 1),
            allow_last_axis_squeeze=True)