How to Build Graph Transformers with O(N) Complexity

Author: Murphy  |  Time: 2025-03-23 18:52:38

Building powerful graph Transformers has become a trending topic in the graph machine learning community, as a surge of recent efforts has shown that pure Transformer-based models can perform competitively, or even better, on quite a few GNN benchmarks (see some representative works in this direction [1, 2, 3]).

The challenge, however, is that the key component of Transformers [4], the attention mechanism, typically has quadratic complexity w.r.t. the number of input tokens. In graph learning, the input tokens for Transformers are the nodes of a graph, and global attention, which aims to capture long-range interactions among nodes, is hard to scale to graphs with arbitrarily many nodes. For example, on the common node classification dataset Pubmed (~20K nodes), running a one-layer, single-head Transformer with all-pair attention on a GPU with 16GB of memory is infeasible.

This tutorial introduces two recent scalable graph Transformers [5, 6] whose global attention mechanisms have linear complexity w.r.t. the number of tokens (nodes). The goal is to provide hands-on guidance on:

  1. how linear complexity is achieved while preserving all-pair attention;
  2. how the new attention functions are implemented in PyTorch.

This complements the published papers, which focus on describing the high-level ideas.


Where Does the O(N²) Come From?

Transformers can be seen as a generalization of graph neural networks (GNNs) in which the fixed adjacency matrix of GNNs is replaced by a variable attention matrix. From a graph view, GNNs perform message passing over a fixed observed graph (which is often sparsely connected), while Transformers perform message passing over a densely connected latent graph (whose edge weights are given by pairwise attention scores).

Comparison of GNNs and Transformers in terms of message passing over different structures: GNNs propagate signals over a sparse observed graph, while Transformers can be seen as propagating signals over a densely connected graph with layer-wise edge weights. The latter requires computing the N×N attention matrix and propagating features over this dense matrix.

We next recap the standard attention computation in the original Transformer [4]. The embeddings at the current layer are first mapped to query, key and value vectors, which are then used to compute all-pair attention for feature aggregation:
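For a node u at layer k, this computation takes the standard form below, written in the notation of the next paragraph. As in the PyTorch snippet of this section, the 1/√d scaling used in [4] is omitted, so treat this as a reconstruction of the usual formula rather than a verbatim copy:

```latex
q_u = W_Q^{(k)} z_u^{(k)}, \qquad k_u = W_K^{(k)} z_u^{(k)}, \qquad v_u = W_V^{(k)} z_u^{(k)},

z_u^{(k+1)} = \sum_{v=1}^{N} \frac{\exp\left(q_u^\top k_v\right)}{\sum_{w=1}^{N} \exp\left(q_u^\top k_w\right)} \, v_v .
```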

Here z denotes the node embeddings, and q, k and v denote the query, key and value vectors, respectively; W_Q, W_K and W_V are the learnable weights of the k-th layer.

Since updating a single node requires O(N) operations (it attends to all N nodes), updating all N nodes in one layer costs O(N²). A more intuitive way to see the O(N²) complexity is the matrix view, which is also how the computation is implemented in practice with deep learning frameworks (e.g., PyTorch, TensorFlow). Specifically, the computation flow of one attention layer is illustrated below.

The left part illustrates the global attention layer from a matrix view, and the right part presents the corresponding data flow, where the matrix product marked in red introduces O(N²) complexity.

The attention layer above can be easily implemented in PyTorch (here we use the "einsum" function, a generalized matrix product; see the PyTorch documentation of torch.einsum for details):

import torch

# qs: [N, H, D], ks: [L, H, D], vs: [L, H, D]

attn = torch.einsum("nhd,lhd->nlh", qs, ks)  # all-pair scores q_u^T k_v: [N, L, H]
attn = torch.softmax(attn, dim=1)  # normalize over the L keys: [N, L, H]
z_next = torch.einsum("nlh,lhd->nhd", attn, vs)  # weighted sum of values: [N, H, D]
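To make the O(N²) cost concrete, here is a dependency-free sketch of the same single-head computation with explicit Python loops. The function name dense_softmax_attention is hypothetical and plain lists stand in for tensors. Every query is scored against every key, so the work grows with N × L; the sketch is mainly useful as a slow reference for checking faster versions.

```python
import math


def dense_softmax_attention(qs, ks, vs):
    """Single-head all-pair softmax attention with explicit loops.

    qs: N query vectors; ks, vs: L key and value vectors (lists of floats).
    Each query is scored against every key, so the cost is O(N * L * d).
    """
    out = []
    for q in qs:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in ks]  # q_u^T k_v for all v
        top = max(scores)  # subtract the max before exp for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, vs)) / total
                    for j in range(len(vs[0]))])
    return out


if __name__ == "__main__":
    qs = [[1.0, 0.0], [0.0, 0.0]]
    ks = [[1.0, 0.0], [0.0, 1.0]]
    vs = [[1.0, 2.0], [3.0, 4.0]]
    # The all-zero query scores both keys equally, so it averages the values.
    print(dense_softmax_attention(qs, ks, vs)[1])  # [2.0, 3.0]
```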

Since the quadratic complexity is cumbersome, we next introduce two effective ways to reduce O(N²) to O(N) while, more importantly, still preserving the expressivity to explicitly model all-pair influence.


NodeFormer: Kernelized Softmax Message Passing

The recent work NodeFormer [5] (a scalable graph structure learning Transformer for node classification) leverages random Fourier features [8, 9] to convert the dot-then-exponential operation into a map-then-dot alternative via kernel approximation:
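In symbols, the exponential of a dot product is replaced by a dot product of mapped vectors. For concreteness, the φ below is assumed to be the positive random feature map popularized by Performer, one common choice for the Softmax kernel; the exact variant in NodeFormer's code may differ:

```latex
\exp\left(q_u^\top k_v\right) \approx \phi(q_u)^\top \phi(k_v),

\phi(x) = \frac{1}{\sqrt{m}} \exp\left(-\frac{\|x\|_2^2}{2}\right)
\left[\exp\left(w_1^\top x\right), \ldots, \exp\left(w_m^\top x\right)\right]^\top,
\qquad w_1, \ldots, w_m \sim \mathcal{N}(0, I_d).
```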

The phi function here is a non-parametric random feature map whose feature dimension m controls how well it approximates the original exponential term. See [8] for more on random feature maps and their theoretical properties.
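A quick way to see the map-then-dot approximation at work is a Monte Carlo check in plain Python. This sketch assumes the positive random feature map φ(x) = exp(w_iᵀx − ‖x‖²/2)/√m with w_i ~ N(0, I); the helper names sample_projections and positive_random_features are hypothetical, not NodeFormer's API. The estimate converges to exp(qᵀk) as m grows.

```python
import math
import random


def sample_projections(m, d, seed=0):
    """Draw m random vectors w_i ~ N(0, I_d); they are fixed, not learned."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]


def positive_random_features(x, ws):
    """phi(x)_i = exp(w_i^T x - |x|^2 / 2) / sqrt(m), one entry per w_i.

    Since E[exp(w^T (q + k))] = exp(|q + k|^2 / 2) for w ~ N(0, I),
    E[phi(q)^T phi(k)] = exp(q^T k): map-then-dot is an unbiased
    estimate of dot-then-exponential.
    """
    half_sq_norm = sum(xi * xi for xi in x) / 2
    scale = 1.0 / math.sqrt(len(ws))
    return [scale * math.exp(sum(wi * xi for wi, xi in zip(w, x)) - half_sq_norm)
            for w in ws]


if __name__ == "__main__":
    q, k = [0.3, -0.2], [0.1, 0.4]
    exact = math.exp(sum(a * b for a, b in zip(q, k)))
    for m in (16, 256, 20000):
        ws = sample_projections(m, d=2)
        approx = sum(a * b for a, b in zip(positive_random_features(q, ws),
                                            positive_random_features(k, ws)))
        print(m, approx, exact)  # the gap shrinks roughly like 1 / sqrt(m)
```

All features are positive, which keeps the approximate attention weights non-negative.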

In this way, the original Softmax attention can be transformed into an efficient version:
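Written out for node u (a reconstruction consistent with the kernel approximation exp(q_uᵀk_v) ≈ φ(q_u)ᵀφ(k_v); the middle expression is the LHS and the last one is the RHS referred to below):

```latex
z_u^{(k+1)}
= \sum_{v=1}^{N} \frac{\exp\left(q_u^\top k_v\right)}{\sum_{w=1}^{N} \exp\left(q_u^\top k_w\right)} \, v_v
\approx \sum_{v=1}^{N} \frac{\phi(q_u)^\top \phi(k_v)}{\sum_{w=1}^{N} \phi(q_u)^\top \phi(k_w)} \, v_v
= \frac{\left(\sum_{v=1}^{N} v_v \, \phi(k_v)^\top\right) \phi(q_u)}{\phi(q_u)^\top \sum_{w=1}^{N} \phi(k_w)} .
```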

The step from the LHS to the RHS follows from the associativity of matrix multiplication, which lets us change the order in which the products are computed.

Notice that on the RHS, the two summation terms over N nodes are independent of node u, so they can be computed once and reused by all nodes. Thus, to update N nodes in a layer, one first spends O(N) to compute the two summations; given these, computing the next-layer embeddings of all N nodes takes only another O(N), for a total complexity of O(N).
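The reuse argument can be checked directly: computing the kernelized attention pair by pair (LHS order) and through the two shared sums (RHS order) gives the same output up to floating-point error. This plain-Python sketch again assumes positive random features for φ; all function names are hypothetical illustrations, not NodeFormer's code.

```python
import math
import random


def feature_map(x, ws):
    # Assumed phi: positive random features exp(w_i^T x - |x|^2 / 2) / sqrt(m)
    half_sq_norm = sum(xi * xi for xi in x) / 2
    return [math.exp(sum(wi * xi for wi, xi in zip(w, x)) - half_sq_norm) / math.sqrt(len(ws))
            for w in ws]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def kernel_attention_quadratic(qs, ks, vs, ws):
    """LHS order: one score phi(q_u)^T phi(k_v) per (u, v) pair, O(N * L * m)."""
    pk = [feature_map(k, ws) for k in ks]
    out = []
    for q in qs:
        pq = feature_map(q, ws)
        scores = [dot(pq, b) for b in pk]
        total = sum(scores)
        out.append([sum(s * v[j] for s, v in zip(scores, vs)) / total
                    for j in range(len(vs[0]))])
    return out


def kernel_attention_linear(qs, ks, vs, ws):
    """RHS order: the two sums over keys are built once and reused (linear in N and L)."""
    m, d = len(ws), len(vs[0])
    pk = [feature_map(k, ws) for k in ks]
    kv = [[sum(b[i] * v[j] for b, v in zip(pk, vs)) for j in range(d)]
          for i in range(m)]  # sum_v phi(k_v) v_v^T, shape [m, d]
    k_sum = [sum(b[i] for b in pk) for i in range(m)]  # sum_w phi(k_w), shape [m]
    out = []
    for q in qs:
        pq = feature_map(q, ws)
        den = dot(pq, k_sum)
        out.append([sum(pq[i] * kv[i][j] for i in range(m)) / den for j in range(d)])
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    qs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)]
    ks = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
    vs = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(4)]
    ws = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(16)]
    slow = kernel_attention_quadratic(qs, ks, vs, ws)
    fast = kernel_attention_linear(qs, ks, vs, ws)
    print(max(abs(a - b) for r, s in zip(slow, fast) for a, b in zip(r, s)))  # round-off only
```

The linear version never materializes the N × L score table; only the m × d and m-sized sums are kept.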

To better understand how the linear complexity is achieved, we can write the matrix-form computation flow as follows.

The left part illustrates the global attention layer of NodeFormer from a matrix view, and the right part presents the corresponding data flow, where the matrix products marked in red are the computational bottleneck, requiring O(Nmd).

Notice that the order of the matrix products plays a key role in reducing the complexity. The computation above never forms the N×N attention matrix, yet still achieves all-pair attentive aggregation. The exact complexity of one layer is O(Nmd). Since for large graphs N is often orders of magnitude larger than m and d, this yields a significant efficiency gain in practice. For example, NodeFormer with three Transformer layers requires only 4GB of GPU memory to compute all-pair attention among 0.1M nodes. Below is the PyTorch code implementing the efficient all-pair attention above; the complete open-source model implementation is publicly available on GitHub.

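import torch

# qs: [N, H, D], ks: [L, H, D], vs: [L, H, D]
# softmax_kernel stands for the random feature map phi(.); it is not
# defined in this snippet (see the full model implementation on GitHub).

qs = softmax_kernel(qs)  # [N, H, M]
ks = softmax_kernel(ks)  # [L, H, M]

# numerator: sum_v phi(k_v) v_v^T is computed once and shared by all queries
kvs = torch.einsum("lhm,lhd->hmd", ks, vs)  # [H, M, D]
attn_num = torch.einsum("nhm,hmd->nhd", qs, kvs)  # [N, H, D]

# denominator: sum_w phi(k_w) is likewise shared by all queries
ks_sum = ks.sum(dim=0)  # [H, M]
attn_den = torch.einsum("nhm,hm->nh", qs, ks_sum)  # [N, H]

# attentive aggregated results ([N, H, D] / [N, H, 1] broadcasts over D)
z_next = attn_num / attn_den.unsqueeze(-1)  # [N, H, D]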

DIFFormer: Simplified Attention Computation

The lesson from NodeFormer is that the crux of complexity reduction lies in the order of the matrix products in attentive aggregation. We can leverage this idea to design another efficient attention function without any stochastic approximation: the simple attention in DIFFormer [6] (a.k.a. the simple diffusivity model in the original paper, motivated by diffusion over latent structures).

Our starting point is the Taylor expansion of the exponential function, which motivates a new attention function:
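Replacing the exponential in the Softmax attention with its first-order expansion, and L2-normalizing queries and keys so that every score 1 + q̃_uᵀk̃_v lies in [0, 2], gives the simple attention below. This is a reconstruction: the per-vector normalization is assumed from the DIFFormer paper's formulation, while the PyTorch snippet in this section divides by the norm of the whole tensor instead.

```latex
e^{x} = 1 + x + \frac{x^2}{2!} + \cdots \approx 1 + x
\quad \Longrightarrow \quad
z_u^{(k+1)} = \sum_{v=1}^{N} \frac{1 + \tilde{q}_u^\top \tilde{k}_v}{\sum_{w=1}^{N} \left(1 + \tilde{q}_u^\top \tilde{k}_w\right)} \, v_v,
\qquad \tilde{q}_u = \frac{q_u}{\|q_u\|_2}, \quad \tilde{k}_v = \frac{k_v}{\|k_v\|_2}.
```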

Note that although the new attention function is motivated by the first-order Taylor expansion of e^x, it need not be a well-posed approximation of the original Softmax attention. That said, we found through extensive experiments that it works well and stably in practice.

This new attention layer can be computed with linear complexity, since we can reuse the trick of re-ordering the matrix products:
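Written out with q̃_u and k̃_v denoting the L2-normalized query and key (a reconstruction; the numerator and denominator correspond to attn_num and attn_den in the PyTorch snippet of this section):

```latex
z_u^{(k+1)}
= \frac{\sum_{v=1}^{N} \left(1 + \tilde{q}_u^\top \tilde{k}_v\right) v_v}{\sum_{w=1}^{N} \left(1 + \tilde{q}_u^\top \tilde{k}_w\right)}
= \frac{\sum_{v=1}^{N} v_v + \left(\sum_{v=1}^{N} v_v \, \tilde{k}_v^\top\right) \tilde{q}_u}{N + \tilde{q}_u^\top \sum_{w=1}^{N} \tilde{k}_w} .
```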

Again, the two summation terms on the RHS are shared by all nodes and thus need to be computed only once. To see the O(N) complexity clearly, we can write down the computation flow in matrix form.

The computational bottleneck, the matrix products marked in red in the right part, costs O(Nd²). Note again that in practice d is orders of magnitude smaller than N: for example, d could range from 32 to 256, while N could reach millions or even billions.
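For a rough sense of scale, take N = 10⁶ nodes and d = 64 (illustrative values picked from the ranges just mentioned) and count the multiply-adds of the dominant products in the quadratic and the re-ordered computation:

```latex
N = 10^{6},\; d = 64:
\qquad N^2 d = 6.4 \times 10^{13},
\qquad N d^2 \approx 4.1 \times 10^{9},
\qquad \frac{N^2 d}{N d^2} = \frac{N}{d} = 15{,}625 .
```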

The following shows the PyTorch implementation of one layer of DIFFormer's simple attention; the complete model implementation is publicly available on GitHub. Equipped with the simple attention, DIFFormer (referred to as DIFFormer-s in the original paper) can scale to graphs with millions of nodes.

import torch

# qs: [N, H, D], ks: [L, H, D], vs: [L, H, D]

# torch.norm(x, p=2) is the norm of the whole tensor here; per-vector
# normalization would be torch.nn.functional.normalize(x, p=2, dim=-1)
qs = qs / torch.norm(qs, p=2)  # [N, H, D]
ks = ks / torch.norm(ks, p=2)  # [L, H, D]
L = ks.shape[0]  # number of keys (L == N for self-attention)

# numerator: sum_v v_v + (sum_v v_v k_v^T) q_u, both sums shared by all queries
kvs = torch.einsum("lhm,lhd->hmd", ks, vs)  # [H, D, D]
attn_num = torch.einsum("nhm,hmd->nhd", qs, kvs)  # [N, H, D]
vs_sum = vs.sum(dim=0)  # [H, D]
attn_num = attn_num + vs_sum.unsqueeze(0)  # broadcast over the N queries: [N, H, D]

# denominator: L + q_u^T sum_w k_w
ks_sum = ks.sum(dim=0)  # [H, D]
attn_den = torch.einsum("nhm,hm->nh", qs, ks_sum)  # [N, H]

# attentive aggregated results
attn_den = attn_den.unsqueeze(-1) + L  # [N, H, 1]
z_next = attn_num / attn_den  # [N, H, D]
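As with NodeFormer, the re-ordering can be sanity-checked in plain Python: the pair-by-pair simple attention and the version built from the shared sums agree up to round-off. This sketch normalizes each query and key vector separately (an assumption matching the per-vector form 1 + q̃_uᵀk̃_v, whereas the PyTorch snippet above divides by the norm of the whole tensor); the function names are hypothetical.

```python
import math
import random


def l2_normalize(x):
    norm = math.sqrt(sum(xi * xi for xi in x))
    return [xi / norm for xi in x]


def simple_attention_quadratic(qs, ks, vs):
    """score(u, v) = 1 + q_u^T k_v on unit-norm q, k, so every score is in [0, 2]."""
    qs = [l2_normalize(q) for q in qs]
    ks = [l2_normalize(k) for k in ks]
    out = []
    for q in qs:
        scores = [1.0 + sum(a * b for a, b in zip(q, k)) for k in ks]
        total = sum(scores)
        out.append([sum(s * v[j] for s, v in zip(scores, vs)) / total
                    for j in range(len(vs[0]))])
    return out


def simple_attention_linear(qs, ks, vs):
    """Shared sums: (sum_v v_v + (sum_v v_v k_v^T) q_u) / (L + q_u^T sum_w k_w)."""
    qs = [l2_normalize(q) for q in qs]
    ks = [l2_normalize(k) for k in ks]
    L, dk, dv = len(ks), len(ks[0]), len(vs[0])
    v_sum = [sum(v[j] for v in vs) for j in range(dv)]  # sum_v v_v
    kv = [[sum(k[i] * v[j] for k, v in zip(ks, vs)) for j in range(dv)]
          for i in range(dk)]  # sum_v k_v v_v^T, shape [dk, dv]
    k_sum = [sum(k[i] for k in ks) for i in range(dk)]  # sum_w k_w
    out = []
    for q in qs:
        den = L + sum(a * b for a, b in zip(q, k_sum))
        out.append([(v_sum[j] + sum(q[i] * kv[i][j] for i in range(dk))) / den
                    for j in range(dv)])
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    qs = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(6)]
    ks = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(5)]
    vs = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)]
    slow = simple_attention_quadratic(qs, ks, vs)
    fast = simple_attention_linear(qs, ks, vs)
    print(max(abs(a - b) for r, s in zip(slow, fast) for a, b in zip(r, s)))  # round-off only
```

No random features are involved, so unlike the NodeFormer sketch the linear form here is exact rather than an approximation of a reference.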

References

[1] Chengxuan Ying et al., Do Transformers Really Perform Bad for Graph Representation?, NeurIPS 2021.

[2] Ladislav Rampášek et al., Recipe for a General, Powerful, Scalable Graph Transformer, NeurIPS 2022.

[3] Jinwoo Kim et al., Pure Transformers are Powerful Graph Learners, NeurIPS 2022.

[4] Ashish Vaswani et al., Attention Is All You Need, NeurIPS 2017.

[5] Qitian Wu et al., NodeFormer: A Scalable Graph Structure Learning Transformer for Node Classification, NeurIPS 2022. This paper proposes an efficient Transformer for large node classification graphs. The key design is the kernelized Softmax message passing that achieves linear complexity w.r.t. number of nodes, and furthermore, the authors extend the kernel trick to Gumbel-Softmax that can learn sparse latent structures from a potentially all-pair connected graph.

[6] Qitian Wu et al., DIFFormer: Scalable (Graph) Transformers Induced by Energy Constrained Diffusion, ICLR 2023. This work designs a scalable graph Transformer whose attention functions are motivated by diffusivity estimates for diffusion over latent structures. In terms of model architecture, DIFFormer generalizes the key idea NodeFormer uses to achieve O(N) complexity, and can therefore be seen as version 2.0 of NodeFormer.

[7] A round-up of linear transformers (blog post). This blog introduces several typical strategies used by recent efficient Transformers to reduce the attention complexity to O(N), e.g., low-rank approximation, local-global attention, and using softmax as a kernel.

[8] Ali Rahimi and Benjamin Recht, Random Features for Large-Scale Kernel Machines, NeurIPS 2007. This early work introduces the random feature map as an effective approximation technique for computation over large numbers of data points, along with its theoretical properties.

[9] Fanghui Liu et al., Random Features for Kernel Approximation: A Survey on Algorithms, Theory, and Beyond, IEEE TPAMI 2022. This survey summarizes a comprehensive set of random features for kernel approximation and discusses their properties and applicability.

All images unless otherwise noted are by the author.
