
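Graph attention network (GAT) for node classification

Author: akensert
Date created: 2021/09/13
Last modified: 2021/12/26
Description: An implementation of a Graph Attention Network (GAT) for node classification.

ⓘ This example uses Keras 2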


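Introduction

Graph neural networks are the preferred neural network architecture for processing data structured as graphs (for example, social networks or molecular structures), yielding better results than fully-connected or convolutional networks.

In this tutorial, we will implement a specific graph neural network known as a Graph Attention Network (GAT) to predict the labels of scientific papers based on the types of papers that cite them (using the Cora dataset).

References

For more information on GAT, see the original paper Graph Attention Networks as well as DGL's Graph Attention Networks documentation.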

Import packages

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
import warnings

warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 6)
pd.set_option("display.max_rows", 6)
np.random.seed(2)

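Obtain the dataset

The preparation of the Cora dataset follows that of the Node classification with Graph Neural Networks tutorial. Refer to that tutorial for more details on the dataset and exploratory data analysis. In brief, the Cora dataset consists of two files: cora.cites, which contains directed links (citations) between papers; and cora.content, which contains the features of the corresponding papers and one of seven labels (the subject of the paper).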

zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)

data_dir = os.path.join(os.path.dirname(zip_file), "cora")

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

print(citations)

print(papers)
      target  source
0          0      21
1          0     905
2          0     906
...      ...     ...
5426    1874    2586
5427    1876    1874
5428    1897    2707
[5429 rows x 2 columns]
      paper_id  term_0  term_1  ...  term_1431  term_1432  subject
0          462       0       0  ...          0          0        2
1         1911       0       0  ...          0          0        5
2         2002       0       0  ...          0          0        4
...        ...     ...     ...  ...        ...        ...      ...
2705      2372       0       0  ...          0          0        1
2706       955       0       0  ...          0          0        0
2707       376       0       0  ...          0          0        2
[2708 rows x 1435 columns]
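
As a small illustration of the re-indexing step above, the following sketch maps a few raw paper IDs to contiguous indices in the same way that paper_idx is built. The IDs and edges are made up for illustration and are not taken from Cora; toy_paper_idx and toy_edges are hypothetical names.

```python
# Hypothetical raw paper IDs (not taken from Cora), with repeats.
raw_ids = [1105, 35, 4330, 35, 1105]

# Map each unique ID to its rank in sorted order: 0, 1, 2, ...
toy_paper_idx = {name: idx for idx, name in enumerate(sorted(set(raw_ids)))}

# Hypothetical citation pairs in (target, source) order, using raw IDs.
raw_edges = [(35, 1105), (35, 4330), (1105, 4330)]

# Re-express the edges with contiguous indices so that they can be used
# directly as row indices into a node feature matrix.
toy_edges = [(toy_paper_idx[t], toy_paper_idx[s]) for t, s in raw_edges]

print(toy_paper_idx)
print(toy_edges)
```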

Split the dataset

# Obtain random indices
random_indices = np.random.permutation(range(papers.shape[0]))

# 50/50 split
train_data = papers.iloc[random_indices[: len(random_indices) // 2]]
test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]

Prepare the graph data

# Obtain paper indices which will be used to gather node states
# from the graph later on when training the model
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()

# Obtain ground truth labels corresponding to each paper_id
train_labels = train_data["subject"].to_numpy()
test_labels = test_data["subject"].to_numpy()

# Define graph, namely an edge tensor and a node feature tensor
edges = tf.convert_to_tensor(citations[["target", "source"]])
node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])

# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)
Edges shape:         (5429, 2)
Node features shape: (2708, 1433)
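
To make the two tensors above more concrete, here is a minimal NumPy sketch of the same layout on a toy graph: an edge array whose rows are (target, source) index pairs, and a node feature matrix with one row per node. All values are made up, and NumPy indexing stands in for tf.gather.

```python
import numpy as np

# Toy graph with 4 nodes and 3 features per node (values are made up).
toy_node_states = np.array(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
    ]
)

# Each row is a (target, source) pair, matching the column order of `edges`.
toy_edges = np.array([[0, 1], [0, 2], [1, 3], [2, 3]])

# Features at both ends of every edge: (num_edges, 2, num_features).
edge_states = toy_node_states[toy_edges]

# Features of the source node of every edge: (num_edges, num_features).
source_states = toy_node_states[toy_edges[:, 1]]

print(edge_states.shape)
print(source_states.shape)
```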

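Build the model

GAT takes a graph (namely an edge tensor and a node feature tensor) as input and outputs [updated] node states. For each target node, the node state is the neighborhood information aggregated over N hops (where N is determined by the number of GAT layers). Importantly, in contrast to the graph convolutional network (GCN), GAT uses an attention mechanism to aggregate information from neighboring nodes (or source nodes). In other words, instead of simply averaging/summing the node states passed from source nodes (source papers) to the target node (target paper), GAT first applies a normalized attention score to each source node state and then sums.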
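
The difference between the two aggregation styles can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions: the source states and the raw scores are made up, whereas in a real GAT the scores come from a learned attention function, and the plain mean only approximates what a GCN does.

```python
import numpy as np

# Transformed states of three source nodes that all point to the same
# target node (made-up values).
source_states = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# Simplified GCN-style aggregation: every source contributes equally.
mean_aggregate = source_states.mean(axis=0)

# GAT-style aggregation: one unnormalized score per source node.
# These scores are made up; a real GAT learns how to compute them.
raw_scores = np.array([2.0, 0.0, 0.0])

# Softmax turns the scores into weights that sum to 1.
weights = np.exp(raw_scores) / np.exp(raw_scores).sum()

# Weighted sum of the source states.
attention_aggregate = (weights[:, None] * source_states).sum(axis=0)

print(mean_aggregate)
print(weights)
print(attention_aggregate)
```

With these scores, the first source node dominates the attention-weighted result, while the mean treats all three sources alike.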

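(Multi-head) graph attention layer

The GAT model implements multi-head graph attention layers. The MultiHeadGraphAttention layer is simply a concatenation (or averaging) of multiple graph attention layers (GraphAttention), each with separate learnable weights W. The GraphAttention layer does the following: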

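Consider input node states h^{l}, which are linearly transformed by W^{l}, resulting in z^{l}.

For each target node:

  1. Compute pair-wise attention scores (a^{l})^{T}(z^{l}_{i} || z^{l}_{j}) for all j, resulting in e_{ij} (for all j). || denotes concatenation, _{i} corresponds to the target node, and _{j} corresponds to a given 1-hop neighbor/source node.
  2. Normalize e_{ij} via softmax, so that the attention scores of the edges coming into the target node (sum_{k}{e^{norm}_{ik}}) add up to 1.
  3. Apply the attention scores e^{norm}_{ij} to z^{l}_{j} and add the result to the new target node state h^{l+1}_{i}, for all j.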
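
For reference, the three steps can be written compactly as follows. This is a sketch under a few assumptions: \alpha_{ij} denotes the normalized score from step 2, \mathcal{N}(i) (notation introduced here) is the set of source nodes with an edge into target node i, and the LeakyReLU is taken from the GraphAttention code rather than from the list above. The code additionally clips the scores to [-2, 2] before exponentiating, which is omitted here.

```latex
\begin{aligned}
z^{l}_{i} &= W^{l} h^{l}_{i} \\
e_{ij} &= \mathrm{LeakyReLU}\left( (a^{l})^{T} \left( z^{l}_{i} \,\|\, z^{l}_{j} \right) \right) \\
\alpha_{ij} &= \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})} \\
h^{l+1}_{i} &= \sum_{j \in \mathcal{N}(i)} \alpha_{ij} \, z^{l}_{j}
\end{aligned}
```
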
class GraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):

        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs

        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
        # Look up the per-target sum for every edge; unlike repeating by
        # bincount, this does not require edges to be sorted by target
        attention_scores_sum = tf.gather(attention_scores_sum, edges[:, 0])
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        node_states, edges = inputs

        # Obtain outputs from each attention head
        outputs = [
            attention_layer([node_states, edges])
            for attention_layer in self.attention_layers
        ]
        # Concatenate or average the node states from each head
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
        return tf.nn.relu(outputs)
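
The normalization in step (2) of GraphAttention.call above is a softmax computed separately for each target node. The sketch below reproduces that idea in NumPy on a toy edge list, using np.add.at as a stand-in for tf.math.unsorted_segment_sum. The edges and scores are made up, and the clipping applied in the layer is left out.

```python
import numpy as np

# Toy edge list in (target, source) order and one raw score per edge.
toy_edges = np.array([[0, 1], [0, 2], [1, 2], [2, 0], [2, 1], [2, 3]])
raw_scores = np.array([0.5, -0.5, 1.0, 0.0, 0.0, 0.0])
num_nodes = 4

targets = toy_edges[:, 0]
exp_scores = np.exp(raw_scores)

# Sum the exponentiated scores of all edges that share a target node.
score_sums = np.zeros(num_nodes)
np.add.at(score_sums, targets, exp_scores)

# Divide each edge score by the sum belonging to its own target node.
norm_scores = exp_scores / score_sums[targets]

# Check: the normalized scores of each target's incoming edges sum to 1.
# Node 3 has no incoming edges, so its total stays 0.
check = np.zeros(num_nodes)
np.add.at(check, targets, norm_scores)

print(norm_scores)
print(check)
```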

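Implement training logic with custom train_step, test_step, and predict_step methods

Notice that the GAT model operates on the entire graph (namely, node_states and edges) in all phases (training, validation and testing). Hence, node_states and edges are passed to the constructor of the keras.Model and used as attributes. What differs between the phases is the set of indices (and labels), which select the outputs of interest (tf.gather(outputs, indices)).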
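
The idea of running the model on every node and then selecting rows by index can be sketched with NumPy, where integer-array indexing plays the role of tf.gather(outputs, indices). The logits, indices, and labels below are made up for illustration and the toy_ names are hypothetical.

```python
import numpy as np

# Pretend the model produced logits for all 5 nodes and 3 classes.
all_outputs = np.array(
    [
        [2.0, 0.1, 0.1],
        [0.1, 2.0, 0.1],
        [0.1, 0.1, 2.0],
        [2.0, 0.1, 0.1],
        [0.1, 2.0, 0.1],
    ]
)

# Hypothetical split: node indices (and labels) used in each phase.
toy_train_indices = np.array([0, 2, 4])
toy_train_labels = np.array([0, 2, 1])
toy_test_indices = np.array([1, 3])

# Pick the rows that belong to each phase.
train_outputs = all_outputs[toy_train_indices]
test_outputs = all_outputs[toy_test_indices]

# Only the gathered rows are compared with the labels.
train_accuracy = (train_outputs.argmax(axis=1) == toy_train_labels).mean()

print(train_outputs.shape, test_outputs.shape, train_accuracy)
```

The forward pass is identical in every phase; only the rows that feed the loss and the metrics change.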

class GraphAttentionNetwork(keras.Model):
    def __init__(
        self,
        node_states,
        edges,
        hidden_units,
        num_heads,
        num_layers,
        output_dim,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x
        outputs = self.output_layer(x)
        return outputs

    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
            # Forward pass
            outputs = self([self.node_states, self.edges])
            # Compute loss
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        # Apply gradients (update weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute probabilities
        return tf.nn.softmax(tf.gather(outputs, indices))

    def test_step(self, data):
        indices, labels = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute loss
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

Train and evaluate

# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = len(class_values)

NUM_EPOCHS = 100
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.1
LEARNING_RATE = 3e-1
MOMENTUM = 0.9

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True
)

# Build model
gat_model = GraphAttentionNetwork(
    node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)

# Compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=VALIDATION_SPLIT,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=[early_stopping],
    verbose=2,
)

_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)

print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%")
Epoch 1/100
5/5 - 26s - loss: 1.8418 - acc: 0.2980 - val_loss: 1.5117 - val_acc: 0.4044 - 26s/epoch - 5s/step
Epoch 2/100
5/5 - 6s - loss: 1.2422 - acc: 0.5640 - val_loss: 1.0407 - val_acc: 0.6471 - 6s/epoch - 1s/step
Epoch 3/100
5/5 - 5s - loss: 0.7092 - acc: 0.7906 - val_loss: 0.8201 - val_acc: 0.7868 - 5s/epoch - 996ms/step
Epoch 4/100
5/5 - 5s - loss: 0.4768 - acc: 0.8604 - val_loss: 0.7451 - val_acc: 0.8088 - 5s/epoch - 934ms/step
Epoch 5/100
5/5 - 5s - loss: 0.2641 - acc: 0.9294 - val_loss: 0.7499 - val_acc: 0.8088 - 5s/epoch - 945ms/step
Epoch 6/100
5/5 - 5s - loss: 0.1487 - acc: 0.9663 - val_loss: 0.6803 - val_acc: 0.8382 - 5s/epoch - 967ms/step
Epoch 7/100
5/5 - 5s - loss: 0.0970 - acc: 0.9811 - val_loss: 0.6688 - val_acc: 0.8088 - 5s/epoch - 960ms/step
Epoch 8/100
5/5 - 5s - loss: 0.0597 - acc: 0.9934 - val_loss: 0.7295 - val_acc: 0.8162 - 5s/epoch - 981ms/step
Epoch 9/100
5/5 - 5s - loss: 0.0398 - acc: 0.9967 - val_loss: 0.7551 - val_acc: 0.8309 - 5s/epoch - 991ms/step
Epoch 10/100
5/5 - 5s - loss: 0.0312 - acc: 0.9984 - val_loss: 0.7666 - val_acc: 0.8309 - 5s/epoch - 987ms/step
Epoch 11/100
5/5 - 5s - loss: 0.0219 - acc: 0.9992 - val_loss: 0.7726 - val_acc: 0.8309 - 5s/epoch - 1s/step
----------------------------------------------------------------------------
Test Accuracy 76.5%

Predict (probabilities)

test_probs = gat_model.predict(x=test_indices)

mapping = {v: k for (k, v) in class_idx.items()}

for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
    print(f"Example {i+1}: {mapping[label]}")
    for j, c in zip(probs, class_idx.keys()):
        print(f"\tProbability of {c: <24} = {j*100:7.3f}%")
    print("---" * 20)
Example 1: Probabilistic_Methods
    Probability of Case_Based               =   0.919%
    Probability of Genetic_Algorithms       =   0.180%
    Probability of Neural_Networks          =  37.896%
    Probability of Probabilistic_Methods    =  59.801%
    Probability of Reinforcement_Learning   =   0.705%
    Probability of Rule_Learning            =   0.044%
    Probability of Theory                   =   0.454%
------------------------------------------------------------
Example 2: Genetic_Algorithms
    Probability of Case_Based               =   0.005%
    Probability of Genetic_Algorithms       =  99.993%
    Probability of Neural_Networks          =   0.001%
    Probability of Probabilistic_Methods    =   0.000%
    Probability of Reinforcement_Learning   =   0.000%
    Probability of Rule_Learning            =   0.000%
    Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 3: Theory
    Probability of Case_Based               =   8.151%
    Probability of Genetic_Algorithms       =   1.021%
    Probability of Neural_Networks          =   0.569%
    Probability of Probabilistic_Methods    =  40.220%
    Probability of Reinforcement_Learning   =   0.792%
    Probability of Rule_Learning            =   6.910%
    Probability of Theory                   =  42.337%
------------------------------------------------------------
Example 4: Neural_Networks
    Probability of Case_Based               =   0.097%
    Probability of Genetic_Algorithms       =   0.026%
    Probability of Neural_Networks          =  93.539%
    Probability of Probabilistic_Methods    =   6.206%
    Probability of Reinforcement_Learning   =   0.028%
    Probability of Rule_Learning            =   0.010%
    Probability of Theory                   =   0.094%
------------------------------------------------------------
Example 5: Theory
    Probability of Case_Based               =  25.259%
    Probability of Genetic_Algorithms       =   4.381%
    Probability of Neural_Networks          =  11.776%
    Probability of Probabilistic_Methods    =  15.053%
    Probability of Reinforcement_Learning   =   1.571%
    Probability of Rule_Learning            =  23.589%
    Probability of Theory                   =  18.370%
------------------------------------------------------------
Example 6: Genetic_Algorithms
    Probability of Case_Based               =   0.000%
    Probability of Genetic_Algorithms       = 100.000%
    Probability of Neural_Networks          =   0.000%
    Probability of Probabilistic_Methods    =   0.000%
    Probability of Reinforcement_Learning   =   0.000%
    Probability of Rule_Learning            =   0.000%
    Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 7: Neural_Networks
    Probability of Case_Based               =   0.296%
    Probability of Genetic_Algorithms       =   0.291%
    Probability of Neural_Networks          =  93.419%
    Probability of Probabilistic_Methods    =   5.696%
    Probability of Reinforcement_Learning   =   0.050%
    Probability of Rule_Learning            =   0.072%
    Probability of Theory                   =   0.177%
------------------------------------------------------------
Example 8: Genetic_Algorithms
    Probability of Case_Based               =   0.000%
    Probability of Genetic_Algorithms       = 100.000%
    Probability of Neural_Networks          =   0.000%
    Probability of Probabilistic_Methods    =   0.000%
    Probability of Reinforcement_Learning   =   0.000%
    Probability of Rule_Learning            =   0.000%
    Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 9: Theory
    Probability of Case_Based               =   4.103%
    Probability of Genetic_Algorithms       =   5.217%
    Probability of Neural_Networks          =  14.532%
    Probability of Probabilistic_Methods    =  66.747%
    Probability of Reinforcement_Learning   =   3.008%
    Probability of Rule_Learning            =   1.782%
    Probability of Theory                   =   4.611%
------------------------------------------------------------
Example 10: Case_Based
    Probability of Case_Based               =  99.566%
    Probability of Genetic_Algorithms       =   0.017%
    Probability of Neural_Networks          =   0.016%
    Probability of Probabilistic_Methods    =   0.155%
    Probability of Reinforcement_Learning   =   0.026%
    Probability of Rule_Learning            =   0.192%
    Probability of Theory                   =   0.028%
------------------------------------------------------------

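Conclusions

The results look OK! In the run shown above, the GAT model reaches a test accuracy of 76.5% (validation accuracy peaked at about 84%), so it predicts the subject of a paper from its citation links correctly roughly three times out of four. Further improvements could be made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers, the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout); or modify the preprocessing step. We could also try to implement self-loops (i.e., paper X cites paper X) and/or make the graph undirected.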
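
As a starting point for the last suggestion, the following NumPy sketch adds self-loops and reverse edges to a toy edge list in (target, source) order. It is an illustration under assumptions, not part of the original example: the edges are made up, and the rows are de-duplicated and sorted by target because the citations table printed earlier also appears to be ordered by target.

```python
import numpy as np

# Toy directed edge list in (target, source) order.
toy_edges = np.array([[0, 1], [0, 2], [1, 2]])
num_nodes = 3

# Make the graph undirected by adding the reverse of every edge.
reversed_edges = toy_edges[:, ::-1]

# Add one self-loop (i, i) per node.
node_ids = np.arange(num_nodes)
self_loops = np.stack([node_ids, node_ids], axis=1)

# Combine and drop duplicates. np.unique with axis=0 also sorts the rows,
# so the result is ordered by target node.
augmented_edges = np.unique(
    np.concatenate([toy_edges, reversed_edges, self_loops], axis=0), axis=0
)

print(augmented_edges)
```

An array built this way could be converted with tf.convert_to_tensor and passed to the model in place of edges; whether it improves accuracy would need to be checked experimentally.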