Steven's Knowledge

Mixed-Precision Training

FP16, BF16, FP8 training — loss scaling, stability, memory math, and getting every FLOP out of your hardware

Training in full FP32 is leaving half your hardware on the table. Mixed-precision training runs most operations in lower precision (FP16, BF16, or FP8) while keeping a master copy of weights in FP32 for numerical stability. The result: roughly half the activation memory, roughly 2x (or more, with FP8) the matmul throughput of TF32 on modern GPUs, and, when done right, final model quality on par with full-precision training.

The Precision Landscape

  • FP32 (32-bit) — full precision. Safe, slow, memory-hungry. Still used for the master weight copy and certain reductions.
  • FP16 (16-bit float) — 5 exponent bits, 10 mantissa bits. Small dynamic range means you need loss scaling to avoid underflow. The original mixed-precision format.
  • BF16 (Brain Float 16) — 8 exponent bits, 7 mantissa bits. Same dynamic range as FP32 but less precision. No loss scaling needed. The default for modern training on Ampere+ GPUs.
  • FP8 (8-bit float) — two variants: E4M3 (more precision, less range) and E5M2 (more range, less precision). Available on Hopper (H100) and Blackwell. Cuts memory and compute roughly in half again versus FP16/BF16.
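The exponent and mantissa widths above determine each format's range and precision. The sketch below derives them directly. It assumes IEEE-754-style encodings that reserve the top exponent for inf/NaN, and special-cases FP8 E4M3, which (in the OCP FP8 spec used for training) drops infinities to extend its range.

```python
def ieee_like_limits(exp_bits: int, man_bits: int) -> tuple[float, float, float]:
    """Return (max finite value, smallest positive normal, machine epsilon) for an
    IEEE-754-style binary format whose all-ones exponent is reserved for inf/NaN."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = (2 ** exp_bits - 2) - bias           # largest non-reserved exponent
    max_val = (2 - 2.0 ** -man_bits) * 2.0 ** max_exp
    min_normal = 2.0 ** (1 - bias)
    eps = 2.0 ** -man_bits                          # gap between 1.0 and the next value
    return max_val, min_normal, eps

FORMATS = {
    "FP32": (8, 23),
    "FP16": (5, 10),
    "BF16": (8, 7),
    "FP8 E5M2": (5, 2),
}

limits = {name: ieee_like_limits(e, m) for name, (e, m) in FORMATS.items()}

# E4M3 (OCP FP8 spec) has no infinities: only the all-ones exponent *and* mantissa
# pattern is NaN, so its max is 1.75 * 2**8 = 448 rather than the IEEE-style 240.
limits["FP8 E4M3"] = (448.0, 2.0 ** -6, 2.0 ** -3)

for name, (max_val, min_normal, eps) in limits.items():
    print(f"{name:9s} max={max_val:.4g}  min_normal={min_normal:.4g}  eps={eps:.4g}")
```

BF16 shares FP32's max and min_normal but has a far coarser eps; FP16 is the reverse trade, with finer eps but a max of only 65504.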

The practical default today is BF16. If you have Hopper or newer hardware, FP8 is worth exploring for pre-training at scale.

How Mixed Precision Works

The standard recipe (AMP — Automatic Mixed Precision):

  1. Master weights in FP32. The optimizer holds the full-precision copy.
  2. Forward pass in lower precision (BF16 or FP16). Activations and most computation happen in reduced precision.
  3. Loss computation — can be in FP32 for stability or in reduced precision.
  4. Backward pass in lower precision. Gradients are computed in BF16/FP16.
  5. Gradient update in FP32. Gradients are cast back to FP32 and applied to the master weights.

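The critical insight: individual operations tolerate low precision fine, but accumulating small gradient updates into much larger weights does not. If an update is smaller than half the gap between adjacent representable values near the weight, it rounds away entirely, so small updates applied to FP16 or BF16 weights can silently vanish. That's why you keep master weights at full precision.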
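A quick way to see this: repeatedly add a small update to a weight stored in FP16 versus FP32. This is a pure-Python simulation, not how a framework does it; it uses the standard struct module's half- ('e') and single-precision ('f') formats to round every write, and the starting weight (1.0), update (1e-4), and step count are arbitrary illustrative values.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE FP16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def apply_updates(rounder, weight: float, update: float, steps: int) -> float:
    # Simulate a weight stored in a given precision: every write is rounded.
    w = rounder(weight)
    for _ in range(steps):
        w = rounder(w + update)
    return w

w16 = apply_updates(to_fp16, 1.0, 1e-4, 1000)
w32 = apply_updates(to_fp32, 1.0, 1e-4, 1000)
print(w16)  # 1.0: each 1e-4 step is below half of FP16's spacing (~9.8e-4) near 1.0
print(w32)  # close to 1.1
```

All 1000 updates are lost in FP16, while the FP32 copy absorbs them; this is the drift that FP32 master weights prevent.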

Loss Scaling (FP16 Only)

FP16 has a narrow dynamic range: its largest finite value is 65504 and its smallest positive (subnormal) value is about 6e-8. Small gradients, common in later training stages, underflow to zero and stall training. Loss scaling fixes this:

  1. Multiply the loss by a large constant (the "scale factor") before the backward pass.
  2. All gradients are scaled up proportionally, pushing them out of the underflow zone.
  3. Before the optimizer step, divide gradients by the same scale factor.
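The three steps above can be simulated in pure Python, using struct's 'e' (IEEE half-precision) format as a stand-in for FP16 storage. The gradient value is a hypothetical example, and the 2**16 scale is an illustrative choice (it happens to match the default initial scale of PyTorch's GradScaler).

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE FP16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_grad = 1e-8          # hypothetical gradient, below FP16's smallest subnormal (~6e-8)
scale = 2.0 ** 16         # loss scale factor

unscaled = to_fp16(true_grad)          # no scaling: underflows to 0.0
scaled = to_fp16(true_grad * scale)    # steps 1-2: the scaled gradient fits in FP16
recovered = scaled / scale             # step 3: unscale in higher precision

print(unscaled)   # 0.0
print(recovered)  # roughly 1e-08
```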

Dynamic loss scaling starts with a large scale and, whenever an inf/NaN shows up in the gradients (a sign of overflow), halves the scale and skips that optimizer step. After N consecutive clean steps it doubles the scale again (PyTorch's default is N = 2000). In PyTorch AMP, GradScaler does all of this automatically.
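Dynamic loss scaling is a small state machine. Below is a minimal pure-Python sketch of that logic operating on a flat list of gradient values; the class and method names are hypothetical, and the default constants are assumed to mirror PyTorch GradScaler's documented defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000).

```python
import math

class DynamicLossScaler:
    """Illustrative dynamic loss-scaling state machine (not PyTorch's implementation)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def process(self, scaled_grads):
        """Return unscaled gradients, or None if the optimizer step must be skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow: shrink the scale and skip this update entirely.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
            return None
        # Unscale with the same scale that was used for the backward pass.
        grads = [g / self.scale for g in scaled_grads]
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            # A long run of clean steps: try a larger scale next time.
            self.scale *= self.growth_factor
            self._clean_steps = 0
        return grads

scaler = DynamicLossScaler(growth_interval=3)      # short interval just for the demo
print(scaler.process([float("inf"), 1.0]))         # None: overflow, step skipped
print(scaler.scale)                                # 32768.0
for _ in range(3):
    grads = scaler.process([scaler.scale * 1e-6])  # clean, scaled gradients
print(scaler.scale)                                # 65536.0
```

Unscaling happens before the scale is grown, so a step that triggers growth is still divided by the factor its loss was actually multiplied by.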

BF16 doesn't need loss scaling because its 8 exponent bits give it the same range as FP32. This is the single biggest reason BF16 has replaced FP16 as the default training precision.

BF16 vs FP16 — Which to Choose

| Property | FP16 | BF16 |
|---|---|---|
| Dynamic range | Narrow (needs loss scaling) | Wide (same as FP32) |
| Precision | Higher (10 mantissa bits) | Lower (7 mantissa bits) |
| Loss scaling | Required | Not needed |
| Hardware | Volta+ (V100, A100, H100) | Ampere+ (A100, H100) |
| Stability | Tricky at large scale | Stable by default |

Pick BF16 unless you're on pre-Ampere hardware such as the V100, where FP16 with loss scaling is the way to get Tensor Core speedups. The reduced precision of BF16 rarely matters in practice; modern optimizers and training recipes are designed around it.
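To make the range-versus-precision trade-off in the table concrete, here is a pure-Python sketch that rounds values to FP16 (via struct's half-precision format) and to BF16 (by rounding the FP32 bit pattern to its top 16 bits, ties to even). The BF16 helper is a hand-rolled illustration that ignores NaN inputs; the test values are arbitrary.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest FP16 value via struct's IEEE half-precision format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Round to the nearest BF16 value: keep the top 16 bits of the FP32 pattern,
    rounding to nearest-even on the 16 dropped bits (NaN inputs not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

tiny = 1e-8              # gradient-sized value
fine = 1.0 + 2.0 ** -9   # needs 9 mantissa bits to represent exactly

print(to_fp16(tiny), to_bf16(tiny))  # FP16 underflows to 0.0; BF16 keeps about 1e-08
print(to_fp16(fine), to_bf16(fine))  # 1.001953125 1.0
```

BF16 keeps the tiny value (range) but loses the small offset from 1.0 (precision); FP16 does the opposite.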

FP8 Training

FP8 is the new frontier. NVIDIA's Transformer Engine handles FP8 training on Hopper and Blackwell GPUs:

  • Forward pass uses E4M3 (more precision for activations and weights).
  • Backward pass uses E5M2 (more range for gradients).
  • Per-tensor scaling: each tensor gets its own scale factor, derived from its maximum absolute value (amax) so that its values fit FP8's narrow range. This is more complex than FP16 loss scaling but handled automatically by Transformer Engine.
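Per-tensor scaling can be sketched in a few lines: find the tensor's amax, choose a scale that maps it onto E4M3's largest finite value (448), round the scaled values to the E4M3 grid, and divide the scale back out later. This is a pure-Python simulation with hypothetical helper names. It uses the current tensor's amax, whereas Transformer Engine's default delayed-scaling recipe derives the scale from a history of amax values from previous iterations.

```python
import math

E4M3_MAX = 448.0  # largest finite E4M3 value (OCP FP8 spec)

def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value (3 mantissa bits, ties to even),
    saturating at +/-448. Simulation only; real kernels do this in hardware."""
    if x == 0.0:
        return 0.0
    sign = math.copysign(1.0, x)
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)          # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)          # unbiased exponent; below 2**-6 values are subnormal
    step = 2.0 ** (exp - 3)       # spacing between neighbors with 3 mantissa bits
    return sign * min(round(a / step) * step, E4M3_MAX)

def quantize_per_tensor(values):
    """Scale so the largest magnitude maps to E4M3_MAX, then round to E4M3."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(v * scale) for v in values], scale

def dequantize(q, scale):
    return [v / scale for v in q]

acts = [3e-4, -1.2e-4, 2e-5, 7.5e-5]        # hypothetical small-magnitude tensor
naive = [round_to_e4m3(v) for v in acts]    # no scaling: everything flushes to zero
q, scale = quantize_per_tensor(acts)
restored = dequantize(q, scale)             # within E4M3's ~6% relative rounding error
```

Without scaling, every value falls below E4M3's smallest subnormal (2**-9) and is lost; with a per-tensor scale, the whole tensor lands in the representable range.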

Memory savings math:

  • FP32 model: 4 bytes/param → FP16/BF16: 2 bytes/param → FP8: 1 byte/param
  • A 70B model's weights alone: 280GB (FP32) → 140GB (BF16) → 70GB (FP8)
  • Optimizer states (Adam, FP32): add 8 bytes/param regardless. This is often the real memory bottleneck.

FP8 training is production-ready for standard transformer architectures with Transformer Engine. For custom architectures, test carefully — not all operations are FP8-safe.

AMP in PyTorch

PyTorch's torch.amp module makes mixed precision straightforward.

The key components:

  • autocast context manager — automatically casts operations to the chosen precision. Knows which ops are safe in low precision and which should stay in FP32 (e.g., softmax, layer norm, loss functions).
  • GradScaler — handles dynamic loss scaling for FP16. Not needed for BF16.
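A minimal training-step sketch putting both components together. The model, data shapes, learning rate, and 20-step loop are placeholders. It runs the BF16 path on CPU or on an Ampere-or-newer CUDA GPU. The FP16/GradScaler branch is included to show where scaling hooks in and assumes CUDA and a recent PyTorch release that exposes torch.amp.GradScaler (older releases use torch.cuda.amp.GradScaler). torch.autocast is the same context manager as torch.amp.autocast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# BF16 needs no GradScaler; setting amp_dtype = torch.float16 (on CUDA) enables it.
amp_dtype = torch.bfloat16
use_scaler = amp_dtype == torch.float16

# Parameters are created and kept in FP32; autocast casts copies on the fly.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda") if use_scaler else None

def train_step(x, y):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, dtype=amp_dtype):
        logits = model(x)                              # linear layers run in amp_dtype
        loss = nn.functional.cross_entropy(logits, y)  # on CUDA, autocast runs this in FP32
    if scaler is not None:
        scaler.scale(loss).backward()   # backward on the scaled loss
        scaler.unscale_(optimizer)      # unscale so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)          # skipped if gradients contain inf/NaN
        scaler.update()                 # grow or shrink the scale factor
    else:
        loss.backward()                 # gradients land in FP32 .grad tensors
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    return logits.dtype, loss.item()

x = torch.randn(32, 64, device=device)
y = torch.randint(0, 10, (32,), device=device)
losses = []
for _ in range(20):
    logits_dtype, loss = train_step(x, y)
    losses.append(loss)
print(logits_dtype, losses[0], losses[-1])
```

Two design points: backward() runs outside the autocast block, as PyTorch's AMP examples recommend, and the parameters and their .grad tensors stay FP32 throughout, so the FP32 master copy comes for free.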

What autocast keeps in FP32:

  • Softmax and log-softmax
  • Loss functions (cross-entropy, etc.)
  • Layer normalization
  • Numerically sensitive reductions and pointwise math (e.g., sum, cumsum, norm, exp, log)

What runs in lower precision:

  • Linear layers (matmuls) — these get the biggest speedup
  • Convolutions
  • Attention score computation
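  • Most other element-wise operations (e.g., ReLU, residual adds): autocast leaves these alone, so they run in their inputs' dtype, which after a low-precision matmul is usually BF16/FP16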

Training Stability with Lower Precision

Lower precision can destabilize training. Watch for:

  • Loss spikes — sudden jumps in loss, often caused by overflow or underflow. With BF16, these are rare. With FP16, check your loss scaling.
  • Gradient norm explosion — gradients blow up due to precision loss in accumulation. Fix with gradient clipping (max norm = 1.0 is a solid default).
  • Slow divergence — training runs that look fine for thousands of steps then drift. Usually caused by FP32 master weights not being maintained correctly.
  • Numerical differences in specific layers — embedding layers and the final output projection can be precision-sensitive. Force these to FP32 if you see instability.
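One common way to pin a layer to FP32 in PyTorch is to disable autocast locally and cast the layer's input up, the "autocast(enabled=False)" subregion pattern from the torch.amp documentation. The wrapper class name and layer sizes below are hypothetical; the sketch runs on CPU or CUDA.

```python
import torch
import torch.nn as nn

class FP32OutputHead(nn.Module):
    """Hypothetical wrapper: always run the wrapped layer in FP32, even inside autocast."""

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.autocast(device_type=x.device.type, enabled=False):
            return self.layer(x.float())  # cast the input up; weights are already FP32

device = "cuda" if torch.cuda.is_available() else "cpu"
body = nn.Linear(32, 32).to(device)
head = FP32OutputHead(nn.Linear(32, 100)).to(device)

x = torch.randn(4, 32, device=device)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    hidden = body(x)       # BF16 under autocast
    logits = head(hidden)  # FP32: autocast is disabled inside the head
```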

Stability checklist:

  1. Master weights in FP32 — always.
  2. Gradient clipping — max norm 1.0.
  3. BF16 over FP16 if hardware supports it.
  4. Monitor loss curves closely for the first 1000 steps.
  5. If using FP8, start with Transformer Engine defaults before tuning.

Memory Savings Math

For a model with P parameters, using the Adam optimizer:

| Component | FP32 | BF16 Mixed | FP8 Mixed |
|---|---|---|---|
| Weights (master) | 4P | 4P | 4P |
| Weights (compute copy) | 0 | 2P | 1P |
| Gradients | 4P | 2P | 2P |
| Optimizer (m + v) | 8P | 8P | 8P |
| Total | 16P | 16P | 15P |
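The per-parameter arithmetic in the table is easy to script for a concrete model size. A minimal sketch, assuming Adam and the byte counts in the table; activations, communication buffers, and framework overhead are deliberately excluded, and the 7B size is just an example.

```python
# Bytes per parameter for each component, following the table (Adam optimizer).
BYTES_PER_PARAM = {
    "FP32":       {"master_weights": 4, "compute_copy": 0, "grads": 4, "adam_m_v": 8},
    "BF16 mixed": {"master_weights": 4, "compute_copy": 2, "grads": 2, "adam_m_v": 8},
    "FP8 mixed":  {"master_weights": 4, "compute_copy": 1, "grads": 2, "adam_m_v": 8},
}

def static_memory_gb(num_params: float, strategy: str) -> float:
    """Weights + gradients + optimizer state in GB (1 GB = 1e9 bytes); no activations."""
    return num_params * sum(BYTES_PER_PARAM[strategy].values()) / 1e9

for strategy in BYTES_PER_PARAM:
    print(f"7B, {strategy}: {static_memory_gb(7e9, strategy):.0f} GB")
```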

Wait: the totals barely change. So where are the savings?

The big wins are:

  1. Activations — stored in lower precision, often 2–4x smaller. This is the dominant memory cost for large batch sizes and long sequences.
  2. Communication buffers — half the size means faster data transfer.
  3. Compute throughput: on A100 and H100, dense BF16 Tensor Core throughput is about 2x that of TF32 (the Tensor Core format used for FP32 matmuls), and FP8 on H100 doubles it again. The primary win is speed, not just memory.

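For a 7B model with sequence length 4096 and batch size 4, storing activations in BF16 instead of FP32 can save 30–50GB, though the exact figure depends on the architecture, the attention implementation, and whether activation checkpointing is used. That's often the difference between fitting on your hardware and not.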

Practical Recommendations

  1. Default to BF16 mixed precision. It's stable, fast, and requires zero tuning.
  2. Use torch.amp.autocast('cuda', dtype=torch.bfloat16) with no GradScaler for BF16.
  3. Keep gradient clipping on. It's cheap insurance against rare instability.
  4. Profile your actual memory breakdown. Don't assume — use torch.cuda.memory_summary() to see where bytes go.
  5. For FP8, use NVIDIA Transformer Engine and start with their recommended recipes before customizing.
  6. Be careful mixing precision strategies with quantization. QLoRA already fixes the compute dtype of its 4-bit layers (commonly BF16); if you also enable AMP, make sure its dtype matches, because conflicting settings can cause dtype-mismatch errors.
