FluxBackbone
类keras_hub.models.FluxBackbone(
input_channels,
hidden_size,
mlp_ratio,
num_heads,
depth,
depth_single_blocks,
axes_dim,
theta,
use_bias,
guidance_embed=False,
image_shape=(None, 768, 3072),
text_shape=(None, 768, 3072),
image_ids_shape=(None, 768, 3072),
text_ids_shape=(None, 768, 3072),
y_shape=(None, 128),
**kwargs
)
用于序列流匹配的 Transformer 模型。
该模型处理图像和文本数据,并关联位置和时间步嵌入,可选择应用 guidance 嵌入。双流块处理独立的图像流和文本流,而单流块则合并这些流。移植自:https://github.com/black-forest-labs/flux
参数
num_heads
整除。调用参数
抛出异常
hidden_size
不能被 num_heads
整除,或者如果 sum(axes_dim)
不等于位置嵌入维度。from_preset
方法FluxBackbone.from_preset(preset, load_weights=True, **kwargs)
从模型预设中实例化一个 keras_hub.models.Backbone
。
预设是用于保存和加载预训练模型的配置、权重及其他文件资产的目录。preset
可以是以下类型之一:
'bert_base_en'
'kaggle://user/bert/keras/bert_base_en'
'hf://user/bert_base_en'
'./bert_base_en'
此构造函数可以通过两种方式之一调用。可以从基类调用,例如 keras_hub.models.Backbone.from_preset()
,或从模型类调用,例如 keras_hub.models.GemmaBackbone.from_preset()
。如果从基类调用,则返回对象的子类将从预设目录中的配置推断出来。
对于任何 Backbone
子类,您可以运行 cls.presets.keys()
来列出该类上所有可用的内置预设。
参数
True
,则将权重加载到模型架构中。如果为 False
,则权重将随机初始化。示例
# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
"gemma_2b_en",
)
# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
"bert_base_en",
load_weights=False,
)