代码示例 / 计算机视觉 / 使用基于注意力的深度多示例学习 (MIL) 进行分类。

使用基于注意力的深度多示例学习 (MIL) 进行分类。

作者: Mohamad Jaber
创建日期 2021/08/16
最后修改日期 2021/11/25
描述:用于对实例包进行分类并获取其各个实例得分的 MIL 方法。

ⓘ 此示例使用 Keras 3

在 Colab 中查看 GitHub 源代码


简介

什么是多示例学习 (MIL)?

通常,使用监督学习算法时,学习器会收到一组实例的标签。在 MIL 的情况下,学习器会收到一组包的标签,每个包都包含一组实例。如果包中至少包含一个正实例,则该包被标记为正包;如果包中不包含任何正实例,则该包被标记为负包。

动机

在图像分类任务中,通常假设每张图像都清晰地表示一个类别标签。在医学影像(例如计算病理学等)中,整个图像由一个单一的类别标签(癌性/非癌性)表示,或者可以给出感兴趣区域。但是,人们会想知道图像中的哪些模式实际上导致它属于该类别。在这种情况下,图像将被分割,子图像将形成实例包。

因此,目标是

  1. 学习一个模型来预测一组实例的类别标签。
  2. 找出包中哪些实例导致了正类别标签预测。

实现

以下步骤描述了模型的工作原理

  1. 特征提取层提取特征嵌入。
  2. 嵌入被馈送到 MIL 注意力层以获得注意力分数。该层被设计为排列不变。
  3. 输入特征及其对应的注意力分数相乘。
  4. 将结果输出传递到 softmax 函数以进行分类。

参考文献


设置

import numpy as np
import keras
from keras import layers
from keras import ops
from tqdm import tqdm
from matplotlib import pyplot as plt

plt.style.use("ggplot")

创建数据集

我们将创建一组包,并根据它们的内容分配标签。如果一个包中至少存在一个正实例,则该包被视为正包。如果它不包含任何正实例,则该包将被视为负包。

配置参数

  • POSITIVE_CLASS:要保留在正包中的目标类别。
  • BAG_COUNT:训练包的数量。
  • VAL_BAG_COUNT:验证包的数量。
  • BAG_SIZE:一个包中的实例数量。
  • PLOT_SIZE:要绘制的包的数量。
  • ENSEMBLE_AVG_COUNT:要创建并平均在一起的模型数量。(可选:通常会导致更好的性能 - 设置为 1 以用于单个模型)
POSITIVE_CLASS = 1
BAG_COUNT = 1000
VAL_BAG_COUNT = 300
BAG_SIZE = 3
PLOT_SIZE = 3
ENSEMBLE_AVG_COUNT = 1

准备包

由于注意力算子是排列不变算子,因此将具有正类别标签的实例随机放置在正包中的实例中。

def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
    # Set up bags.
    bags = []
    bag_labels = []

    # Normalize input data.
    input_data = np.divide(input_data, 255.0)

    # Count positive samples.
    count = 0

    for _ in range(bag_count):
        # Pick a fixed size random subset of samples.
        index = np.random.choice(input_data.shape[0], instance_count, replace=False)
        instances_data = input_data[index]
        instances_labels = input_labels[index]

        # By default, all bags are labeled as 0.
        bag_label = 0

        # Check if there is at least a positive class in the bag.
        if positive_class in instances_labels:
            # Positive bag will be labeled as 1.
            bag_label = 1
            count += 1

        bags.append(instances_data)
        bag_labels.append(np.array([bag_label]))

    print(f"Positive bags: {count}")
    print(f"Negative bags: {bag_count - count}")

    return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))


# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()

# Create training data.
train_data, train_labels = create_bags(
    x_train, y_train, POSITIVE_CLASS, BAG_COUNT, BAG_SIZE
)

# Create validation data.
val_data, val_labels = create_bags(
    x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
)
Positive bags: 283
Negative bags: 717
Positive bags: 104
Negative bags: 196

创建模型

我们现在将构建注意力层,准备一些实用程序,然后构建和训练整个模型。

注意力算子实现

此层的输出大小由单个包的大小决定。

注意力机制使用包中实例的加权平均值,其中权重的总和必须等于 1(与包大小无关)。

权重矩阵(参数)是 wv。为了包含正值和负值,使用双曲正切逐元素非线性。

可以使用门控注意力机制来处理复杂的關係。另一个权重矩阵 u 被添加到计算中。使用 sigmoid 非线性来克服双曲正切非线性对x ∈ [−1, 1] 的近似线性行为。

class MILAttentionLayer(layers.Layer):
    """Implementation of the attention-based Deep MIL layer.

    Args:
      weight_params_dim: Positive Integer. Dimension of the weight matrix.
      kernel_initializer: Initializer for the `kernel` matrix.
      kernel_regularizer: Regularizer function applied to the `kernel` matrix.
      use_gated: Boolean, whether or not to use the gated mechanism.

    Returns:
      List of 2D tensors with BAG_SIZE length.
      The tensors are the attention scores after softmax with shape `(batch_size, 1)`.
    """

    def __init__(
        self,
        weight_params_dim,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        use_gated=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.weight_params_dim = weight_params_dim
        self.use_gated = use_gated

        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

        self.v_init = self.kernel_initializer
        self.w_init = self.kernel_initializer
        self.u_init = self.kernel_initializer

        self.v_regularizer = self.kernel_regularizer
        self.w_regularizer = self.kernel_regularizer
        self.u_regularizer = self.kernel_regularizer

    def build(self, input_shape):
        # Input shape.
        # List of 2D tensors with shape: (batch_size, input_dim).
        input_dim = input_shape[0][1]

        self.v_weight_params = self.add_weight(
            shape=(input_dim, self.weight_params_dim),
            initializer=self.v_init,
            name="v",
            regularizer=self.v_regularizer,
            trainable=True,
        )

        self.w_weight_params = self.add_weight(
            shape=(self.weight_params_dim, 1),
            initializer=self.w_init,
            name="w",
            regularizer=self.w_regularizer,
            trainable=True,
        )

        if self.use_gated:
            self.u_weight_params = self.add_weight(
                shape=(input_dim, self.weight_params_dim),
                initializer=self.u_init,
                name="u",
                regularizer=self.u_regularizer,
                trainable=True,
            )
        else:
            self.u_weight_params = None

        self.input_built = True

    def call(self, inputs):
        # Assigning variables from the number of inputs.
        instances = [self.compute_attention_scores(instance) for instance in inputs]

        # Stack instances into a single tensor.
        instances = ops.stack(instances)

        # Apply softmax over instances such that the output summation is equal to 1.
        alpha = ops.softmax(instances, axis=0)

        # Split to recreate the same array of tensors we had as inputs.
        return [alpha[i] for i in range(alpha.shape[0])]

    def compute_attention_scores(self, instance):
        # Reserve in-case "gated mechanism" used.
        original_instance = instance

        # tanh(v*h_k^T)
        instance = ops.tanh(ops.tensordot(instance, self.v_weight_params, axes=1))

        # for learning non-linear relations efficiently.
        if self.use_gated:
            instance = instance * ops.sigmoid(
                ops.tensordot(original_instance, self.u_weight_params, axes=1)
            )

        # w^T*(tanh(v*h_k^T)) / w^T*(tanh(v*h_k^T)*sigmoid(u*h_k^T))
        return ops.tensordot(instance, self.w_weight_params, axes=1)

可视化工具

绘制相对于类别的包数量(由 PLOT_SIZE 给出)。

此外,如果启用,则可以查看每个包的类别标签预测及其关联的实例得分(在模型经过训练之后)。

def plot(data, labels, bag_class, predictions=None, attention_weights=None):
    """ "Utility for plotting bags and attention weights.

    Args:
      data: Input data that contains the bags of instances.
      labels: The associated bag labels of the input data.
      bag_class: String name of the desired bag class.
        The options are: "positive" or "negative".
      predictions: Class labels model predictions.
      If you don't specify anything, ground truth labels will be used.
      attention_weights: Attention weights for each instance within the input data.
      If you don't specify anything, the values won't be displayed.
    """
    return  ## TODO
    labels = np.array(labels).reshape(-1)

    if bag_class == "positive":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 1)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

        else:
            labels = np.where(labels == 1)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

    elif bag_class == "negative":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 0)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]
        else:
            labels = np.where(labels == 0)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

    else:
        print(f"There is no class {bag_class}")
        return

    print(f"The bag class label is {bag_class}")
    for i in range(PLOT_SIZE):
        figure = plt.figure(figsize=(8, 8))
        print(f"Bag number: {labels[i]}")
        for j in range(BAG_SIZE):
            image = bags[j][i]
            figure.add_subplot(1, BAG_SIZE, j + 1)
            plt.grid(False)
            if attention_weights is not None:
                plt.title(np.around(attention_weights[labels[i]][j], 2))
            plt.imshow(image)
        plt.show()


# Plot some of validation data bags per class.
plot(val_data, val_labels, "positive")
plot(val_data, val_labels, "negative")

创建模型

首先,我们将为每个实例创建一些嵌入,调用注意力算子,然后使用 softmax 函数输出类别概率。

def create_model(instance_shape):
    # Extract features from inputs.
    inputs, embeddings = [], []
    shared_dense_layer_1 = layers.Dense(128, activation="relu")
    shared_dense_layer_2 = layers.Dense(64, activation="relu")
    for _ in range(BAG_SIZE):
        inp = layers.Input(instance_shape)
        flatten = layers.Flatten()(inp)
        dense_1 = shared_dense_layer_1(flatten)
        dense_2 = shared_dense_layer_2(dense_1)
        inputs.append(inp)
        embeddings.append(dense_2)

    # Invoke the attention layer.
    alpha = MILAttentionLayer(
        weight_params_dim=256,
        kernel_regularizer=keras.regularizers.L2(0.01),
        use_gated=True,
        name="alpha",
    )(embeddings)

    # Multiply attention weights with the input layers.
    multiply_layers = [
        layers.multiply([alpha[i], embeddings[i]]) for i in range(len(alpha))
    ]

    # Concatenate layers.
    concat = layers.concatenate(multiply_layers, axis=1)

    # Classification output node.
    output = layers.Dense(2, activation="softmax")(concat)

    return keras.Model(inputs, output)

类别权重

由于这种类型的问题可能简单地变成不平衡数据分类问题,因此应该考虑类别加权。

假设有 1000 个包。通常可能存在大约 90% 的包不包含任何正标签,而大约 10% 的包包含正标签。这种数据可以被称为不平衡数据

使用类别权重,模型将倾向于对稀有类别赋予更高的权重。

def compute_class_weights(labels):
    # Count number of postive and negative bags.
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count

    # Build class weight dictionary.
    return {
        0: (1 / negative_count) * (total_count / 2),
        1: (1 / positive_count) * (total_count / 2),
    }

构建和训练模型

模型的构建和训练在本节进行。

def train(train_data, train_labels, val_data, val_labels, model):
    # Train model.
    # Prepare callbacks.
    # Path where to save best weights.

    # Take the file name from the wrapper.
    file_path = "/tmp/best_model.weights.h5"

    # Initialize model checkpoint callback.
    model_checkpoint = keras.callbacks.ModelCheckpoint(
        file_path,
        monitor="val_loss",
        verbose=0,
        mode="min",
        save_best_only=True,
        save_weights_only=True,
    )

    # Initialize early stopping callback.
    # The model performance is monitored across the validation data and stops training
    # when the generalization error cease to decrease.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, mode="min"
    )

    # Compile model.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Fit model.
    model.fit(
        train_data,
        train_labels,
        validation_data=(val_data, val_labels),
        epochs=20,
        class_weight=compute_class_weights(train_labels),
        batch_size=1,
        callbacks=[early_stopping, model_checkpoint],
        verbose=0,
    )

    # Load best weights.
    model.load_weights(file_path)

    return model


# Building model(s).
instance_shape = train_data[0][0].shape
models = [create_model(instance_shape) for _ in range(ENSEMBLE_AVG_COUNT)]

# Show single model architecture.
print(models[0].summary())

# Training model(s).
trained_models = [
    train(train_data, train_labels, val_data, val_labels, model)
    for model in tqdm(models)
]
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape       Param #  Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_layer_1       │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_layer_2       │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten (Flatten)   │ (None, 784)       │       0 │ input_layer[0][0]    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten_1 (Flatten) │ (None, 784)       │       0 │ input_layer_1[0][0]  │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten_2 (Flatten) │ (None, 784)       │       0 │ input_layer_2[0][0]  │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense (Dense)       │ (None, 128)       │ 100,480 │ flatten[0][0],       │
│                     │                   │         │ flatten_1[0][0],     │
│                     │                   │         │ flatten_2[0][0]      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_1 (Dense)     │ (None, 64)        │   8,256 │ dense[0][0],         │
│                     │                   │         │ dense[1][0],         │
│                     │                   │         │ dense[2][0]          │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ alpha               │ [(None, 1),       │  33,024 │ dense_1[0][0],       │
│ (MILAttentionLayer) │ (None, 1), (None, │         │ dense_1[1][0],       │
│                     │ 1)]               │         │ dense_1[2][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply (Multiply) │ (None, 64)        │       0 │ alpha[0][0],         │
│                     │                   │         │ dense_1[0][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply_1          │ (None, 64)        │       0 │ alpha[0][1],         │
│ (Multiply)          │                   │         │ dense_1[1][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply_2          │ (None, 64)        │       0 │ alpha[0][2],         │
│ (Multiply)          │                   │         │ dense_1[2][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ concatenate         │ (None, 192)       │       0 │ multiply[0][0],      │
│ (Concatenate)       │                   │         │ multiply_1[0][0],    │
│                     │                   │         │ multiply_2[0][0]     │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_2 (Dense)     │ (None, 2)         │     386 │ concatenate[0][0]    │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
 Total params: 142,146 (555.26 KB)
 Trainable params: 142,146 (555.26 KB)
 Non-trainable params: 0 (0.00 B)
None

100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:36<00:00, 36.67s/it]

模型评估

模型现在准备评估。对于每个模型,我们还创建一个相关的中间模型以获取注意力层的权重。

我们将计算我们 `ENSEMBLE_AVG_COUNT` 个模型中每个模型的预测,并将它们平均起来以获得最终预测。

def predict(data, labels, trained_models):
    # Collect info per model.
    models_predictions = []
    models_attention_weights = []
    models_losses = []
    models_accuracies = []

    for model in trained_models:
        # Predict output classes on data.
        predictions = model.predict(data)
        models_predictions.append(predictions)

        # Create intermediate model to get MIL attention layer weights.
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)

        # Predict MIL attention layer weights.
        intermediate_predictions = intermediate_model.predict(data)

        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        models_attention_weights.append(attention_weights)

        loss, accuracy = model.evaluate(data, labels, verbose=0)
        models_losses.append(loss)
        models_accuracies.append(accuracy)

    print(
        f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"
        f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f} % resp."
    )

    return (
        np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,
        np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,
    )


# Evaluate and predict classes and attention scores on validation data.
class_predictions, attention_params = predict(val_data, val_labels, trained_models)

# Plot some results from our validation data.
plot(
    val_data,
    val_labels,
    "positive",
    predictions=class_predictions,
    attention_weights=attention_params,
)
plot(
    val_data,
    val_labels,
    "negative",
    predictions=class_predictions,
    attention_weights=attention_params,
)
 10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step
 10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step
The average loss and accuracy are 0.03 and 99.00 % resp.

结论

从上面的图中,您可以注意到权重总是加起来为 1。在正预测包中,导致正标签的实例将具有比包中其他实例明显更高的注意力分数。然而,在负预测包中,存在两种情况。

  • 所有实例将具有近似相似的分数。
  • 一个实例将具有相对较高的分数(但不像正实例那么高)。这是因为此实例的特征空间接近正实例的特征空间。

备注

  • 如果模型过拟合,则所有包的权重将均匀分布。因此,正则化技术是必要的。
  • 在论文中,包的大小可能因包而异。为了简单起见,这里包的大小是固定的。
  • 为了不依赖于单个模型的随机初始权重,应该考虑平均集成方法。