TFSMLayer
类keras.layers.TFSMLayer(
filepath,
call_endpoint="serve",
call_training_endpoint=None,
trainable=True,
name=None,
dtype=None,
)
重新加载通过 SavedModel / ExportArchive 保存的 Keras 模型/层。
参数
str
或 pathlib.Path
对象。 SavedModel 的路径。call()
方法的端点名称。如果 SavedModel 是通过 model.export()
创建的,则默认端点名称为 'serve'
。在其他情况下,它可能被命名为 'serving_default'
。示例
model.export("path/to/artifact")
reloaded_layer = TFSMLayer("path/to/artifact")
outputs = reloaded_layer(inputs)
重新加载的对象可以像常规 Keras 层一样使用,并支持对其可训练权重的训练/微调。请注意,重新加载的对象不保留原始对象的任何内部结构或自定义方法 – 它是围绕保存的函数创建的全新层。
局限性
inputs
张量参数(可以是张量的 dict/tuple/list)的调用端点。对于具有多个独立输入张量参数的端点,请考虑子类化 TFSMLayer
并实现具有自定义签名的 call()
方法。__call__()
中支持 training=True
参数),请确保训练时调用函数作为工件中的独立端点保存,并通过 call_training_endpoint
参数将其名称提供给 TFSMLayer
。