Jumping-Knowledge Representation Learning With LSTMs
Background
As I mentioned in my previous post on constrained graph attention networks, graph neural networks suffer from overfitting and oversmoothing as network depth increases. These issues can ultimately be linked to the local topologies of the graph.
If we consider a 2D image as a graph (i.e. pixels become nodes), we see that images are highly regular – that is, each node has the same number of neighbors, except for those at the image periphery. When we apply convolution kernels over node signals, filters at any given layer aggregate information from neighborhoods of the same size, irrespective of their location.
However, if we consider an arbitrary graph, there is no guarantee that it will be regular. In fact, in many situations, graphs are highly irregular, and are characterized by unique topological neighborhood properties such as tree-like structures or expander graphs, which are sparse yet highly connected [2]. If we compare an expander node to a node whose local topology is more regular, we would find that the number of signals each node implicitly convolves at each network layer varies considerably. These topological discrepancies have important implications when we consider problems like node and graph classification, as well as edge prediction. The problem ultimately boils down to one of flexibility: can we account for the unique local topologies of a graph in order to dynamically aggregate local information on a node-by-node basis? As a quick illustration of how effective neighborhood sizes can diverge with depth, see the sketch below.
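As a small, self-contained illustration (not from the original analysis), the snippet below counts how many nodes fall within $k$ hops of two different nodes of a barbell graph – one inside a dense clique and one along the connecting path – using networkx; the graph choice and variable names are mine.

```python
# A minimal sketch of how effective neighborhood sizes diverge with depth.
import networkx as nx

# A barbell graph: two dense cliques joined by a path. Clique nodes and path
# nodes have very different local topologies.
G = nx.barbell_graph(10, 10)

def k_hop_size(graph, node, k):
    """Number of nodes within k hops of `node` -- the implicit receptive field
    of a k-layer message-passing network rooted at that node."""
    return len(nx.single_source_shortest_path_length(graph, node, cutoff=k))

clique_node, path_node = 0, 15  # one node inside a clique, one on the path
for k in range(1, 5):
    print(k, k_hop_size(G, clique_node, k), k_hop_size(G, path_node, k))
```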

In a recent paper, the authors propose one approach to address this question, which they call "jumping knowledge representation learning" [1]. Instead of utilizing the output of the last convolution layer to inform the prediction, jumping-knowledge networks aggregate the embeddings from all hidden layers to inform the final prediction. The authors develop an approach to study the "influence distribution" of nodes: for a given node $x$, the influence of another node $y$ on $x$ measures how much a change in the input features of $y$ perturbs the final hidden representation of $x$ (formally, via the entries of the Jacobian $\partial h_x^{(k)} / \partial h_y^{(0)}$); normalizing these influence scores over all nodes yields the influence distribution of $x$.
They show that, under mild assumptions, the influence distribution of a node in a $k$-layer graph convolutional network corresponds, in expectation, to the distribution of a $k$-step random walk starting at $x$. In other words, how far a node's effective receptive field spreads is dictated by the graph's topology, which is exactly why nodes in expander-like regions aggregate information from much larger neighborhoods than nodes in tree-like or more regular regions.
The jumping knowledge network architecture is conceptually similar to other graph neural networks, and we can, in fact, simply incorporate the jumping knowledge mechanism as an additional layer. The goal is to adaptively learn the effective neighborhood size on a node-by-node basis, rather than enforcing the same aggregation radius for every node (remember, we want to account for local topological and feature variations). The authors suggest three possible aggregation functions: concatenation, max-pooling, and an LSTM-attention mechanism [1][3]. Each aggregator learns an optimal combination of the hidden embeddings, which is then pushed through a linear layer to generate the final network output. Concatenation determines the optimal linear combination of hidden embeddings for the entire dataset simultaneously, so it is not a node-specific aggregator. Max-pooling selects the most important hidden layer for each feature element on a node-by-node basis; empirically, however, I found that max-pooling was highly unstable in terms of model learning. The LSTM-attention aggregator treats the hidden embeddings as a sequence of elements and assigns a unique attention score to each hidden embedding [4]. The first two aggregators can be sketched in a few lines, as shown below.
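To make the first two aggregators concrete, here is a minimal sketch in PyTorch, assuming `layer_embeddings` is a Python list holding each hidden layer's node embeddings of shape `(num_nodes, hidden_dim)`; the function names are illustrative and not taken from the paper or my repository.

```python
import torch

def jk_concat(layer_embeddings):
    # (num_nodes, num_layers * hidden_dim): a single linear layer on top of this
    # learns one dataset-wide combination of the hidden layers.
    return torch.cat(layer_embeddings, dim=-1)

def jk_max_pool(layer_embeddings):
    # (num_nodes, hidden_dim): for every node and feature element, keep the layer
    # with the largest activation -- node-adaptive, but parameter-free.
    return torch.stack(layer_embeddings, dim=0).max(dim=0).values
```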
Long Short-Term Memory
Briefly, given a sequence of samples $(x_1, x_2, \ldots, x_T)$, an LSTM cell processes one element at a time and maintains a memory state that is updated via a set of learned gates:

$$
\begin{aligned}
f_t &= \sigma\left(W_f x_t + U_f h_{t-1} + b_f\right) \\
i_t &= \sigma\left(W_i x_t + U_i h_{t-1} + b_i\right) \\
o_t &= \sigma\left(W_o x_t + U_o h_{t-1} + b_o\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh\left(W_c x_t + U_c h_{t-1} + b_c\right) \\
h_t &= o_t \odot \tanh\left(c_t\right)
\end{aligned}
$$

where $\sigma$ is the sigmoid function, $\odot$ is the element-wise product, and $f_t$, $i_t$, and $o_t$ are the forget, input, and output gates [4].

In the jumping-knowledge setting, the "sequence" presented to the LSTM for a node $v$ is its set of layer-wise hidden embeddings $h_v^{(1)}, h_v^{(2)}, \ldots, h_v^{(K)}$, processed by a bi-directional LSTM. The embeddings for each element learned by the LSTM cell are represented by the concatenated forward and backward hidden states $[\overrightarrow{h}_v^{(k)} \,\|\, \overleftarrow{h}_v^{(k)}]$, each of which is mapped by a linear layer to a scalar importance score $s_v^{(k)}$. The softmax-normalized attention weights represent a probability distribution over the hidden layers,

$$\alpha_v^{(k)} = \frac{\exp\left(s_v^{(k)}\right)}{\sum_{k'=1}^{K} \exp\left(s_v^{(k')}\right)},$$

where $\sum_{k=1}^{K} \alpha_v^{(k)} = 1$. The final representation of node $v$ is then the attention-weighted combination of its layer-wise embeddings, $\sum_{k} \alpha_v^{(k)} h_v^{(k)}$.
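A minimal sketch of this LSTM-attention aggregator in PyTorch is given below; the module name, the choice of LSTM width, and the decision to weight the original layer embeddings (rather than the LSTM states) reflect my own reading of the paper, not the post's exact implementation.

```python
import torch
import torch.nn as nn

class LSTMAttentionJK(nn.Module):
    """Attention over a node's layer-wise embeddings via a bi-directional LSTM."""

    def __init__(self, hidden_dim, lstm_dim=None):
        super().__init__()
        lstm_dim = lstm_dim or hidden_dim
        self.lstm = nn.LSTM(hidden_dim, lstm_dim, bidirectional=True, batch_first=True)
        self.score = nn.Linear(2 * lstm_dim, 1)  # scalar importance score per layer

    def forward(self, layer_embeddings):
        # Treat each node's K layer embeddings as a length-K sequence.
        h = torch.stack(layer_embeddings, dim=1)        # (num_nodes, K, hidden_dim)
        states, _ = self.lstm(h)                        # (num_nodes, K, 2 * lstm_dim)
        alpha = torch.softmax(self.score(states).squeeze(-1), dim=-1)  # (num_nodes, K)
        # Attention-weighted combination of the original layer embeddings.
        out = (alpha.unsqueeze(-1) * h).sum(dim=1)      # (num_nodes, hidden_dim)
        return out, alpha
```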
An Application of Jumping Knowledge Networks to Cortical Segmentation
I’ve implemented the jumping knowledge network using DGL here. Below, I’ll demonstrate the application of jumping knowledge representation learning to a cortical segmentation task. Neuroscientifically, we have reason to believe that the influence radius will vary along the cortical manifold, even if the mesh structure is highly regular. As such, I am specifically interested in examining the importance that each node assigns to the embeddings of each hidden layer. To that end, I utilize the LSTM-attention aggregator. Similarly, as the jumping-knowledge mechanism can be incorporated as an additional layer to any general graph neural network, I will use graph attention networks (GAT) as the base network architecture, and compare vanilla GAT performance to GATs with a jumping knowledge mechanism (JKGAT).
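For reference, here is a hedged sketch of how the pieces might be wired together using DGL's `GATConv` and the `LSTMAttentionJK` module sketched earlier; the layer sizes mirror the experiment below, but the class and argument names are illustrative rather than the repository's actual code.

```python
import torch
import torch.nn as nn
from dgl.nn import GATConv

class JKGAT(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_classes, num_layers=9, num_heads=4):
        super().__init__()
        self.layers = nn.ModuleList([GATConv(in_dim, hidden_dim, num_heads)])
        for _ in range(num_layers - 1):
            self.layers.append(GATConv(hidden_dim * num_heads, hidden_dim, num_heads))
        self.jump = LSTMAttentionJK(hidden_dim * num_heads)   # aggregator sketched above
        self.classify = nn.Linear(hidden_dim * num_heads, num_classes)

    def forward(self, g, feats):
        h, layer_embeddings = feats, []
        for conv in self.layers:
            h = torch.relu(torch.flatten(conv(g, h), start_dim=1))  # merge attention heads
            layer_embeddings.append(h)
        h, alpha = self.jump(layer_embeddings)   # jumping-knowledge aggregation
        return self.classify(h), alpha           # logits and per-layer attention weights
```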
Below, I show the prediction generated by a 9-layer JKGAT model, with 4 attention heads and 32 hidden channels per layer, compared against the “known” or “true” cortical map. We find slight differences between the JKGAT prediction and the ground truth map, notably in the lateral occipital cortex and the medial prefrontal cortex.

When we consider the accuracies of various parameterizations of our models, we see that the JKGAT performs quite well, outperforming the GAT model in most cases. Likewise, as hypothesized, the JKGAT’s advantage over the GAT grows as network depth increases, specifically because we are able to dynamically learn the optimal influence radius for each node, rather than enforcing the same radius across the entire graph. This allows us to learn more abstract representations of the input features by mitigating oversmoothing and by accounting for node-level topological variability, which is important for additional use-cases like graph classification.

Similarly, we find that JKGAT networks generate segmentation predictions that are more reproducible and consistent across resampled datasets. This is important, especially in the case where we acquire data from an individual multiple times and want to generate a cortical map for each acquisition. Unless an individual suffers from a rapidly progressing neurological disorder, experiences a traumatic neurological injury, or the time between consecutive scans is very long (on the order of years), we expect the cortical map of any given individual to remain quite static (though how the “map” of an individual changes over time is still an open question).

Finally, when we consider the importance that each cortical node assigns to the unique embedding at each hidden layer, we can visualize the learned attention weights $\alpha_v^{(k)}$ as a cortical map for each layer.
Let us consider the attention map for layer 4. We can interpret the maps as follows: for a given network architecture (in this case, a network with 9 layers), we find that areas in the primary motor cortex (i.e. Brodmann area 3a and the banks of area 4) and primary auditory cortex (areas A1 and R1) preferentially attend to the embedding of hidden layer 4, relative to the rest of the cortex. This indicates that the implicit aggregation over an influence radius of 4 layers is deemed more informative for the classification of nodes in the primary motor and auditory regions than for other cortical areas. However, whether this says anything about the implicit complexity of the cortical signals in these areas remains to be studied.
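These maps can be produced directly from the attention weights returned by the hypothetical `JKGAT` sketched above; the snippet below (with assumed `model`, `graph`, and `node_features` variables) shows how one might pull out the weights for layer 4 for plotting on the surface mesh.

```python
import torch

model.eval()
with torch.no_grad():
    _, alpha = model(graph, node_features)      # alpha: (num_nodes, num_layers)

layer_4_attention = alpha[:, 3].cpu().numpy()   # weights assigned to hidden layer 4
# `layer_4_attention` can then be mapped back onto the cortical mesh vertices
# and rendered with any surface-visualization tool.
```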
1. Xu et al. Representation Learning on Graphs with Jumping Knowledge Networks. 2018.
2. Dinitz et al. Large Low-Diameter Graphs are Good Expanders. 2017.
3. Lutzeyer et al. Comparing Graph Spectra of Adjacency and Laplacian Matrices. 2017.
4. Gers, Felix. Long Short-Term Memory in Recurrent Neural Networks. 2001.
5. Fan et al. Comparison of Long Short Term Memory Networks and the Hydrological Model in Runoff Simulation. 2020.