BASNet
类keras_cv.models.BASNet(
backbone,
num_classes,
input_shape=(None, None, 3),
input_tensor=None,
include_rescaling=False,
projection_filters=64,
prediction_heads=None,
refinement_head=None,
**kwargs
)
一个实现用于语义分割的 BASNet 架构的 Keras 模型。
参考文献
参数
keras.Model
。模型的骨干网络,用作 BASNet 预测编码器的特征提取器。当前支持的骨干是 ResNet18 和 ResNet34。默认骨干是 keras_cv.models.ResNet34Backbone()
(注意:不要在骨干中指定“input_shape”、“input_tensor”或“include_rescaling”。请在初始化“BASNet”模型时提供这些参数。)layers.Input()
的输出),用作模型的图像输入。True
,输入将通过 Rescaling(1/255.0)
层。backbone
投影低级特征的卷积层中的过滤器数量。keras.layers.Layer
的列表,定义模型的预测模块头。如果未提供,则使用 Conv2D 层后跟调整大小来创建默认头。keras.layers.Layer
,定义模型的细化模块头。如果未提供,则使用 Conv2D 层创建默认头。示例
import keras_cv
images = np.ones(shape=(1, 288, 288, 3))
labels = np.zeros(shape=(1, 288, 288, 1))
# Note: Do not specify 'input_shape', 'input_tensor', or
# 'include_rescaling' within the backbone.
backbone = keras_cv.models.ResNet34Backbone()
model = keras_cv.models.segmentation.BASNet(
backbone=backbone,
num_classes=1,
input_shape=[288, 288, 3],
include_rescaling=False
)
# Evaluate model
output = model(images)
pred_labels = output[0]
# Train model
model.compile(
optimizer="adam",
loss=keras.losses.BinaryCrossentropy(from_logits=False),
metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
```
----
<span style="float:right;">[[source]](https://github.com/keras-team/keras-cv/tree/v0.9.0/keras_cv/src/models/task.py#L183)</span>
### `from_preset` method
```python
BASNet.from_preset()
从预设配置和权重实例化 BASNet 模型。
参数
None
,它遵循预设是否有预训练权重可用。None
。如果为 None
,将使用预设值。示例
# Load architecture and weights from preset
model = keras_cv.models.BASNet.from_preset(
"",
)
# Load randomly initialized model from preset architecture with weights
model = keras_cv.models.BASNet.from_preset(
"",
load_weights=False,
预设名称 | 参数 | 描述 |
---|---|---|
basnet_resnet18 | 98.78M | 具有 ResNet18 v1 骨干的 BASNet。 |
basnet_resnet34 | 108.90M | 具有 ResNet34 v1 骨干的 BASNet。 |