经过五个月广泛的公开测试,我们很高兴地宣布 Keras 3.0 正式发布。Keras 3 是 Keras 的完全重写版本,它使您能够在 JAX、TensorFlow、PyTorch 或 OpenVINO(仅用于推理)之上运行您的 Keras 工作流程,并解锁全新的大规模模型训练和部署能力。您可以选择最适合您的框架,并根据当前目标在它们之间切换。您还可以将 Keras 用作低级跨框架语言,以开发自定义组件(如层、模型或指标),这些组件可以在 JAX、TensorFlow 或 PyTorch 的本机工作流程中使用 —— 使用一个代码库。


欢迎来到多框架机器学习。

您已经熟悉使用 Keras 的好处 —— 它通过专注于出色的用户体验、API 设计和可调试性来实现高速开发。它也是一个经过实战检验的框架,已被超过 250 万开发者选择,并为世界上一些最复杂、最大规模的机器学习系统提供支持,例如 Waymo 无人驾驶车队和 YouTube 推荐引擎。但是,使用新的多后端 Keras 3 有哪些额外的好处呢?

  • 始终获得模型最佳性能。在我们的基准测试中,我们发现 JAX 通常在 GPU、TPU 和 CPU 上提供最佳的训练和推理性能 —— 但结果因模型而异,因为非 XLA TensorFlow 有时在 GPU 上速度更快。能够动态选择最适合您模型性能的后端,而无需更改任何代码,这意味着您将保证以最高效率进行训练和服务。
  • 解锁模型生态系统可选性。任何 Keras 3 模型都可以实例化为 PyTorch Module,可以导出为 TensorFlow SavedModel,或者可以实例化为无状态 JAX 函数。这意味着您可以使用 Keras 3 模型与 PyTorch 生态系统软件包一起使用,可以使用全套 TensorFlow 部署和生产工具(如 TF-Serving、TF.js 和 TFLite),以及使用 JAX 大规模 TPU 训练基础设施。使用 Keras 3 API 编写一个 model.py,即可访问机器学习世界所能提供的一切。
  • 利用 JAX 进行大规模模型并行和数据并行。Keras 3 包含一个全新的分布 API,即 keras.distribution 命名空间,目前已为 JAX 后端实现(即将支持 TensorFlow 和 PyTorch 后端)。它可以轻松实现模型并行、数据并行以及两者的组合 —— 适用于任意模型规模和集群规模。因为它将模型定义、训练逻辑和分片配置完全分离,因此使您的分布工作流程易于开发和易于维护。请参阅我们的入门指南
  • 最大化开源模型发布的覆盖范围。想要发布预训练模型?希望尽可能多的人能够使用它?如果您使用纯 TensorFlow 或 PyTorch 实现它,则大约一半的社区可以使用它。如果您使用 Keras 3 实现它,则任何人都可以在他们选择的框架中使用它(即使他们本身不是 Keras 用户)。无需增加开发成本即可获得双倍的影响。
  • 使用来自任何来源的数据管道。Keras 3 的 fit()/evaluate()/predict() 例程与 tf.data.Dataset 对象、PyTorch DataLoader 对象、NumPy 数组、Pandas 数据帧兼容 —— 无论您使用哪个后端。您可以在 PyTorch DataLoader 上训练 Keras 3 + TensorFlow 模型,或者在 tf.data.Dataset 上训练 Keras 3 + PyTorch 模型。

完整的 Keras API,适用于 JAX、TensorFlow 和 PyTorch。

Keras 3 实现了完整的 Keras API,并将其提供给 TensorFlow、JAX 和 PyTorch —— 超过一百个层、数十个指标、损失函数、优化器和回调、Keras 训练和评估循环以及 Keras 保存和序列化基础设施。您熟悉并喜爱的所有 API 都可以在这里找到。

任何仅使用内置层的 Keras 模型都将立即与所有支持的后端一起使用。事实上,您现有的仅使用内置层的 tf.keras 模型可以立即开始在 JAX 和 PyTorch 中运行!没错,您的代码库刚刚获得了一整套新的功能。


创作多框架层、模型、指标...

Keras 3 使您能够创建组件(如任意自定义层或预训练模型),这些组件在任何框架中都将以相同的方式工作。特别是,Keras 3 使您可以访问跨所有后端的 keras.ops 命名空间。它包含

  • NumPy API 的完整实现。不是“类似 NumPy”的东西 —— 而是真正的 NumPy API,具有相同的函数和相同的参数。您将获得 ops.matmulops.sumops.stackops.einsum 等。
  • 一组 NumPy 中没有的神经网络特定函数,例如 ops.softmaxops.binary_crossentropyops.conv 等。

只要您仅使用 keras.ops 中的操作,您的自定义层、自定义损失、自定义指标和自定义优化器将与 JAX、PyTorch 和 TensorFlow 一起使用 —— 使用相同的代码。这意味着您只需维护一个组件实现(例如,一个 model.py 和一个检查点文件),您可以在所有框架中使用它,并且具有完全相同的数值。


...与任何 JAX、TensorFlow 和 PyTorch 工作流程无缝协作。

Keras 3 不仅适用于以 Keras 为中心的工作流程,在这些工作流程中,您定义一个 Keras 模型、一个 Keras 优化器、一个 Keras 损失和指标,然后调用 fit()evaluate()predict()。它还旨在与低级后端本机工作流程无缝协作:您可以获取一个 Keras 模型(或任何其他组件,如损失或指标),并开始在 JAX 训练循环、TensorFlow 训练循环或 PyTorch 训练循环中使用它,或将其作为 JAX 或 PyTorch 模型的一部分使用,而不会产生任何摩擦。Keras 3 在 JAX 和 PyTorch 中提供与 tf.keras 之前在 TensorFlow 中提供的完全相同程度的低级实现灵活性。

您可以

  • 编写低级 JAX 训练循环,以使用 optax 优化器、jax.gradjax.jitjax.pmap 训练 Keras 模型。
  • 编写低级 TensorFlow 训练循环,以使用 tf.GradientTapetf.distribute 训练 Keras 模型。
  • 编写低级 PyTorch 训练循环,以使用 torch.optim 优化器、torch 损失函数和 torch.nn.parallel.DistributedDataParallel 包装器训练 Keras 模型。
  • 在 PyTorch Module 中使用 Keras 层(因为它们也是 Module 实例!)
  • 在 Keras 模型中使用任何 PyTorch Module,就像它是一个 Keras 层一样。
  • 等等。


用于大规模数据并行和模型并行的全新分布 API。

我们一直使用的模型越来越大,因此我们希望为多设备模型分片问题提供 Keras 解决方案。我们设计的 API 将模型定义、训练逻辑和分片配置完全分离,这意味着您的模型可以像在单个设备上运行一样编写。然后在需要训练它们时,您可以将任意分片配置添加到任意模型中。

数据并行(在多个设备上完全复制小型模型)只需两行即可处理

模型并行允许您指定模型变量和中间输出张量的分片布局,沿着多个命名维度。在典型情况下,您会将可用设备组织为 2D 网格(称为设备网格),其中第一个维度用于数据并行,第二个维度用于模型并行。然后,您将配置您的模型沿模型维度进行分片,并沿数据维度进行复制。

该 API 允许您通过正则表达式配置每个变量和每个输出张量的布局。这使得很容易快速地为整个类别的变量指定相同的布局。

新的分布 API 旨在支持多后端,但目前仅适用于 JAX 后端。TensorFlow 和 PyTorch 支持即将推出。请通过本指南开始使用!


预训练模型。

您可以使用 Keras 3 开始使用各种预训练模型。

所有 40 个 Keras Applications 模型(keras.applications 命名空间)都可以在所有后端中使用。 KerasCVKerasHub 中的大量预训练模型也可以在所有后端中使用。这包括

  • BERT
  • OPT
  • Whisper
  • T5
  • StableDiffusion
  • YOLOv8
  • SegmentAnything
  • 等等。

支持所有后端的跨框架数据管道。

多框架 ML 也意味着多框架数据加载和预处理。Keras 3 模型可以使用各种数据管道进行训练 —— 无论您使用的是 JAX、PyTorch 还是 TensorFlow 后端。它就是可以工作。

  • tf.data.Dataset 管道:可扩展生产 ML 的参考。
  • torch.utils.data.DataLoader 对象。
  • NumPy 数组和 Pandas 数据帧。
  • Keras 自己的 keras.utils.PyDataset 对象。

逐步揭示复杂性。

逐步揭示复杂性是 Keras API 的核心设计原则。Keras 不会强迫您遵循单一的“真实”模型构建和训练方式。相反,它支持各种不同的工作流程,从非常高级到非常低级,对应于不同的用户配置文件。

这意味着您可以从简单的工作流程开始 —— 例如使用 SequentialFunctional 模型,并使用 fit() 训练它们 —— 当您需要更多灵活性时,您可以轻松自定义不同的组件,同时重复使用您的大部分先前代码。随着您的需求变得更加具体,您不会突然跌入复杂性的深渊,也不需要切换到不同的工具集。

我们将这一原则应用到了我们所有的后端。例如,您可以自定义训练循环中的行为,同时仍然利用 fit() 的强大功能,而无需从头开始编写自己的训练循环——只需重写 train_step 方法即可。

以下是它在 PyTorch 和 TensorFlow 中的工作方式

这是 JAX 版本的链接


层、模型、指标和优化器的新无状态 API。

您喜欢函数式编程吗?您一定会感到惊喜。

Keras 中所有有状态的对象(即拥有在训练或评估期间更新的数值变量的对象)现在都有一个无状态 API,这使得它们可以在 JAX 函数中使用(JAX 函数需要完全无状态)。

  • 所有层和模型都有一个 stateless_call() 方法,该方法与 __call__() 方法相对应。
  • 所有优化器都有一个 stateless_apply() 方法,该方法与 apply() 方法相对应。
  • 所有指标都有一个 stateless_update_state() 方法,该方法与 update_state() 方法相对应,以及一个 stateless_result() 方法,该方法与 result() 方法相对应。

这些方法没有任何副作用:它们将目标对象的状态变量的当前值作为输入,并将更新值作为其输出的一部分返回,例如:

outputs, updated_non_trainable_variables = layer.stateless_call(
    trainable_variables,
    non_trainable_variables,
    inputs,
)

您永远不必自己实现这些方法——只要您已经实现了有状态的版本(例如 call()update_state()),它们就会自动可用。


使用 OpenVINO 后端运行推理。

从 3.8 版本开始,Keras 引入了 OpenVINO 后端,这是一个仅用于推理的后端,这意味着它仅设计用于使用 predict() 方法运行模型预测。此后端可以直接在 Keras 工作流程中利用 OpenVINO 性能优化,从而在 OpenVINO 支持的硬件上实现更快的推理。

要切换到 OpenVINO 后端,请将 KERAS_BACKEND 环境变量设置为 "openvino" 或在 ~/.keras/keras.json 的本地配置文件中指定后端。以下是如何使用 OpenVINO 后端推断模型(使用 PyTorch、JAX 或 TensorFlow 后端训练)的示例

import os
os.environ["KERAS_BACKEND"] = "openvino"
import keras

loaded_model = keras.saving.load_model(...)
predictions = loaded_model.predict(...)

请注意,OpenVINO 后端目前可能缺少对某些操作的支持。随着操作覆盖范围的扩大,这将在即将发布的 Keras 版本中得到解决。


从 Keras 2 迁移到 Keras 3

Keras 3 与 Keras 2 高度向后兼容:它实现了 Keras 2 的完整公共 API 表面,只有少数例外,此处列出了这些例外。大多数用户无需进行任何代码更改即可开始在 Keras 3 上运行他们的 Keras 脚本。

较大的代码库可能需要进行一些代码更改,因为它们更有可能遇到上面列出的例外情况之一,并且更有可能使用私有 API 或已弃用的 API(tf.compat.v1.keras 命名空间、experimental 命名空间、keras.src 私有命名空间)。为了帮助您迁移到 Keras 3,我们发布了一个完整的迁移指南,其中提供了您可能遇到的所有问题的快速修复方法。

您还可以选择忽略 Keras 3 中的更改,而只是继续将 Keras 2 与 TensorFlow 一起使用——对于那些不积极开发但需要保持使用更新依赖项运行的项目来说,这可能是一个不错的选择。您有两种可能性

  1. 如果您将 keras 作为独立软件包访问,只需切换为使用 Python 软件包 tf_keras,您可以通过 pip install tf_keras 安装它。代码和 API 完全不变——它是带有不同软件包名称的 Keras 2.15。我们将继续修复 tf_keras 中的错误,并将定期发布新版本。但是,由于该软件包现在处于维护模式,因此不会添加任何新功能或性能改进。
  2. 如果您通过 tf.keras 访问 keras,则在 TensorFlow 2.16 之前不会有任何直接更改。TensorFlow 2.16+ 将默认使用 Keras 3。在 TensorFlow 2.16+ 中,要继续使用 Keras 2,您可以先安装 tf_keras,然后导出环境变量 TF_USE_LEGACY_KERAS=1。这将指示 TensorFlow 2.16+ 将 tf.keras 解析为本地安装的 tf_keras 软件包。请注意,这可能会影响您自己的代码之外的其他内容:它会影响 Python 进程中导入 tf.keras 的任何软件包。为了确保您的更改仅影响您自己的代码,您应该使用 tf_keras 软件包。

享受这个库吧!

我们很高兴您能尝试新的 Keras 并通过利用多框架机器学习来改进您的工作流程。请告诉我们进展如何:问题、摩擦点、功能请求或成功案例——我们渴望收到您的反馈!


常见问题

问:Keras 3 与旧版 Keras 2 兼容吗?

使用 tf.keras 开发的代码通常可以直接在 Keras 3(使用 TensorFlow 后端)中运行。您应该注意一些不兼容的情况,所有这些都在本迁移指南中解决。

当涉及到并行使用来自 tf.keras 和 Keras 3 的 API 时,这是不可能的——它们是不同的软件包,在完全独立的引擎上运行。

问:在旧版 Keras 2 中开发的预训练模型可以在 Keras 3 中工作吗?

一般来说,是的。任何 tf.keras 模型都应该可以在 Keras 3 中使用 TensorFlow 后端开箱即用(请确保以 .keras v3 格式保存它)。此外,如果模型仅使用内置的 Keras 层,那么它也可以在 Keras 3 中使用 JAX 和 PyTorch 后端开箱即用。

如果模型包含使用 TensorFlow API 编写的自定义层,则通常很容易将代码转换为与后端无关。例如,我们只花了几个小时就将 Keras Applications 中的所有 40 个旧版 tf.keras 模型转换为与后端无关的模型。

问:我可以在一个后端中保存 Keras 3 模型,然后在另一个后端中重新加载它吗?

是的,您可以。保存的 .keras 文件中没有任何后端专业化。您保存的 Keras 模型与框架无关,可以使用任何后端重新加载。

但是,请注意,使用不同的后端重新加载包含自定义组件的模型需要使用与后端无关的 API(例如 keras.ops)来实现您的自定义组件。

问:我可以在 tf.data 管道中使用 Keras 3 组件吗?

使用 TensorFlow 后端,Keras 3 与 tf.data 完全兼容(例如,您可以将 Sequential 模型 .map()tf.data 管道中)。

使用不同的后端,Keras 3 对 tf.data 的支持有限。您将无法将任意层或模型 .map()tf.data 管道中。但是,您将能够将特定的 Keras 3 预处理层与 tf.data 一起使用,例如 IntegerLookupCategoryEncoding

当涉及到使用 tf.data 管道(不使用 Keras)来馈送您对 .fit().evaluate().predict() 的调用时——这可以在所有后端开箱即用。

问:使用不同后端运行时,Keras 3 模型是否行为相同?

是的,数字在不同后端之间是相同的。但是,请记住以下注意事项

  • RNG 行为在不同后端之间是不同的(即使在播种之后——您的结果在每个后端中都是确定的,但在不同后端之间会有所不同)。因此,随机权重初始化值和 dropout 值在不同后端之间会有所不同。
  • 由于浮点实现的性质,每次函数执行,结果仅在 float32 中具有高达 1e-7 的精度是相同的。因此,当长时间训练模型时,微小的数值差异会累积,最终可能会导致明显的数值差异。
  • 由于 PyTorch 中不支持具有不对称填充的平均池化,因此 padding="same" 的平均池化层可能会导致边界行/列上的数值不同。这种情况在实践中并不经常发生——在 40 个 Keras Applications 视觉模型中,只有一个受到影响。

问:Keras 3 支持分布式训练吗?

数据并行分布在 JAX、TensorFlow 和 PyTorch 中开箱即用。模型并行分布通过 keras.distribution API 在 JAX 中开箱即用。

使用 TensorFlow

Keras 3 与 tf.distribute 兼容——只需打开一个 Distribution Strategy 作用域并在其中创建/训练您的模型即可。这是一个示例

使用 PyTorch

Keras 3 与 PyTorch 的 DistributedDataParallel 实用程序兼容。这是一个示例

使用 JAX

您可以使用 keras.distribution API 在 JAX 中进行数据并行和模型并行分布。例如,要进行数据并行分布,您只需要以下代码片段

distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)

有关模型并行分布,请参阅以下指南

您还可以通过 JAX API(例如 jax.sharding)自行分发训练。这是一个示例

问:我的自定义 Keras 层可以在原生 PyTorch Modules 中使用或与 Flax Modules 一起使用吗?

如果它们仅使用 Keras API(例如 keras.ops 命名空间)编写,那么是的,您的 Keras 层可以在原生 PyTorch 和 JAX 代码中开箱即用。在 PyTorch 中,只需像使用任何其他 PyTorch Module 一样使用您的 Keras 层。在 JAX 中,请确保使用无状态层 API,即 layer.stateless_call()

问:未来会添加更多后端吗?框架 XYZ 呢?

我们欢迎添加新的后端,只要目标框架拥有庞大的用户群或在其他方面具有某些独特的技术优势。但是,添加和维护新的后端是一项巨大的负担,因此我们将根据具体情况仔细考虑每个新的后端候选者,并且我们不太可能添加许多新的后端。我们不会添加任何尚未建立良好的新框架。我们现在可能正在考虑添加一个用Mojo编写的后端。如果您觉得这可能对您有用,请告知 Mojo 团队。