The Ultra-Scale Playbook: Training LLMs on GPU Clusters


Fueled by the scaling laws, the trend of training ever larger language models on vaster amounts of data has been driving progress in AI for the past couple of years. Initially, the development of the largest models happened exclusively behind the closed doors of a handful of research labs, but it has recently opened up with the release of models such as Llama 3.1 405B and DeepSeek R1. While these models have openly shared weights and their training recipes are described in technical reports, the challenging engineering involved in training at the necessary infrastructure scale is still hidden between the lines of a handful of papers and complex training frameworks. This ~~long blog post~~ open-source book is here to open that black box!

In this book we invite you to follow us into the wonderful world of scaling the training of Large Language Models to tens, hundreds, and thousands of GPUs. It assumes you know the basics of LLM architecture and training, but are new to distributed training. This writing can be seen as the second part of a trilogy, following our first blog post on processing data for pre-training, the so-called “FineWeb blog post”. Having read both blog posts, you should have almost all the core knowledge needed to deeply understand how LLMs are built nowadays, missing just a few final spices, like data mixing and architecture choices, to complete the recipe (stay tuned…).

Pre-training LLMs from scratch now requires amounts of compute which exceed in almost every case the use of a single GPU or machine. The clusters used to train these models range from hundreds to thousands of nodes, each usually equipped with 4 to 8 GPUs. To make the best use of such expensive hardware, as well as to train in a reasonable time, a range of distributed training methods have been developed with the goal of ensuring that GPUs are highly utilized at all times. Efficiently scaling LLM training is also no longer confined to pretraining, as fine-tuning larger models on more domain-specific data is becoming the standard practice to achieve the best results.

In this post we’ll cover these scaling methods exhaustively while keeping a single story-line to understand where each technique comes from. We’ll cover data, tensor, pipeline and context parallelism as well as ZeRO and kernel fusion. The post is built on the following three foundations:

Quick intros on theory and concepts: before diving into code and experiments, we want to understand how each method works at a high level and what its advantages and limits are. You’ll learn which parts of a language model eat up your memory and when during training it happens. You’ll learn how we can solve memory constraints by parallelizing the model and increase throughput by scaling up to more GPUs. As a result you'll understand how the following widget to compute the memory breakdown of a transformer model works:

While this widget gives a theoretical breakdown, the following tool can be used to predict the actual memory usage:

[Image: memory usage prediction tool]

Clear code implementations: theory is one thing, but we discover all kinds of edge cases and important details when we implement something. That’s why we link to implementation references where possible. Depending on the case, we’ll use two code references: the picotron repository, which is built for education and thus usually implements concepts in single, self-contained short files; and, to look at production-ready code, the nanotron implementation, a production training codebase used at Hugging Face.

Picotron implements each key concept in a self-contained way, such that the method can be studied separately and in isolation.

Real training efficiency benchmarks: Finally, how to actually scale your LLM training depends on your infrastructure, such as the kind of chips, interconnect etc., so we can’t give a single unified recipe. What we will give, though, is a way to benchmark several setups, which is what we have done on our cluster! We ran over 4100 distributed experiments with up to 512 GPUs to scan many possible distributed training layouts and model sizes. TODO: link to dataset too

An overview of the over 4000 experiments across all Llama architectures where each data point corresponds to an experiment launch.

As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training, let’s take a quick high-level look at what we’ll cover in this post.

TL;DR

This book is very extensive, so we decided to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size as fast as possible.

When scaling up models and input batches, we quickly end up in situations where either our target batch size won't fit in memory, and/or the model itself is too large to fit in a single GPU's memory.

To solve this scaling issue we’ll need to carefully evaluate different parallelization strategies and find the optimal balance between three main factors:

  1. Memory Usage
    • Hard limitation - if a training step doesn't fit in memory, training cannot proceed
    • Sometimes we can increase compute (e.g. recomputation) or increase communication (e.g. ZeRO) to reduce memory usage
  2. Compute Efficiency
    • Memory transfer can also decrease compute efficiency.
    • We want our hardware to spend most time computing, so we need to reduce time spent on data transfers or unoptimized kernels.
    • GPUs need sufficient workload (large enough matrices/batch sizes) to maintain high utilization (compute-bound) otherwise they become memory-bound (limited by memory bandwidth).
  3. Communication overhead
    • Two main types. For GPUs: intra-node (NVLink TODO: bandwidth) and inter-node (network TODO: bandwidth)
    • Two main attributes: base latency and peak bandwidth. Base latency is a constant overhead that makes us want to do as few communications as possible, and peak bandwidth controls how fast we can move data between GPUs
    • We prioritize using the fastest communication channels (like NVLink) for operations that occur frequently and/or block computation (e.g. tensor parallelism)
    • We want to minimize communication overhead as it keeps GPUs idle, so we try to overlap communication with compute as much as possible

But let’s not get too far ahead of ourselves and scale progressively. To guide you along the journey, and as a practical reference, we’ve summarized the key concepts in a cheatsheet:

[TODO: ADD CHEATSHEET]

Now that we’ve nailed down a few key concepts and terms, let’s get started by revisiting the basic training steps of an LLM!

First Steps: Training on one GPU

Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps:

  1. a forward pass which passes inputs through the model to yield its outputs,
  2. a backward pass to compute the gradients, and
  3. an optimization step using the gradients to update the parameters

It looks generally like this:

[Figure: a single training step: forward pass, backward pass and optimizer step]

In this figure, the boxes on the top line can be seen as successive layers inside a model (same for the last line). The red boxes are the associated gradients for each of these layers, computed during the backward pass.

The batch size (bs) is one of the important hyper-parameters for model training and affects both model convergence and throughput.

If the batch size is too small, gradients will tend to be noisy and the model may not be able to converge to the most optimal performance; on the contrary, noisy gradients can be useful in early training to navigate quickly through the training landscape. On the other hand, a batch size that is too large will make less use of each training token, rendering convergence slower and wasting compute. You can find a nice discussion of this topic in OpenAI’s paper on large batch training or in Section 4.2 of the MiniMax-01 technical report.

Batch size also affects the time it takes to train on a given text dataset: a small batch size will require more optimizer steps to train on the same number of samples. Optimizer steps are costly (in compute time), and the total time to train will thus increase compared to using a larger batch size. That being said, note that the batch size can often be adjusted quite widely around the optimal batch size without a major impact on the performance of the model, i.e. the sensitivity of final model performance to the exact batch size value is usually rather low around the optimal batch size.

In the LLM pretraining community, batch sizes are commonly reported in terms of tokens rather than in number of samples (bst = Batch Size Tokens), which makes training numbers generally independent of the exact input sequence length used during training.

In the simplest case, training on a single machine, the bs (in samples) and bst can be computed from the model input sequence length (seq) as follows:

bst = bs \cdot seq

A sweet spot for recent LLM training is typically on the order of 4-60 million tokens per batch. However, a typical issue when scaling the training of our model to these large batch sizes is running out of memory, i.e. our GPU doesn’t have enough memory.
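As a quick sanity check, here is the kind of back-of-the-envelope arithmetic this formula enables (the numbers below are purely illustrative):

```python
seq = 4096      # input sequence length
bs = 1024       # batch size in samples
bst = bs * seq  # batch size in tokens
print(f"{bst:,} tokens per batch")  # 4,194,304 -> at the low end of the 4-60M token sweet spot
```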

It’s time to tackle our first scaling problem: what if our model starts exploding GPU memory before we’ve reached our target batch size (maybe in some cases even when using the lowest possible batch size, bs=1)?

Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.

Memory usage in Transformers

When training a neural network model, one stores several items in memory:

  • Model weights
  • Model gradients
  • Optimizer states
  • Activations needed to compute the gradients

These items are stored as tensors which come in different shapes and precisions. The shapes are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. Precision refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.

So how can we quickly determine memory usage from these variables? One simple way is to do this empirically and just measure it.

Memory profiling a training step

Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not static but varies a lot during training and within a single training step:

[Figure: memory profile of a training run for a Llama 1B model]

Clearly the first step looks very different from the subsequent ones, but let’s first have a look at the general anatomy of a step: first the activations increase quickly as we do the forward pass; then, during the backward pass, the gradients build up, and as the backward pass propagates, the stored activations used to compute the gradients are progressively cleared. Finally, we perform the optimization step, during which we need all the gradients, and then update the optimizer states before we start the next forward pass.

Why does the first step look different? The activations increase quickly and then plateau for a while. In this first step the torch cache allocator does a lot of preparation, setting up memory allocations so that the subsequent steps don’t require searching for free memory blocks (see Zach’s blog). After the first step we also see the optimizer states appear, which generally offsets the memory usage for further training steps.
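For reference, a minimal sketch of this kind of memory profiling using PyTorch's built-in allocator history looks roughly as follows (this is not the exact appendix snippet; the toy model and training loop are placeholders):

```python
import torch
from torch import nn

# Record every CUDA allocation/free (available in recent PyTorch versions)
torch.cuda.memory._record_memory_history(max_entries=100_000)

model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).cuda()  # toy stand-in for an LLM
optimizer = torch.optim.AdamW(model.parameters())

for step in range(3):
    inputs = torch.randn(8, 4096, device="cuda")
    loss = model(inputs).pow(2).mean()  # activations build up during the forward pass
    loss.backward()                     # gradients appear; activations are freed as backward proceeds
    optimizer.step()                    # optimizer states are allocated on the first step
    optimizer.zero_grad(set_to_none=True)

# Dump a snapshot that can be inspected with PyTorch's memory visualizer (pytorch.org/memory_viz)
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
```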

Now that we have a first view of memory, let’s see how scaling up training is often a question of maximizing compute efficiency while keeping the memory requirements of these various items (activations, parameters, gradients, optimizer states) within the memory constraints of the GPUs.

Weights/grads/optimizer states memory

We can actually pretty easily estimate the memory needed for the model’s weights, gradients and optimizer states.

For a simple transformer LLM the number of parameters is given by the following formula:

N = h * v + L * (12 * h^2 + 13 * h) + 2*h

In this equation, h is the hidden dimension, v the vocabulary size, and L the number of layers in the model. Note that looking at the equation we can see that the term that will dominate at large hidden dimensions is the h^2 term, since it’s the only one growing quadratically as we scale up the hidden dimension.
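In code, this estimate is a one-liner; the dimensions below are purely illustrative (roughly in the ballpark of a 7B-class model, not an exact config):

```python
def num_params(h: int, v: int, L: int) -> int:
    """N = h*v + L*(12*h^2 + 13*h) + 2*h for a simple transformer LLM."""
    return h * v + L * (12 * h**2 + 13 * h) + 2 * h

# Illustrative dimensions, not an exact published config
print(num_params(h=4096, v=32_000, L=32) / 1e9)  # ~6.6 (billion parameters)
```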

Memory requirements for the parameters and gradients are simply the number of parameters multiplied by the number of bytes per parameter. In good old-fashioned full precision (FP32) training, both parameters and gradients require 4 bytes each, while the optimizer, if we use Adam, requires the momentum and variance to be stored, which adds another two 4-byte values per parameter. In summary:

\begin{aligned} & m_{params} = 4 * N \\ & m_{grad} = 4 * N \\ & m_{opt} = (4+4) * N \end{aligned}

Now let’s have a look at how things change if we train with mixed precision. The default nowadays for mixed precision training is BF16, which requires 2 bytes per parameter and gradient, as well as an additional copy of the model weights and gradients in FP32, thus 12 bytes per parameter in total. In addition to the parameters and gradients, we need to store the optimizer states: for the Adam optimizer, this requires the momentum and the variance, usually stored in FP32 for numerical stability, each using 4 bytes.

Here’s the summary:

\begin{aligned} & m_{params} = 2 * N \\ & m_{grad} = 2 * N \\ & m_{params\_fp32} = 4 * N \\ & m_{opt} = (4+4) * N \end{aligned}

Interestingly, mixed precision itself doesn’t save overall memory, as it just distributes the memory differently across the three components, and it in fact adds another 4 bytes over full-precision training if we accumulate gradients in FP32. It’s still advantageous, as performing the forward/backward passes in half precision allows us to (1) use optimized lower-precision operations on the GPU, which are faster, and (2) reduce the activation memory requirements during the forward pass.

Let’s get a sense of how much memory we generally need for a model (full and mixed precision giving the same overall value):

| Model parameters | FP32 or BF16 w/o FP32 grad acc | BF16 w/ FP32 grad acc |
|---|---|---|
| 1B | 16 GB | 20 GB |
| 7B | 112 GB | 140 GB |
| 70B | 1120 GB | 1400 GB |
| 405B | 6480 GB | 8100 GB |
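These numbers follow directly from the byte counts above. Here is a small helper to reproduce them (a sketch that counts only weights, gradients and optimizer states, ignores activations, and uses 1 GB = 10^9 bytes):

```python
def training_memory_gb(n_params: float, fp32_grad_acc: bool = True) -> float:
    """Memory for weights + gradients + Adam optimizer states, in GB."""
    if fp32_grad_acc:
        # BF16 params (2) + BF16 grads (2) + FP32 param copy (4) + FP32 grads (4) + Adam m, v (4 + 4)
        bytes_per_param = 2 + 2 + 4 + 4 + 8
    else:
        # FP32 training (4 + 4 + 8) or BF16 without FP32 grad accumulation (2 + 2 + 4 + 8): 16 bytes either way
        bytes_per_param = 16
    return n_params * bytes_per_param / 1e9

for n in (1e9, 7e9, 70e9, 405e9):
    print(f"{n / 1e9:.0f}B: {training_memory_gb(n, False):.0f} GB / {training_memory_gb(n, True):.0f} GB")
# 1B: 16 / 20, 7B: 112 / 140, 70B: 1120 / 1400, 405B: 6480 / 8100
```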

As we can see, as soon as we reach 7B (!), weights and optimizer requirements already start to add up significantly and exceed the size of a typical GPU's memory, e.g. 80 GB for an H100 GPU.

But for now, let’s stick with models which still fit on a single GPU and take a look at the other big contributor to our memory budget: the activation memory.

Activations memory

Activation memory is a bit more complex to compute than the weights, gradients and optimizer states, in part because it depends on the inputs of the model. If you’re unsure why we even need to store activations for the backward pass, this reference is a good quick refresher. After a careful inspection of how the backward pass is computed, we can estimate the total memory required for the activations in mixed precision and arrive at the following equation:

m_{act} = L \cdot seq \cdot bs \cdot h \cdot \left(34 + \frac{5 \cdot n_{heads} \cdot seq}{h}\right)

Here L is the number of layers, seq the sequence length, bs the batch size in samples, h the hidden dimension of the model and n_{heads} the number of heads.

For the exact derivation of these numbers, you can follow the original NVIDIA paper on recomputation; it essentially requires you to do some accounting of the sizes of all the intermediate activations between each operation in a transformer layer.
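In code, the estimate looks like this (a sketch of the equation above; it returns bytes since the constants in the formula already account for mixed-precision storage, and the model dimensions in the example are only illustrative, roughly Llama-8B-like):

```python
def activation_memory_gb(L: int, seq: int, bs: int, h: int, n_heads: int) -> float:
    """Activation memory estimate (GB) for mixed-precision training, without recomputation."""
    total_bytes = L * seq * bs * h * (34 + 5 * n_heads * seq / h)
    return total_bytes / 1e9

# Roughly Llama-8B-like dimensions (illustrative, not an exact config), bs=1
for seq in (1024, 4096, 16_384):
    print(f"seq={seq}: {activation_memory_gb(L=32, seq=seq, bs=1, h=4096, n_heads=32):.0f} GB")
# -> roughly 10 GB, 104 GB and 1447 GB respectively: activations blow up with sequence length
```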

An interesting observation here is that the memory is not static for a given model: it scales linearly with the batch size and quadratically with the sequence length. This means the activation memory is the part which will blow up when we increase our batch size or train with longer sequences. We can use this equation to look at how memory usage changes for various sequence lengths, for example for Llama models (bs=1):

[Figure: activation memory vs. sequence length for Llama models, without recomputation (bs=1)]

This graph tells a striking story: for short sequences (or similarly for small batch sizes), activations are almost negligible, but starting at around 2-4k tokens they come to take up a significant amount of memory, while parameter, gradient and optimizer state usage (that we’ll discuss later) stays roughly independent of the sequence length and batch size.

For a large number of input tokens (a.k.a. large batch sizes/sequences), activations become by far the largest memory burden.

Is there a way to tame this “activation explosion”? Good question, reader!

It’s time to explain our first technique, called activation recomputation, which will help us cap the activation memory footprint. It is an essential tool in today’s large model training toolbox.

Activation recomputation

The general idea behind activation recomputation – also called gradient checkpointing or rematerialization – is to discard some activations during the forward pass to save memory, and to spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm, etc.) so that we can use them during the backward pass to compute gradients. When we use recomputation, we typically only store activations at a few key points along the model architecture, discard the rest, and recompute them on the fly during the backward pass from the nearest saved activations, basically performing a sub-part of the forward pass again to trade off memory for compute. It generally looks like this:

[Figure: activation recomputation during the forward and backward passes]

There are several strategies for selecting which key activations to store:

  • Full: we checkpoint activations only at the transition point between each layer of the model. This saves the most memory but is the most expensive option in compute, since it essentially requires re-running a full extra forward pass during the backward pass.
  • Selective: we discard only the activations that take up a lot of memory but are relatively cheap to recompute, most notably the attention scores and matrices, while keeping the cheaper-to-store intermediate results of the feed-forward blocks.

Let’s see how drastically recomputation strategies can reduce the memory footprint in practice, and how selective recomputation strikes a nice balance between memory savings and recomputation cost:

[Figure: Llama 8B memory usage with and without recomputation]

Most training frameworks these days use FlashAttention (which we’ll cover a bit later), which natively integrates activation recomputation in its optimization strategy by recomputing attention scores and matrices in the backward pass instead of storing them. Thus most people using Flash Attention are already making use of selective recomputation.
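If you want to experiment with recomputation yourself in plain PyTorch, the built-in checkpointing utility is the simplest entry point: it keeps only a block's inputs and outputs and replays the block's forward pass during backward. A minimal sketch (the blocks below are toy placeholders for real transformer layers):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(nn.Module):
    """Runs each block under activation checkpointing: inner activations are recomputed in backward."""
    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Only the block's input/output are stored; everything inside is recomputed during backward
            x = checkpoint(block, x, use_reentrant=False)
        return x

# Toy usage: feed-forward "blocks" standing in for transformer layers
blocks = nn.ModuleList(
    nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)) for _ in range(4)
)
model = CheckpointedStack(blocks)
out = model(torch.randn(2, 1024, requires_grad=True))
out.sum().backward()
```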

As you’ve now understood, activation recomputation increases the number of FLOPs slightly due to recomputation, while it significantly reduces memory access overhead.

This trade-off is particularly advantageous on hardware with small high-speed memory, like GPUs, as accessing memory is typically slower than performing computations. Despite the additional operations involved, the overall effect is thus often faster computation as well, in addition to the much lower memory footprint.

Now that we’ve learned about recomputation, we can tame the activations memory usage as we saw in the above graphs!

However, activation memory still bears a linear dependence on the batch size, and all our profiles in the bar plots above used bs=1, so as we move to larger batch sizes it might become an issue again. Do not despair, as we have a second tool in our box - gradient accumulation to the rescue!

Gradient accumulation

Now that we’ve used activation recomputation to fit our model with a small batch size on a single GPU, we still need to reach our target batch size, let’s say 1M tokens (see our earlier discussion on optimal batch size). Gradient accumulation is a very straightforward method to avoid memory explosion when doing this.

With gradient accumulation we split our batch into micro-batches, do forward and backward passes repeatedly on each micro-batch, compute the gradients, and, as the name suggests, sum the gradients for each micro-batch before doing a final optimizer step. In practice, we perform the optimization step not on the sum but on the average of the gradients, so the result is independent of the number of gradient accumulation steps.

Let’s call the batch size for each forward pass the micro batch size (mbs). We’ll refer to the overall batch size between each optimizer step as the global batch size (gbs). If we do one optimizer step for each 8 forward/backward passes, the global batch size will be 8 times the micro batch size.

What we now call global batch size thus corresponds to what we’ve called up to now just batch size for simplicity (we now make our terms more precise to avoid ambiguity).

With gradient accumulation the global batch size can be simply computed as follows:

bs = gbs = mbs \times grad\_acc

Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step thereby increasing the compute overhead and slowing down training. No free lunch!

[Figure: gradient accumulation over micro-batches]

Gradient accumulation allows us to reduce the activation memory, which grows linearly with batch size, by computing only partial micro-batches at a time. There is a small overhead caused by the additional forward and backward passes.
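In a training loop, gradient accumulation boils down to delaying the optimizer step and averaging the micro-batch gradients; a minimal sketch (the model, data and loss below are placeholders):

```python
import torch
from torch import nn

model = nn.Linear(1024, 1024).cuda()        # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
grad_acc_steps = 8                          # gbs = mbs * grad_acc_steps

optimizer.zero_grad(set_to_none=True)
for step, micro_batch in enumerate(torch.randn(64, 4, 1024, device="cuda")):  # micro-batches of mbs=4
    loss = model(micro_batch).pow(2).mean()
    (loss / grad_acc_steps).backward()      # scale so the accumulated gradient is an average, not a sum
    if (step + 1) % grad_acc_steps == 0:    # one optimizer step every grad_acc_steps micro-batches
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```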

But if you’ve followed carefully, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent of each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU!

Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called data parallelism which is just a parallel version of gradient accumulation.

TODO: add profiling here or not?

Data Parallelism

First optimization: Overlap gradient synchronization with backward pass

Second optimization: Bucketing gradients

Third optimization: Interplay with gradient accumulation

Revisit global batch size

Our journey up to now

ZeRO (Zero Redundancy Optimizer)

Memory usage revisited

ZeRO-1: Partitioning Optimizer States

ZeRO-2: Adding Gradient Partitioning

ZeRO-3: Adding Parameter Partitioning

Tensor Parallelism

Tensor Parallelism in a Transformer Block

Sequence Parallelism

Context Parallelism

Introducing Context Parallelism

Discovering Ring Attention

Zig-Zag Ring Attention – A Balanced Compute Implementation

Pipeline Parallelism

Splitting layers on various nodes - All forward, all backward

One-forward-one-backward and LLama 3.1 schemes

Interleaving stages

Zero Bubble and DualPipe

Expert parallelism

5D parallelism in a nutshell

How to Find the Best Training Configuration

Diving in the GPUs – fusing, threading, mixing

A primer on GPU

How to improve performance with Kernels ?

Memory Coalescing

Tiling

Thread Coarsening

Minimizing Control Divergence

Flash Attention 1-3

Fused Kernels

Mixed Precision Training

FP16 and BF16 training

FP8 pretraining

Conclusion

What you learned

What we learned

What’s next?

References

Landmark LLM Scaling Papers

Training Frameworks

Debugging

Distribution Techniques

CUDA Kernels

Hardware

Others

Appendix

Citation

For attribution in academic contexts, please cite this work as

XXX, et al., "The Ultra-Scale Playbook: Training LLMs on GPU Clusters", 2025.

BibTeX citation

@misc{TODO,
      title={The Ultra-Scale Playbook: Training LLMs on GPU Clusters},
      author={TODO},
      year={2025},
}