代码示例 / 计算机视觉 / 基于注意力机制的深度多示例学习 (MIL) 分类。

基于注意力机制的深度多示例学习 (MIL) 分类。

作者: Mohamad Jaber
创建日期 2021/08/16
最后修改日期 2021/11/25
描述: 使用 MIL 方法对实例包进行分类并获取其单个实例分数。

ⓘ 此示例使用 Keras 3

在 Colab 中查看 GitHub 源代码


简介

什么是多示例学习 (MIL)?

通常,在使用监督学习算法时,学习器接收一组实例的标签。在 MIL 的情况下,学习器接收一组包的标签,每个包都包含一组实例。如果包中至少包含一个正实例,则该包被标记为正,如果没有包含任何正实例,则该包被标记为负。

动机

在图像分类任务中,通常假设每张图像都清楚地代表一个类别标签。在医学成像(例如,计算病理学等)中,整个图像由一个类别标签(癌性/非癌性)表示,或者可以给出感兴趣的区域。然而,人们会感兴趣地想知道图像中哪些模式实际上导致它属于该类别。在这种情况下,图像将被分割,子图像将形成实例包。

因此,目标是

  1. 学习一个模型来预测一个实例包的类别标签。
  2. 找出包中哪些实例导致了正类别标签预测。

实现

以下步骤描述了模型的工作方式

  1. 特征提取器层提取特征嵌入。
  2. 将嵌入馈送到 MIL 注意力层以获取注意力分数。该层被设计为排列不变的。
  3. 输入特征及其对应的注意力分数相乘。
  4. 将生成的输出传递给 softmax 函数进行分类。

参考


设置

import numpy as np
import keras
from keras import layers
from keras import ops
from tqdm import tqdm
from matplotlib import pyplot as plt

plt.style.use("ggplot")

创建数据集

我们将创建一组包,并根据其内容分配其标签。如果一个包中至少有一个正实例,则该包被视为正包。如果它不包含任何正实例,则该包将被视为负包。

配置参数

  • POSITIVE_CLASS:要保留在正包中的所需类别。
  • BAG_COUNT:训练包的数量。
  • VAL_BAG_COUNT:验证包的数量。
  • BAG_SIZE:一个包中的实例数量。
  • PLOT_SIZE:要绘制的包的数量。
  • ENSEMBLE_AVG_COUNT:要创建和平均在一起的模型数量。(可选:通常会带来更好的性能 - 对于单个模型设置为 1)
POSITIVE_CLASS = 1
BAG_COUNT = 1000
VAL_BAG_COUNT = 300
BAG_SIZE = 3
PLOT_SIZE = 3
ENSEMBLE_AVG_COUNT = 1

准备包

由于注意力算子是排列不变算子,因此具有正类别标签的实例随机放置在正包中的实例之间。

def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
    # Set up bags.
    bags = []
    bag_labels = []

    # Normalize input data.
    input_data = np.divide(input_data, 255.0)

    # Count positive samples.
    count = 0

    for _ in range(bag_count):
        # Pick a fixed size random subset of samples.
        index = np.random.choice(input_data.shape[0], instance_count, replace=False)
        instances_data = input_data[index]
        instances_labels = input_labels[index]

        # By default, all bags are labeled as 0.
        bag_label = 0

        # Check if there is at least a positive class in the bag.
        if positive_class in instances_labels:
            # Positive bag will be labeled as 1.
            bag_label = 1
            count += 1

        bags.append(instances_data)
        bag_labels.append(np.array([bag_label]))

    print(f"Positive bags: {count}")
    print(f"Negative bags: {bag_count - count}")

    return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))


# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()

# Create training data.
train_data, train_labels = create_bags(
    x_train, y_train, POSITIVE_CLASS, BAG_COUNT, BAG_SIZE
)

# Create validation data.
val_data, val_labels = create_bags(
    x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
)
Positive bags: 283
Negative bags: 717
Positive bags: 104
Negative bags: 196

创建模型

我们现在将构建注意力层,准备一些实用程序,然后构建和训练整个模型。

注意力算子实现

该层的输出大小由单个包的大小决定。

注意力机制使用包中实例的加权平均值,其中权重之和必须等于 1(与包大小无关)。

权重矩阵(参数)是 wv。为了包含正值和负值,使用了双曲正切逐元素非线性。

可以使用门控注意力机制来处理复杂的关系。将另一个权重矩阵 u 添加到计算中。使用 sigmoid 非线性来克服双曲正切非线性在 x ∈ [−1, 1] 的近似线性行为。

class MILAttentionLayer(layers.Layer):
    """Implementation of the attention-based Deep MIL layer.

    Args:
      weight_params_dim: Positive Integer. Dimension of the weight matrix.
      kernel_initializer: Initializer for the `kernel` matrix.
      kernel_regularizer: Regularizer function applied to the `kernel` matrix.
      use_gated: Boolean, whether or not to use the gated mechanism.

    Returns:
      List of 2D tensors with BAG_SIZE length.
      The tensors are the attention scores after softmax with shape `(batch_size, 1)`.
    """

    def __init__(
        self,
        weight_params_dim,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        use_gated=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.weight_params_dim = weight_params_dim
        self.use_gated = use_gated

        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

        self.v_init = self.kernel_initializer
        self.w_init = self.kernel_initializer
        self.u_init = self.kernel_initializer

        self.v_regularizer = self.kernel_regularizer
        self.w_regularizer = self.kernel_regularizer
        self.u_regularizer = self.kernel_regularizer

    def build(self, input_shape):
        # Input shape.
        # List of 2D tensors with shape: (batch_size, input_dim).
        input_dim = input_shape[0][1]

        self.v_weight_params = self.add_weight(
            shape=(input_dim, self.weight_params_dim),
            initializer=self.v_init,
            name="v",
            regularizer=self.v_regularizer,
            trainable=True,
        )

        self.w_weight_params = self.add_weight(
            shape=(self.weight_params_dim, 1),
            initializer=self.w_init,
            name="w",
            regularizer=self.w_regularizer,
            trainable=True,
        )

        if self.use_gated:
            self.u_weight_params = self.add_weight(
                shape=(input_dim, self.weight_params_dim),
                initializer=self.u_init,
                name="u",
                regularizer=self.u_regularizer,
                trainable=True,
            )
        else:
            self.u_weight_params = None

        self.input_built = True

    def call(self, inputs):
        # Assigning variables from the number of inputs.
        instances = [self.compute_attention_scores(instance) for instance in inputs]

        # Stack instances into a single tensor.
        instances = ops.stack(instances)

        # Apply softmax over instances such that the output summation is equal to 1.
        alpha = ops.softmax(instances, axis=0)

        # Split to recreate the same array of tensors we had as inputs.
        return [alpha[i] for i in range(alpha.shape[0])]

    def compute_attention_scores(self, instance):
        # Reserve in-case "gated mechanism" used.
        original_instance = instance

        # tanh(v*h_k^T)
        instance = ops.tanh(ops.tensordot(instance, self.v_weight_params, axes=1))

        # for learning non-linear relations efficiently.
        if self.use_gated:
            instance = instance * ops.sigmoid(
                ops.tensordot(original_instance, self.u_weight_params, axes=1)
            )

        # w^T*(tanh(v*h_k^T)) / w^T*(tanh(v*h_k^T)*sigmoid(u*h_k^T))
        return ops.tensordot(instance, self.w_weight_params, axes=1)

可视化工具

根据类别绘制包的数量(由 PLOT_SIZE 指定)。

此外,如果激活,则可以看到每个包的类别标签预测及其相关的实例分数(在模型训练完成后)。

def plot(data, labels, bag_class, predictions=None, attention_weights=None):
    """ "Utility for plotting bags and attention weights.

    Args:
      data: Input data that contains the bags of instances.
      labels: The associated bag labels of the input data.
      bag_class: String name of the desired bag class.
        The options are: "positive" or "negative".
      predictions: Class labels model predictions.
      If you don't specify anything, ground truth labels will be used.
      attention_weights: Attention weights for each instance within the input data.
      If you don't specify anything, the values won't be displayed.
    """
    return  ## TODO
    labels = np.array(labels).reshape(-1)

    if bag_class == "positive":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 1)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

        else:
            labels = np.where(labels == 1)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

    elif bag_class == "negative":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 0)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]
        else:
            labels = np.where(labels == 0)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

    else:
        print(f"There is no class {bag_class}")
        return

    print(f"The bag class label is {bag_class}")
    for i in range(PLOT_SIZE):
        figure = plt.figure(figsize=(8, 8))
        print(f"Bag number: {labels[i]}")
        for j in range(BAG_SIZE):
            image = bags[j][i]
            figure.add_subplot(1, BAG_SIZE, j + 1)
            plt.grid(False)
            if attention_weights is not None:
                plt.title(np.around(attention_weights[labels[i]][j], 2))
            plt.imshow(image)
        plt.show()


# Plot some of validation data bags per class.
plot(val_data, val_labels, "positive")
plot(val_data, val_labels, "negative")

创建模型

首先,我们将为每个实例创建一些嵌入,调用注意力运算符,然后使用 softmax 函数输出类别概率。

def create_model(instance_shape):
    # Extract features from inputs.
    inputs, embeddings = [], []
    shared_dense_layer_1 = layers.Dense(128, activation="relu")
    shared_dense_layer_2 = layers.Dense(64, activation="relu")
    for _ in range(BAG_SIZE):
        inp = layers.Input(instance_shape)
        flatten = layers.Flatten()(inp)
        dense_1 = shared_dense_layer_1(flatten)
        dense_2 = shared_dense_layer_2(dense_1)
        inputs.append(inp)
        embeddings.append(dense_2)

    # Invoke the attention layer.
    alpha = MILAttentionLayer(
        weight_params_dim=256,
        kernel_regularizer=keras.regularizers.L2(0.01),
        use_gated=True,
        name="alpha",
    )(embeddings)

    # Multiply attention weights with the input layers.
    multiply_layers = [
        layers.multiply([alpha[i], embeddings[i]]) for i in range(len(alpha))
    ]

    # Concatenate layers.
    concat = layers.concatenate(multiply_layers, axis=1)

    # Classification output node.
    output = layers.Dense(2, activation="softmax")(concat)

    return keras.Model(inputs, output)

类别权重

由于此类问题可能简单地转化为不平衡数据分类问题,因此应考虑类别加权。

假设有 1000 个包。通常会出现大约 90% 的包不包含任何正标签,而大约 10% 的包包含正标签的情况。这种数据可以被称为不平衡数据

使用类别权重,模型会倾向于给稀有类别更高的权重。

def compute_class_weights(labels):
    # Count number of positive and negative bags.
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count

    # Build class weight dictionary.
    return {
        0: (1 / negative_count) * (total_count / 2),
        1: (1 / positive_count) * (total_count / 2),
    }

构建和训练模型

本节构建并训练模型。

def train(train_data, train_labels, val_data, val_labels, model):
    # Train model.
    # Prepare callbacks.
    # Path where to save best weights.

    # Take the file name from the wrapper.
    file_path = "/tmp/best_model.weights.h5"

    # Initialize model checkpoint callback.
    model_checkpoint = keras.callbacks.ModelCheckpoint(
        file_path,
        monitor="val_loss",
        verbose=0,
        mode="min",
        save_best_only=True,
        save_weights_only=True,
    )

    # Initialize early stopping callback.
    # The model performance is monitored across the validation data and stops training
    # when the generalization error cease to decrease.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, mode="min"
    )

    # Compile model.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Fit model.
    model.fit(
        train_data,
        train_labels,
        validation_data=(val_data, val_labels),
        epochs=20,
        class_weight=compute_class_weights(train_labels),
        batch_size=1,
        callbacks=[early_stopping, model_checkpoint],
        verbose=0,
    )

    # Load best weights.
    model.load_weights(file_path)

    return model


# Building model(s).
instance_shape = train_data[0][0].shape
models = [create_model(instance_shape) for _ in range(ENSEMBLE_AVG_COUNT)]

# Show single model architecture.
print(models[0].summary())

# Training model(s).
trained_models = [
    train(train_data, train_labels, val_data, val_labels, model)
    for model in tqdm(models)
]
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape       Param #  Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_layer_1       │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_layer_2       │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten (Flatten)   │ (None, 784)       │       0 │ input_layer[0][0]    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten_1 (Flatten) │ (None, 784)       │       0 │ input_layer_1[0][0]  │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten_2 (Flatten) │ (None, 784)       │       0 │ input_layer_2[0][0]  │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense (Dense)       │ (None, 128)       │ 100,480 │ flatten[0][0],       │
│                     │                   │         │ flatten_1[0][0],     │
│                     │                   │         │ flatten_2[0][0]      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_1 (Dense)     │ (None, 64)        │   8,256 │ dense[0][0],         │
│                     │                   │         │ dense[1][0],         │
│                     │                   │         │ dense[2][0]          │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ alpha               │ [(None, 1),       │  33,024 │ dense_1[0][0],       │
│ (MILAttentionLayer) │ (None, 1), (None, │         │ dense_1[1][0],       │
│                     │ 1)]               │         │ dense_1[2][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply (Multiply) │ (None, 64)        │       0 │ alpha[0][0],         │
│                     │                   │         │ dense_1[0][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply_1          │ (None, 64)        │       0 │ alpha[0][1],         │
│ (Multiply)          │                   │         │ dense_1[1][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply_2          │ (None, 64)        │       0 │ alpha[0][2],         │
│ (Multiply)          │                   │         │ dense_1[2][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ concatenate         │ (None, 192)       │       0 │ multiply[0][0],      │
│ (Concatenate)       │                   │         │ multiply_1[0][0],    │
│                     │                   │         │ multiply_2[0][0]     │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_2 (Dense)     │ (None, 2)         │     386 │ concatenate[0][0]    │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
 Total params: 142,146 (555.26 KB)
 Trainable params: 142,146 (555.26 KB)
 Non-trainable params: 0 (0.00 B)
None

100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:36<00:00, 36.67s/it]

模型评估

模型现在可以进行评估了。对于每个模型,我们还会创建一个相关的中间模型,以获取注意力层的权重。

我们将计算每个 ENSEMBLE_AVG_COUNT 模型的预测,并将它们平均在一起,得到我们的最终预测。

def predict(data, labels, trained_models):
    # Collect info per model.
    models_predictions = []
    models_attention_weights = []
    models_losses = []
    models_accuracies = []

    for model in trained_models:
        # Predict output classes on data.
        predictions = model.predict(data)
        models_predictions.append(predictions)

        # Create intermediate model to get MIL attention layer weights.
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)

        # Predict MIL attention layer weights.
        intermediate_predictions = intermediate_model.predict(data)

        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        models_attention_weights.append(attention_weights)

        loss, accuracy = model.evaluate(data, labels, verbose=0)
        models_losses.append(loss)
        models_accuracies.append(accuracy)

    print(
        f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"
        f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f} % resp."
    )

    return (
        np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,
        np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,
    )


# Evaluate and predict classes and attention scores on validation data.
class_predictions, attention_params = predict(val_data, val_labels, trained_models)

# Plot some results from our validation data.
plot(
    val_data,
    val_labels,
    "positive",
    predictions=class_predictions,
    attention_weights=attention_params,
)
plot(
    val_data,
    val_labels,
    "negative",
    predictions=class_predictions,
    attention_weights=attention_params,
)
 10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step
 10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step
The average loss and accuracy are 0.03 and 99.00 % resp.

结论

从上面的图中,您可以注意到权重总和始终为 1。在一个被正向预测的包中,导致正向标记的实例将比包中的其余实例具有明显更高的注意力得分。但是,在一个被负向预测的包中,有两种情况:

  • 所有实例将具有大致相似的分数。
  • 一个实例将具有相对较高的分数(但没有正向实例那么高)。这是因为此实例的特征空间与正向实例的特征空间接近。

备注

  • 如果模型过拟合,则所有包的权重将均匀分布。因此,正则化技术是必要的。
  • 在论文中,包的大小可能因包而异。为了简单起见,这里的包大小是固定的。
  • 为了不依赖于单个模型的随机初始权重,应考虑平均集成方法。