ExportArchive
类tf_keras.export.ExportArchive()
ExportArchive 用于编写 SavedModel 工件(例如用于推理)。
如果您有一个 TF-Keras 模型或层,想要导出为 SavedModel 以供服务(例如通过 TensorFlow-Serving),您可以使用 ExportArchive
来配置您需要提供的不同服务端点及其签名。只需实例化一个 ExportArchive
,使用 track()
注册要使用的层或模型,然后使用 add_endpoint()
方法注册新的服务端点。完成后,使用 write_out()
方法保存工件。
生成的工件是一个 SavedModel,可以通过 tf.saved_model.load
重新加载。
示例
以下是如何导出模型用于推理。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)
以下是如何导出模型,其中一个端点用于推理,另一个端点用于训练模式下的前向传播(例如启用 dropout)。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="call_inference",
fn=lambda x: model.call(x, training=False),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
name="call_training",
fn=lambda x: model.call(x, training=True),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
关于资源追踪的注意事项
ExportArchive
能够自动追踪其端点使用的所有 tf.Variables
,因此大多数时候并不严格要求调用 .track(model)
。但是,如果您的模型使用了查找层,例如 IntegerLookup
、StringLookup
或 TextVectorization
,则需要通过 .track(model)
进行显式追踪。
如果您需要在恢复的归档文件上访问属性 variables
、trainable_variables
或 non_trainable_variables
,也需要进行显式追踪。