Constrained Graph Attention Networks
In their recent paper, Wang et al. propose a few updates to the Graph Attention Network (GAT) neural network algorithm (if you want to skip the technical bit and get to the code, click here). Briefly, GATs are a recently-developed neural network architecture applied to data distributed over a graph domain. We can think of graph convolutional networks as progressively transforming and aggregating signals from within a local neighborhood of a node. At each iteration of this process, we implicitly merge signals from larger and larger neighborhoods of the node of interest, and thereby learn unique representations of nodes that are dependent on their surroundings.
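To make that concrete (notation is mine, not the paper’s), a single graph-convolution step for node \(i\) over its neighborhood \(\mathcal{N}(i)\) can be written as

$$ h_i^{(l+1)} = \sigma\left( \sum_{j \in \mathcal{N}(i)} c_{ij} \, W^{(l)} h_j^{(l)} \right), $$

where \(W^{(l)}\) is a learned weight matrix and the \(c_{ij}\) are fixed normalization constants (e.g. derived from node degrees).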
GATs incorporate the seminal idea of “attention” into this learning process. In each message-passing step, rather than updating the features of a node via equally weighted contributions from its neighbors, GAT models learn an attention function – i.e. they learn how to differentially pay attention to the various signals in the neighborhood. In this way, the algorithm can learn to focus on important signals and disregard superfluous ones. If we consider neural networks as universal function approximators, the attention mechanism improves the approximating ability by incorporating multiplicative weight factors into the learning.
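Concretely, for a single attention head the original GAT computes

$$ e_{ij} = \mathrm{LeakyReLU}\!\left(\mathbf{a}^{\top}\left[ W h_i \,\Vert\, W h_j \right]\right), \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}, \qquad h_i' = \sigma\!\left( \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j \right), $$

so the learned attention coefficients \(\alpha_{ij}\) take the place of the fixed normalization constants above.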

However, GATs are not without their pitfalls, as noted by Wang et al. Notably, the authors point to two important issues that GATs suffer from: overfitting of attention values and oversmoothing of signals across class boundaries. The authors argue that GATs overfit the attention function because the learning process is driven only by classification error: the attention values receive no direct supervision of their own, and their number scales with the number of edges, attention heads, and layers, which can easily outstrip the available labeled data.
Wang et al. propose to incorporate two margin-based constraints into the learning process. The first constraint, on the graph structure, pushes the attention paid to a node’s true neighbors above the attention paid to non-adjacent nodes (obtained by negative sampling) by at least a margin:

$$ \mathcal{L}_g = \sum_{i} \sum_{j \in \mathcal{N}(i)} \sum_{j' \notin \mathcal{N}(i)} \max\left(0,\; \phi(i, j') + \zeta_g - \phi(i, j)\right). $$

The second constraint, on class boundaries, pushes the attention paid to neighbors sharing the node’s label above the attention paid to neighbors with a different label (\(\mathcal{N}^{+}(i)\) and \(\mathcal{N}^{-}(i)\) respectively), again by a margin:

$$ \mathcal{L}_b = \sum_{i} \sum_{j \in \mathcal{N}^{+}(i)} \sum_{j' \in \mathcal{N}^{-}(i)} \max\left(0,\; \phi(i, j') + \zeta_b - \phi(i, j)\right). $$

In both cases, \(\phi(i, j)\) is the unnormalized attention score on the edge from \(j\) to \(i\), \(\zeta\) is a slack (margin) variable, and the resulting hinge losses are added to the classification objective during training.
The authors propose one final addition to alleviate the oversmoothing issue posed by vanilla GATs. Rather than aggregating over all adjacent nodes in a neighborhood, the authors propose to aggregate over only the \(k\) nodes with the highest attention values, so that weak (and potentially cross-class) signals are excluded from the update.
Implementation
I wasn’t able to find an implementation of the Constrained Graph Attention Network to use for my own purposes, so I’ve implemented the algorithm myself in the Deep Graph Library (DGL) – the source code for this convolutional layer can be found here. This implementation makes use of the original DGL GATConv layer structure, with modifications made for the constraints and aggregations. Specifically, the API for CGATConv has the following modifications:
CGATConv(in_feats,
         out_feats,
         num_heads,
         feat_drop=0.,
         graph_margin=0.1,    # graph structure loss slack variable
         class_margin=0.1,    # class boundary loss slack variable
         top_k=3,             # number of messages to aggregate over
         negative_slope=0.2,
         residual=False,
         activation=None,
         allow_zero_in_degree=False)
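For orientation, here’s a minimal, hypothetical usage sketch. The exact forward signature lives in cgatconv.py and may differ from what I assume here – in particular, I assume the layer accepts node labels (needed for the class boundary constraint) and returns the two constraint losses alongside the output features:

import dgl
import torch as th
from cgatconv import CGATConv  # the layer described in this post

g = dgl.rand_graph(100, 500)                 # toy graph: 100 nodes, 500 edges
feat = th.randn(g.num_nodes(), 64)           # input node features
labels = th.randint(0, 7, (g.num_nodes(),))  # node labels for the boundary constraint

layer = CGATConv(in_feats=64, out_feats=16, num_heads=4,
                 graph_margin=0.1, class_margin=0.1, top_k=3)

# assumed signature: output features plus the two constraint losses
h, Lg, Lb = layer(g, feat, labels)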
Of note is the fact that the attn_drop parameter has been substituted by the top_k parameter in order to mitigate oversmoothing, and that the two slack variables graph_margin and class_margin have been added to set the margins for the graph structure and class boundary constraints.
With regard to the loss functions, the authors compute all-pairs differences between the attention values of the edges incident on a node, rather than summing over the positive / negative sample attentions separately; this is implemented via the graph_loss reduction function below:
def graph_loss(nodes):
    """
    Loss function on graph structure.
    Enforces high attention to adjacent nodes and
    lower attention to distant nodes via negative sampling.
    """
    msg = nodes.mailbox['m']
    pw = msg[:, :, :, 0, :].unsqueeze(1)
    nw = msg[:, :, :, 1, :].unsqueeze(2)
    loss = (nw + self._graph_margin - pw).clamp(0)
    loss = loss.sum(1).sum(1).squeeze()
    return {'graph_loss': loss}
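To see what the broadcasting above is doing, here’s a tiny standalone example (invented numbers, ignoring the head and feature dimensions of the real mailbox) with two positive (adjacent) scores and two negative (sampled) scores for a single node – every negative score is compared against every positive score, and only pairs that violate the margin contribute to the loss:

import torch as th

pw = th.tensor([[0.9, 0.7]]).unsqueeze(1)  # positive scores, shape (1, 1, 2)
nw = th.tensor([[0.8, 0.2]]).unsqueeze(2)  # negative scores, shape (1, 2, 1)
margin = 0.1
loss = (nw + margin - pw).clamp(0)         # all (negative, positive) pairs
print(loss.squeeze())
# tensor([[0.0000, 0.2000],
#         [0.0000, 0.0000]])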
.
.
.
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
# construct the negative graph by shuffling edges
# does not assume a single graph or blocked graphs
# see cgatconv.py for ```construct_negative_graph``` function
neg_graph = [construct_negative_graph(i, k=1) for i in dgl.unbatch(graph)]
neg_graph = dgl.batch(neg_graph)
neg_graph.srcdata.update({'ft': feat_src, 'el': el})
neg_graph.dstdata.update({'er': er})
neg_graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
ne = self.leaky_relu(neg_graph.edata.pop('e'))
combined = th.stack([e, ne]).transpose(0, 1).transpose(1, 2)
graph.edata['combined'] = combined
graph.update_all(fn.copy_e('combined', 'm'), graph_loss)
# compute graph structured loss
Lg = graph.ndata['graph_loss'].sum() / (graph.num_nodes() * self._num_heads)
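The construct_negative_graph helper lives in cgatconv.py. As a rough sketch of what such a helper could look like – here I draw random destination nodes in the style of DGL’s link-prediction tutorial, whereas the actual helper shuffles the existing edges – consider:

import dgl
import torch as th

def construct_negative_graph(graph, k=1):
    """
    Hypothetical sketch: build a 'negative' graph over the same node set
    by pairing each source node with k randomly drawn destinations.
    """
    src, _ = graph.edges()
    neg_src = src.repeat_interleave(k)
    neg_dst = th.randint(0, graph.num_nodes(), (graph.num_edges() * k,))
    return dgl.graph((neg_src, neg_dst), num_nodes=graph.num_nodes())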
Similarly, the class boundary loss function relies on a binary edge message marking whether the source and destination nodes of each edge share a label:
def adjacency_message(edges):
    """
    Compute binary message on edges.
    Compares whether source and destination nodes
    have the same or different labels.
    """
    l_src = edges.src['l']
    l_dst = edges.dst['l']
    if l_src.ndim > 1:
        adj = th.all(l_src == l_dst, dim=1)
    else:
        adj = (l_src == l_dst)
    return {'adj': adj.detach()}
def class_loss(nodes):
    """
    Loss function on class boundaries.
    Enforces high attention to adjacent nodes with the same label
    and lower attention to adjacent nodes with different labels.
    """
    m = nodes.mailbox['m']
    w = m[:, :, :-1]
    adj = m[:, :, -1].unsqueeze(-1).bool()
    same_class = w.masked_fill(adj == 0, np.nan).unsqueeze(2)
    diff_class = w.masked_fill(adj == 1, np.nan).unsqueeze(1)
    difference = (diff_class + self._class_margin - same_class).clamp(0)
    loss = th.nansum(th.nansum(difference, 1), 1)
    return {'boundary_loss': loss}
.
.
.
graph.ndata['l'] = label
graph.apply_edges(adjacency_message)
adj = graph.edata.pop('adj').float()
combined = th.cat([e.squeeze(), adj.unsqueeze(-1)], dim=1)
graph.edata['combined'] = combined
graph.update_all(fn.copy_e('combined', 'm'), class_loss)
Lb = graph.ndata['boundary_loss'].sum() / (graph.num_nodes() * self._num_heads)
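Once more, a tiny standalone illustration of the NaN-masking trick (invented numbers, single head): same-class and different-class scores are separated by filling the opposite group with NaN, so that the broadcasted hinge only compares valid (different-class, same-class) pairs and th.nansum ignores everything else:

import numpy as np
import torch as th

w = th.tensor([[0.9, 0.6, 0.2]])        # attention scores on a node's 3 incident edges
adj = th.tensor([[True, False, True]])  # True where the neighbor shares the node's label
same_class = w.masked_fill(~adj, np.nan).unsqueeze(2)  # (1, 3, 1)
diff_class = w.masked_fill(adj, np.nan).unsqueeze(1)   # (1, 1, 3)
margin = 0.1
loss = th.nansum((diff_class + margin - same_class).clamp(0))
print(loss)  # tensor(0.5000): only the 0.2-vs-0.6 pair violates the margin (0.6 + 0.1 - 0.2)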
And finally, the constrained message aggregation is implemented using the following reduction function:
def topk_reduce_func(nodes):
    """
    Aggregate attention-weighted messages over the top-K
    attention-valued neighboring nodes.
    """
    K = self._top_k
    m = nodes.mailbox['m']
    m, _ = th.sort(m, dim=1, descending=True)
    m = m[:, :K, :, :].sum(1)
    return {'ft': m}
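As a quick sanity check of the reduction (invented numbers, single head and feature), only the K largest messages in a node’s mailbox survive the sum:

import torch as th

K = 3
m = th.tensor([4., 1., 3., 2.]).reshape(1, 4, 1, 1)  # mailbox: 1 node, 4 messages
m, _ = th.sort(m, dim=1, descending=True)
print(m[:, :K, :, :].sum(1).squeeze())               # tensor(9.) = 4 + 3 + 2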
.
.
.
# message passing
if self._top_k is not None:
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     topk_reduce_func)
else:
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
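Finally, for completeness: during training the two constraint losses are simply added to the usual classification objective, each weighted by a coefficient of your choosing. Continuing the hypothetical usage sketch from above – the forward signature and weighting scheme are my assumptions, not something prescribed by the layer:

import torch as th
import torch.nn.functional as F

lambda_g, lambda_b = 0.5, 0.5       # constraint weights (hyperparameters)
h, Lg, Lb = layer(g, feat, labels)  # assumed forward signature, as above
logits = h.mean(1)                  # e.g. average over attention heads
train_mask = th.zeros(g.num_nodes(), dtype=th.bool)
train_mask[:20] = True              # pretend the first 20 nodes are labeled
loss = F.cross_entropy(logits[train_mask], labels[train_mask]) + lambda_g * Lg + lambda_b * Lb
loss.backward()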