Keras 3 API 文档 / KerasCV / 模型 / 任务 / BASNet 分割

BASNet 分割

[来源]

BASNet

keras_cv.models.BASNet(
    backbone,
    num_classes,
    input_shape=(None, None, 3),
    input_tensor=None,
    include_rescaling=False,
    projection_filters=64,
    prediction_heads=None,
    refinement_head=None,
    **kwargs
)

一个实现用于语义分割的 BASNet 架构的 Keras 模型。

参考文献

参数

  • backbone: keras.Model。模型的骨干网络,用作 BASNet 预测编码器的特征提取器。当前支持的骨干是 ResNet18 和 ResNet34。默认骨干是 keras_cv.models.ResNet34Backbone()(注意:不要在骨干中指定“input_shape”、“input_tensor”或“include_rescaling”。请在初始化“BASNet”模型时提供这些参数。)
  • num_classes: int,分割模型的类别数。
  • input_shape: 可选的形状元组,默认为 (None, None, 3)。
  • input_tensor: 可选的 Keras 张量(即 layers.Input() 的输出),用作模型的图像输入。
  • include_rescaling: bool,是否重新调整输入。如果设置为 True,输入将通过 Rescaling(1/255.0) 层。
  • projection_filters: int,从 backbone 投影低级特征的卷积层中的过滤器数量。
  • prediction_heads: (可选)keras.layers.Layer 的列表,定义模型的预测模块头。如果未提供,则使用 Conv2D 层后跟调整大小来创建默认头。
  • refinement_head: (可选)keras.layers.Layer,定义模型的细化模块头。如果未提供,则使用 Conv2D 层创建默认头。

示例

import keras_cv

images = np.ones(shape=(1, 288, 288, 3))
labels = np.zeros(shape=(1, 288, 288, 1))

# Note: Do not specify 'input_shape', 'input_tensor', or
# 'include_rescaling' within the backbone.
backbone = keras_cv.models.ResNet34Backbone()
model = keras_cv.models.segmentation.BASNet(
    backbone=backbone,
    num_classes=1,
    input_shape=[288, 288, 3],
    include_rescaling=False
)

# Evaluate model
output = model(images)
pred_labels = output[0]

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
    ```


----

<span style="float:right;">[[source]](https://github.com/keras-team/keras-cv/tree/v0.9.0/keras_cv/src/models/task.py#L183)</span>

### `from_preset` method


```python
BASNet.from_preset()

从预设配置和权重实例化 BASNet 模型。

参数

  • preset: 字符串。必须是“resnet18”、“resnet34”、“basnet_resnet18”、“basnet_resnet34”之一。如果要查找具有预训练权重的预设,请选择“”。
  • load_weights: 是否将预训练权重加载到模型中。默认为 None,它遵循预设是否有预训练权重可用。
  • input_shape : 将传递给骨干初始化的输入形状,默认为 None。如果为 None,将使用预设值。

示例

# Load architecture and weights from preset
model = keras_cv.models.BASNet.from_preset(
    "",
)

# Load randomly initialized model from preset architecture with weights
model = keras_cv.models.BASNet.from_preset(
    "",
    load_weights=False,
预设名称 参数 描述
basnet_resnet18 98.78M 具有 ResNet18 v1 骨干的 BASNet。
basnet_resnet34 108.90M 具有 ResNet34 v1 骨干的 BASNet。