Distributed Training
FSDP, DeepSpeed, Megatron-LM, and choosing the right strategy for your scale
Once a model or its training batch outgrows a single GPU, you're doing distributed training whether you planned for it or not. The good news: the ecosystem has converged around a few battle-tested frameworks. The bad news: picking the wrong one for your scale wastes weeks.
The Three Frameworks That Matter
- FSDP (Fully Sharded Data Parallel) — PyTorch-native. Shards model parameters, gradients, and optimizer states across GPUs. The default choice if you're already in PyTorch and your model breaks down cleanly into wrappable sub-modules, such as repeated transformer blocks. Works well up to a few hundred GPUs.
- DeepSpeed — Microsoft's library. Famous for its ZeRO optimizer stages. More knobs, more complexity, more mature for very large runs. Often the go-to for 100B+ parameter training.
- Megatron-LM — NVIDIA's framework for training massive transformers. Combines tensor, pipeline, and data parallelism in a tightly optimized stack. Hardest to set up, fastest when configured correctly. Used for frontier-scale pre-training.
DeepSpeed ZeRO Stages
DeepSpeed's ZeRO optimizer eliminates redundancy progressively:
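- ZeRO Stage 1 — partition optimizer states across GPUs. Each GPU only updates its shard. Memory savings: optimizer-state memory per GPU shrinks by a factor of N (the number of data-parallel GPUs), which works out to roughly 4x on total model-state memory for mixed-precision Adam at large N.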
- ZeRO Stage 2 — also partition gradients. Gradients are reduce-scattered instead of all-reduced. Additional memory savings with minimal communication overhead.
- ZeRO Stage 3 — also partition parameters. Each GPU only holds 1/N of the model at rest. Parameters are gathered on demand for forward and backward passes. Maximum memory savings, highest communication cost.
The practical rule: start with Stage 2. Move to Stage 3 only when your model won't fit with Stage 2. Stage 1 is rarely worth the setup if you're already configuring DeepSpeed.
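To make the trade-off between stages concrete, here is a rough back-of-the-envelope calculator. It is a sketch under stated assumptions: mixed-precision Adam with 2 bytes per parameter for weights, 2 for gradients, and 12 for optimizer state, and it ignores activations and temporary buffers entirely. The function name `zero_model_state_gb` is made up for illustration and is not part of DeepSpeed.

```python
def zero_model_state_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Approximate per-GPU memory (GB) for model states under a ZeRO stage.

    Assumption: mixed-precision Adam, i.e. 2 bytes/param for half-precision
    weights, 2 bytes/param for half-precision gradients, and 12 bytes/param
    of optimizer state (FP32 master weights, momentum, variance).
    Activations and communication buffers are NOT counted.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")

    params = 2.0 * num_params
    grads = 2.0 * num_params
    optim = 12.0 * num_params

    if stage >= 1:  # Stage 1: shard optimizer states
        optim /= num_gpus
    if stage >= 2:  # Stage 2: also shard gradients
        grads /= num_gpus
    if stage >= 3:  # Stage 3: also shard parameters
        params /= num_gpus

    return (params + grads + optim) / 1e9


if __name__ == "__main__":
    # Hypothetical example: a 7B-parameter model on 8 GPUs.
    for stage in range(4):
        gb = zero_model_state_gb(7e9, 8, stage)
        print(f"stage {stage}: {gb:.2f} GB per GPU")
```

Under these assumptions, a 7B model on 8 GPUs needs about 112 GB per GPU with no sharding, 38.5 GB at Stage 1, 26.25 GB at Stage 2, and 14 GB at Stage 3. Real runs need extra headroom for activations, so treat the numbers as a lower bound.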
Data Parallelism vs Model Parallelism
These are two fundamentally different answers to the question "how do I use more GPUs?"
Data parallelism — every GPU has a copy of the whole model and processes a different mini-batch. Gradients are synchronized (all-reduced) during each backward pass, before the optimizer step. Simple, efficient, scales well when the model fits on one GPU.
Model parallelism — the model itself is split across GPUs. Required when a single GPU can't hold the model. Comes in multiple flavors:
- Tensor parallelism — split individual layers (e.g., split a matrix multiply across GPUs). Low latency but requires fast interconnects (NVLink).
- Pipeline parallelism — split the model into sequential stages. Each GPU handles a chunk of layers. Introduces "pipeline bubbles" (idle time) but works over slower interconnects.
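The two flavors can be illustrated without any GPUs. The sketch below simulates column-wise tensor parallelism on plain Python lists (each "GPU" multiplies by its own slice of the weight matrix, then the slices are concatenated), and estimates pipeline idle time using the usual bubble formula for a simple synchronous schedule, (stages - 1) / (microbatches + stages - 1). Both function names are hypothetical, and real frameworks use more elaborate schedules, so treat this as intuition rather than an implementation.

```python
def matmul(a, b):
    """Plain matrix multiply on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def column_parallel_matmul(x, w, num_gpus):
    """Simulate tensor parallelism: split the columns of w across GPUs."""
    n_cols = len(w[0])
    if n_cols % num_gpus != 0:
        raise ValueError("number of columns must divide evenly across GPUs")
    width = n_cols // num_gpus

    partial_outputs = []
    for rank in range(num_gpus):
        # Each rank holds only its own column slice of the weight matrix.
        w_shard = [row[rank * width:(rank + 1) * width] for row in w]
        partial_outputs.append(matmul(x, w_shard))

    # Concatenating the slices along the column axis plays the role of
    # the all-gather that a real implementation would run over NVLink.
    return [
        [value for part in partial_outputs for value in part[i]]
        for i in range(len(x))
    ]


def pipeline_bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction for a simple synchronous pipeline schedule (assumed)."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


if __name__ == "__main__":
    x = [[1, 2], [3, 4]]
    w = [[1, 0, 2, 1], [0, 1, 1, 3]]
    print(column_parallel_matmul(x, w, num_gpus=2) == matmul(x, w))
    print(pipeline_bubble_fraction(4, 4), pipeline_bubble_fraction(4, 32))
```

The bubble formula also shows why pipeline parallelism wants many microbatches: with 4 stages, going from 4 to 32 microbatches cuts the idle fraction from 3/7 to 3/35.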
When to use what:
- Model fits on one GPU → pure data parallelism (DDP, or FSDP with SHARD_GRAD_OP)
- Model doesn't fit on one GPU → FSDP with FULL_SHARD, or combine data + tensor parallelism
- Hundreds of GPUs, frontier-scale → combine all three (data + tensor + pipeline), likely with Megatron-LM
FSDP in Practice
FSDP wraps your model and handles sharding automatically. Key decisions:
- Sharding strategy — FULL_SHARD (like ZeRO-3), SHARD_GRAD_OP (like ZeRO-2), or NO_SHARD (plain DDP).
- Auto-wrapping — FSDP needs to know which sub-modules to wrap as individual FSDP units. Transformer blocks are the obvious choice. Get this wrong and you'll either OOM or serialize communication.
- Mixed precision — FSDP has built-in mixed-precision support. Use BF16 compute with FP32 reduce for stability.
- Activation checkpointing — trade compute for memory by recomputing activations in the backward pass instead of storing them. Essential for large models.
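The first three decisions map onto constructor arguments of PyTorch's `FullyShardedDataParallel`. The following is a minimal sketch, assuming a model built from `nn.TransformerEncoderLayer` blocks (swap in your own block class); the helper names `build_fsdp_kwargs` and `wrap_model` are made up for illustration. Activation checkpointing is configured separately and is left out here.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def build_fsdp_kwargs(block_cls=nn.TransformerEncoderLayer):
    """Collect FSDP settings for sharding, auto-wrapping, and precision.

    block_cls is an assumption: pass the class of your own transformer block.
    """
    return {
        # FULL_SHARD shards parameters, gradients, and optimizer states.
        "sharding_strategy": ShardingStrategy.FULL_SHARD,
        # Wrap every instance of block_cls as its own FSDP unit.
        "auto_wrap_policy": functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={block_cls},
        ),
        # BF16 compute, FP32 gradient reduction for stability.
        "mixed_precision": MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        ),
    }


def wrap_model(model: nn.Module) -> nn.Module:
    """Wrap a model with FSDP.

    Requires an initialized process group and a CUDA device, for example
    when launched with torchrun after dist.init_process_group("nccl").
    """
    return FSDP(
        model,
        device_id=torch.cuda.current_device(),
        **build_fsdp_kwargs(),
    )
```

Keeping the settings in a plain dictionary makes it easy to switch to `ShardingStrategy.SHARD_GRAD_OP` for a comparison run without touching the rest of the training script.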
Communication Overhead
Distributed training is often bottlenecked by communication rather than compute. The key operations:
- All-reduce — sum gradients across all GPUs. O(model_size) communication per step. Used in DDP.
- All-gather — reconstruct full parameters from shards. Used in FSDP/ZeRO-3 forward pass.
- Reduce-scatter — reduce and distribute gradient shards. Used in FSDP/ZeRO-2+ backward pass.
The single most impactful optimization: overlap communication with computation. FSDP and DeepSpeed both do this automatically when configured correctly — prefetching the next layer's parameters while computing the current layer.
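The relationship between the three collectives is easier to see in a toy simulation. The sketch below models each GPU's buffer as a Python list and shows that a reduce-scatter followed by an all-gather produces the same result as an all-reduce. It captures only the semantics: there is no network, no ring algorithm, and no overlap with compute, and the function names are illustrative rather than any library's API.

```python
def all_reduce(per_gpu):
    """Every GPU ends up with the element-wise sum of all buffers."""
    total = [sum(values) for values in zip(*per_gpu)]
    return [list(total) for _ in per_gpu]


def reduce_scatter(per_gpu):
    """Sum all buffers, then give each GPU only its own slice of the sum."""
    num_gpus = len(per_gpu)
    length = len(per_gpu[0])
    if length % num_gpus != 0:
        raise ValueError("buffer length must divide evenly across GPUs")
    shard = length // num_gpus
    total = [sum(values) for values in zip(*per_gpu)]
    return [total[rank * shard:(rank + 1) * shard] for rank in range(num_gpus)]


def all_gather(shards):
    """Every GPU ends up with the concatenation of all shards."""
    full = [value for shard in shards for value in shard]
    return [list(full) for _ in shards]


if __name__ == "__main__":
    # Two simulated GPUs, each holding a gradient buffer of four values.
    grads = [[1, 2, 3, 4], [10, 20, 30, 40]]
    print(reduce_scatter(grads))
    print(all_gather(reduce_scatter(grads)) == all_reduce(grads))
```

This is why sharded training does not add communication for gradients: the all-reduce that DDP performs is replaced by a reduce-scatter that stops halfway, and each GPU keeps only the shard it is responsible for.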
Choosing Your Strategy
A decision tree for most teams:
- Model fits on one GPU with batch size you want? → DDP. Done.
- Model fits but batch is too small? → DDP with gradient accumulation. Done.
- Model doesn't fit on one GPU? → FSDP with FULL_SHARD. If you need more control, DeepSpeed ZeRO Stage 3.
- Training 70B+ parameters on 100+ GPUs? → Combine FSDP/DeepSpeed with tensor parallelism. Consider Megatron-LM.
- Frontier scale (hundreds of billions, thousands of GPUs)? → Megatron-LM or a custom 3D parallelism setup. You probably already have infra engineers for this.
Don't over-engineer. Most fine-tuning jobs run beautifully on FSDP with 4–8 GPUs. You only need the heavy machinery for pre-training or training very large models.
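The decision tree can be written down as a small function, which is handy as a sanity check in a launch script. This is only a sketch: the 70B and 100-GPU cutoffs come from the list above, while the frontier cutoffs (200B parameters, 1000 GPUs) are assumed stand-ins for "hundreds of billions, thousands of GPUs". The function names are hypothetical. A helper for the gradient accumulation case is included as well.

```python
import math

# Cutoffs from the decision tree above.
LARGE_PARAMS_B, LARGE_GPUS = 70, 100
# Assumed values for "hundreds of billions, thousands of GPUs".
FRONTIER_PARAMS_B, FRONTIER_GPUS = 200, 1000


def choose_strategy(fits_on_one_gpu: bool, desired_batch_fits: bool,
                    params_billions: float, num_gpus: int) -> str:
    """Return a starting-point strategy; most specific cases are checked first."""
    if params_billions >= FRONTIER_PARAMS_B and num_gpus >= FRONTIER_GPUS:
        return "Megatron-LM or custom 3D parallelism"
    if params_billions >= LARGE_PARAMS_B and num_gpus >= LARGE_GPUS:
        return "FSDP/DeepSpeed + tensor parallelism"
    if not fits_on_one_gpu:
        return "FSDP FULL_SHARD"
    if not desired_batch_fits:
        return "DDP + gradient accumulation"
    return "DDP"


def accumulation_steps(target_global_batch: int, micro_batch_per_gpu: int,
                       num_gpus: int) -> int:
    """Micro-steps to accumulate so the effective batch reaches the target."""
    return math.ceil(target_global_batch / (micro_batch_per_gpu * num_gpus))


if __name__ == "__main__":
    print(choose_strategy(False, False, params_billions=13, num_gpus=8))
    print(accumulation_steps(512, micro_batch_per_gpu=4, num_gpus=8))
```

For example, a target global batch of 512 with 4 samples per GPU on 8 GPUs needs 16 accumulation steps per optimizer update.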
Common Pitfalls
- Forgetting to scale the learning rate — when you scale batch size across GPUs, you typically need to scale the learning rate too (linear scaling rule, with warmup).
- Uneven data loading — every GPU must see different data. Use DistributedSampler or equivalent.
- Checkpointing with sharded state — saving and loading FSDP/DeepSpeed checkpoints is more complex than single-GPU. Use the framework's native checkpoint utilities.
- Network bottlenecks — tensor parallelism over Ethernet instead of NVLink will destroy your throughput. Profile before committing to a parallelism strategy.
- Not profiling at all — PyTorch Profiler and NVIDIA Nsight Systems are your friends. A 10-minute profiling session can reveal a 2x speedup.
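For the first pitfall, the linear scaling rule with warmup is simple enough to sketch in a few lines. The version below assumes a linear ramp from a fraction of the peak rate up to the peak over a fixed number of steps. The function names and the example numbers are illustrative, and the rule itself is a heuristic that tends to break down at very large batch sizes, so validate it on your own runs.

```python
def scaled_lr(base_lr: float, base_batch: int, global_batch: int) -> float:
    """Linear scaling rule: grow the learning rate with the global batch size."""
    return base_lr * global_batch / base_batch


def lr_at_step(step: int, peak_lr: float, warmup_steps: int) -> float:
    """Linear warmup to peak_lr, then constant (decay schedules omitted)."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return peak_lr
    # step is 0-indexed, so the final warmup step reaches peak_lr exactly.
    return peak_lr * (step + 1) / warmup_steps


if __name__ == "__main__":
    # Hypothetical setup: tuned at batch 256 with lr 0.1, now running
    # 8 GPUs x 256 samples each for a global batch of 2048.
    peak = scaled_lr(base_lr=0.1, base_batch=256, global_batch=8 * 256)
    for step in range(6):
        print(step, lr_at_step(step, peak, warmup_steps=5))
```

Note that the global batch is the per-GPU batch times the number of GPUs times any gradient accumulation steps; forgetting one of those factors is a common source of a mis-scaled learning rate.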