Author: Mohamad Jaber
Date created: 2021/08/16
Last modified: 2021/11/25
Description: MIL approach to classify bags of instances and get their individual instance scores.
Usually, with supervised learning algorithms, the learner receives labels for a set of instances. In the case of MIL, the learner receives labels for a set of bags, each of which contains a set of instances. A bag is labeled positive if it contains at least one positive instance, and negative if it does not contain any.
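As a minimal sketch of this labeling rule (the `instance_labels` array here is a made-up example; `POSITIVE_CLASS` is the constant defined later in this example):

import numpy as np

POSITIVE_CLASS = 1
instance_labels = np.array([7, 1, 3])  # one hypothetical bag of MNIST digit labels

# A bag is positive (1) if it contains at least one instance of POSITIVE_CLASS,
# otherwise it is negative (0).
bag_label = int(np.any(instance_labels == POSITIVE_CLASS))
print(bag_label)  # -> 1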
In image classification tasks, it is usually assumed that each image clearly represents a class label. In medical imaging (e.g. computational pathology), an entire image is represented by a single class label (cancerous/non-cancerous), or a region of interest may be given. However, one would like to know which patterns in the image actually cause it to belong to that class. In this context, the image is divided and the sub-images form the bag of instances.
Therefore, the goal is two-fold:

1. Learn a model that predicts a class label for a bag of instances.
2. Find out which instances within the bag caused a positive class label prediction.

The following steps describe how the model works:

1. The feature extractor layers extract feature embeddings.
2. The embeddings are fed into the MIL attention layer to get the attention scores. The layer is designed to be permutation-invariant.
3. The input features and their corresponding attention scores are multiplied together.
4. The resulting output is passed to a softmax function for classification.
import numpy as np
import keras
from keras import layers
from keras import ops
from tqdm import tqdm
from matplotlib import pyplot as plt
plt.style.use("ggplot")
We will create a set of bags and assign their labels according to their contents. If at least one positive instance is present in a bag, the bag is considered a positive bag. If it does not contain any positive instance, the bag is considered negative.
POSITIVE_CLASS: The desired class to be kept in the positive bag.
BAG_COUNT: The number of training bags.
VAL_BAG_COUNT: The number of validation bags.
BAG_SIZE: The number of instances in a bag.
PLOT_SIZE: The number of bags to plot.
ENSEMBLE_AVG_COUNT: The number of models to create and average together. (Optional: often results in better performance - set to 1 for a single model)

POSITIVE_CLASS = 1
BAG_COUNT = 1000
VAL_BAG_COUNT = 300
BAG_SIZE = 3
PLOT_SIZE = 3
ENSEMBLE_AVG_COUNT = 1
Since the attention operator is permutation-invariant, an instance carrying the positive class label is placed at a random position among the instances of a positive bag.
def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
# Set up bags.
bags = []
bag_labels = []
# Normalize input data.
input_data = np.divide(input_data, 255.0)
# Count positive samples.
count = 0
for _ in range(bag_count):
# Pick a fixed size random subset of samples.
index = np.random.choice(input_data.shape[0], instance_count, replace=False)
instances_data = input_data[index]
instances_labels = input_labels[index]
# By default, all bags are labeled as 0.
bag_label = 0
# Check if there is at least a positive class in the bag.
if positive_class in instances_labels:
# Positive bag will be labeled as 1.
bag_label = 1
count += 1
bags.append(instances_data)
bag_labels.append(np.array([bag_label]))
print(f"Positive bags: {count}")
print(f"Negative bags: {bag_count - count}")
return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))
# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()
# Create training data.
train_data, train_labels = create_bags(
x_train, y_train, POSITIVE_CLASS, BAG_COUNT, BAG_SIZE
)
# Create validation data.
val_data, val_labels = create_bags(
x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
)
Positive bags: 283
Negative bags: 717
Positive bags: 104
Negative bags: 196
We will now build the attention layer, prepare some utilities, and then build and train the entire model.

The output size of this layer is decided by the size of a single bag.

The attention mechanism uses a weighted average of the instances in a bag, in which the sum of the weights must equal 1 (invariant of the bag size).

The weight matrices (parameters) are w and v. To include both positive and negative values, hyperbolic tangent element-wise non-linearity is used.

A gated attention mechanism can be used to deal with complex relations. Another weight matrix, u, is added to the computation. A sigmoid non-linearity is used to overcome the approximately linear behavior of the hyperbolic tangent for x ∈ [−1, 1].
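Concretely, the weighting described above can be sketched as follows (using the notation from the comments in the layer code below, where h_k is the embedding of the k-th of K instances in a bag):

a_k = \frac{\exp\left(\mathbf{w}^\top \tanh\left(\mathbf{v}\,\mathbf{h}_k^\top\right)\right)}{\sum_{j=1}^{K} \exp\left(\mathbf{w}^\top \tanh\left(\mathbf{v}\,\mathbf{h}_j^\top\right)\right)}

and, for the gated variant,

a_k = \frac{\exp\left(\mathbf{w}^\top \left(\tanh\left(\mathbf{v}\,\mathbf{h}_k^\top\right) \odot \mathrm{sigm}\left(\mathbf{u}\,\mathbf{h}_k^\top\right)\right)\right)}{\sum_{j=1}^{K} \exp\left(\mathbf{w}^\top \left(\tanh\left(\mathbf{v}\,\mathbf{h}_j^\top\right) \odot \mathrm{sigm}\left(\mathbf{u}\,\mathbf{h}_j^\top\right)\right)\right)}

The softmax in the denominator is what makes the attention weights of a bag sum to 1.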
class MILAttentionLayer(layers.Layer):
"""Implementation of the attention-based Deep MIL layer.
Args:
weight_params_dim: Positive Integer. Dimension of the weight matrix.
kernel_initializer: Initializer for the `kernel` matrix.
kernel_regularizer: Regularizer function applied to the `kernel` matrix.
use_gated: Boolean, whether or not to use the gated mechanism.
Returns:
List of 2D tensors with BAG_SIZE length.
The tensors are the attention scores after softmax with shape `(batch_size, 1)`.
"""
def __init__(
self,
weight_params_dim,
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
use_gated=False,
**kwargs,
):
super().__init__(**kwargs)
self.weight_params_dim = weight_params_dim
self.use_gated = use_gated
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
self.v_init = self.kernel_initializer
self.w_init = self.kernel_initializer
self.u_init = self.kernel_initializer
self.v_regularizer = self.kernel_regularizer
self.w_regularizer = self.kernel_regularizer
self.u_regularizer = self.kernel_regularizer
def build(self, input_shape):
# Input shape.
# List of 2D tensors with shape: (batch_size, input_dim).
input_dim = input_shape[0][1]
self.v_weight_params = self.add_weight(
shape=(input_dim, self.weight_params_dim),
initializer=self.v_init,
name="v",
regularizer=self.v_regularizer,
trainable=True,
)
self.w_weight_params = self.add_weight(
shape=(self.weight_params_dim, 1),
initializer=self.w_init,
name="w",
regularizer=self.w_regularizer,
trainable=True,
)
if self.use_gated:
self.u_weight_params = self.add_weight(
shape=(input_dim, self.weight_params_dim),
initializer=self.u_init,
name="u",
regularizer=self.u_regularizer,
trainable=True,
)
else:
self.u_weight_params = None
self.input_built = True
def call(self, inputs):
# Assigning variables from the number of inputs.
instances = [self.compute_attention_scores(instance) for instance in inputs]
# Stack instances into a single tensor.
instances = ops.stack(instances)
# Apply softmax over instances such that the output summation is equal to 1.
alpha = ops.softmax(instances, axis=0)
# Split to recreate the same array of tensors we had as inputs.
return [alpha[i] for i in range(alpha.shape[0])]
def compute_attention_scores(self, instance):
        # Reserve in case the gated mechanism is used.
original_instance = instance
# tanh(v*h_k^T)
instance = ops.tanh(ops.tensordot(instance, self.v_weight_params, axes=1))
# for learning non-linear relations efficiently.
if self.use_gated:
instance = instance * ops.sigmoid(
ops.tensordot(original_instance, self.u_weight_params, axes=1)
)
# w^T*(tanh(v*h_k^T)) / w^T*(tanh(v*h_k^T)*sigmoid(u*h_k^T))
return ops.tensordot(instance, self.w_weight_params, axes=1)
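As a quick sanity check (a hypothetical snippet, not part of the original example), the layer can be called on a few random embeddings to confirm that the scores over a bag sum to 1:

# Hypothetical smoke test: BAG_SIZE = 3 embeddings of dimension 64 for a batch of 4 bags.
dummy_embeddings = [np.random.rand(4, 64).astype("float32") for _ in range(3)]
scores = MILAttentionLayer(weight_params_dim=256, use_gated=True)(dummy_embeddings)
# Stacking the per-instance scores and summing over the bag axis should give ~1 per bag.
print(np.asarray(ops.sum(ops.stack(scores), axis=0)))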
Plot the number of bags per class (given by PLOT_SIZE).

Moreover, if enabled, the class label prediction with its associated instance score for each bag can be seen (after the model has been trained).
def plot(data, labels, bag_class, predictions=None, attention_weights=None):
""" "Utility for plotting bags and attention weights.
Args:
data: Input data that contains the bags of instances.
labels: The associated bag labels of the input data.
bag_class: String name of the desired bag class.
The options are: "positive" or "negative".
predictions: Class labels model predictions.
If you don't specify anything, ground truth labels will be used.
attention_weights: Attention weights for each instance within the input data.
If you don't specify anything, the values won't be displayed.
"""
labels = np.array(labels).reshape(-1)
if bag_class == "positive":
if predictions is not None:
labels = np.where(predictions.argmax(1) == 1)[0]
bags = np.array(data)[:, labels[0:PLOT_SIZE]]
else:
labels = np.where(labels == 1)[0]
bags = np.array(data)[:, labels[0:PLOT_SIZE]]
elif bag_class == "negative":
if predictions is not None:
labels = np.where(predictions.argmax(1) == 0)[0]
bags = np.array(data)[:, labels[0:PLOT_SIZE]]
else:
labels = np.where(labels == 0)[0]
bags = np.array(data)[:, labels[0:PLOT_SIZE]]
else:
print(f"There is no class {bag_class}")
return
print(f"The bag class label is {bag_class}")
for i in range(PLOT_SIZE):
figure = plt.figure(figsize=(8, 8))
print(f"Bag number: {labels[i]}")
for j in range(BAG_SIZE):
image = bags[j][i]
figure.add_subplot(1, BAG_SIZE, j + 1)
plt.grid(False)
if attention_weights is not None:
plt.title(np.around(attention_weights[labels[i]][j], 2))
plt.imshow(image)
plt.show()
# Plot some of validation data bags per class.
plot(val_data, val_labels, "positive")
plot(val_data, val_labels, "negative")
First we will create some embeddings per instance, invoke the attention operator, and then use the softmax function to output the class probabilities.
def create_model(instance_shape):
# Extract features from inputs.
inputs, embeddings = [], []
shared_dense_layer_1 = layers.Dense(128, activation="relu")
shared_dense_layer_2 = layers.Dense(64, activation="relu")
for _ in range(BAG_SIZE):
inp = layers.Input(instance_shape)
flatten = layers.Flatten()(inp)
dense_1 = shared_dense_layer_1(flatten)
dense_2 = shared_dense_layer_2(dense_1)
inputs.append(inp)
embeddings.append(dense_2)
# Invoke the attention layer.
alpha = MILAttentionLayer(
weight_params_dim=256,
kernel_regularizer=keras.regularizers.L2(0.01),
use_gated=True,
name="alpha",
)(embeddings)
# Multiply attention weights with the input layers.
multiply_layers = [
layers.multiply([alpha[i], embeddings[i]]) for i in range(len(alpha))
]
# Concatenate layers.
concat = layers.concatenate(multiply_layers, axis=1)
# Classification output node.
output = layers.Dense(2, activation="softmax")(concat)
return keras.Model(inputs, output)
Since this kind of problem can simply turn into an imbalanced data classification problem, class weighting should be considered.

Let's say there are 1000 bags. There can often be cases where roughly 90% of the bags do not contain any positive label and roughly 10% do. Such data can be referred to as imbalanced data.

Using class weights, the model will tend to give a higher weight to the rare class.
def compute_class_weights(labels):
    # Count number of positive and negative bags.
negative_count = len(np.where(labels == 0)[0])
positive_count = len(np.where(labels == 1)[0])
total_count = negative_count + positive_count
# Build class weight dictionary.
return {
0: (1 / negative_count) * (total_count / 2),
1: (1 / positive_count) * (total_count / 2),
}
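As a hypothetical illustration of the 90% / 10% split mentioned above (not output from this example's data):

# 1000 hypothetical bags: 900 negative, 100 positive.
example_labels = np.array([0] * 900 + [1] * 100)
print(compute_class_weights(example_labels))
# -> {0: (1/900) * 500 ≈ 0.56, 1: (1/100) * 500 = 5.0}
# The rare positive class is weighted roughly 9x higher than the negative class.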
The model is built and trained in this section.
def train(train_data, train_labels, val_data, val_labels, model):
# Train model.
# Prepare callbacks.
# Path where to save best weights.
# Take the file name from the wrapper.
file_path = "/tmp/best_model.weights.h5"
# Initialize model checkpoint callback.
model_checkpoint = keras.callbacks.ModelCheckpoint(
file_path,
monitor="val_loss",
verbose=0,
mode="min",
save_best_only=True,
save_weights_only=True,
)
# Initialize early stopping callback.
# The model performance is monitored across the validation data and stops training
    # when the generalization error ceases to decrease.
early_stopping = keras.callbacks.EarlyStopping(
monitor="val_loss", patience=10, mode="min"
)
# Compile model.
model.compile(
optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
# Fit model.
model.fit(
train_data,
train_labels,
validation_data=(val_data, val_labels),
epochs=20,
class_weight=compute_class_weights(train_labels),
batch_size=1,
callbacks=[early_stopping, model_checkpoint],
verbose=0,
)
# Load best weights.
model.load_weights(file_path)
return model
# Building model(s).
instance_shape = train_data[0][0].shape
models = [create_model(instance_shape) for _ in range(ENSEMBLE_AVG_COUNT)]
# Show single model architecture.
print(models[0].summary())
# Training model(s).
trained_models = [
train(train_data, train_labels, val_data, val_labels, model)
for model in tqdm(models)
]
Model: "functional_1"
Layer (type)               | Output Shape                      | Param # | Connected to
input_layer (InputLayer)   | (None, 28, 28)                    |       0 | -
input_layer_1 (InputLayer) | (None, 28, 28)                    |       0 | -
input_layer_2 (InputLayer) | (None, 28, 28)                    |       0 | -
flatten (Flatten)          | (None, 784)                       |       0 | input_layer[0][0]
flatten_1 (Flatten)        | (None, 784)                       |       0 | input_layer_1[0][0]
flatten_2 (Flatten)        | (None, 784)                       |       0 | input_layer_2[0][0]
dense (Dense)              | (None, 128)                       | 100,480 | flatten[0][0], flatten_1[0][0], flatten_2[0][0]
dense_1 (Dense)            | (None, 64)                        |   8,256 | dense[0][0], dense[1][0], dense[2][0]
alpha (MILAttentionLayer)  | [(None, 1), (None, 1), (None, 1)] |  33,024 | dense_1[0][0], dense_1[1][0], dense_1[2][0]
multiply (Multiply)        | (None, 64)                        |       0 | alpha[0][0], dense_1[0][0]
multiply_1 (Multiply)      | (None, 64)                        |       0 | alpha[0][1], dense_1[1][0]
multiply_2 (Multiply)      | (None, 64)                        |       0 | alpha[0][2], dense_1[2][0]
concatenate (Concatenate)  | (None, 192)                       |       0 | multiply[0][0], multiply_1[0][0], multiply_2[0][0]
dense_2 (Dense)            | (None, 2)                         |     386 | concatenate[0][0]
Total params: 142,146 (555.26 KB)
Trainable params: 142,146 (555.26 KB)
Non-trainable params: 0 (0.00 B)
None
100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:36<00:00, 36.67s/it]
The models are now ready for evaluation. For each model, we also create an associated intermediate model to get the weights from the attention layer.

We will compute a prediction for each of our `ENSEMBLE_AVG_COUNT` models, and average them together for our final prediction.
def predict(data, labels, trained_models):
# Collect info per model.
models_predictions = []
models_attention_weights = []
models_losses = []
models_accuracies = []
for model in trained_models:
# Predict output classes on data.
predictions = model.predict(data)
models_predictions.append(predictions)
# Create intermediate model to get MIL attention layer weights.
intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
# Predict MIL attention layer weights.
intermediate_predictions = intermediate_model.predict(data)
attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
models_attention_weights.append(attention_weights)
loss, accuracy = model.evaluate(data, labels, verbose=0)
models_losses.append(loss)
models_accuracies.append(accuracy)
print(
f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"
f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f} % resp."
)
return (
np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,
np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,
)
# Evaluate and predict classes and attention scores on validation data.
class_predictions, attention_params = predict(val_data, val_labels, trained_models)
# Plot some results from our validation data.
plot(
val_data,
val_labels,
"positive",
predictions=class_predictions,
attention_weights=attention_params,
)
plot(
val_data,
val_labels,
"negative",
predictions=class_predictions,
attention_weights=attention_params,
)
10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step
10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step
The average loss and accuracy are 0.03 and 99.00 % resp.
From the above plot, you can notice that the weights always sum to 1. In a positively predicted bag, the instance which resulted in the positive labeling will have a substantially higher attention score than the rest of the bag. However, in a negatively predicted bag, there are two cases: either all instances have roughly similar attention scores, or one instance gets a relatively higher score (though not as high as that of a positive instance) because its features are closest to those of the POSITIVE_CLASS digit.
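The first observation can be checked numerically with a small (hypothetical) snippet on the `attention_params` array returned by `predict` above:

# attention_params has shape (number of validation bags, BAG_SIZE); each row holds
# the attention scores of one bag, so every row should sum to approximately 1.
print(np.allclose(attention_params.sum(axis=1), 1.0, atol=1e-3))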