Keras 2 API 文档 / 模型 API / 保存与序列化 / 模型导出用于推理

模型导出用于推理

[源代码]

ExportArchive

tf_keras.export.ExportArchive()

ExportArchive 用于编写 SavedModel 工件(例如用于推理)。

如果您有一个 TF-Keras 模型或层,想要导出为 SavedModel 以供服务(例如通过 TensorFlow-Serving),您可以使用 ExportArchive 来配置您需要提供的不同服务端点及其签名。只需实例化一个 ExportArchive,使用 track() 注册要使用的层或模型,然后使用 add_endpoint() 方法注册新的服务端点。完成后,使用 write_out() 方法保存工件。

生成的工件是一个 SavedModel,可以通过 tf.saved_model.load 重新加载。

示例

以下是如何导出模型用于推理。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")

# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)

以下是如何导出模型,其中一个端点用于推理,另一个端点用于训练模式下的前向传播(例如启用 dropout)。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="call_inference",
    fn=lambda x: model.call(x, training=False),
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
    name="call_training",
    fn=lambda x: model.call(x, training=True),
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")

关于资源追踪的注意事项

ExportArchive 能够自动追踪其端点使用的所有 tf.Variables,因此大多数时候并不严格要求调用 .track(model)。但是,如果您的模型使用了查找层,例如 IntegerLookupStringLookupTextVectorization,则需要通过 .track(model) 进行显式追踪。

如果您需要在恢复的归档文件上访问属性 variablestrainable_variablesnon_trainable_variables,也需要进行显式追踪。