1. Introduction
In PyTorch Geometric we usually define custom GNN models through the message-passing paradigm, but this approach has a drawback: during neighborhood aggregation, materializing x_i and x_j can consume a large amount of memory (especially on large graphs). However, not every GNN has to be expressed in message-passing form; some GNNs can be written directly as sparse matrix multiplications. Since version 1.6.0, PyG has officially provided stronger support for GNNs based on sparse matrix multiplication (via the SparseTensor class from torch-sparse). Expressing aggregation as a sparse matmul is more memory-efficient and also speeds up execution.
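To make the connection concrete, here is a minimal sketch (with hypothetical toy data) showing that sum-aggregation over neighbors in the message-passing view is exactly one sparse matrix multiplication, without ever materializing a per-edge feature tensor:

import torch
from torch_sparse import SparseTensor

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])    # three directed edges (source, target)
x = torch.randn(3, 4)                                # node features

# Message-passing view: gather x_j for every edge, then scatter-sum into the targets
x_j = x[edge_index[0]]                               # materializes one row per edge
out_mp = torch.zeros_like(x).index_add_(0, edge_index[1], x_j)

# Sparse-matmul view: the per-edge tensor is never materialized
adj_t = SparseTensor(row=edge_index[1], col=edge_index[0], sparse_sizes=(3, 3))
out_spmm = adj_t.matmul(x)

print(torch.allclose(out_mp, out_spmm))              # True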
2. SparseTensor in Detail
2.1 Construction
PyG provides sparse matrix support through the SparseTensor class, which lives in the torch_sparse module and can be imported as follows:
from torch_sparse import SparseTensor
The constructor of this class is shown below:
def __init__(
    self,
    row: Optional[torch.Tensor] = None,
    rowptr: Optional[torch.Tensor] = None,
    col: Optional[torch.Tensor] = None,
    value: Optional[torch.Tensor] = None,
    sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
    is_sorted: bool = False,
    trust_data: bool = False,
)
The commonly used parameters are:

| Parameter | Description |
|---|---|
| row | Row indices of the non-zero entries, i.e. edge_index[0] |
| col | Column indices of the non-zero entries, i.e. edge_index[1] |
| value | Non-zero values at the corresponding indices (optional) |
| sparse_sizes | Size of the sparse matrix |
Given the following graph (figure not reproduced here), its sparse adjacency matrix can be constructed as follows:
import torch
from torch_sparse import SparseTensor
from torch_geometric.utils import to_undirected
edge_index = torch.LongTensor(
    [[0, 0, 0, 1, 2, 1, 2, 3], [1, 2, 3, 2, 3, 5, 4, 6]])
edge_index = to_undirected(edge_index)
adj = SparseTensor(row=edge_index[0], col=edge_index[1], sparse_sizes=(7, 7))
print(adj)
"""
SparseTensor(row=tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 5, 6]),
             col=tensor([1, 2, 3, 0, 2, 5, 0, 1, 3, 4, 0, 2, 6, 2, 1, 3]),
             size=(7, 7), nnz=16, density=32.65%)
"""
Besides the approach above, a SparseTensor can also be created from other representations:
# Create from a dense matrix (an ordinary tensor)
adj = SparseTensor.from_dense(mat)
# Create an identity matrix of the given size
adj = SparseTensor.eye(100, 100)
# Create from a scipy sparse matrix
adj = SparseTensor.from_scipy(mat)
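As a small sketch, these constructors make it easy to round-trip between representations (reusing the 7-node adj built above):

dense = adj.to_dense()                    # back to an ordinary 7 x 7 tensor
adj2 = SparseTensor.from_dense(dense)     # and back to a SparseTensor
row, col, _ = adj2.coo()
edge_index2 = torch.stack([row, col])     # recover an edge_index-style tensor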
2.2 Common Usage
Much like scipy, SparseTensor supports the COO, CSR, and CSC sparse storage formats:
row, col, value = adj.coo()
rowptr, col, value = adj.csr()
colptr, row, value = adj.csc()
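For the 7-node example above, the CSR layout can be inspected directly (a sketch; the values follow from the adjacency built earlier):

rowptr, col, value = adj.csr()
# rowptr[i + 1] - rowptr[i] is the number of neighbours of node i
print(rowptr)   # tensor([ 0,  3,  6, 10, 13, 14, 15, 16])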
For a more detailed introduction to these storage formats, see the article "Sparse稀疏矩阵主要存储格式总结" (a summary of the main sparse-matrix storage formats); we will not expand on them here.
Common operations supported by SparseTensor include:
# Slicing
adj = adj[:100, :100]
# Add diagonal entries: A + I, where I is the identity matrix
adj = adj.set_diag()
# Transpose
adj_t = adj.t()
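The transpose deserves a brief remark (a sketch): for the undirected example above A equals its transpose, so transposing changes nothing, but for directed graphs it matters. PyG's T.ToSparseTensor transform, used in Section 3, stores the adjacency in this transposed form as data.adj_t.

# For the undirected 7-node graph, adj and its transpose coincide
print(adj.to_dense().equal(adj.t().to_dense()))   # True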
A SparseTensor can be multiplied both with a dense matrix and with another sparse matrix:
# Sparse-Dense Matrix Multiplication
x = torch.rand(7, 4)
out = adj.matmul(x)
print(out.shape)
# torch.Size([7, 4])
# Sparse-Sparse Matrix Multiplication
adj = adj.matmul(adj)
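One handy use of the sparse-sparse product (a sketch, rebuilding the original 7-node adjacency first): for a 0/1 adjacency matrix, entry (i, j) of A·A counts the number of length-2 paths from i to j, so two-hop connectivity can be computed without leaving the sparse representation.

adj = SparseTensor(row=edge_index[0], col=edge_index[1], sparse_sizes=(7, 7))
adj_w = adj.fill_value(1.0)        # make the 0/1 edge weights explicit
two_hop = adj_w.matmul(adj_w)      # sparse-sparse product
print(two_hop.to_dense())          # entry (i, j) = number of length-2 paths from i to j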
In GNNs we often need to normalize the adjacency matrix. The normalization code for the sparse case is shown below:
from torch_sparse import fill_diag, mul
from torch_sparse import sum as sparsesum

def norm_adj(adj_t, add_self_loops=True):
    """Symmetrically normalize a sparse adjacency matrix."""
    if not adj_t.has_value():
        adj_t = adj_t.fill_value(1.)
    if add_self_loops:
        adj_t = fill_diag(adj_t, 1.)                 # A + I
    deg = sparsesum(adj_t, dim=1)                    # node degrees
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
    adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))     # D^{-1/2} * A
    adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))     # ... * D^{-1/2}
    return adj_t
adj = norm_adj(adj)
print(adj.to_dense())
"""
tensor([[0.2500, 0.2500, 0.2236, 0.2500, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2236, 0.0000, 0.0000, 0.3536, 0.0000],
        [0.2236, 0.2236, 0.2000, 0.2236, 0.3162, 0.0000, 0.0000],
        [0.2500, 0.0000, 0.2236, 0.2500, 0.0000, 0.0000, 0.3536],
        [0.0000, 0.0000, 0.3162, 0.0000, 0.5000, 0.0000, 0.0000],
        [0.0000, 0.3536, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.3536, 0.0000, 0.0000, 0.5000]])
"""
3. GCN in Sparse-Matrix Form
The propagation rule of GCN is

$$\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}$$

where $\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$. Based on this formula, the sparse-matrix implementation of GCN, together with its training and evaluation, looks like this:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_sparse import fill_diag, mul
from torch_sparse import sum as sparsesum
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from copy import deepcopy


def norm_adj(adj_t, add_self_loops=True):
    """Symmetrically normalize a sparse adjacency matrix."""
    if not adj_t.has_value():
        adj_t = adj_t.fill_value(1.)
    if add_self_loops:
        adj_t = fill_diag(adj_t, 1.)
    deg = sparsesum(adj_t, dim=1)
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
    adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
    adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
    return adj_t


class GCNConv(nn.Module):
    def __init__(self, in_feats, out_feats, bias=False) -> None:
        super().__init__()
        self.lin = nn.Linear(in_feats, out_feats, bias)

    def forward(self, x, adj):
        # X' = A_norm @ (X W)
        x = self.lin(x)
        return adj.matmul(x)


class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, out_feats) -> None:
        super().__init__()
        self.gcn_conv1 = GCNConv(in_feats, hidden_size)
        self.gcn_conv2 = GCNConv(hidden_size, out_feats)

    def forward(self, x, adj):
        x = self.gcn_conv1(x, adj)
        x = F.relu(x)
        x = self.gcn_conv2(x, adj)
        return F.log_softmax(x, dim=1)


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # T.ToSparseTensor replaces edge_index with the SparseTensor attribute data.adj_t
    dataset = Planetoid("temp", name="Cora", transform=T.ToSparseTensor())
    data = dataset[0].to(device)
    model = GCN(in_feats=dataset.num_features,
                hidden_size=16, out_feats=dataset.num_classes)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_acc, best_model = 0., None
    model.train()
    for epoch in range(600):
        optimizer.zero_grad()
        out = model(data.x, norm_adj(data.adj_t))
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        # track the best model by the number of correct validation predictions
        valid_acc = (out[data.val_mask].argmax(dim=1)
                     == data.y[data.val_mask]).sum()
        if valid_acc > best_acc:
            best_acc = valid_acc
            best_model = deepcopy(model)
        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch + 1}: loss: {loss.item()}")
        loss.backward()
        optimizer.step()

    best_model.eval()
    pred = best_model(data.x, norm_adj(data.adj_t)).argmax(dim=1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    print(f'Accuracy: {acc:.4f}')
"""
Epoch 50: loss: 1.4066219329833984
Epoch 100: loss: 0.7072957754135132
Epoch 150: loss: 0.3014388084411621
Epoch 200: loss: 0.14514322578907013
Epoch 250: loss: 0.08141915500164032
Epoch 300: loss: 0.05128778889775276
Epoch 350: loss: 0.0351191945374012
Epoch 400: loss: 0.02553318254649639
Epoch 450: loss: 0.01940452679991722
Epoch 500: loss: 0.015253005549311638
Epoch 550: loss: 0.012309947051107883
Epoch 600: loss: 0.010146616026759148
Accuracy: 0.8090
"""
Conclusion
The sparse-matrix form is very useful on large graphs: it saves computing resources, and the matrix multiplications themselves are more efficient. So if your GCN can be expressed through sparse matrix multiplications, doing so is a sound strategy.
That's all for this article. If you found it helpful, please give it a like or follow the author; your support is what keeps me writing. And if you spot any problems, corrections and feedback are very welcome!