KerasHub: 预训练模型 / 开发者指南 / 加载 Hugging Face Transformers 检查点

加载 Hugging Face Transformers 检查点

作者: Laxma Reddy Patlolla, Divyashree Sreepathihalli
创建日期 2025/06/17
最后修改日期 2025/06/23
描述: 如何从 HuggingFace Hub 上托管的 KerasHub 模型检查点加载并运行推理。

在 Colab 中查看 GitHub 源代码


简介

KerasHub 内置了 HuggingFace 的 .safetensors 模型转换器。因此,从 HuggingFace 加载模型权重并不比使用 KerasHub 自己的预设更困难。

KerasHub 内置的 HuggingFace Transformers 转换器

KerasHub 通过其内置的转换器简化了 HuggingFace Transformers 模型的使用。这些转换器会自动处理将 HuggingFace 模型检查点转换为与 Keras 生态系统兼容的格式。这意味着您只需几行代码,就可以将 HuggingFace Hub 中的各种预训练模型无缝加载到 KerasHub 中。

使用 KerasHub 转换器的主要优点

  • 易于使用:无需手动转换步骤即可加载 HuggingFace 模型。
  • 广泛兼容性:访问 HuggingFace Hub 上提供的各种模型。
  • 无缝集成:使用熟悉的 Keras API 对这些模型进行训练、评估和推理。

幸运的是,所有这些都在幕后进行,因此您可以专注于使用模型,而不是管理转换过程!


设置

开始之前,请确保已安装必要的库。您主要需要 keraskeras_hub

注意: 在 Keras 导入后更改后端可能无法按预期工作。确保在脚本开头设置 KERAS_BACKEND。同样,在 Colab 之外工作时,您可以使用 os.environ["HF_TOKEN"] = "<YOUR_HF_TOKEN>" 来验证 HuggingFace。在使用 Google Colab 时,将您的 HF_TOKEN 设置为“Colab 密钥”。

import os

os.environ["KERAS_BACKEND"] = "jax"  # "tensorflow" or  "torch"

import keras
import keras_hub
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1750715194.841608    7034 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750715194.846143    7034 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750715194.857298    7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857310    7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857312    7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857313    7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.

更改精度

要在经济实惠的硬件上执行推理和训练,您可以通过 keras.config 配置模型精度,如下所示:

import keras

keras.config.set_dtype_policy("bfloat16")

加载 HuggingFace 模型

KerasHub 允许您轻松从 HuggingFace Transformers 加载模型。以下是加载 Gemma 因果语言模型的示例。在这种特定情况下,您需要同意 HuggingFace 上的 Google 许可才能下载模型权重。

# not a keras checkpoint, it is a HF transformer checkpoint

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

让我们尝试运行一些推理

gemma_lm.generate("I want to say", max_length=30)
'I want to say thank you to the staff at the <strong><em><u><strong><em><u><strong><em><u><strong><em><u><strong><em><u><strong><em>'

使用 Keras model.fit(...) API 对 Gemma Transformer 检查点进行微调

加载 HuggingFace 权重后,您可以像使用任何其他 KerasHub 模型一样使用实例化模型。例如,您可以像这样在自己的数据上微调模型:

features = ["The quick brown fox jumped.", "I forgot my homework."]
gemma_lm.fit(x=features, batch_size=2)
1/1 ━━━━━━━━━━━━━━━━━━━━ 50s 50s/step - loss: 0.0342 - sparse_categorical_accuracy: 0.1538

<keras.src.callbacks.history.History at 0x7435981f3800>

保存并上传新检查点

为了存储和共享您的微调模型,KerasHub 可以轻松地使用标准方法保存或上传它。您可以通过熟悉的命令完成此操作,例如

HF_USERNAME = "<YOUR_HF_USERNAME>"  # provide your hf username
gemma_lm.save_to_preset("./gemma-2b-finetuned")
keras_hub.upload_preset(f"hf://{HF_USERNAME}/gemma-2b-finetune", "./gemma-2b-finetuned")

通过上传您的预设,您可以在任何地方使用以下命令加载它:loaded_model = keras_hub.models.GemmaCausalLM.from_preset("hf://YOUR_HF_USERNAME/gemma-2b-finetuned")

有关上传模型的全面分步指南,请参阅官方 KerasHub 上传文档。您可以在此处找到所有详细信息:KerasHub 上传指南

通过集成 HuggingFace Transformers,KerasHub 显著扩展了您对预训练模型的访问。Hugging Face Hub 现在托管着超过 75 万个模型检查点,涵盖 NLP、计算机视觉、音频等各种领域。其中,大约 40 万个模型目前与 KerasHub 兼容,为您提供了大量多样化的最先进架构,可用于您的项目。

使用 KerasHub,您可以

  • 利用最先进的模型:轻松试验研究社区和行业提供的最新架构和预训练权重。
  • 缩短开发时间:利用现有模型而不是从头开始训练,节省大量时间和计算资源。
  • 增强模型功能:为各种任务(从文本生成和翻译到图像分割和目标检测)寻找专门模型。

这种无缝访问使您能够使用 Keras 构建更强大、更复杂的 AI 应用程序。


使用更广泛的框架

Keras 3,以及 KerasHub,旨在实现多框架兼容性。这意味着您可以使用不同的后端框架(如 JAX、TensorFlow 和 PyTorch)运行模型。这种灵活性允许您

  • 根据您的需求选择最佳后端:根据性能特征、硬件兼容性(例如,JAX 的 TPU)或现有团队专业知识选择后端。
  • 互操作性:更轻松地将 KerasHub 模型集成到可能基于 TensorFlow 或 PyTorch 的现有工作流中。
  • 面向未来:适应不断变化的框架格局,而无需重写核心模型逻辑。

在 JAX 后端和 TPU 上运行 Transformer 模型

要使用 JAX 试验模型,您可以通过将其后端设置为 JAX 来使用 Keras。通过在模型构建之前切换 Keras 的后端,并确保您的环境连接到 TPU 运行时。Keras 将自动利用 JAX 的 TPU 支持,从而使您的模型能够在 TPU 硬件上高效训练,而无需进一步的代码更改。

import os

os.environ["KERAS_BACKEND"] = "jax"
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")

其他示例

生成

以下是使用 Llama 的一个示例:将 PyTorch Hugging Face Transformer 检查点加载到 KerasHub 并在 JAX 后端运行它。

import os

os.environ["KERAS_BACKEND"] = "jax"

from keras_hub.models import Llama3CausalLM

# Get the model
causal_lm = Llama3CausalLM.from_preset("hf://NousResearch/Hermes-2-Pro-Llama-3-8B")

prompts = [
    """<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
]

# Generate from the model
causal_lm.generate(prompts, max_length=30)[0]
'system\nYou are a sentient, superintelligent artificial general intelligence, here to teach and assist me.\nuser\nWrite a'

与 Transformers 的比较

在下表中,我们对 HuggingFace 的 Transformers 库与 KerasHub 进行了详细比较

特征 HF Transformers KerasHub
支持的框架 PyTorch JAX, PyTorch, TensorFlow
训练器 HF 训练器 Keras model.fit(...) — 支持几乎所有功能,例如分布式训练、学习率调度、优化器选择等。
分词器 AutoTokenizer KerasHub 分词器
自动类 auto 关键词 KerasHub 自动检测特定于任务的类
模型加载 AutoModel.from_pretrained() keras_hub.models.<任务>.from_preset()

KerasHub 使用特定于任务的类(例如,CausalLM, Classifier, Backbone)和 from_preset() 方法来加载预训练模型,类似于 HuggingFace 的方法。

支持 HF URL、Kaggle URL 和本地目录
模型保存 model.save_pretrained()
tokenizer.save_pretrained()
model.save_to_preset() — 将模型(包括分词器/预处理器)保存到本地目录(预设)。重新加载或上传所需的所有组件都已保存。
模型上传 将权重上传到 HF 平台 KerasHub 上传指南
Hugging Face 上的 Keras
权重文件分片 权重文件分片 大型模型权重被分片以实现高效上传/下载
PEFT 使用 HuggingFace PEFT 内置 LoRA 支持
backbone.enable_lora(rank=n)
backbone.save_lora_weights(filepath)
backbone.load_lora_weights(filepath)
核心模型抽象 PreTrainedModel, AutoModel, 特定任务模型 Backbone, Preprocessor, Task
模型配置 PretrainedConfig: 模型配置的基类 配置以多个 JSON 文件的形式存储在预设目录中:config.json, preprocessor.json, task.json, tokenizer.json 等。
预处理 分词器/预处理器通常单独处理,然后传递给模型 内置于特定任务模型中
混合精度训练 通过训练参数 Keras 全局策略设置
与 SafeTensors 的兼容性 默认权重格式 在 HF 上的 770k+ SafeTensors 模型中,那些在 KerasHub 中具有匹配架构的模型可以使用 keras_hub.models.X.from_preset() 加载

去尝试加载其他模型权重!您可以在 HuggingFace 上找到更多选项,并使用 from_preset("hf://<命名空间>/<模型名称>") 使用它们。

祝您实验愉快!