隆重推出 Keras 3.0

经过五个月广泛的公开 Beta 测试,我们激动地宣布 Keras 3.0 正式发布。Keras 3 是 Keras 的一次彻底重写,它使您能够在 JAX、TensorFlow、PyTorch 或 OpenVINO(仅用于推理)之上运行您的 Keras 工作流程,并解锁全新的大规模模型训练和部署能力。您可以选择最适合您的框架,并根据当前目标在框架之间切换。您还可以将 Keras 用作低级跨框架语言,以开发自定义组件,例如层、模型或度量,这些组件可以在 JAX、TensorFlow 或 PyTorch 的原生工作流程中使用——只需一个代码库。


欢迎来到多框架机器学习时代。

您已经熟悉使用 Keras 的好处——它通过对出色的 UX、API 设计和可调试性的专注,实现了高速开发。它也是一个经过实战检验的框架,已被超过 250 万开发者选用,并为世界上一些最复杂、规模最大的 ML 系统提供支持,例如 Waymo 自动驾驶车队和 YouTube 推荐引擎。但是使用新的多后端 Keras 3 还有哪些额外的好处呢?

  • 始终为您的模型获得最佳性能。在我们的基准测试中,我们发现 JAX 通常在 GPU、TPU 和 CPU 上提供最佳的训练和推理性能——但结果因模型而异,因为非 XLA 的 TensorFlow 有时在 GPU 上更快。能够动态选择能为您的模型带来最佳性能的后端,而无需更改您的代码,这意味着您能够以可实现的最高效率进行训练和服务。
  • 为您的模型解锁生态系统选择性。任何 Keras 3 模型都可以实例化为 PyTorch Module,可以导出为 TensorFlow SavedModel,或者可以实例化为无状态的 JAX 函数。这意味着您可以使用您的 Keras 3 模型与 PyTorch 生态系统包一起使用,使用全套的 TensorFlow 部署&生产工具(如 TF-Serving、TF.js 和 TFLite),以及使用 JAX 的大规模 TPU 训练基础设施。使用 Keras 3 API 编写一个 model.py 文件,即可访问 ML 世界所提供的一切。
  • 利用 JAX 进行大规模模型并行&数据并行。Keras 3 包含一个全新的分布式 API,即 keras.distribution 命名空间,目前已为 JAX 后端实现(即将支持 TensorFlow 和 PyTorch 后端)。它使得模型并行、数据并行以及两者的组合变得容易实现——无论模型规模和集群规模如何。因为它将模型定义、训练逻辑和分片配置完全分开,使得您的分布式工作流程易于开发和维护。请参阅我们的入门指南
  • 最大化您的开源模型发布的覆盖范围。想要发布预训练模型吗?想要尽可能多的人能够使用它吗?如果您使用纯粹的 TensorFlow 或 PyTorch 实现它,大约只有一半的社区能够使用它。如果您使用 Keras 3 实现它,无论他们选择哪个框架(即使他们自己不是 Keras 用户),任何人都可以立即使用它。以零额外的开发成本获得两倍的影响力。
  • 使用来自任何来源的数据管道。Keras 3 的 fit()/evaluate()/predict() 例程与 tf.data.Dataset 对象、PyTorch DataLoader 对象、NumPy 数组、Pandas 数据帧兼容——无论您使用哪个后端。您可以使用 PyTorch DataLoader 训练一个 Keras 3 + TensorFlow 模型,或者使用 tf.data.Dataset 训练一个 Keras 3 + PyTorch 模型。

完整的 Keras API,适用于 JAX、TensorFlow 和 PyTorch。

Keras 3 实现了完整的 Keras API,并使其适用于 TensorFlow、JAX 和 PyTorch——包括一百多个层、数十个度量、损失函数、优化器和回调,以及 Keras 的训练和评估循环,以及 Keras 的保存&序列化基础设施。您熟悉并喜爱的所有 API 都在这里。

任何仅使用内置层的 Keras 模型都可以立即与所有支持的后端一起工作。事实上,您现有的仅使用内置层的 tf.keras 模型可以立即开始在 JAX 和 PyTorch 中运行!没错,您的代码库刚刚获得了一整套全新的能力。


编写多框架层、模型、度量...

Keras 3 使您能够创建在任何框架中都能以相同方式工作的组件(例如任意自定义层或预训练模型)。特别是,Keras 3 提供了跨所有后端工作的 keras.ops 命名空间的访问。它包含

  • NumPy API 的完整实现。不是“类似 NumPy”的东西——而是字面意义上的 NumPy API,具有相同的函数和相同的参数。您可以使用 ops.matmulops.sumops.stackops.einsum 等等。
  • 一组 NumPy 中没有的神经网络特有函数,例如 ops.softmaxops.binary_crossentropyops.conv 等等。

只要您只使用来自 keras.ops 的操作,您的自定义层、自定义损失、自定义度量和自定义优化器就可以使用相同的代码在 JAX、PyTorch 和 TensorFlow 中工作。这意味着您只需维护一个组件实现(例如,一个 model.py 文件和一个检查点文件),就可以在所有框架中使用它,并且数值计算结果完全一致。


...与任何 JAX、TensorFlow 和 PyTorch 工作流程无缝协作。

Keras 3 不仅适用于以 Keras 为中心的工作流程,即您定义 Keras 模型、Keras 优化器、Keras 损失和度量,然后调用 fit()evaluate()predict()。它也旨在与低级后端原生工作流程无缝协作:您可以获取一个 Keras 模型(或任何其他组件,例如损失或度量),并开始在 JAX 训练循环、TensorFlow 训练循环或 PyTorch 训练循环中使用它,或者作为 JAX 或 PyTorch 模型的一部分使用,没有任何阻碍。Keras 3 在 JAX 和 PyTorch 中提供的低级实现灵活性,与之前的 tf.keras 在 TensorFlow 中提供的完全相同。

您可以

  • 编写低级 JAX 训练循环,使用 optax 优化器、jax.gradjax.jitjax.pmap 来训练 Keras 模型。
  • 编写低级 TensorFlow 训练循环,使用 tf.GradientTapetf.distribute 来训练 Keras 模型。
  • 编写低级 PyTorch 训练循环,使用 torch.optim 优化器、torch 损失函数以及 torch.nn.parallel.DistributedDataParallel 包装器来训练 Keras 模型。
  • 在 PyTorch Module 中使用 Keras 层(因为它们本身也是 Module 实例!)
  • 在 Keras 模型中使用任何 PyTorch Module,就像它是一个 Keras 层一样。
  • 等等。


用于大规模数据并行和模型并行的新分布式 API。

我们一直在处理的模型越来越大,因此我们希望为多设备模型分片问题提供一个 Keras 式的解决方案。我们设计的 API 将模型定义、训练逻辑和分片配置完全分开,这意味着您的模型可以像在单个设备上运行一样编写。然后您可以在训练模型时为任意模型添加任意分片配置。

数据并行(在多个设备上完全复制一个小型模型)只需两行代码即可处理

模型并行允许您沿多个命名维度指定模型变量和中间输出张量的分片布局。在典型情况下,您将可用设备组织成一个二维网格(称为设备网格),其中第一个维度用于数据并行,第二个维度用于模型并行。然后,您将模型配置为沿模型维度进行分片,并沿数据维度进行复制。

该 API 允许您通过正则表达式配置每个变量和每个输出张量的布局。这使得快速为整类变量指定相同的布局变得容易。

新的分布式 API 旨在支持多后端,但目前仅适用于 JAX 后端。TensorFlow 和 PyTorch 的支持即将推出。通过此指南开始使用吧!


预训练模型。

有各种各样的预训练模型,您可以立即开始与 Keras 3 一起使用。

所有 40 个 Keras Applications 模型(即 keras.applications 命名空间中的模型)在所有后端均可用。KerasCVKerasHub 中的大量预训练模型也适用于所有后端。这包括

  • BERT
  • OPT
  • Whisper
  • T5
  • StableDiffusion
  • YOLOv8
  • SegmentAnything
  • 等等。

支持所有后端的跨框架数据管道。

多框架 ML 也意味着多框架数据加载和预处理。Keras 3 模型可以使用各种数据管道进行训练——无论您使用的是 JAX、PyTorch 还是 TensorFlow 后端。就是这么方便。

  • tf.data.Dataset 管道:可扩展生产 ML 的参考。
  • torch.utils.data.DataLoader 对象。
  • NumPy 数组和 Pandas 数据帧。
  • Keras 自带的 keras.utils.PyDataset 对象。

渐进式地暴露复杂性。

渐进式地暴露复杂性是 Keras API 的核心设计原则。Keras 不强制您遵循单一的“正确”方法来构建和训练模型。相反,它支持各种不同的工作流程,从非常高级到非常低级,对应于不同的用户画像。

这意味着您可以从简单的工作流程开始——例如使用 SequentialFunctional 模型并通过 fit() 进行训练——当您需要更高的灵活性时,您可以轻松自定义不同的组件,同时重用您之前的大部分代码。随着您的需求变得越来越具体,您不会突然陷入复杂性悬崖,也不需要切换到另一套工具。

我们将这一原则带到了所有后端。例如,您可以在训练循环中自定义发生的事情,同时仍然利用 fit() 的强大功能,而无需从头编写自己的训练循环——只需重写 train_step 方法即可。

以下是它在 PyTorch 和 TensorFlow 中的工作方式

这是 JAX 版本的链接


层、模型、度量和优化器的新无状态 API。

您喜欢函数式编程吗?那您可有福了。

Keras 中所有有状态对象(即拥有在训练或评估期间更新的数值变量的对象)现在都有了无状态 API,使得可以在 JAX 函数中使用它们(JAX 函数要求是完全无状态的)

  • 所有层和模型都有一个 stateless_call() 方法,它镜像了 __call__() 方法。
  • 所有优化器都有一个 stateless_apply() 方法,它镜像了 apply() 方法。
  • 所有度量都有一个 stateless_update_state() 方法,它镜像了 update_state() 方法,以及一个 stateless_result() 方法,它镜像了 result() 方法。

这些方法完全没有副作用:它们将目标对象状态变量的当前值作为输入,并将更新后的值作为其输出的一部分返回,例如

outputs, updated_non_trainable_variables = layer.stateless_call(
    trainable_variables,
    non_trainable_variables,
    inputs,
)

您无需自己实现这些方法——只要您实现了有状态版本(例如 call()update_state()),这些方法就会自动可用。


使用 OpenVINO 后端运行推理。

从 3.8 版本开始,Keras 引入了 OpenVINO 后端,这是一个仅用于推理的后端,意味着它专门设计用于使用 predict() 方法运行模型预测。此后端允许直接在 Keras 工作流程中利用 OpenVINO 的性能优化,从而在支持 OpenVINO 的硬件上实现更快的推理。

要切换到 OpenVINO 后端,请将 KERAS_BACKEND 环境变量设置为 "openvino",或者在本地配置文件 ~/.keras/keras.json 中指定后端。以下是使用 OpenVINO 后端对模型(使用 PyTorch、JAX 或 TensorFlow 后端训练)进行推理的示例

import os
os.environ["KERAS_BACKEND"] = "openvino"
import keras

loaded_model = keras.saving.load_model(...)
predictions = loaded_model.predict(...)

请注意,OpenVINO 后端目前可能缺乏对某些操作的支持。随着操作覆盖范围的扩大,这将在后续的 Keras 版本中得到解决。


从 Keras 2 迁移到 Keras 3

Keras 3 与 Keras 2 高度向后兼容:它实现了 Keras 2 的全部公共 API 表面,只有少数例外,列表见此处。大多数用户无需进行任何代码更改即可在 Keras 3 上运行其 Keras 脚本。

较大的代码库可能需要一些代码更改,因为它们更有可能遇到上面列出的例外情况之一,并且更有可能使用了私有 API 或已弃用的 API(tf.compat.v1.keras 命名空间、experimental 命名空间、keras.src 私有命名空间)。为了帮助您迁移到 Keras 3,我们发布了一份完整的迁移指南,其中包含针对您可能遇到的所有问题的快速修复方法。

您还可以选择忽略 Keras 3 的更改,继续使用 Keras 2 和 TensorFlow——这对于那些没有积极开发但需要使用更新依赖项的项目来说是一个不错的选择。您有两种可能性

  1. 如果您之前将 keras 作为独立包访问,只需切换到使用 Python 包 tf_keras 即可,您可以通过 pip install tf_keras 安装它。代码和 API 完全没有变化——它只是包名不同的 Keras 2.15。我们将继续修复 tf_keras 中的错误并定期发布新版本。然而,由于该包现已处于维护模式,因此不会添加新功能或性能改进。
  2. 如果您通过 tf.keras 访问 keras,那么在 TensorFlow 2.16 之前没有立即变化。TensorFlow 2.16 及更高版本将默认使用 Keras 3。在 TensorFlow 2.16+ 中,要继续使用 Keras 2,您可以首先安装 tf_keras,然后导出环境变量 TF_USE_LEGACY_KERAS=1。这将指示 TensorFlow 2.16+ 将 tf.keras 解析到本地安装的 tf_keras 包。但请注意,这可能会影响到您的代码之外的部分:它将影响到您的 Python 进程中任何导入 tf.keras 的包。为确保您的更改仅影响您自己的代码,您应该使用 tf_keras 包。

尽情享受这个库吧!

我们很高兴您能尝试新的 Keras,并通过利用多框架 ML 来改进您的工作流程。请告诉我们您的使用体验:遇到的问题、困难点、功能请求或成功案例——我们渴望听到您的反馈!


常见问题解答

问:Keras 3 与旧版 Keras 2 兼容吗?

使用 tf.keras 开发的代码通常可以在 Keras 3 中直接运行(使用 TensorFlow 后端)。需要注意的少数不兼容之处已在此迁移指南中得到解决。

至于同时使用来自 tf.keras 和 Keras 3 的 API,这是不可行的——它们是不同的包,运行在完全独立的引擎上。

问:在旧版 Keras 2 中开发的预训练模型可以在 Keras 3 中工作吗?

通常情况下,是的。任何 tf.keras 模型都应该可以直接在 Keras 3(使用 TensorFlow 后端)中工作(请确保以 .keras v3 格式保存)。此外,如果模型仅使用内置的 Keras 层,那么它也可以直接在 Keras 3(使用 JAX 和 PyTorch 后端)中工作。

如果模型包含使用 TensorFlow API 编写的自定义层,通常很容易将代码转换为后端无关的代码。例如,我们仅用了几个小时就将 Keras Applications 中所有 40 个旧版 tf.keras 模型转换为后端无关的模型。

问:我可以在一个后端中保存 Keras 3 模型,然后在另一个后端中重新加载吗?

是的,可以。保存的 .keras 文件中没有任何后端特化内容。您保存的 Keras 模型是框架无关的,可以使用任何后端重新加载。

但是请注意,使用不同后端重新加载包含自定义组件的模型,需要您的自定义组件使用后端无关的 API 实现,例如 keras.ops

问:我可以在 tf.data 管道中使用 Keras 3 组件吗?

使用 TensorFlow 后端时,Keras 3 与 tf.data 完全兼容(例如,您可以将 Sequential 模型 .map()tf.data 管道中)。

使用不同后端时,Keras 3 对 tf.data 的支持有限。您将无法将任意层或模型 .map()tf.data 管道中。但是,您可以使用特定的 Keras 3 预处理层与 tf.data 一起使用,例如 IntegerLookupCategoryEncoding

至于使用不包含 Keras 的 tf.data 管道来为您的 .fit().evaluate().predict() 调用提供数据——这在所有后端中都直接可用。

问:Keras 3 模型在使用不同后端运行时表现一致吗?

是的,数值计算结果在不同后端之间是相同的。但是,请注意以下几点

  • 不同后端之间的 RNG(随机数生成器)行为不同(即使设置了种子——每个后端的结果是确定性的,但不同后端之间的结果会有所不同)。因此,随机权重初始化值和 Dropout 值在不同后端之间也会有所不同。
  • 由于浮点实现本身的特性,每次函数执行的结果在 float32 精度下仅能保证达到 1e-7 的一致性。因此,当长时间训练模型时,微小的数值差异会累积,最终可能导致明显的数值差异。
  • 由于 PyTorch 缺乏对非对称填充的平均池化的支持,使用 padding="same" 的平均池化层可能在边界行/列上产生不同的数值结果。在实践中这种情况并不常见——在 40 个 Keras Applications 视觉模型中,只有一个受到了影响。

问:Keras 3 支持分布式训练吗?

JAX、TensorFlow 和 PyTorch 开箱即用地支持数据并行分布式训练。JAX 使用 keras.distribution API 开箱即用地支持模型并行分布式训练。

使用 TensorFlow

Keras 3 与 tf.distribute 兼容——只需打开一个分布式策略作用域,并在其中创建/训练您的模型。这里有一个示例

使用 PyTorch

Keras 3 与 PyTorch 的 DistributedDataParallel 工具兼容。这里有一个示例

使用 JAX

您可以在 JAX 中使用 keras.distribution API 进行数据并行和模型并行分布式训练。例如,要进行数据并行分布式训练,您只需要以下代码片段

distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)

有关模型并行分布式训练,请参阅以下指南

您也可以通过 jax.sharding 等 JAX API 自己实现分布式训练。这里有一个示例

问:我的自定义 Keras 层可以在原生的 PyTorch Modules 或 Flax Modules 中使用吗?

如果它们仅使用 Keras API 编写(例如使用 keras.ops 命名空间),那么是的,您的 Keras 层可以直接与原生 PyTorch 和 JAX 代码一起工作。在 PyTorch 中,只需像使用任何其他 PyTorch Module 一样使用您的 Keras 层即可。在 JAX 中,请确保使用无状态层 API,即 layer.stateless_call()

问:将来会添加更多后端吗?那 XYZ 框架呢?

只要目标框架拥有庞大的用户群或具有独特的技术优势,我们都乐于添加新的后端。然而,添加和维护一个新后端是一项巨大的负担,因此我们将逐个案例地仔细考虑每个新的后端候选者,并且不太可能添加很多新后端。我们不会添加任何尚未成熟的新框架。我们目前正在考虑添加一个用 Mojo 编写的后端。如果您认为这很有用,请告知 Mojo 团队。