开发者指南 / Keras 中的 8 位整数量化

Keras 中的 8 位整数量化

作者: Jyotinder Singh
创建日期 2025/10/14
最后修改日期 2025/10/14
描述: Keras 和 KerasHub 中 INT8 量化的完整指南。

在 Colab 中查看 GitHub 源代码


什么是 INT8 量化?

量化将权重和激活的数值精度降低,以减少内存使用并常常加速推理,但会以牺牲少量精度为代价。从 float32 切换到 float16 会使内存需求减半;从 float32 到 INT8 的内存占用大约是 4 倍(与 float16 相比大约是 2 倍)。在具有低精度内核的硬件上(例如 NVIDIA Tensor Cores),这还可以提高吞吐量和降低延迟。实际的收益取决于您的后端和设备。

工作原理

量化将实数值映射到带有比例因子的 8 位整数

  • 整数域:[-128, 127](256 个级别)。
  • 对于一个张量(权重通常按输出通道计算),其值为 w
    • 计算 a_max = max(abs(w))
    • 设置比例因子 s = (2 * a_max) / 256
    • 量化:q = clip(round(w / s), -128, 127)(存储为 INT8)并保留 s
  • 推理使用 qs 来即时重构有效的权重(w ≈ s · q)或将 s 折叠到 matmul/conv 中以提高效率。

优点

  • 内存/带宽受限模型:当实现大部分时间花费在内存 I/O 上时,减少计算时间并不能缩短整体运行时间。INT8 将移动的字节数与 float32 相比减少了约 4 倍,改善了缓存行为并减少了内存停顿;这通常比增加原始 FLOPs 的帮助更大。
  • 受支持硬件上的计算受限层:在 NVIDIA GPU 上,INT8 Tensor Cores 加速了 matmul/conv,从而提高了计算受限层的吞吐量。
  • 精度:许多模型在 float16 下可以保持接近 FP 的精度;INT8 可能会引入轻微的下降(通常取决于任务/模型/数据,约 1-5%)。务必在您自己的数据集上进行验证。

Keras 在 INT8 模式下做什么

  • 映射:对称、线性量化,使用 INT8 加上一个浮点比例因子。
  • 权重:每输出通道的比例因子,以保持精度。
  • 激活动态 AbsMax 缩放,在运行时计算。
  • 图重写:在训练完权重并构建模型后应用量化;图会被重写,以便您可以立即运行或保存。

概述

本指南将介绍如何在 Keras 中使用 8 位整数训练后量化 (PTQ)

  1. 量化一个最小的函数式模型
  2. 保存和重新加载量化模型
  3. 量化一个 KerasHub 模型

量化一个最小的函数式模型。

我们构建一个小的函数式模型,捕获基线输出,将其原地量化为 INT8,然后使用 MSE 度量来比较输出。

import os
import numpy as np
import keras
from keras import layers


# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)

# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")

# Baseline (FP) outputs.
y_fp32 = model(x_eval)

# Quantize the model in-place to INT8.
model.quantize("int8")

# INT8 outputs after quantization.
y_int8 = model(x_eval)

# Compute a simple MSE between FP and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))
Full-Precision vs INT8 MSE: 4.982496648153756e-06

正如低 MSE 值所示,INT8 量化模型产生的输出非常接近原始 FP32 模型。


保存和重新加载量化模型

您可以使用标准的 Keras 保存和加载 API 来处理量化模型。保存为 .keras 并重新加载时,量化会得到保留。

# Save the quantized model and reload to verify round-trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))
MSE (INT8 vs reloaded-INT8): 0.0

量化一个 KerasHub 模型

所有 KerasHub 模型都支持用于训练后量化的 .quantize(...) API,并且遵循与上述相同的流程。

在此示例中,我们将

  1. 加载 KerasHub 的 gemma3_1b 预训练模型
  2. 使用全精度模型和量化模型生成文本,并比较输出。
  3. 将两个模型保存到磁盘并计算存储节省。
  4. 重新加载 INT8 模型并验证输出与原始量化模型的一致性。
from keras_hub.models import Gemma3CausalLM

# Load from Gemma3 preset
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate text for a single prompt
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)

# Save FP32 Gemma3 model for size comparison.
gemma3.save_to_preset("gemma3_fp32")

# Quantize in-place to INT8 and generate again
gemma3.quantize("int8")

output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)

# Save INT8 Gemma3 model
gemma3.save_to_preset("gemma3_int8")

# Reload and compare outputs
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")

output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)


# Compute storage savings
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
Full-precision output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning. It is a library for deep learning. It is a library for deep learning. It is a
Quantized output: Keras is a Python library for deep learning. It is a high-level API for building deep learning models. It is designed to be easy
Quantized reloaded output: Keras is a Python library for deep learning. It is a high-level API for building deep learning models. It is designed to be easy
Gemma3: FP32 file size: 3815.32 MiB
Gemma3: INT8 file size: 957.81 MiB
Gemma3: Size reduction: 74.9%

实用技巧

  • 训练后量化 (PTQ) 是一次性操作;量化为 INT8 后,您无法再训练模型。
  • 在量化之前务必使权重具体化(例如,通过 build() 或前向传播)。
  • 预期会有微小的数值差异;使用 MSE 等度量在一个验证批次上进行量化。
  • 存储节省是即时的;速度提升取决于后端/设备的内核。

参考文献