我们对 Keras 3 的三个后端(TensorFlow、JAX、PyTorch)以及 Keras 2 配合 TensorFlow 进行了基准测试。请在此处 查看用于重现我们结果的代码和设置详情。
我们选择了一系列流行的计算机视觉和自然语言处理模型,涵盖了生成式和非生成式人工智能任务。请参阅下表了解我们的选择。
表 1:基准测试中使用的模型。
| 非生成式 | 生成式 | |
|---|---|---|
| CV | SegmentAnything1 | StableDiffusion2 |
| NLP | BERT3 | Gemma4, Mistral5 |
我们衡量的是常见用户工作流的开箱即用性能,而不是每个框架能达到的最佳性能。本着这一目标,对于 Keras 版本中的模型,我们利用了 KerasCV 和 KerasHub 中现有的实现。
所有基准测试均在 Google Cloud Compute Engine 的 a2-highgpu-1g 机器类型上进行,该类型配备 12 个 vCPU 和 85GB 主机内存,并使用单块 NVIDIA A100 GPU,拥有 40GB GPU 内存。
表 2 显示了每步毫秒为单位的基准测试结果。每一步都涉及对单个数据批次进行训练或预测。结果取 100 步的平均值,不包括第一步(包括模型创建和编译的开销)。
为了公平比较,对于相同的模型和任务(fit 或 predict),我们在不同框架中使用相同的批次大小。然而,对于不同大小和架构的不同模型和任务,我们使用不同的批次大小,以避免内存不足(过大)或 GPU 利用率低下(过小)。
对于大型语言模型(Gemma 和 Mistral),我们也使用了相同的批次大小,因为它们是相同类型的模型,具有相似数量的参数(7B)。我们还对批次大小为 1 的文本生成进行了基准测试,因为它被用户广泛要求。我们使用 bfloat16 精度进行训练和推理,并使用 LoRA6 进行训练(微调)。
为了衡量开箱即用的性能,我们尝试使用所有默认设置。例如,使用高层 API(如 Keras 的 model.fit()),尽量减少配置。
请注意,这与衡量针对特定硬件/框架/模型组合的优化实现有很大不同。请参考 MLPerf 以获取不同框架的最佳优化结果。
表 2:基准测试结果。速度以 ms/step 衡量。越低越好。
| 批次 大小 |
Keras 2 (TensorFlow) |
Keras 3 (TensorFlow) |
Keras 3 (JAX) |
Keras 3 (PyTorch) (eager) |
Keras 3 (最佳) |
|
|---|---|---|---|---|---|---|
| SegmentAnything (fit) |
1 | 386.93 | 355.25 | 361.69 | 1,388.87 | 355.25 |
| SegmentAnything (predict) |
4 | 1,859.27 | 438.50 | 376.34 | 1,720.96 | 376.34 |
| Stable Diffusion (fit) |
8 | 1,023.21 | 392.24 | 391.21 | 823.44 | 391.21 |
| Stable Diffusion (predict) |
13 | 649.71 | 616.04 | 627.27 | 1,337.17 | 616.04 |
| BERT (fit) |
32 | 486.00 | 214.49 | 222.37 | 808.68 | 214.49 |
| BERT (predict) |
256 | 470.12 | 466.01 | 418.72 | 1,865.98 | 418.72 |
| Gemma (fit) |
8 | NA | 232.52 | 273.67 | 525.15 | 232.52 |
| Gemma (generate) |
32 | NA | 1,134.91 | 1,128.21 | 7,952.67* | 1,128.21 |
| Gemma (generate) |
1 | NA | 758.57 | 703.46 | 7,649.40* | 703.46 |
| Mistral (fit) |
8 | NA | 185.92 | 213.22 | 452.12 | 185.92 |
| Mistral (generate) |
32 | NA | 966.06 | 957.25 | 10,932.59* | 957.25 |
| Mistral (generate) |
1 | NA | 743.28 | 679.30 | 11,054.67* | 679.30 |
* 目前,使用 PyTorch 后端的 LLM 推理速度异常缓慢,因为 KerasHub 使用静态序列填充,而 HuggingFace 则不使用。此问题将很快得到解决。
Keras 的三个后端各有优势。至关重要的是,从性能的角度来看,没有一个后端能够持续超越其他后端。最快的后端通常取决于您的具体模型架构。
这凸显了在追求最佳性能时,框架可选性的价值。Keras 3 使您能够无缝切换后端,从而确保找到最适合您模型的匹配项。
我们还计算了 Keras 3(使用其性能最佳的后端)相对于 Keras 2 配合 TensorFlow 的吞吐量(steps/ms)提升(来自表 1)。结果显示在下图。

图 1:Keras 3 相对于 Keras 2 的提速,以吞吐量(steps/ms)衡量
Keras 3 在所有基准测试模型中均持续优于 Keras 2,在许多情况下速度大幅提升。SegmentAnything 推理速度提高了惊人的 380%,StableDiffusion 训练吞吐量提高了 150% 以上,BERT 训练吞吐量提高了 100% 以上。
重要的是,即使您只是升级到 Keras 3 并继续使用 TensorFlow 后端,您仍然会看到性能的提升。这主要是因为 Keras 2 直接使用了更多 TensorFlow 融合算子,这在某些用例中可能不利于 XLA 编译。
框架性能在很大程度上取决于具体模型。Keras 3 使您能够为您的任务选择最快的框架,这几乎总能优于 Keras 2。
1 Kirillov, Alexander, et al. "Segment anything." ICCV (2023).
2 Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR (2022).
3 Kenton, Jacob, et al. "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." NAACL (2019).
4 Banks, Jeanine, et al. "Gemma: Introducing new state-of-the-art open models." The Keyword, Google (2024).
5 Jiang, Albert Q., et al. "Mistral 7B." arXiv preprint arXiv:2310.06825 (2023).
6 Hu, Edward J., et al. "Lora: Low-rank adaptation of large language models." ICLR (2022).