Keras 3 API 文档 / KerasCV / 模型 / 主干网络 / MixTransformer 主干网络

MixTransformer 主干网络

[源代码]

MiTBackbone

keras_cv.models.MiTBackbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

主干网络模型的基类。

主干网络是模型的可重用层,在标准任务(如 ImageNet 分类)上训练,可以在其他任务中重用。


[源代码]

from_preset 方法

MiTBackbone.from_preset()

从预设配置和权重实例化 MiTBackbone 模型。

参数

  • preset: 字符串。必须是以下之一:"mit_b0","mit_b1","mit_b2","mit_b3","mit_b4","mit_b5","mit_b0_imagenet"。 如果要查找带有预训练权重的预设,请选择 "mit_b0_imagenet"。
  • load_weights: 是否将预训练权重加载到模型中。默认为 None,遵循预设是否具有可用的预训练权重。

示例

# Load architecture and weights from preset
model = keras_cv.models.MiTBackbone.from_preset(
    "mit_b0_imagenet",
)

# Load randomly initialized model from preset architecture with weights
model = keras_cv.models.MiTBackbone.from_preset(
    "mit_b0_imagenet",
    load_weights=False,
预设名称 参数 描述
mit_b0 3.32M 具有 8 个 Transformer 块的 MiT(MixTransformer)模型。
mit_b1 13.16M 具有 8 个 Transformer 块的 MiT(MixTransformer)模型。
mit_b2 24.20M 具有 16 个 Transformer 块的 MiT(MixTransformer)模型。
mit_b3 44.08M 具有 28 个 Transformer 块的 MiT(MixTransformer)模型。
mit_b4 60.85M 具有 41 个 Transformer 块的 MiT(MixTransformer)模型。
mit_b5 81.45M 具有 52 个 Transformer 块的 MiT(MixTransformer)模型。
mit_b0_imagenet 3.32M 具有 8 个 Transformer 块的 MiT(MixTransformer)模型。在 ImageNet-1K 上预训练,验证集上的 top-1 准确率为 69%。

[源代码]

MiTB0Backbone

keras_cv.models.MiTB0Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT 模型。

对于迁移学习用例,请务必阅读 迁移学习和微调指南

参数

  • include_rescaling: bool,是否对输入进行重新缩放。如果设置为 True,则输入将通过 Rescaling(scale=1 / 255) 层。默认为 True。
  • input_shape: 可选的形状元组,默认为 (None, None, 3)。
  • input_tensor: 可选的 Keras 张量(即 layers.Input() 的输出),用作模型的图像输入。

示例

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = MiTB0Backbone()
output = model(input_data)

[源代码]

MiTB1Backbone

keras_cv.models.MiTB1Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT 模型。

对于迁移学习用例,请务必阅读 迁移学习和微调指南

参数

  • include_rescaling: bool,是否对输入进行重新缩放。如果设置为 True,则输入将通过 Rescaling(scale=1 / 255) 层。默认为 True。
  • input_shape: 可选的形状元组,默认为 (None, None, 3)。
  • input_tensor: 可选的 Keras 张量(即 layers.Input() 的输出),用作模型的图像输入。

示例

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = MiTB1Backbone()
output = model(input_data)

[源代码]

MiTB2Backbone

keras_cv.models.MiTB2Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT 模型。

对于迁移学习用例,请务必阅读 迁移学习和微调指南

参数

  • include_rescaling: bool,是否对输入进行重新缩放。如果设置为 True,则输入将通过 Rescaling(scale=1 / 255) 层。默认为 True。
  • input_shape: 可选的形状元组,默认为 (None, None, 3)。
  • input_tensor: 可选的 Keras 张量(即 layers.Input() 的输出),用作模型的图像输入。

示例

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = MiTB2Backbone()
output = model(input_data)

[源代码]

MiTB3Backbone

keras_cv.models.MiTB3Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT 模型。

对于迁移学习用例,请务必阅读 迁移学习和微调指南

参数

  • include_rescaling: bool,是否对输入进行重新缩放。如果设置为 True,则输入将通过 Rescaling(scale=1 / 255) 层。默认为 True。
  • input_shape: 可选的形状元组,默认为 (None, None, 3)。
  • input_tensor: 可选的 Keras 张量(即 layers.Input() 的输出),用作模型的图像输入。

示例

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = MiTB3Backbone()
output = model(input_data)

[源代码]

MiTB4Backbone

keras_cv.models.MiTB4Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT 模型。

对于迁移学习用例,请务必阅读 迁移学习和微调指南

参数

  • include_rescaling: bool,是否对输入进行重新缩放。如果设置为 True,则输入将通过 Rescaling(scale=1 / 255) 层。默认为 True。
  • input_shape: 可选的形状元组,默认为 (None, None, 3)。
  • input_tensor: 可选的 Keras 张量(即 layers.Input() 的输出),用作模型的图像输入。

示例

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = MiTB4Backbone()
output = model(input_data)

[源代码]

MiTB5Backbone

keras_cv.models.MiTB5Backbone(
    include_rescaling,
    depths,
    input_shape=(224, 224, 3),
    input_tensor=None,
    embedding_dims=None,
    **kwargs
)

MiT 模型。

对于迁移学习用例,请务必阅读 迁移学习和微调指南

参数

  • include_rescaling: bool,是否对输入进行重新缩放。如果设置为 True,则输入将通过 Rescaling(scale=1 / 255) 层。默认为 True。
  • input_shape: 可选的形状元组,默认为 (None, None, 3)。
  • input_tensor: 可选的 Keras 张量(即 layers.Input() 的输出),用作模型的图像输入。

示例

input_data = tf.ones(shape=(8, 224, 224, 3))

# Randomly initialized backbone
model = MiTB5Backbone()
output = model(input_data)