TFSMLayer
类keras.layers.TFSMLayer(
filepath,
call_endpoint="serve",
call_training_endpoint=None,
trainable=True,
name=None,
dtype=None,
)
重新加载通过 SavedModel / ExportArchive 保存的 Keras 模型/层。
参数
str
或 pathlib.Path
对象。SavedModel 的路径。call()
方法的端点名称。如果 SavedModel 是通过 model.export()
创建的,则默认端点名称为 'serve'
。在其他情况下,它可能被命名为 'serving_default'
。示例
model.export("path/to/artifact")
reloaded_layer = TFSMLayer("path/to/artifact")
outputs = reloaded_layer(inputs)
重新加载的对象可以像常规 Keras 层一样使用,并支持对其可训练权重的训练/微调。请注意,重新加载的对象不会保留原始对象的任何内部结构或自定义方法 - 它是在保存的函数周围创建的一个全新的层。
限制
inputs
张量参数(可以选择是张量的字典/元组/列表)的调用端点。对于具有多个单独输入张量参数的端点,请考虑子类化 TFSMLayer
并使用自定义签名实现 call()
方法。__call__()
中支持 training=True
参数),请确保将训练时调用函数另存为工件中的独立端点,并通过 call_training_endpoint
参数将其名称提供给 TFSMLayer
。