作者: Jyotinder Singh
创建日期 2025/10/14
最后修改日期 2025/10/14
描述: Keras 和 KerasHub 中 INT4 量化的完整指南。
量化通过降低权重和激活的数值精度来减少内存使用量并经常加快推理速度,但会以少量精度下降为代价。INT4 训练后量化 (PTQ) 将模型权重存储在 4 位有符号整数中,并在运行时动态地将激活量化为 8 位(W4A8 方案)。与 FP32 相比,这可以将权重存储大小缩小约 8 倍(比 INT8 小 2 倍),同时仍能为许多编码器模型和一些解码器模型保持可接受的精度。计算仍可利用广泛可用的 NVIDIA INT8 Tensor Cores。
4 位比 8 位更具侵略性地压缩,可能会导致更严重的质量下降,尤其是对于大型自回归语言模型。
量化将实数值映射到具有比例因子的 4 位整数
[-8, 7](4 位)范围内的值,并将它们打包成每字节两个。(input_scale * per_channel_kernel_scale) 进行反量化。这与 INT8 指南 中描述的 INT8 路径类似,只是为了实现更强的压缩而增加了一些额外的解包开销。
本指南介绍了如何在 Keras 中使用 4 位 (W4A8) 训练后量化
下面我们将构建一个小的函数式模型,捕获基线输出,原地量化为 INT4,并使用 MSE 指标比较输出。(对于实际评估,请使用您的验证指标。)
import numpy as np
import keras
from keras import layers
# Create a random number generator.
rng = np.random.default_rng()
# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)
# Baseline output with full-precision weights.
x_eval = rng.random((32, 10)).astype("float32")
y_fp32 = model(x_eval)
# Quantize the model in-place to INT4 (W4A8).
model.quantize("int4")
# Compare outputs (MSE).
y_int4 = model(x_eval)
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4))
print("Full-Precision vs INT4 MSE:", float(mse))
Full-Precision vs INT4 MSE: 0.00028205406852066517
INT4 量化模型通常会产生足够接近的输出,可用于许多下游任务。期望比 INT8 的差异更大,因此请务必在您自己的数据上进行验证。
您可以使用标准的 Keras 保存/加载 API。量化元数据(包括比例因子和打包的权重)将得到保留。
# Save the quantized model and reload to verify round-trip.
model.save("int4.keras")
int4_reloaded = keras.saving.load_model("int4.keras")
y_int4_reloaded = int4_reloaded(x_eval)
# Compare outputs (MSE).
roundtrip_mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4_reloaded))
print("MSE (INT4 vs reloaded INT4):", float(roundtrip_mse))
MSE (INT4 vs reloaded INT4): 0.00028205406852066517
所有 KerasHub 模型都支持用于训练后量化的 .quantize(...) API,并且与上述工作流程相同。
在此示例中,我们将
import os
from keras_hub.models import Gemma3CausalLM
# Load a Gemma3 preset from KerasHub.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")
# Generate with full-precision weights.
fp_output = gemma3.generate("Keras is a", max_length=30)
print("Full-precision output:", fp_output)
# Save the full-precision model to a preset.
gemma3.save_to_preset("gemma3_fp32")
# Quantize to INT4.
gemma3.quantize("int4")
# Generate with INT4 weights.
output = gemma3.generate("Keras is a", max_length=30)
print("Quantized output:", output)
# Save INT4 model to a new preset.
gemma3.save_to_preset("gemma3_int4")
# Reload and compare outputs
gemma3_int4 = Gemma3CausalLM.from_preset("gemma3_int4")
output = gemma3_int4.generate("Keras is a", max_length=30)
print("Quantized reloaded output:", output)
# Compute storage savings
def bytes_to_mib(n):
return n / (1024**2)
gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int4_size = os.path.getsize("gemma3_int4/model.weights.h5")
gemma_reduction = 100.0 * (1.0 - (gemma_int4_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT4 file size: {bytes_to_mib(gemma_int4_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
Full-precision output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning
Quantized output: Keras is a python-based, open-source, and free-to-use, open-source, and a popular, and a
Quantized reloaded output: Keras is a python-based, open-source, and free-to-use, open-source, and a popular, and a
Gemma3: FP32 file size: 3815.32 MiB
Gemma3: INT4 file size: 1488.10 MiB
Gemma3: Size reduction: 61.0%
在单个 NVIDIA L4 (22.5 GB) 上收集的微基准测试。基线是 FP32。
| 指标 | FP32 | INT4 | 变化 |
|---|---|---|---|
| 准确率 (↑) | 91.06% | 90.14% | -0.92pp |
| 模型大小 (MB, ↓) | 255.86 | 159.49 | -37.67% |
| 峰值 GPU 内存 (MiB, ↓) | 1554.00 | 1243.26 | -20.00% |
| 延迟 (ms/样本, ↓) | 6.43 | 5.73 | -10.83% |
| 吞吐量 (样本/秒, ↑) | 155.60 | 174.50 | +12.15% |
分析: 精度下降温和(<1pp),速度和内存增益显著;仅编码器模型在更重的权重压缩下往往能保持保真度。
| 指标 | FP32 | INT4 | 变化 |
|---|---|---|---|
| 困惑度 (↓) | 7.44 | 9.98 | +34.15% |
| 模型大小 (GB, ↓) | 4.8884 | 0.9526 | -80.51% |
| 峰值 GPU 内存 (MiB, ↓) | 8021.12 | 5483.46 | -31.64% |
| 第一个 token 延迟 (ms, ↓) | 128.87 | 122.50 | -4.95% |
| 序列延迟 (ms, ↓) | 338.29 | 181.93 | -46.22% |
| Token 吞吐量 (tokens/秒, ↑) | 174.41 | 256.96 | +47.33% |
分析: INT4 实现了显著的大小(-80%)和内存(-32%)缩减。困惑度有所增加(符合预期,因为压缩更激进),但序列延迟下降,吞吐量提高了约 50%。
| 指标 | FP32 | INT4 | 变化 |
|---|---|---|---|
| 困惑度 (↓) | 6.17 | 10.46 | +69.61% |
| 模型大小 (GB, ↓) | 3.7303 | 1.4576 | -60.92% |
| 峰值 GPU 内存 (MiB, ↓) | 6844.67 | 5008.14 | -26.83% |
| 第一个 token 延迟 (ms, ↓) | 57.42 | 64.21 | +11.83% |
| 序列延迟 (ms, ↓) | 239.78 | 161.18 | -32.78% |
| Token 吞吐量 (tokens/秒, ↑) | 246.06 | 366.05 | +48.76% |
分析: INT4 实现了显著的大小(-61%)和内存(-27%)缩减。困惑度有所增加(符合预期,因为压缩更激进),但序列延迟下降,吞吐量提高了约 50%。
| 指标 | FP32 | INT4 | 变化 |
|---|---|---|---|
| 困惑度 (↓) | 6.38 | 14.16 | +121.78% |
| 模型大小 (GB, ↓) | 5.5890 | 2.4186 | -56.73% |
| 峰值 GPU 内存 (MiB, ↓) | 9509.49 | 6810.26 | -28.38% |
| 第一个 token 延迟 (ms, ↓) | 209.41 | 219.09 | +4.62% |
| 序列延迟 (ms, ↓) | 322.33 | 262.15 | -18.67% |
| Token 吞吐量 (tokens/秒, ↑) | 183.82 | 230.78 | +25.55% |
分析: INT4 实现了显著的大小(-57%)和内存(-28%)缩减。困惑度有所增加(符合预期,因为压缩更激进),但序列延迟下降,吞吐量提高了约 25%。
| 指标 | FP32 | INT4 | 变化 |
|---|---|---|---|
| 困惑度 (↓) | 13.85 | 21.02 | +51.79% |
| 模型大小 (MB, ↓) | 468.3 | 284.0 | -39.37% |
| 峰值 GPU 内存 (MiB, ↓) | 1007.23 | 659.28 | -34.54% |
| 第一个 token 延迟 (ms/样本, ↓) | 95.79 | 97.87 | +2.18% |
| 序列延迟 (ms/样本, ↓) | 60.35 | 54.64 | -9.46% |
| 吞吐量 (样本/秒, ↑) | 973.41 | 1075.15 | +10.45% |
分析: INT4 实现了显著的大小(-39%)和内存(-35%)缩减。困惑度有所增加(符合预期,因为压缩更激进),但序列延迟下降,吞吐量提高了约 10%。
| 目标/约束 | 首选 INT8 | 首选 INT4 (W4A8) |
|---|---|---|
| 最小精度损失至关重要 | ✔︎ | |
| 最大压缩(磁盘/RAM) | ✔︎ | |
| 带宽受限的推理 | 可能 | 通常更好 |
| 解码器 LLM | ✔︎ | 先尝试并评估 |
| 编码器/分类模型 | ✔︎ | ✔︎ |
| 可用内核/工具成熟度 | ✔︎ | 新兴 |
build() 或前向传播)。