经过五个月的广泛公开测试,我们很高兴地宣布 Keras 3.0 正式发布。Keras 3 是 Keras 的完整重写,它使您能够在 JAX、TensorFlow 或 PyTorch 之上运行 Keras 工作流程,并解锁全新的大型模型训练和部署功能。您可以选择最适合您的框架,并根据您的当前目标在它们之间切换。您还可以使用 Keras 作为低级跨框架语言来开发自定义组件,例如层、模型或指标,这些组件可以在 JAX、TensorFlow 或 PyTorch 的原生工作流程中使用——只需一个代码库即可。


欢迎来到多框架机器学习时代。

您已经熟悉使用 Keras 的好处——它通过专注于出色的用户体验、API 设计和可调试性,实现了高速开发。它也是一个经过实战检验的框架,已被超过 250 万开发人员选择,并为世界上一些最复杂、最大规模的机器学习系统提供支持,例如 Waymo 自动驾驶车队和 YouTube 推荐引擎。但是,使用新的多后端 Keras 3 还有哪些额外的好处呢?

  • 始终为您的模型获得最佳性能。 在我们的基准测试中,我们发现 JAX 通常在 GPU、TPU 和 CPU 上提供最佳的训练和推理性能——但结果因模型而异,因为非 XLA TensorFlow 有时在 GPU 上更快。能够动态选择能够为您的模型提供最佳性能的后端而无需更改任何代码,意味着您可以保证以最高的效率进行训练和服务。
  • 为您的模型解锁生态系统可选性。 任何 Keras 3 模型都可以实例化为 PyTorch 的Module,可以导出为 TensorFlow 的SavedModel,或者可以实例化为无状态的 JAX 函数。这意味着您可以将 Keras 3 模型与 PyTorch 生态系统包一起使用,与 TensorFlow 的全套部署和生产工具(如 TF-Serving、TF.js 和 TFLite)一起使用,以及与 JAX 大型 TPU 训练基础设施一起使用。使用 Keras 3 API 编写一个model.py,即可访问机器学习领域提供的所有功能。
  • 利用 JAX 进行大规模模型并行和数据并行。 Keras 3 包含一个全新的分布式 API,即keras.distribution 命名空间,目前已针对 JAX 后端实现(TensorFlow 和 PyTorch 后端即将推出)。它可以轻松地进行模型并行、数据并行以及两者的组合——在任意模型规模和集群规模上。因为它将模型定义、训练逻辑和分片配置彼此分开,因此使您的分布式工作流程易于开发和维护。请参阅我们的入门指南
  • 最大化开源模型发布的覆盖范围。 想发布一个预训练模型吗?希望尽可能多的人能够使用它?如果您使用纯 TensorFlow 或 PyTorch 实现它,它将可供大约一半的社区使用。如果您在 Keras 3 中实现它,它将立即可供任何人使用,无论他们选择什么框架(即使他们本身不是 Keras 用户)。在不增加开发成本的情况下,获得双倍的影响。
  • 使用来自任何来源的数据管道。 Keras 3 的fit()/evaluate()/predict()例程与tf.data.Dataset对象、PyTorch 的DataLoader对象、NumPy 数组、Pandas 数据框兼容——无论您使用什么后端。您可以使用 PyTorch 的DataLoader训练 Keras 3 + TensorFlow 模型,或者使用tf.data.Dataset训练 Keras 3 + PyTorch 模型。

适用于 JAX、TensorFlow 和 PyTorch 的完整 Keras API。

Keras 3 实施了完整的 Keras API,并使其可用于 TensorFlow、JAX 和 PyTorch——一百多个层、几十个指标、损失函数、优化器和回调、Keras 训练和评估循环以及 Keras 保存和序列化基础设施。所有您熟悉和喜爱的 API 都在这里。

任何仅使用内置层的 Keras 模型都将立即与所有支持的后端一起使用。事实上,您现有的仅使用内置层的tf.keras模型可以立即开始在 JAX 和 PyTorch 中运行!没错,您的代码库刚刚获得了一整套全新的功能。


编写多框架层、模型、指标……

Keras 3 使您能够创建可在任何框架中以相同方式工作的组件(如任意自定义层或预训练模型)。特别是,Keras 3 使您可以访问适用于所有后端的keras.ops命名空间。它包含

  • NumPy API 的完整实现。 不是“类似 NumPy”的东西——而是真正的 NumPy API,具有相同的函数和相同的参数。您可以使用ops.matmulops.sumops.stackops.einsum等。
  • 一组神经网络特定函数,这些函数在 NumPy 中不存在,例如ops.softmaxops.binary_crossentropyops.conv等。

只要您只使用来自keras.ops的操作,您的自定义层、自定义损失、自定义指标和自定义优化器将与 JAX、PyTorch 和 TensorFlow 一起使用——使用相同的代码。这意味着您只需维护一个组件实现(例如,一个model.py以及一个检查点文件),您就可以在所有框架中使用它,并获得完全相同的数值结果。


……可以与任何 JAX、TensorFlow 和 PyTorch 工作流程无缝协作。

Keras 3 不仅适用于以 Keras 为中心的工作流程,在这种工作流程中,您定义一个 Keras 模型、一个 Keras 优化器、一个 Keras 损失和指标,然后调用fit()evaluate()predict()。它也旨在与低级后端原生工作流程无缝协作:您可以获取 Keras 模型(或任何其他组件,例如损失或指标),并开始在 JAX 训练循环、TensorFlow 训练循环或 PyTorch 训练循环中使用它,或者作为 JAX 或 PyTorch 模型的一部分使用,没有任何摩擦。Keras 3 在 JAX 和 PyTorch 中提供了与tf.keras以前在 TensorFlow 中提供的相同程度的低级实现灵活性。

您可以

  • 编写低级 JAX 训练循环以使用optax优化器、jax.gradjax.jitjax.pmap训练 Keras 模型。
  • 编写低级 TensorFlow 训练循环以使用tf.GradientTapetf.distribute训练 Keras 模型。
  • 编写低级 PyTorch 训练循环以使用torch.optim优化器、torch损失函数和torch.nn.parallel.DistributedDataParallel包装器训练 Keras 模型。
  • 在 PyTorch 的Module中使用 Keras 层(因为它们也是Module实例!)
  • 在 Keras 模型中使用任何 PyTorch 的Module,就像它是一个 Keras 层一样。
  • 等等。


用于大规模数据并行和模型并行的新分布式 API。

我们一直在使用的模型越来越大,因此我们希望为多设备模型分片问题提供一个 Keras 式的解决方案。我们设计的 API 将模型定义、训练逻辑和分片配置完全彼此分开,这意味着您的模型可以像在单个设备上运行一样编写。然后,当您需要训练它们时,可以向任意模型添加任意分片配置。

数据并行(在多个设备上相同地复制小型模型)只需两行代码即可处理

模型并行允许您指定模型变量和中间输出张量的分片布局,跨多个命名维度。在典型情况下,您会将可用设备组织成一个二维网格(称为设备网格),其中第一个维度用于数据并行,第二个维度用于模型并行。然后,您可以将模型配置为沿模型维度分片并在数据维度上复制。

API 允许您通过正则表达式配置每个变量和每个输出张量的布局。这使得轻松快速地为整个变量类别指定相同的布局变得容易。

新的分布式 API 旨在是多后端的,但目前仅适用于 JAX 后端。TensorFlow 和 PyTorch 支持即将推出。请从本指南开始!


预训练模型。

您可以立即开始使用 Keras 3 中的各种预训练模型。

所有 40 个 Keras 应用模型(keras.applications命名空间)都可在所有后端使用。KerasCVKerasHub中大量预训练模型也适用于所有后端。这包括

  • BERT
  • OPT
  • Whisper
  • T5
  • StableDiffusion
  • YOLOv8
  • SegmentAnything
  • 等等。

支持所有后端的跨框架数据管道。

多框架机器学习也意味着多框架数据加载和预处理。Keras 3 模型可以使用各种数据管道进行训练——无论您使用的是 JAX、PyTorch 还是 TensorFlow 后端。它只需正常工作即可。

  • tf.data.Dataset管道:可扩展生产机器学习的参考。
  • torch.utils.data.DataLoader对象。
  • NumPy 数组和 Pandas 数据框。
  • Keras 自身的keras.utils.PyDataset对象。

渐进式复杂度揭示。

渐进式复杂度揭示是 Keras API 核心的设计原则。Keras 不会强迫您遵循构建和训练模型的唯一“正确”方法。相反,它支持各种不同的工作流程,从非常高级别到非常低级别,对应于不同的用户配置文件。

这意味着您可以从简单的工作流程开始——例如使用SequentialFunctional模型并使用fit()训练它们——当您需要更多灵活性时,您可以轻松地自定义不同的组件,同时重用大部分以前的代码。随着您的需求变得更加具体,您不会突然陷入复杂度的悬崖,也不需要切换到不同的工具集。

我们将此原则应用到了我们所有的后端。例如,您可以自定义训练循环中发生的事情,同时仍然利用fit()的功能,而无需从头编写自己的训练循环——只需覆盖train_step方法即可。

以下是它在 PyTorch 和 TensorFlow 中的工作原理

以及JAX 版本的链接


用于层、模型、指标和优化器的新无状态 API。

您是否喜欢函数式编程?您将大饱眼福。

Keras 中的所有有状态对象(即拥有在训练或评估期间更新的数值变量的对象)现在都有一个无状态 API,这使得可以在 JAX 函数中使用它们(这些函数需要完全无状态)

  • 所有层和模型都有一个stateless_call()方法,它反映了__call__()
  • 所有优化器都有一个stateless_apply()方法,它反映了apply()
  • 所有指标都有一个stateless_update_state()方法,它反映了update_state(),以及一个stateless_result()方法,它反映了result()

这些方法没有任何副作用:它们将目标对象的当前状态变量值作为输入,并将更新值作为其输出的一部分返回,例如

outputs, updated_non_trainable_variables = layer.stateless_call(
    trainable_variables,
    non_trainable_variables,
    inputs,
)

您无需自己实现这些方法——只要您实现了有状态版本(例如call()update_state()),它们就会自动可用。


从 Keras 2 迁移到 Keras 3

Keras 3 与 Keras 2 具有高度向后兼容性:它实现了 Keras 2 的完整公共 API 表面,仅有少量例外情况,详见此处。大多数用户无需更改任何代码即可开始在 Keras 3 上运行其 Keras 脚本。

较大的代码库可能需要一些代码更改,因为它们更有可能遇到上述例外情况之一,并且更有可能使用私有 API 或已弃用的 API(tf.compat.v1.keras 命名空间、experimental 命名空间、keras.src 私有命名空间)。为了帮助您迁移到 Keras 3,我们发布了一个完整的迁移指南,其中包含您可能遇到的所有问题的快速修复。

您也可以选择忽略 Keras 3 中的更改,并继续使用 TensorFlow 中的 Keras 2 — 这对于那些没有积极开发但需要保持更新依赖项运行的项目来说是一个不错的选择。您有两种可能性

  1. 如果您将 keras 作为独立包访问,只需切换到使用 Python 包 tf_keras 即可,您可以通过 pip install tf_keras 进行安装。代码和 API 完全没有改变 — 它只是 Keras 2.15,但包名不同。我们将继续修复 tf_keras 中的 bug,并定期发布新版本。但是,不会添加任何新功能或性能改进,因为该包现在处于维护模式。
  2. 如果您通过 tf.keras 访问 keras,则在 TensorFlow 2.16 之前没有立即更改。TensorFlow 2.16 及更高版本将默认使用 Keras 3。在 TensorFlow 2.16 及更高版本中,要继续使用 Keras 2,您可以先安装 tf_keras,然后导出环境变量 TF_USE_LEGACY_KERAS=1。这将指示 TensorFlow 2.16 及更高版本将 tf.keras 解析到本地安装的 tf_keras 包。但是请注意,这可能会影响超出您自己代码的范围:它会影响 Python 进程中任何导入 tf.keras 的包。为了确保您的更改仅影响您自己的代码,您应该使用 tf_keras 包。

享受这个库!

我们很高兴您尝试新的 Keras 并通过利用多框架机器学习来改进您的工作流程。请告诉我们您的体验:问题、摩擦点、功能请求或成功案例 — 我们渴望听到您的反馈!


常见问题

问:Keras 3 是否与旧版 Keras 2 兼容?

使用 tf.keras 开发的代码通常可以与 Keras 3(使用 TensorFlow 后端)一起按原样运行。您应该注意一些有限的兼容性问题,所有这些问题都在此迁移指南中进行了说明。

当谈到并排使用 tf.keras 和 Keras 3 中的 API 时,这是**不可能的** — 它们是不同的包,运行在完全独立的引擎上。

问:在旧版 Keras 2 中开发的预训练模型是否适用于 Keras 3?

通常,是的。任何 tf.keras 模型都应该能够与使用 TensorFlow 后端的 Keras 3 一起开箱即用(确保以 .keras v3 格式保存它)。此外,如果模型仅使用内置 Keras 层,那么它也能够与使用 JAX 和 PyTorch 后端的 Keras 3 一起开箱即用。

如果模型包含使用 TensorFlow API 编写的自定义层,则通常很容易将代码转换为与后端无关的代码。例如,我们仅用几个小时就将 Keras Applications 中的所有 40 个旧版 tf.keras 模型转换为与后端无关的模型。

问:我可以在一个后端中保存 Keras 3 模型并在另一个后端中重新加载它吗?

是的,可以。保存的 .keras 文件中没有任何后端专门化。您保存的 Keras 模型与框架无关,并且可以使用任何后端重新加载。

但是,请注意,使用不同的后端重新加载包含自定义组件的模型需要您的自定义组件使用与后端无关的 API 实现,例如 keras.ops

问:我可以在 tf.data 管道中使用 Keras 3 组件吗?

使用 TensorFlow 后端,Keras 3 与 tf.data 完全兼容(例如,您可以将 Sequential 模型 .map()tf.data 管道中)。

使用不同的后端,Keras 3 对 tf.data 的支持有限。您将无法将任意层或模型 .map()tf.data 管道中。但是,您将能够将特定的 Keras 3 预处理层与 tf.data 一起使用,例如 IntegerLookupCategoryEncoding

当谈到使用 tf.data 管道(不使用 Keras)来馈送您对 .fit().evaluate().predict() 的调用时 — 这在所有后端中都可以开箱即用。

问:Keras 3 模型在使用不同后端运行时行为是否相同?

是的,数值在不同后端之间是相同的。但是,请记住以下注意事项

  • RNG 行为在不同后端之间有所不同(即使在播种后 — 您的结果在每个后端中都是确定性的,但在不同后端之间会存在差异)。因此,随机权重初始化值和 dropout 值在不同后端之间会有所不同。
  • 由于浮点实现的性质,结果在每个函数执行中仅在 float32 中精确到 1e-7。因此,当长时间训练模型时,微小的数值差异会累积,最终可能导致明显的数值差异。
  • 由于 PyTorch 不支持具有非对称填充的平均池化,因此具有 padding="same" 的平均池化层可能导致边界行/列上的数值不同。在实践中这种情况并不常见 — 在 40 个 Keras Applications 视觉模型中,只有一个受到影响。

问:Keras 3 是否支持分布式训练?

数据并行分布在 JAX、TensorFlow 和 PyTorch 中开箱即用地受支持。模型并行分布在 JAX 中使用 keras.distribution API 开箱即用地受支持。

使用 TensorFlow

Keras 3 与 tf.distribute 兼容 — 只需打开一个分布式策略范围并在其中创建/训练您的模型即可。这是一个示例

使用 PyTorch

Keras 3 与 PyTorch 的 DistributedDataParallel 实用程序兼容。这是一个示例

使用 JAX

您可以使用 keras.distribution API 在 JAX 中进行数据并行和模型并行分布。例如,要进行数据并行分布,您只需要以下代码片段

distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)

有关模型并行分布,请参阅以下指南

您还可以通过 JAX API(如 jax.sharding)自行分配训练。这是一个示例

问:我的自定义 Keras 层可以在本机 PyTorch Modules 或 Flax Modules 中使用吗?

如果它们仅使用 Keras API(例如 keras.ops 命名空间)编写,那么是的,您的 Keras 层将能够与本机 PyTorch 和 JAX 代码一起开箱即用。在 PyTorch 中,只需像使用任何其他 PyTorch Module 一样使用您的 Keras 层。在 JAX 中,请确保使用无状态层 API,即 layer.stateless_call()

问:将来会添加更多后端吗?框架 XYZ 怎么样?

只要目标框架拥有庞大的用户群或以其他方式带来一些独特的技术优势,我们都乐于添加新的后端。但是,添加和维护新的后端是一项巨大的负担,因此我们将逐案仔细考虑每个新的后端候选者,并且我们不太可能添加许多新的后端。我们不会添加任何尚未成熟的新框架。我们现在正在考虑添加一个用Mojo编写的后端。如果您认为这可能对您有用,请告知 Mojo 团队。