入门 / Keras 3 基准测试

Keras 3 基准测试

我们对 Keras 3 的三个后端(TensorFlowJAXPyTorch)以及结合 TensorFlow 的 Keras 2 进行了基准测试。您可以在此处找到复现结果的代码和设置详情。

模型

我们选择了一系列流行的计算机视觉和自然语言处理模型,用于生成式和非生成式 AI 任务。请参阅下表了解我们的选择。

表 1:基准测试中使用的模型。

非生成式 生成式
CV SegmentAnything1 StableDiffusion2
NLP BERT3 Gemma4, Mistral5

我们测量的不是每个框架能够达到的最佳性能,而是常见用户工作流程的开箱即用性能。考虑到这一目标,我们利用了 KerasCV 和 KerasHub 中现有的模型 Keras 版本实现。

硬件

所有基准测试均使用 Google Cloud Compute Engine 上具有 12 个 vCPU 和 85GB 主机内存的 a2-highgpu-1g 机器类型,配备单个具有 40GB GPU 内存的 NVIDIA A100 GPU 进行。

结果

表 2 显示了以每步毫秒数(ms/step)为单位的基准测试结果。每一步都涉及在一个数据批次上进行训练或预测。结果是对 100 个步骤取平均值,不包括第一步(第一步包含模型创建和编译开销)。

为了公平比较,如果是相同的模型和任务(fit 或 predict),我们在不同框架中使用了相同的批次大小。然而,对于不同的模型和任务,由于它们的尺寸和架构不同,我们使用了不同的批次大小,以避免内存不足(过大)或 GPU 利用率不足(过小)。

对于大型语言模型(Gemma 和 Mistral),我们也使用了相同的批次大小,因为它们是相同类型的模型,具有相似的参数数量(7B)。我们还对批次大小为 1 的文本生成进行了基准测试,因为这是用户广泛要求的。我们在其训练和推理中使用了 bfloat16 精度,并在其训练(微调)中使用了 LoRA6

为了测量开箱即用性能,我们尽量使用所有默认设置。例如,使用高级 API(例如 Keras 的 model.fit()),配置尽可能少。

请注意,这与测量针对特定硬件/框架/模型组合的优化实现完全不同。有关不同框架的最佳优化结果,请参阅 MLPerf

表 2:基准测试结果。速度以 ms/step 为单位测量。数值越低越好。

批次
大小
Keras 2
(TensorFlow)
Keras 3
(TensorFlow)
Keras 3
(JAX)
Keras 3
(PyTorch)
(eager)
Keras 3
(最佳)
SegmentAnything
(fit)
1 386.93 355.25 361.69 1,388.87 355.25
SegmentAnything
(predict)
4 1,859.27 438.50 376.34 1,720.96 376.34
Stable Diffusion
(fit)
8 1,023.21 392.24 391.21 823.44 391.21
Stable Diffusion
(predict)
13 649.71 616.04 627.27 1,337.17 616.04
BERT
(fit)
32 486.00 214.49 222.37 808.68 214.49
BERT
(predict)
256 470.12 466.01 418.72 1,865.98 418.72
Gemma
(fit)
8 不适用 (NA) 232.52 273.67 525.15 232.52
Gemma
(generate)
32 不适用 (NA) 1,134.91 1,128.21 7,952.67* 1,128.21
Gemma
(generate)
1 不适用 (NA) 758.57 703.46 7,649.40* 703.46
Mistral
(fit)
8 不适用 (NA) 185.92 213.22 452.12 185.92
Mistral
(generate)
32 不适用 (NA) 966.06 957.25 10,932.59* 957.25
Mistral
(generate)
1 不适用 (NA) 743.28 679.30 11,054.67* 679.30

* PyTorch 后端的大型语言模型推理目前异常缓慢,因为 KerasHub 使用静态序列填充,这与 HuggingFace 不同。此问题将很快得到解决。

讨论

主要发现 1:“最佳”后端并不存在

Keras 的三个后端各自拥有独特的优势。重要的是,从性能角度来看,没有哪个后端能始终超越其他后端。最快的后端通常取决于您特定的模型架构。

这突显了在追求最优性能时框架可选性的价值。Keras 3 使您能够无缝切换后端,确保为您的模型找到理想的匹配。

主要发现 2:Keras 3 比 Keras 2 更快

我们还根据表 1 计算了 Keras 3(使用其性能最佳的后端)相对于 Keras 2(使用 TensorFlow)的吞吐量(steps/ms)提升。结果如下图所示。

Figrue 2

图 1:Keras 3 相对于 Keras 2 的加速比,以吞吐量(steps/ms)衡量

Keras 3 在所有基准测试模型中均持续优于 Keras 2,在许多情况下速度显著提升。SegmentAnything 推理性能提升了惊人的 380%,StableDiffusion 训练吞吐量增加了 150% 以上,BERT 训练吞吐量提升了 100% 以上。

重要的是,即使您只是升级到 Keras 3 并继续使用 TensorFlow 后端,您仍然会看到性能提升。这主要是因为 Keras 2 更直接地使用了更多的 TensorFlow 融合操作(fused ops),这在某些用例中可能不利于 XLA 编译。

结论

框架性能很大程度上取决于具体的模型。Keras 3 使您能够为您的任务选择最快的框架——这几乎总是优于 Keras 2 的选项。

参考文献

1 Kirillov, Alexander, 等。“Segment anything。” ICCV (2023)。

2 Rombach, Robin, 等。“High-resolution image synthesis with latent diffusion models。” CVPR (2022)。

3 Kenton, Jacob, 等。“BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding。” NAACL (2019)。

4 Banks, Jeanine, 等。“Gemma: Introducing new state-of-the-art open models。” The Keyword, Google (2024)。

5 Jiang, Albert Q., 等。“Mistral 7B。” arXiv preprint arXiv:2310.06825 (2023)。

6 Hu, Edward J., 等。“Lora: Low-rank adaptation of large language models。” ICLR (2022)。