Paper notes: Efficiently Scaling Transformer Inference


Efficient inference of massive transformers is difficult and important. The paper (Efficiently Scaling Transformer Inference) from Google discusses their approach to inference of their massive transformer model in generous (perhaps surprising) detail. This post will explain the paper with some of my own notes. All data, tables, and images are from the paper! I’ll assume some familiarity with the decoder-only transformer architecture and MPI collective communication routines (intro here) as well as the basics of Transformer inference (e.g. KV cache).

Why is autoregressive transformer inference nontrivial and qualitatively different from training? At training time the transformer is run once for each sequence of length S to produce S logits, from which you compute gradients and update your model. But at inference time, you need to sample one token at a time to generate your sequence, which is also variable in length. This sequential sampling reduces opportunities for parallelism, making high chip utilization difficult. When serving an application, latency is also a concern in a way that it is not for training or batch processing.

Most of the literature currently discusses parallelism and huge engineering scale in the context of training. Inference performance work is often related to quantization and various tricks applied to mid-size models. But it’s very important, and underrated by academics, who generally do not have users. Inference is a privileged problem, but once you find customers / product-market-fit / AGI, inference compute may well be the bottleneck on your company/public benefit corporation/non-profit’s growth. Various chat services already seem to feel this problem! So this paper is a welcome rigorous analysis of (dense AR decoder only) massive transformer inference (perhaps it is one of the last).

The goal of the paper is to find the best way to partition huge models across large numbers of chips to optimize throughput or latency in various scenarios. In order to do so, the authors describe communication costs analytically and give a general notation and set of principles to minimize the communication cost in the different settings of transformer inference, and they appear successful in doing so.

Some background:

To optimize inference performance, the authors develop an analytical model of performance under different partitioning strategies, and solve for the best strategy for a given workload. They also describe various other optimizations, but those are not the focus of their work.

Since the authors work(ed) at Google, the paper describes optimizations for PaLM models of up to 540B parameters, a dense transformer, on a 3D torus of up to 256 TPUv4 chips with 270GB/s network bandwidth. They achieve 29ms/tok latency @ int8 on 64 chips for a latency-sensitive application, and 76% MFU (Model FLOPs Utilization) for large batch size processing @ 2048 sequence length.

Understanding Inference Cost

As the authors describe, there are three interesting metrics at inference time:

Inference latency can be broken down into two stages, prefill and decoding. During prefill, the entire input prompt is passed through the model to produce logits for the first sampled token and initiate the kv cache. The sampling of subsequent tokens is called decoding, and during decoding a single query token is passed through the network and the kv cache is used to compute attention outputs.

Memory Costs

Memory costs are basically weights and KV cache. There are other tensors that take memory or need to be loaded into the compute cores (activations, communication buffers) but they are much smaller. Every element of both the weights and the KV cache needs to be moved from HBM to the compute cores once per forward pass, whether prefill or decode. (The kv cache is empty at prefill; in its place the k/v vectors are large activations.)
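To make the data-movement point concrete, here is a rough sketch of a lower bound on per-step decode latency from weight loading alone. The HBM bandwidth of 1.2 TB/s per chip and the helper name `weight_load_floor_ms` are my assumptions for illustration, not figures from this post; the int8 weights and 64 chips match the latency-oriented configuration mentioned earlier.

```python
def weight_load_floor_ms(n_params, bytes_per_param, n_chips, hbm_bytes_per_s):
    """Time (ms) to stream every weight from HBM once, split evenly over chips."""
    bytes_per_chip = n_params * bytes_per_param / n_chips
    return 1e3 * bytes_per_chip / hbm_bytes_per_s

# 540B int8 parameters on 64 chips; 1.2e12 B/s HBM bandwidth is an assumed figure.
floor_ms = weight_load_floor_ms(540e9, 1, 64, 1.2e12)
print(f"{floor_ms:.1f} ms per decode step")
```

Under these assumptions the floor is about 7 ms per step, so the 29 ms/token quoted above is within roughly 4x of the weight-loading bound, before counting the KV cache or any communication.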

At what batch and sequence length does the memory cost to load the KV cache approach the cost to load the parameters? Assuming PaLM-1 with MQA and two bytes per float, for one forward pass:

$$ \begin{align*} \text{Parameter Mem} &= 2 \cdot N = 1080 \text{ GB}\\\ \text{KV Cache Mem} &= 2 \cdot 2 \cdot B \cdot S \cdot L \cdot d_{head} \text{ bytes} \end{align*} $$

Here $L = 118$ is the number of layers and $d_{head} = 384$. Setting the two costs equal:

$$ \begin{align*} 4 \cdot B \cdot S \cdot L \cdot d_{head} \text{ bytes}&= 1080 \text{ GB} \\\ B \cdot S &= \frac{1080 \text{ GB}}{4 \cdot 118 \cdot 384 \text{ B}}\\\ B \cdot S &\approx 5.96 \times 10^{6} \end{align*} $$

The memory cost to load the kv cache equals the cost to load the model parameters at a total token count of around $5.96 \times 10^{6}$ for PaLM-1. For a sequence length of 2048, this corresponds to a batch size of around 2900.
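The break-even point can be recomputed directly from the two formulas. This is a minimal sketch using the values stated in this section (540B parameters, 118 layers, a head dimension of 384, two bytes per element, one shared K/V head); the function names are mine.

```python
def kv_cache_bytes(batch, seq_len, n_layers, d_head, n_kv_heads=1, bytes_per_el=2):
    # The leading factor of 2 covers keys and values.
    return 2 * bytes_per_el * batch * seq_len * n_layers * n_kv_heads * d_head

def break_even_tokens(n_params, n_layers, d_head, n_kv_heads=1, bytes_per_el=2):
    """Total tokens (B * S) at which the KV cache is as large as the weights."""
    param_bytes = bytes_per_el * n_params
    per_token = kv_cache_bytes(1, 1, n_layers, d_head, n_kv_heads, bytes_per_el)
    return param_bytes / per_token

tokens = break_even_tokens(540e9, n_layers=118, d_head=384)
print(f"B*S ~ {tokens:,.0f}; batch at S=2048 ~ {tokens / 2048:,.0f}")
```

Passing `n_kv_heads` greater than 1 shows how quickly the break-even point falls without multiquery attention.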

Compute Costs

Each matmul parameter participates in two flops per token per forward pass. The matrix multiplications inside attention, which compute $\text{softmax}(QK^T)V$ given Q, K, and V, cost far fewer flops than the weight matmuls at typical sequence lengths, so they are generally excluded.
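A quick way to sanity-check that exclusion is to compare the two FLOP counts for one decoded token. This sketch assumes 48 query heads (the published PaLM 540B configuration, not stated in this post) and reuses the layer count and head dimension from the memory calculation above; the function names are hypothetical.

```python
def weight_flops_per_token(n_params):
    # One multiply and one add per matmul parameter.
    return 2 * n_params

def attention_flops_per_token(context_len, n_layers, n_heads, d_head):
    # Per head, QK^T and the weighted sum over V each cost 2 * context_len * d_head.
    return 4 * context_len * n_layers * n_heads * d_head

w = weight_flops_per_token(540e9)
a = attention_flops_per_token(2048, n_layers=118, n_heads=48, d_head=384)
print(f"attention / weight FLOPs = {a / w:.3f}")
```

With a 2048-token context the attention term is on the order of a couple of percent of the weight term, and it grows linearly with context length.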

Partitioning Notation

The authors re-introduce a partitioning notation for describing the sharding of tensors across a multidimensional grid of chips. It is quite powerful and extends the concept of the “shape” of a tensor. A tensor $BLE$ has shape $(B, L, E)$. Chips are arranged in a 3d torus $X \times Y \times Z$. A tensor $BLE_{x}$ has dimension $E$ partitioned over the $x$ axis of chips, i.e. the shard on each chip has shape $(B, L, \frac{E}{X})$. Note that the $B$ and $L$ dimensions are the same on every chip: no subscript $\implies$ no partitioning. Additionally, the authors mark a tensor which still needs to be summed over a chip axis $x$ with the suffix “partialsum-$x$”. For example, $BE \text{ partialsum-x}$ denotes a tensor of shape $BE$ whose true value is the sum of the per-chip tensors, $\sum_{\text{chip} \in x} BE_{\text{chip}}$.

How do these shapes change with communication and compute? Let’s examine what MPI collective communication routines look like in this notation:
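As a minimal single-host sketch, the four collectives can be simulated with NumPy by representing a sharded tensor as a list of per-chip arrays along one mesh axis. The helper names are mine, and the shape transitions in the docstrings are the standard meaning of each collective rather than a quote from the paper.

```python
import numpy as np

def shard(t, n, axis):
    """BLE -> BLE_x: split `axis` evenly over n chips."""
    return np.split(t, n, axis=axis)

def all_gather(shards, axis):
    """BLE_x -> BLE: every chip ends with the concatenation of all shards."""
    full = np.concatenate(shards, axis=axis)
    return [full.copy() for _ in shards]

def all_reduce(partials):
    """BLE partialsum-x -> BLE: every chip ends with the full sum."""
    total = sum(partials)
    return [total.copy() for _ in partials]

def reduce_scatter(partials, axis):
    """BLE partialsum-x -> BLE_x: sum, then each chip keeps one slice."""
    return shard(sum(partials), len(partials), axis)

def all_to_all(shards, src_axis, dst_axis):
    """BLH_x -> B_xLH: move the sharding from src_axis to dst_axis."""
    return shard(np.concatenate(shards, axis=src_axis), len(shards), dst_axis)

X = 4
t = np.arange(4 * 3 * 8, dtype=np.float64).reshape(4, 3, 8)  # shape (B, L, E)
ble_x = shard(t, X, axis=2)                    # each chip holds (4, 3, 2)
gathered = all_gather(ble_x, axis=2)           # each chip holds (4, 3, 8)
partials = [t / X for _ in range(X)]           # BLE partialsum-x whose sum is t
scattered = reduce_scatter(partials, axis=2)   # each chip holds (4, 3, 2)
resharded = all_to_all(ble_x, src_axis=2, dst_axis=0)  # each chip holds (1, 3, 8)
```

Note that an all-reduce is equivalent to a reduce-scatter followed by an all-gather, which is why the partitioned layouts below can often use the cheaper half of it.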

How does this notation describe compute, e.g. matrix multiplications? Recall a tensor of shape $HL$ can be multiplied by any matrix $BH$ to produce a tensor $BL$, where $B$ and $L$ stand in for arbitrary dimensions. The multiplication also works when the inner dimension is sharded, as long as it is sharded in the same way on both operands: $H_{xy}L$ can be multiplied with $BH_{xy}$, but the result is not $BL$; it is $BL \text{ partialsum-xy}$. To see this, note that a matrix multiply is a series of dot products of vectors of size $H$, and partitioning those vectors into $H_{xy}$ pieces yields a partial dot product on each chip. Summing the partial results across chips gives the final result, by the linearity of the dot product. This reasoning extends to einsums in general.
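The partial-sum behavior is easy to check numerically. The sketch below shards the contracted dimension over a single chip axis (rather than $xy$) for brevity; the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, L, n_chips = 4, 8, 6, 4
a = rng.standard_normal((B, H))   # BH
w = rng.standard_normal((H, L))   # HL

# Shard the contracted dimension H the same way on both operands.
a_shards = np.split(a, n_chips, axis=1)   # BH_x: each (B, H / n_chips)
w_shards = np.split(w, n_chips, axis=0)   # H_xL: each (H / n_chips, L)

# Each chip computes a full-shape but incomplete result: BL partialsum-x.
partials = [a_i @ w_i for a_i, w_i in zip(a_shards, w_shards)]

# Summing over chips (what an all-reduce would do) recovers the unsharded product.
result = sum(partials)
print(np.allclose(result, a @ w))
```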

Armed with terse notation, we are ready to analyze partitioning the transformer model in general. The transformer (mostly) consists of MLP blocks which apply pointwise to tokens and attention blocks which combine information between tokens. Their behavior is quite different so we will analyze their partitioning separately.

Partitioning The MLP

Due to the size of the models, remember that whatever strategy we pursue must partition the weights of the model across at least N / chip_memory chips (in int8). To fit a 540B param model onto 32GB cards, we need to shard the weights across at least 17 chips to run the model at all. So let’s begin by sharding the weights.
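The chip-count floor is simple arithmetic. This sketch counts weights only and ignores the KV cache, activations, and any other buffers, so real deployments need more headroom; the function name is mine.

```python
import math

def min_chips_for_weights(n_params, bytes_per_param, chip_memory_bytes):
    """Smallest chip count whose combined memory can hold the weights alone."""
    return math.ceil(n_params * bytes_per_param / chip_memory_bytes)

chips_int8 = min_chips_for_weights(540e9, 1, 32e9)   # one byte per parameter
chips_bf16 = min_chips_for_weights(540e9, 2, 32e9)   # two bytes per parameter
print(chips_int8, chips_bf16)
```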

1D Weight Stationary

The baseline way to partition the Transformer MLP comes from the original Megatron Paper. We can avoid communicating between the two matmuls if we shard the first along its rows, and the second along its columns. The explanation in the original paper is great, but at a glance assume we have an input vector $\vec{x}$ and would like to compute the MLP output $y = \mathbf{B}\text{GeLU}(\mathbf{A}\vec{x})$ across two chips. We can write

$$ \mathbf{A} = \begin{bmatrix} \mathbf{A}_0 \\\ \mathbf{A}_1 \end{bmatrix} $$

where we partition A across its rows on two chips. Let’s multiply it with x and apply an activation. Because GeLU is applied elementwise, it can be applied to each row block independently:

$$\text{GeLU}(\mathbf{A}\vec{x}) = \text{GeLU}( \begin{bmatrix} \mathbf{A}_0 \\\ \mathbf{A}_1 \end{bmatrix} \vec{x}) = \begin{bmatrix} \text{GeLU}(\mathbf{A}_0\vec{x}) \\\ \text{GeLU}(\mathbf{A}_1\vec{x}) \end{bmatrix} $$

Now to multiply the row-sharded post-GeLU activations by the column sharded $\mathbf{B} = [\mathbf{B}_0 \mathbf{B}_1]$ we don’t have to do any cross-chip communication!

$$ \begin{align*} y &= [\mathbf{B}_0 \mathbf{B}_1] \begin{bmatrix} \text{GeLU}(\mathbf{A}_0\vec{x}) \\\ \text{GeLU}(\mathbf{A}_1\vec{x}) \end{bmatrix} \\\ y &= \mathbf{B}_0 \text{GeLU}(\mathbf{A}_0\vec{x}) + \mathbf{B}_1 \text{GeLU}(\mathbf{A}_1\vec{x}) \end{align*} $$

So the second matmul can be done on each chip locally, and then we need a reduce-scatter to sum the partial outputs and leave each chip with its partition of the full output. (An all-reduce would produce the full $BLE$ activations on each chip, but this scheme partitions activations between layers.) Note that the output must have exactly the same shape and sharding as the input in order to add the residual.
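The two-chip derivation above can be checked numerically. This is a single-host sketch with arbitrary small sizes; it uses the tanh approximation of GeLU, and `Bm` stands for the second matrix $\mathbf{B}$ (renamed to avoid clashing with the batch dimension).

```python
import numpy as np

def gelu(v):
    # tanh approximation of GeLU; elementwise, so it commutes with row-sharding.
    return 0.5 * v * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (v + 0.044715 * v**3)))

rng = np.random.default_rng(0)
E, F, n_chips = 8, 32, 2
x = rng.standard_normal(E)
A = rng.standard_normal((F, E))    # first MLP matrix
Bm = rng.standard_normal((E, F))   # second MLP matrix

A_shards = np.split(A, n_chips, axis=0)    # row blocks A_0, A_1
B_shards = np.split(Bm, n_chips, axis=1)   # column blocks B_0, B_1

# Each chip works only with its own shards; no communication so far.
partial_y = [B_i @ gelu(A_i @ x) for A_i, B_i in zip(A_shards, B_shards)]

# Reduce-scatter: sum the partial outputs, then each chip keeps one slice of y.
y_shards = np.split(sum(partial_y), n_chips)

y_reference = Bm @ gelu(A @ x)
print(np.allclose(np.concatenate(y_shards), y_reference))
```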

Let’s examine the situation above with the notation described earlier:

The sharding notation allows for pleasing and terse demonstrations!

This partitioning strategy is very similar to the one called 1D weight-stationary in the paper. The difference is that in the Megatron partitioning strategy described above, each chip has the full activations at the beginning of each MLP, simplifying layernorm and residual calculations. 1D weight-stationary, however, partitions activations on the last dimension (since they may be too large to fit on one chip). The embeddings/LM head then need to be adjusted accordingly, and communication needs to be introduced into LayerNorm.

The name 1D weight-stationary comes from the weights being 1. stationary 2. partitioned on one tensor dimension. According to the paper, the communication cost per layer is $\frac{2BLE}{\text{network bandwidth}}$, as the all-gather and the reduce-scatter each communicate about $BLE$ elements. NB: I believe this accounting assumes network bandwidth in units of elements rather than bytes.

2D Weight Stationary

We can partition the MLP weight matrices on both axes. This introduces communication before and after the activation, but as we will see can be more efficient. I’ll ignore the $z$ axis of the torus for clarity:

Our input activations are $BLE_{xy}$. Our weights are $E_xF_y$ and $F_yE_x$ (2d sharding, as promised). We then:

In the paper, $yz$ takes the place of $y$ above, i.e. they are always together.
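To make the data flow concrete, here is a single-host NumPy simulation on a 2x2 chip grid, with $z$ ignored as above. The sequence of collectives is my reconstruction from the layouts stated above (input $BLE_{xy}$, weights $E_xF_y$ and $F_yE_x$), so treat it as a sketch rather than the paper's exact recipe. `T` stands in for the flattened $B \cdot L$ token dimension, and the `grid` helper is hypothetical.

```python
import numpy as np

def gelu(v):
    return 0.5 * v * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (v + 0.044715 * v**3)))

rng = np.random.default_rng(0)
X, Y = 2, 2                 # chip grid; chip (i, j) sits at x=i, y=j
T, E, F = 3, 8, 32          # T stands in for the flattened B*L token dimension
act = rng.standard_normal((T, E))
w_in = rng.standard_normal((E, F))
w_out = rng.standard_normal((F, E))

def grid(fn):
    """Build a per-chip table by calling fn(i, j) for every chip."""
    return [[fn(i, j) for j in range(Y)] for i in range(X)]

# Stationary weight shards: E_x F_y and F_y E_x.
w_in_s = grid(lambda i, j: np.split(np.split(w_in, X, axis=0)[i], Y, axis=1)[j])
w_out_s = grid(lambda i, j: np.split(np.split(w_out, Y, axis=0)[j], X, axis=1)[i])
# Input activations BLE_xy: E is cut into X*Y pieces, x-major.
act_s = grid(lambda i, j: np.split(act, X * Y, axis=1)[i * Y + j])

# 1. all-gather over y: BLE_xy -> BLE_x
act_x = grid(lambda i, j: np.concatenate(act_s[i], axis=1))
# 2. local matmul: BLE_x @ E_x F_y -> BLF_y partialsum-x
h_part = grid(lambda i, j: act_x[i][j] @ w_in_s[i][j])
# 3. reduce-scatter over x: BLF_y partialsum-x -> BLF_yx, then elementwise GeLU
h_s = grid(lambda i, j: gelu(np.split(sum(h_part[k][j] for k in range(X)), X, axis=1)[i]))
# 4. all-gather over x: BLF_yx -> BLF_y
h_y = grid(lambda i, j: np.concatenate([h_s[k][j] for k in range(X)], axis=1))
# 5. local matmul: BLF_y @ F_y E_x -> BLE_x partialsum-y
o_part = grid(lambda i, j: h_y[i][j] @ w_out_s[i][j])
# 6. reduce-scatter over y: BLE_x partialsum-y -> BLE_xy
out_s = grid(lambda i, j: np.split(sum(o_part[i][k] for k in range(Y)), Y, axis=1)[j])

out = np.concatenate([out_s[i][j] for i in range(X) for j in range(Y)], axis=1)
reference = gelu(act @ w_in) @ w_out
print(np.allclose(out, reference))
```

The output shards have the same layout as the input shards, so the residual can be added locally, and each collective runs over only one axis of the grid.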

Why can this be more efficient? The volume each communication primitive moves shrinks in proportion to the square root of the number of chips. For example, in 1D weight-stationary we would have $BLE_{xyz} \rightarrow BLE$ via an all-gather. This costs $BLE$, since all-gather communication scales linearly with the size of its output. But in the 2D case we have $BLE_{xy} \rightarrow BLE_x$ (again, ignoring $z$ for now), and the cost is then $\frac{BLE}{X}$. Similar logic applies to all of the other communications, each of which runs over only one dimension of $(X, Y, Z)$. Assuming the standard $d_{ff} = 4 \cdot d_{model}$, the communication cost is $\frac{8BLE}{\sqrt{n_{chips}} \cdot \text{network bandwidth}}$, which is lower than the 1D cost of $\frac{2BLE}{\text{network bandwidth}}$ when $n_{chips} > 16$.
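The crossover at 16 chips follows directly from the two cost expressions. The sketch below compares per-layer communication volumes in elements (the shared division by network bandwidth is dropped); the token count of 512 and $E = 18432$ are illustrative values I chose, and the ratio does not depend on them.

```python
import math

def comm_1d(tokens, d_model):
    """Per-layer volume for 1D weight-stationary: 2 * BL * E elements."""
    return 2 * tokens * d_model

def comm_2d(tokens, d_model, n_chips):
    """Per-layer volume for 2D weight-stationary with d_ff = 4 * d_model."""
    return 8 * tokens * d_model / math.sqrt(n_chips)

for n in (8, 16, 64, 256):
    ratio = comm_2d(512, 18432, n) / comm_1d(512, 18432)
    print(f"n_chips={n:3d}  2D/1D volume = {ratio:.2f}")
```

The ratio is $4 / \sqrt{n_{chips}}$, so 2D sharding moves half the volume of 1D at 64 chips and a quarter at 256.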

Analysis of this strategy is less clean when using NVIDIA GPUs, since typically you cannot pack more than 8 chips per node, and inter-node network bandwidth is much slower than intra-node bandwidth, though the DGX GH200 appears to address this with 900GB/s of interconnect between 256 GH200 superchips.

Weight Gathered

The previous approaches communicated activations, with the implicit assumption that weights are much larger than activations. But as batch size and sequence length grow large, this may not be the case. Then it will be cheaper to communicate weights rather than activations. The analogy to weight-gathering in training is ZeRO stage 3 or FSDP, where weights are all-gathered just before use in the forward pass and then discarded.

At very large batch sizes the authors use a pure weight gathered approach (termed “x-weight-gathered”), but at moderate batch sizes it is optimal to do both 2D activation partitioning and 1D weight partitioning. Since this is the MLP, the sequence dimension is only long during prefill; it is 1 during decoding. Because decoding follows directly after prefill, the weights should be stored in the same layout for both phases. This strategy is called “XY-weight-gathered”. Let’s go through this one in the notation, as before:

Our input activations are $B_{xy}LE_{z}$. Our weights are $E_xF_{yz}$ and $F_{yz}E_x$. Fret not, for they soon will be gathered:

Other options for weight gathering include all-gathering over $x$ or $xyz$, with activations $B_xLE_{yz}$ and $B_{xyz}LE$ respectively; these configurations are analogous to the $xy$ configuration described above. The $xyz$ weight gathered setup has a constant communication cost: try to derive this!

Here is a visual summary from the paper on three of the configurations described:

Sharding MLP

The optimal configuration, the one with the minimal communication cost for the MLP, changes as the tokens per batch increases, resulting in this elegant pareto frontier:

MLP communication pareto frontier

Partitioning The Attention Layer

Multihead attention can be parallelized much like the MLP because it has a similar structure (matmul, independent computations over heads, then matmul again). However, the KV cache dominates costs with full multihead attention, so most transformers these days use multiquery or grouped query attention, where there are n_heads heads for the query tensor but the key and value heads are shared among them. In multiquery attention, every query head shares the same key/value head, whereas grouped query attention relaxes this to group_size kv heads. This architectural change is designed to reduce the memory impact of the KV cache: it shrinks the cache by a factor of n_heads / group_size, where group_size = 1 for multiquery attention.
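The size reduction is easy to tabulate per token. In this sketch the 48 query heads are the published PaLM 540B value (not stated in this post), the 8 K/V heads for the grouped-query case are a hypothetical choice, and the layer count and head dimension are reused from the memory calculation earlier.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_el=2):
    # Keys and values for every layer and every K/V head.
    return 2 * bytes_per_el * n_layers * n_kv_heads * d_head

n_layers, n_heads, d_head = 118, 48, 384
mha = kv_bytes_per_token(n_layers, n_heads, d_head)   # one K/V head per query head
gqa = kv_bytes_per_token(n_layers, 8, d_head)         # 8 shared K/V heads (hypothetical)
mqa = kv_bytes_per_token(n_layers, 1, d_head)         # a single shared K/V head

print(f"MHA {mha / 1e6:.2f} MB, GQA {gqa / 1e6:.2f} MB, MQA {mqa / 1e6:.2f} MB per token")
```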

Even with multiquery attention, loading the KV cache dominates memory time during decoding. Unfortunately, multiquery attention also eliminates an opportunity for parallelism over the head dimension, since there is only one K/V head to shard. To avoid k/v replication, we can only really shard the computation over the batch dimension. This does incur an additional all-to-all communication operation, but otherwise the $W_Q$ and $W_O$ weights behave as they do in the MLP and in the attention mechanism everything is sharded as $B_{xyz}…$ to minimize time to load the KV cache.
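A rough sketch of why the batch-sharded layout helps: with a single shared K/V head, sharding over query heads leaves every chip loading the whole cache, while sharding over batch divides it evenly. The batch size of 512, the context of 2048, and the 64 chips are illustrative values, and the function and layout names are mine.

```python
def kv_bytes_per_chip(batch, seq_len, n_layers, d_head, n_chips, layout, bytes_per_el=2):
    """KV cache bytes each chip must load per decode step under multiquery attention."""
    total = 2 * bytes_per_el * batch * seq_len * n_layers * d_head  # one shared K/V head
    if layout == "heads":   # query heads sharded; the single K/V head is replicated
        return total
    if layout == "batch":   # each chip holds the cache for batch / n_chips sequences
        return total / n_chips
    raise ValueError(f"unknown layout: {layout}")

by_heads = kv_bytes_per_chip(512, 2048, 118, 384, 64, "heads")
by_batch = kv_bytes_per_chip(512, 2048, 118, 384, 64, "batch")
print(f"heads: {by_heads / 1e9:.1f} GB per chip, batch: {by_batch / 1e9:.2f} GB per chip")
```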

Sharding attention

During prefill, however, this all-to-all communication is more expensive and there is no KV cache to load. So the attention computation is sharded over heads, and the K/V vectors are replicated across chips for multiquery attention.

Additional Optimizations

The authors further improve performance via:

They reserve 30% of HBM (chip memory) for the KV cache, but do not address how it is allocated or whether it is used efficiently a la vLLM.

And the final results, in graph form:

Prefill and Decoding cost vs latency curves

And in table form:

Performance table for 62B model Performance table for 540B model

The left table shows performance numbers for the 62B model, and the right for the 540B model. Curiously, the MFU for latency-sensitive batched decoding is much lower than for training or throughput-oriented inference: around 10% on TPUs, even with a lot of engineering effort! Memory and communication are still the bottleneck. There is certainly performance on the table, perhaps with quantized KV cache / activations + speculative decoding tricks. Additionally, I suspect the most interesting applications will produce more than 64 tokens of output, and most of the time will be spent in the decoding phase.

At least for Google, the lowest latencies are at pretty high chip counts (>8 at least). Adapting this work to a two-tier communications model and reproducing the pareto frontier with NVIDIA GPUs would certainly be interesting, as would combining this with Pipeline Parallelism, as the work is mostly orthogonal.


Inference is important. How does one best run a large transformer for inference? Communication and memory are the bottlenecks, so the paper rigorously analyzes the communication and memory costs of various partitioning strategies for the MLP and attention layers. The optimal partitioning strategy and batch size/chip count change with tokens per forward pass, with prefill versus decoding, and with whether you optimize for latency or throughput. The specific results are mostly applicable only to TPUs, but one can improve on basic tensor parallelism for inference.