Keras 3 API 文档 / 层 API / 后端特定层 / TensorFlow SavedModel 层

TensorFlow SavedModel 层

[来源]

TFSMLayer

keras.layers.TFSMLayer(
    filepath,
    call_endpoint="serve",
    call_training_endpoint=None,
    trainable=True,
    name=None,
    dtype=None,
)

重新加载通过 SavedModel / ExportArchive 保存的 Keras 模型/层。

参数

  • filepath: strpathlib.Path 对象。SavedModel 的路径。
  • call_endpoint: 用作重新加载层的 call() 方法的端点名称。如果 SavedModel 是通过 model.export() 创建的,则默认端点名称为 'serve'。在其他情况下,它可能被命名为 'serving_default'

示例

model.export("path/to/artifact")
reloaded_layer = TFSMLayer("path/to/artifact")
outputs = reloaded_layer(inputs)

重新加载的对象可以像常规 Keras 层一样使用,并支持对其可训练权重的训练/微调。请注意,重新加载的对象不会保留原始对象的任何内部结构或自定义方法 - 它是在保存的函数周围创建的一个全新的层。

限制

  • 仅支持具有单个 inputs 张量参数(可以选择是张量的字典/元组/列表)的调用端点。对于具有多个单独输入张量参数的端点,请考虑子类化 TFSMLayer 并使用自定义签名实现 call() 方法。
  • 如果您需要训练时行为与推理时行为不同(即,如果您需要重新加载的对象在 __call__() 中支持 training=True 参数),请确保将训练时调用函数另存为工件中的独立端点,并通过 call_training_endpoint 参数将其名称提供给 TFSMLayer