CausalLM 类keras_hub.models.CausalLM()
用于生成式语言建模任务的基类。
CausalLM 任务封装了一个 keras_hub.models.Backbone 和一个 keras_hub.models.Preprocessor,以创建一个可用于生成和生成式微调的模型。
CausalLM 任务提供了一个额外的、高级的 generate() 函数,该函数可用于以字符串输入、字符串输出的方式,逐个 token 自回归地采样模型。所有 CausalLM 类的 compile() 方法都包含一个额外的 sampler 参数,可用于传递一个 keras_hub.samplers.Sampler 来控制如何对预测分布进行采样。
调用 fit() 时,将应用因果掩码逐个 token 预测分词后的输入,这为控制推理时间生成提供了预训练和监督式微调的设置。
所有 CausalLM 任务都包含一个 from_preset() 构造函数,可用于加载预训练的配置和权重。
示例
# Load a GPT2 backbone with pre-trained weights.
causal_lm = keras_hub.models.CausalLM.from_preset(
"gpt2_base_en",
)
causal_lm.compile(sampler="top_k")
causal_lm.generate("Keras is a", max_length=64)
# Load a Mistral instruction tuned checkpoint at bfloat16 precision.
causal_lm = keras_hub.models.CausalLM.from_preset(
"mistral_instruct_7b_en",
dtype="bfloat16",
)
causal_lm.compile(sampler="greedy")
causal_lm.generate("Keras is a", max_length=64)
from_preset 方法CausalLM.from_preset(preset, load_weights=True, **kwargs)
从模型预设实例化一个 keras_hub.models.Task。
预设是一个包含配置、权重和其他文件资产的目录,用于保存和加载预训练模型。preset 可以作为以下之一传递:
'bert_base_en''kaggle://user/bert/keras/bert_base_en''hf://user/bert_base_en''./bert_base_en'对于任何 Task 子类,您都可以运行 cls.presets.keys() 来列出该类上所有可用的内置预设。
此构造函数可以通过两种方式调用。一种方式是从特定任务的基类(如 keras_hub.models.CausalLM.from_preset())调用,另一种方式是从模型类(如 keras_hub.models.BertTextClassifier.from_preset())调用。如果从基类调用,返回对象的子类将从预设目录中的配置推断出来。
参数
True,已保存的权重将被加载到模型架构中。如果为 False,所有权重将被随机初始化。示例
# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
"gemma_2b_en",
)
# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
"bert_base_en",
num_classes=2,
)
| 预设 | 参数 | 描述 |
|---|---|---|
| bart_base_en | 139.42M | 6 层 BART 模型,大小写保持不变。在 BookCorpus、英文维基百科和 CommonCrawl 上训练。 |
| bart_large_en | 406.29M | 12 层 BART 模型,大小写保持不变。在 BookCorpus、英文维基百科和 CommonCrawl 上训练。 |
| bart_large_en_cnn | 406.29M | 在 CNN+DM 摘要数据集上微调的 bart_large_en 骨干模型。 |
| bloom_560m_multi | 559.21M | 24 层 Bloom 模型,隐藏维度为 1024。在 45 种自然语言和 12 种编程语言上训练。 |
| bloomz_560m_multi | 559.21M | 24 层 Bloom 模型,隐藏维度为 1024。在跨语言任务混合 (xP3) 数据集上进行微调。 |
| bloom_1.1b_multi | 1.07B | 24 层 Bloom 模型,隐藏维度为 1536。在 45 种自然语言和 12 种编程语言上训练。 |
| bloomz_1.1b_multi | 1.07B | 24 层 Bloom 模型,隐藏维度为 1536。在跨语言任务混合 (xP3) 数据集上进行微调。 |
| bloom_1.7b_multi | 1.72B | 24 层 Bloom 模型,隐藏维度为 2048。在 45 种自然语言和 12 种编程语言上训练。 |
| bloomz_1.7b_multi | 1.72B | 24 层 Bloom 模型,隐藏维度为 2048。在跨语言任务混合 (xP3) 数据集上进行微调。 |
| bloom_3b_multi | 3.00B | 30 层 Bloom 模型,隐藏维度为 2560。在 45 种自然语言和 12 种编程语言上训练。 |
| bloomz_3b_multi | 3.00B | 30 层 Bloom 模型,隐藏维度为 2560。在跨语言任务混合 (xP3) 数据集上进行微调。 |
| falcon_refinedweb_1b_en | 1.31B | 24 层 Falcon 模型(参数为 10 亿的 Falcon),在 RefinedWeb 数据集的 3500 亿个标记上训练。 |
| vault_gemma_1b_en | 1.04B | 10 亿参数,26 层,VaultGemma 模型。 |
| gemma_2b_en | 25.1 亿 | 20 亿参数、18 层的 Gemma 基础模型。 |
| gemma_instruct_2b_en | 25.1 亿 | 20 亿参数、18 层的 Gemma 指令微调模型。 |
| gemma_1.1_instruct_2b_en | 25.1 亿 | 20 亿参数、18 层的 Gemma 指令微调模型。1.1 版本更新提高了模型质量。 |
| code_gemma_1.1_2b_en | 25.1 亿 | 20 亿参数、18 层的 CodeGemma 模型。此模型针对代码补全的“填充中间”(FIM) 任务进行了训练。1.1 版本更新提高了模型质量。 |
| code_gemma_2b_en | 25.1 亿 | 20 亿参数、18 层的 CodeGemma 模型。此模型针对代码补全的“填充中间”(FIM) 任务进行了训练。 |
| gemma2_2b_en | 26.1 亿 | 20 亿参数、26 层的 Gemma 基础模型。 |
| gemma2_instruct_2b_en | 26.1 亿 | 20 亿参数、26 层的 Gemma 指令微调模型。 |
| shieldgemma_2b_en | 26.1 亿 | 20 亿参数、26 层的 ShieldGemma 模型。 |
| c2s_scale_gemma_2_2b_en | 26.1 亿 | 一个 20 亿参数,基于 Gemma-2 架构构建的单细胞生物学感知模型。 |
| gemma_7b_en | 85.4 亿 | 70 亿参数、28 层的 Gemma 基础模型。 |
| gemma_instruct_7b_en | 85.4 亿 | 70 亿参数、28 层的 Gemma 指令微调模型。 |
| gemma_1.1_instruct_7b_en | 85.4 亿 | 70 亿参数、28 层的 Gemma 指令微调模型。1.1 版本更新提高了模型质量。 |
| code_gemma_7b_en | 85.4 亿 | 70 亿参数、28 层的 CodeGemma 模型。此模型针对代码补全的“填充中间”(FIM) 任务进行了训练。 |
| code_gemma_instruct_7b_en | 85.4 亿 | 70 亿参数、28 层的 CodeGemma 指令微调模型。此模型针对与代码相关的聊天用例进行了训练。 |
| code_gemma_1.1_instruct_7b_en | 85.4 亿 | 70 亿参数、28 层的 CodeGemma 指令微调模型。此模型针对与代码相关的聊天用例进行了训练。1.1 版本更新提高了模型质量。 |
| gemma2_9b_en | 92.4 亿 | 90 亿参数、42 层的 Gemma 基础模型。 |
| gemma2_instruct_9b_en | 92.4 亿 | 90 亿参数、42 层的 Gemma 指令微调模型。 |
| shieldgemma_9b_en | 92.4 亿 | 90 亿参数、42 层的 ShieldGemma 模型。 |
| gemma2_27b_en | 272.3 亿 | 270 亿参数、42 层的 Gemma 基础模型。 |
| gemma2_instruct_27b_en | 272.3 亿 | 270 亿参数、42 层的 Gemma 指令微调模型。 |
| shieldgemma_27b_en | 272.3 亿 | 270 亿参数、42 层的 ShieldGemma 模型。 |
| c2s_scale_gemma_2_27b_en | 272.3 亿 | 一个 270 亿参数,基于 Gemma-2 架构构建的单细胞生物学感知模型。 |
| gemma3_270m | 268.10M | 2.7 亿参数(1.7 亿嵌入参数,1 亿 Transformer 参数)模型,18 层,纯文本模型,专为超高效 AI 设计,特别适合任务特定微调。 |
| gemma3_instruct_270m | 268.10M | 2.7 亿参数(1.7 亿嵌入参数,1 亿 Transformer 参数)模型,18 层,纯文本模型,指令微调模型,专为超高效 AI 设计,特别适合任务特定微调。 |
| gemma3_1b | 999.89M | 10 亿参数,26 层,仅文本预训练 Gemma3 模型。 |
| gemma3_instruct_1b | 999.89M | 10 亿参数,26 层,仅文本指令微调 Gemma3 模型。 |
| gemma3_4b_text | 3.88B | 40 亿参数,34 层,仅文本预训练 Gemma3 模型。 |
| gemma3_instruct_4b_text | 3.88B | 40 亿参数,34 层,仅文本指令微调 Gemma3 模型。 |
| gemma3_4b | 4.30B | 40 亿参数,34 层,视觉+文本预训练 Gemma3 模型。 |
| gemma3_instruct_4b | 4.30B | 40 亿参数,34 层,视觉+文本指令微调 Gemma3 模型。 |
| gemma3_12b_text | 11.77B | 120 亿参数,48 层,仅文本预训练 Gemma3 模型。 |
| gemma3_instruct_12b_text | 11.77B | 120 亿参数,48 层,仅文本指令微调 Gemma3 模型。 |
| gemma3_12b | 12.19B | 120 亿参数,48 层,视觉+文本预训练 Gemma3 模型。 |
| gemma3_instruct_12b | 12.19B | 120 亿参数,48 层,视觉+文本指令微调 Gemma3 模型。 |
| gemma3_27b_text | 27.01B | 270 亿参数,62 层,仅文本预训练 Gemma3 模型。 |
| gemma3_instruct_27b_text | 27.01B | 270 亿参数,62 层,仅文本指令微调 Gemma3 模型。 |
| gemma3_27b | 27.43B | 270 亿参数,62 层,视觉+文本预训练 Gemma3 模型。 |
| gemma3_instruct_27b | 27.43B | 270 亿参数,62 层,视觉+文本指令微调 Gemma3 模型。 |
| gpt2_base_en | 124.44M | 12 层 GPT-2 模型,大小写保持不变。在 WebText 上训练。 |
| gpt2_base_en_cnn_dailymail | 124.44M | 12 层 GPT-2 模型,大小写保持不变。在 CNN/DailyMail 摘要数据集上微调。 |
| gpt2_medium_en | 354.82M | 24 层 GPT-2 模型,大小写保持不变。在 WebText 上训练。 |
| gpt2_large_en | 774.03M | 36 层 GPT-2 模型,大小写保持不变。在 WebText 上训练。 |
| gpt2_extra_large_en | 1.56B | 48 层 GPT-2 模型,大小写保持不变。在 WebText 上训练。 |
| llama2_7b_en | 6.74B | 70 亿参数,32 层,基础 LLaMA 2 模型。 |
| llama2_instruct_7b_en | 6.74B | 70 亿参数,32 层,指令微调 LLaMA 2 模型。 |
| vicuna_1.5_7b_en | 6.74B | 70 亿参数,32 层,指令微调 Vicuna v1.5 模型。 |
| llama2_7b_en_int8 | 6.74B | 70 亿参数,32 层,基础 LLaMA 2 模型,激活和权重均量化为 int8。 |
| llama2_instruct_7b_en_int8 | 6.74B | 70 亿参数,32 层,指令微调 LLaMA 2 模型,激活和权重均量化为 int8。 |
| llama3.2_1b | 15.0 亿 | 10 亿参数、16 层的 LLaMA 3.2 基础模型。 |
| llama3.2_instruct_1b | 15.0 亿 | 10 亿参数、16 层的经过指令调优的 LLaMA 3.2 模型。 |
| llama3.2_guard_1b | 15.0 亿 | 10 亿参数、16 层的 LLaMA 3.2 基础模型,为同意安全分类进行了微调。 |
| llama3.2_3b | 36.1 亿 | 30 亿参数、26 层的 LLaMA 3.2 基础模型。 |
| llama3.2_instruct_3b | 36.1 亿 | 30 亿参数、28 层的经过指令调优的 LLaMA 3.2 模型。 |
| llama3_8b_en | 80.3 亿 | 80 亿参数、32 层的 LLaMA 3 基础模型。 |
| llama3_instruct_8b_en | 80.3 亿 | 80 亿参数、32 层的经过指令调优的 LLaMA 3 模型。 |
| llama3.1_8b | 80.3 亿 | 80 亿参数、32 层的 LLaMA 3.1 基础模型。 |
| llama3.1_instruct_8b | 80.3 亿 | 80 亿参数、32 层的经过指令调优的 LLaMA 3.1 模型。 |
| llama3.1_guard_8b | 80.3 亿 | 80 亿参数、32 层的 LLaMA 3.1 模型,为同意安全分类进行了微调。 |
| llama3_8b_en_int8 | 80.3 亿 | 80 亿参数、32 层的 LLaMA 3 基础模型,其激活和权重被量化为 int8。 |
| llama3_instruct_8b_en_int8 | 80.3 亿 | 80 亿参数、32 层的经过指令调优的 LLaMA 3 模型,其激活和权重被量化为 int8。 |
| mistral_7b_en | 72.4 亿 | Mistral 7B 基础模型 |
| mistral_instruct_7b_en | 72.4 亿 | Mistral 7B 指令模型 |
| mistral_0.2_instruct_7b_en | 72.4 亿 | Mistral 7B instruct version 0.2 model |
| mistral_0.3_7b_en | 7.25B | Mistral 7B base version 0.3 model |
| mistral_0.3_instruct_7b_en | 7.25B | Mistral 7B instruct version 0.3 model |
| mixtral_8_7b_en | 46.70B | 32 层 Mixtral MoE 模型,具有 70 亿个活动参数和每个 MoE 层 8 个专家。 |
| mixtral_8_instruct_7b_en | 46.70B | 指令微调 32 层 Mixtral MoE 模型,具有 70 亿个活动参数和每个 MoE 层 8 个专家。 |
| moonshine_tiny_en | 27.09M | 用于英语语音识别的 Moonshine tiny 模型。由 Useful Sensors 开发,用于实时转录。 |
| moonshine_base_en | 61.51M | 用于英语语音识别的 Moonshine base 模型。由 Useful Sensors 开发,用于实时转录。 |
| opt_125m_en | 125.24M | 12 层 OPT 模型,大小写保持不变。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 语料库上训练。 |
| opt_1.3b_en | 1.32B | 24 层 OPT 模型,大小写保持不变。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 语料库上训练。 |
| opt_2.7b_en | 2.70B | 32 层 OPT 模型,大小写保持不变。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 语料库上训练。 |
| opt_6.7b_en | 6.70B | 32 层 OPT 模型,大小写保持不变。在 BookCorpus、CommonCrawl、Pile 和 PushShift.io 语料库上训练。 |
| pali_gemma_3b_mix_224 | 2.92B | 图像大小 224,混合微调,文本序列长度为 256 |
| pali_gemma_3b_224 | 2.92B | 图像大小 224,预训练,文本序列长度为 128 |
| pali_gemma_3b_mix_448 | 2.92B | 图像大小 448,混合微调,文本序列长度为 512 |
| pali_gemma_3b_448 | 2.92B | 图像大小 448,预训练,文本序列长度为 512 |
| pali_gemma_3b_896 | 2.93B | 图像大小 896,预训练,文本序列长度为 512 |
| pali_gemma2_mix_3b_224 | 3.03B | 30 亿参数,图像大小 224,SigLIP-So400m 视觉编码器 27 层,Gemma2 2B 语言模型 26 层。此模型已在各种视觉语言任务和领域上进行微调。 |
| pali_gemma2_pt_3b_224 | 3.03B | 30 亿参数,图像大小 224,SigLIP-So400m 视觉编码器 27 层,Gemma2 2B 语言模型 26 层。此模型已在混合数据集上进行预训练。 |
| pali_gemma_2_ft_docci_3b_448 | 3.03B | 30 亿参数,图像大小 448,SigLIP-So400m 视觉编码器 27 层,Gemma2 2B 语言模型 26 层。此模型已在 DOCCI 数据集上进行微调,以改进具有细粒度细节的描述。 |
| pali_gemma2_mix_3b_448 | 3.03B | 30 亿参数,图像大小 448,SigLIP-So400m 视觉编码器 27 层,Gemma2 2B 语言模型 26 层。此模型已在各种视觉语言任务和领域上进行微调。 |
| pali_gemma2_pt_3b_448 | 3.03B | 30 亿参数,图像大小 448,SigLIP-So400m 视觉编码器 27 层,Gemma2 2B 语言模型 26 层。此模型已在混合数据集上进行预训练。 |
| pali_gemma2_pt_3b_896 | 3.04B | 30 亿参数,图像大小 896,SigLIP-So400m 视觉编码器 27 层,Gemma2 2B 语言模型 26 层。此模型已在混合数据集上进行预训练。 |
| pali_gemma2_mix_10b_224 | 9.66B | 100 亿参数,图像大小 224,SigLIP-So400m 视觉编码器 27 层,Gemma2 9B 语言模型 42 层。此模型已在各种视觉语言任务和领域上进行微调。 |
| pali_gemma2_pt_10b_224 | 9.66B | 100 亿参数,图像大小 224,SigLIP-So400m 视觉编码器 27 层,Gemma2 9B 语言模型 42 层。此模型已在混合数据集上进行预训练。 |
| pali_gemma2_ft_docci_10b_448 | 9.66B | 100 亿参数,SigLIP-So400m 视觉编码器 27 层,Gemma2 9B 语言模型 42 层。此模型已在 DOCCI 数据集上进行微调,以改进具有细粒度细节的描述。 |
| pali_gemma2_mix_10b_448 | 9.66B | 100 亿参数,图像大小 448,SigLIP-So400m 视觉编码器 27 层,Gemma2 9B 语言模型 42 层。此模型已在各种视觉语言任务和领域上进行微调。 |
| pali_gemma2_pt_10b_448 | 9.66B | 100 亿参数,图像大小 448,SigLIP-So400m 视觉编码器 27 层,Gemma2 9B 语言模型 42 层。此模型已在混合数据集上进行预训练。 |
| pali_gemma2_pt_10b_896 | 9.67B | 100 亿参数,图像大小 896,SigLIP-So400m 视觉编码器 27 层,Gemma2 9B 语言模型 42 层。此模型已在混合数据集上进行预训练。 |
| pali_gemma2_mix_28b_224 | 27.65B | 280 亿参数,图像大小 224,SigLIP-So400m 视觉编码器 27 层,Gemma2 27B 语言模型 46 层。此模型已在各种视觉语言任务和领域上进行微调。 |
| pali_gemma2_mix_28b_448 | 27.65B | 280 亿参数,图像大小 448,SigLIP-So400m 视觉编码器 27 层,Gemma2 27B 语言模型 46 层。此模型已在各种视觉语言任务和领域上进行微调。 |
| pali_gemma2_pt_28b_224 | 27.65B | 280 亿参数,图像大小 224,SigLIP-So400m 视觉编码器 27 层,Gemma2 27B 语言模型 46 层。此模型已在混合数据集上进行预训练。 |
| pali_gemma2_pt_28b_448 | 27.65B | 280 亿参数,图像大小 448,SigLIP-So400m 视觉编码器 27 层,Gemma2 27B 语言模型 46 层。此模型已在混合数据集上进行预训练。 |
| pali_gemma2_pt_28b_896 | 27.65B | 280 亿参数,图像大小 896,SigLIP-So400m 视觉编码器 27 层,Gemma2 27B 语言模型 46 层。此模型已在混合数据集上进行预训练。 |
| parseq | 23.83M | 用于场景文本识别的排列自回归序列 (PARSeq) 基础模型。 |
| phi3_mini_4k_instruct_en | 3.82B | 38 亿参数,32 层,4k 上下文长度,Phi-3 模型。该模型使用 Phi-3 数据集训练。该数据集包含合成数据和经过筛选的公开可用网站数据,重点关注高质量和推理密集型特性。 |
| phi3_mini_128k_instruct_en | 3.82B | 38 亿参数,32 层,128k 上下文长度,Phi-3 模型。该模型使用 Phi-3 数据集训练。该数据集包含合成数据和经过筛选的公开可用网站数据,重点关注高质量和推理密集型特性。 |
| qwen2.5_0.5b_en | 494.03M | 24 层 Qwen 模型,参数为 5 亿。 |
| qwen2.5_instruct_0.5b_en | 494.03M | 指令微调 24 层 Qwen 模型,参数为 5 亿。 |
| qwen2.5_3b_en | 3.09B | 36 层 Qwen 模型,参数为 31 亿。 |
| qwen2.5_7b_en | 6.99B | 48 层 Qwen 模型,参数为 70 亿。 |
| qwen2.5_instruct_32b_en | 32.76B | 指令微调 64 层 Qwen 模型,参数为 320 亿。 |
| qwen2.5_instruct_72b_en | 72.71B | 指令微调 80 层 Qwen 模型,参数为 720 亿。 |
| qwen3_0.6b_en | 596.05M | 28 层 Qwen3 模型,具有 5.96 亿参数,针对资源受限设备的效率和快速推理进行了优化。 |
| qwen3_1.7b_en | 1.72B | 28 层 Qwen3 模型,具有 17.2 亿参数,在性能和资源使用之间取得了良好的平衡。 |
| qwen3_4b_en | 4.02B | 36 层 Qwen3 模型,具有 40.2 亿参数,提供了比小型变体更强的推理能力和更好的性能。 |
| qwen3_8b_en | 8.19B | 36 层 Qwen3 模型,具有 81.9 亿参数,具备增强的推理、编码和指令遵循能力。 |
| qwen3_14b_en | 14.77B | 40 层 Qwen3 模型,具有 147.7 亿参数,具备先进的推理、编码和多语言能力。 |
| qwen3_32b_en | 32.76B | 64 层 Qwen3 模型,具有 327.6 亿参数,在推理、编码和通用语言任务方面均达到最先进的性能。 |
| qwen3_moe_30b_a3b_en | 30.53B | 混合专家 (MoE) 模型拥有 305 亿总参数,激活 33 亿参数,基于 48 层构建,并利用 32 个查询头和 4 个键/值注意力头,拥有 128 个专家(激活 8 个)。 |
| qwen3_moe_235b_a22b_en | 235.09B | 混合专家 (MoE) 模型拥有 2350 亿总参数,激活 220 亿参数,基于 94 层构建,并利用 64 个查询头和 4 个键/值注意力头,拥有 128 个专家(激活 8 个)。 |
| qwen1.5_moe_2.7b_en | 14.32B | 24 层 Qwen MoE 模型,具有 27 亿个活动参数和每个 MoE 层 8 个专家。 |
| t5gemma_s_s_ul2 | 312.52M | T5Gemma S/S 模型,具有小型编码器和小型解码器,被适配为 UL2 模型。 |
| t5gemma_s_s_prefixlm | 312.52M | T5Gemma S/S 模型,具有小型编码器和小型解码器,被适配为前缀语言模型。 |
| t5gemma_s_s_ul2_it | 312.52M | T5Gemma S/S 模型,具有小型编码器和小型解码器,被适配为 UL2 模型,并进行了指令遵循微调。 |
| t5gemma_s_s_prefixlm_it | 312.52M | T5Gemma S/S 模型,具有小型编码器和小型解码器,被适配为前缀语言模型,并进行了指令遵循微调。 |
| t5gemma_b_b_ul2 | 591.49M | T5Gemma B/B 模型,具有基础编码器和基础解码器,被适配为 UL2 模型。 |
| t5gemma_b_b_prefixlm | 591.49M | T5Gemma B/B 模型,具有基础编码器和基础解码器,被适配为前缀语言模型。 |
| t5gemma_b_b_ul2_it | 591.49M | T5Gemma B/B 模型,具有基础编码器和基础解码器,被适配为 UL2 模型,并进行了指令遵循微调。 |
| t5gemma_b_b_prefixlm_it | 591.49M | T5Gemma B/B 模型,具有基础编码器和基础解码器,被适配为前缀语言模型,并进行了指令遵循微调。 |
| t5gemma_l_l_ul2 | 1.24B | T5Gemma L/L 模型,具有大型编码器和大型解码器,被适配为 UL2 模型。 |
| t5gemma_l_l_prefixlm | 1.24B | T5Gemma L/L 模型,具有大型编码器和大型解码器,被适配为前缀语言模型。 |
| t5gemma_l_l_ul2_it | 1.24B | T5Gemma L/L 模型,具有大型编码器和大型解码器,被适配为 UL2 模型,并进行了指令遵循微调。 |
| t5gemma_l_l_prefixlm_it | 1.24B | T5Gemma L/L 模型,具有大型编码器和大型解码器,被适配为前缀语言模型,并进行了指令遵循微调。 |
| t5gemma_ml_ml_ul2 | 2.20B | T5Gemma ML/ML 模型,具有中大型编码器和中大型解码器,被适配为 UL2 模型。 |
| t5gemma_ml_ml_prefixlm | 2.20B | T5Gemma ML/ML 模型,具有中大型编码器和中大型解码器,被适配为前缀语言模型。 |
| t5gemma_ml_ml_ul2_it | 2.20B | T5Gemma ML/ML 模型,具有中大型编码器和中大型解码器,被适配为 UL2 模型,并进行了指令遵循微调。 |
| t5gemma_ml_ml_prefixlm_it | 2.20B | T5Gemma ML/ML 模型,具有中大型编码器和中大型解码器,被适配为前缀语言模型,并进行了指令遵循微调。 |
| t5gemma_xl_xl_ul2 | 3.77B | T5Gemma XL/XL 模型,具有超大型编码器和超大型解码器,被适配为 UL2 模型。 |
| t5gemma_xl_xl_prefixlm | 3.77B | T5Gemma XL/XL 模型,具有超大型编码器和超大型解码器,被适配为前缀语言模型。 |
| t5gemma_xl_xl_ul2_it | 3.77B | T5Gemma XL/XL 模型,具有超大型编码器和超大型解码器,被适配为 UL2 模型,并进行了指令遵循微调。 |
| t5gemma_xl_xl_prefixlm_it | 3.77B | T5Gemma XL/XL 模型,具有超大型编码器和超大型解码器,被适配为前缀语言模型,并进行了指令遵循微调。 |
| t5gemma_2b_2b_ul2 | 5.60B | T5Gemma 2B/2B 模型,具有 20 亿参数的编码器和 20 亿参数的解码器,被适配为 UL2 模型。 |
| t5gemma_2b_2b_prefixlm | 5.60B | T5Gemma 2B/2B 模型,具有 20 亿参数的编码器和 20 亿参数的解码器,被适配为前缀语言模型。 |
| t5gemma_2b_2b_ul2_it | 5.60B | T5Gemma 2B/2B 模型,具有 20 亿参数的编码器和 20 亿参数的解码器,被适配为 UL2 模型,并进行了指令遵循微调。 |
| t5gemma_2b_2b_prefixlm_it | 5.60B | T5Gemma 2B/2B 模型,具有 20 亿参数的编码器和 20 亿参数的解码器,被适配为前缀语言模型,并进行了指令遵循微调。 |
| t5gemma_9b_2b_ul2 | 12.29B | T5Gemma 9B/2B 模型,具有 90 亿参数的编码器和 20 亿参数的解码器,被适配为 UL2 模型。 |
| t5gemma_9b_2b_prefixlm | 12.29B | T5Gemma 9B/2B 模型,具有 90 亿参数的编码器和 20 亿参数的解码器,被适配为前缀语言模型。 |
| t5gemma_9b_2b_ul2_it | 12.29B | T5Gemma 9B/2B 模型,具有 90 亿参数的编码器和 20 亿参数的解码器,被适配为 UL2 模型,并进行了指令遵循微调。 |
| t5gemma_9b_2b_prefixlm_it | 12.29B | T5Gemma 9B/2B 模型,具有 90 亿参数的编码器和 20 亿参数的解码器,被适配为前缀语言模型,并进行了指令遵循微调。 |
| t5gemma_9b_9b_ul2 | 20.33B | T5Gemma 9B/9B 模型,具有 90 亿参数的编码器和 90 亿参数的解码器,被适配为 UL2 模型。 |
| t5gemma_9b_9b_prefixlm | 20.33B | T5Gemma 9B/9B 模型,具有 90 亿参数的编码器和 90 亿参数的解码器,被适配为前缀语言模型。 |
| t5gemma_9b_9b_ul2_it | 20.33B | T5Gemma 9B/9B 模型,具有 90 亿参数的编码器和 90 亿参数的解码器,被适配为 UL2 模型,并进行了指令遵循微调。 |
| t5gemma_9b_9b_prefixlm_it | 20.33B | T5Gemma 9B/9B 模型,具有 90 亿参数的编码器和 90 亿参数的解码器,被适配为前缀语言模型,并进行了指令遵循微调。 |
compile 方法CausalLM.compile(
optimizer="auto", loss="auto", weighted_metrics="auto", sampler="top_k", **kwargs
)
配置 CausalLM 任务以进行训练和生成。
CausalLM 任务通过为 optimizer、loss 和 weighted_metrics 设置默认值,扩展了 keras.Model.compile 的默认编译签名。要覆盖这些默认值,请在编译期间为这些参数传递任何值。
CausalLM 任务向 compile 添加了一个新的 sampler 参数,可用于控制 generate 函数使用的采样策略。
请注意,由于训练输入包含从损失中排除的填充标记,因此使用 weighted_metrics 而不是 metrics 进行编译几乎总是一个好主意。
参数
"auto"、优化器名称或 keras.Optimizer 实例。默认为 "auto",它会为给定的模型和任务使用默认优化器。有关可能的 optimizer 值的更多信息,请参阅 keras.Model.compile 和 keras.optimizers。"auto"、损失名称或 keras.losses.Loss 实例。默认为 "auto",对于 token 分类 CausalLM 任务,将应用 keras.losses.SparseCategoricalCrossentropy 损失。有关可能的 loss 值,请参阅 keras.Model.compile 和 keras.losses。"auto" 或要在训练和测试期间由模型评估的指标列表。默认为 "auto",将应用 keras.metrics.SparseCategoricalAccuracy 来跟踪模型在猜测掩码标记值方面的准确性。有关可能的 weighted_metrics 值,请参阅 keras.Model.compile 和 keras.metrics。keras_hub.samplers.Sampler 实例。配置 generate() 调用期间使用的采样方法。有关内置采样策略的完整列表,请参阅 keras_hub.samplers。keras.Model.compile。generate 方法CausalLM.generate(inputs, max_length=None, stop_token_ids="auto", strip_prompt=False)
根据提示 inputs 生成文本。
此方法根据给定的 inputs 生成文本。用于生成的采样方法可以通过 compile() 方法设置。
如果 inputs 是一个 tf.data.Dataset,输出将“逐批”生成并连接起来。否则,所有输入将被视为单个批次处理。
如果模型附加了 preprocessor,inputs 将在 generate() 函数内部进行预处理,并且应与 preprocessor 层期望的结构匹配(通常是原始字符串)。如果未附加 preprocessor,则 inputs 应与 backbone 期望的结构匹配。请参阅上面的示例用法以了解每个的演示。
参数
tf.data.Dataset。如果模型附加了 preprocessor,inputs 应与 preprocessor 层期望的结构匹配。如果未附加 preprocessor,inputs 应与 backbone 模型期望的结构匹配。preprocessor 配置的 sequence_length 最大值。如果 preprocessor 为 None,则 inputs 应填充到所需的最大长度,并且此参数将被忽略。None、"auto" 或 token ID 的元组。默认为 "auto",它使用 preprocessor.tokenizer.end_token_id。不指定处理器将产生错误。None 在生成 max_length 个 token 后停止生成。您也可以指定一个 token ID 列表,模型应在此停止。请注意,token 序列中的每个序列都将被解释为一个停止 token,不支持多 token 停止序列。save_to_preset 方法CausalLM.save_to_preset(preset_dir, max_shard_size=10)
将任务保存到预设目录。
参数
int 或 float。每个分片文件的最大大小(以 GB 为单位)。如果为 None,则不进行分片。默认为 10。preprocessor 属性keras_hub.models.CausalLM.preprocessor
用于预处理输入的 keras_hub.models.Preprocessor 层。
backbone 属性keras_hub.models.CausalLM.backbone
一个具有核心架构的 keras_hub.models.Backbone 模型。