tutorial 2025-04-07 14 min read

Distributed Training Explained: From One GPU to Many

Understand data parallelism, model parallelism, and gradient accumulation. Learn how PyTorch DDP and FSDP work, and when to use which distributed training strategy.

Tags: distributed training, PyTorch, DDP, FSDP, data parallelism, model parallelism

Why Distributed Training?

Two reasons to distribute training across multiple GPUs:

  1. Speed: A model that takes 2 weeks on one GPU can finish in about 2 days on 8, assuming near-linear scaling
  2. Memory: A 70B model needs ~140GB for its weights alone in bf16, which doesn't fit on a single 80GB A100

These require different strategies. Speed → data parallelism. Memory → sharding (FSDP), model parallelism, or tensor parallelism.
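The memory figure above is simple arithmetic: parameter count times bytes per parameter. The sketch below reproduces it for weights only. The helper names are illustrative, and the estimate deliberately ignores gradients, optimizer state, and activations, which all add to the real training footprint.

```python
# Rough memory estimate for model weights only
# (no gradients, optimizer state, or activations).
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2, "int8": 1}


def weight_memory_gb(num_params: float, dtype: str = "bf16") -> float:
    """Return weight memory in GB, using 1 GB = 1e9 bytes."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


def fits_on_gpu(num_params: float, gpu_gb: float, dtype: str = "bf16") -> bool:
    """True if the weights alone fit in the given GPU memory."""
    return weight_memory_gb(num_params, dtype) <= gpu_gb


print(weight_memory_gb(70e9))   # 70B parameters in bf16
print(fits_on_gpu(70e9, 80))    # single 80GB A100
```

A model whose weights fail this check cannot use plain data parallelism, since every GPU would need the full copy.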

Data Parallelism: The Simplest Case

Each GPU trains on a different batch. Gradients are averaged across GPUs and weights are updated identically on all GPUs.

GPU 0: batch_0 → forward → loss_0 → gradients_0 ─┐
GPU 1: batch_1 → forward → loss_1 → gradients_1 ─┼──> average gradients → update all weights
GPU 2: batch_2 → forward → loss_2 → gradients_2 ─┘

This works because all GPUs hold a full copy of the model. Effective batch size = single-GPU batch × number of GPUs.
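To see why averaging per-GPU gradients gives the same update as one large batch, here is a single-process simulation with no GPUs involved. It assumes a toy one-parameter model (y_hat = w * x) with mean squared error and equally sized shards; all names are illustrative.

```python
# Single-process simulation of data-parallel gradient averaging.
# Toy model: y_hat = w * x, loss = mean((w * x - y)^2) over a batch.

def grad_mse(w, batch):
    """dLoss/dw for mean squared error over one batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Stand-in for the all-reduce that averages gradients across GPUs."""
    return sum(values) / len(values)


w = 0.5
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0), (5.0, 10.0), (6.0, 12.0)]
num_gpus = 3
per_gpu = len(data) // num_gpus

# Each simulated GPU gets its own equally sized slice of the global batch.
shards = [data[r * per_gpu:(r + 1) * per_gpu] for r in range(num_gpus)]
local_grads = [grad_mse(w, shard) for shard in shards]

averaged = all_reduce_mean(local_grads)   # what every GPU applies
full_batch = grad_mse(w, data)            # one GPU seeing all the data
effective_batch_size = per_gpu * num_gpus

print(averaged, full_batch, effective_batch_size)
```

The equality relies on equal shard sizes and a mean-reduced loss; with uneven shards a plain average would weight samples differently.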

PyTorch DistributedDataParallel (DDP)

import argparse
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# MyModel and dataset are placeholders for your own model class and Dataset.

def setup():
    # torchrun exports RANK, WORLD_SIZE, and LOCAL_RANK for every process
    dist.init_process_group(backend="nccl")  # NCCL for GPU-GPU communication
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
    torch.cuda.set_device(local_rank)
    return dist.get_rank(), dist.get_world_size(), local_rank

def cleanup():
    dist.destroy_process_group()

def train(args):
    rank, world_size, local_rank = setup()

    # Load model and move to this process's GPU.
    # Use local_rank for devices: on multi-node runs the global rank
    # exceeds the number of GPUs on a node.
    model = MyModel().to(local_rank)

    # Wrap with DDP, which handles gradient synchronization
    model = DDP(model, device_ids=[local_rank])

    # Distributed sampler: each GPU sees different data
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)  # Reshuffle each epoch

        for batch in loader:
            batch = {k: v.to(local_rank) for k, v in batch.items()}

            optimizer.zero_grad()
            loss = model(batch)  # Assumes the model returns a scalar loss

            loss.backward()
            # DDP averages gradients across GPUs during backward

            optimizer.step()

        # Only log/save from rank 0
        if rank == 0:
            print(f"Epoch {epoch}: loss={loss.item():.4f}")
            torch.save(model.module.state_dict(), "checkpoint.pt")  # .module unwraps DDP

    cleanup()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10)
    train(parser.parse_args())

# Launch: torchrun --nproc_per_node=4 train.py
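The sampler is what guarantees that each GPU sees different data. The sketch below emulates, in plain Python, how a non-shuffled split across ranks can work: pad the index list to a multiple of the replica count, then give each rank every N-th index. It is an approximation of DistributedSampler's behavior with shuffle=False and drop_last=False, not its actual source, and shard_indices is a hypothetical helper.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Pad indices to a multiple of num_replicas, then stride by rank."""
    indices = list(range(dataset_len))
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    # Pad by repeating indices from the start so every rank gets the
    # same number of samples (assumes padding <= dataset_len).
    indices += indices[: total_size - dataset_len]
    return indices[rank:total_size:num_replicas]


# 10 samples split across 4 ranks: two samples are repeated as padding.
parts = [shard_indices(10, 4, r) for r in range(4)]
for r, part in enumerate(parts):
    print(r, part)
```

Equal shard sizes matter because every rank must run the same number of steps; a rank that finishes early would leave the others waiting in the gradient all-reduce.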

Launching with torchrun

# Single node, 4 GPUs
torchrun --nproc_per_node=4 train.py --epochs 10

# Multi-node: 2 nodes, 4 GPUs each (8 GPUs total)
# On node 0:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr=192.168.1.1 train.py
# On node 1:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=1 --master_addr=192.168.1.1 train.py
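Each process started by torchrun finds its identity in environment variables (RANK, LOCAL_RANK, WORLD_SIZE). A minimal sketch of reading them follows; the DistEnv class and both helpers are illustrative names, and the global-rank formula assumes the usual contiguous assignment of ranks per node.

```python
import os
from dataclasses import dataclass


@dataclass
class DistEnv:
    rank: int         # global rank across all nodes
    local_rank: int   # GPU index on this node
    world_size: int   # total number of processes


def read_dist_env(env=os.environ) -> DistEnv:
    """Read the variables torchrun exports; default to a single process."""
    return DistEnv(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


def expected_global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """Global rank, assuming ranks are assigned contiguously per node."""
    return node_rank * nproc_per_node + local_rank


# Second GPU (local_rank=1) on node 1 of the 2-node, 4-GPU-per-node launch.
print(expected_global_rank(node_rank=1, nproc_per_node=4, local_rank=1))
print(read_dist_env({"RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8"}))
```

The distinction matters on multi-node runs: the global rank selects the data shard, while the local rank selects the GPU.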

Gradient Accumulation: Simulating Larger Batches

When a large batch doesn't fit in one GPU's memory, accumulate gradients over several smaller batches before each optimizer step:

accumulation_steps = 4  # Simulate 4x larger batch size
optimizer.zero_grad()

for i, batch in enumerate(loader):
    loss = model(batch) / accumulation_steps  # Scale loss

    loss.backward()  # Accumulate gradients

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()   # Update weights every 4 steps
        optimizer.zero_grad()

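For a mean-reduced loss, this produces the same gradients as a batch size of batch_size × accumulation_steps while holding only one small batch's activations in memory. Layers that depend on batch statistics, such as BatchNorm, still see the small batch.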
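Two details are easy to miss with accumulation: what happens when the number of batches is not a multiple of accumulation_steps, and how the effective batch size combines with multiple GPUs. The sketch below works through both in plain Python. Flushing the final partial window is one possible policy, not the only one, and both helpers are hypothetical.

```python
def optimizer_step_indices(num_batches: int, accumulation_steps: int) -> list[int]:
    """Batch indices after which optimizer.step() runs.

    Steps on every full window and also flushes a final partial window,
    so no accumulated gradients are left over at the end of the epoch.
    """
    steps = []
    for i in range(num_batches):
        window_full = (i + 1) % accumulation_steps == 0
        last_batch = i == num_batches - 1
        if window_full or last_batch:
            steps.append(i)
    return steps


def effective_batch_size(per_gpu_batch: int, accumulation_steps: int,
                         world_size: int = 1) -> int:
    """Samples contributing to each weight update."""
    return per_gpu_batch * accumulation_steps * world_size


print(optimizer_step_indices(10, 4))   # 10 batches, windows of 4
print(effective_batch_size(32, 4, world_size=4))
```

Without the flush, gradients from the trailing batches would carry over into the next epoch's first update. With the flush, the last window holds fewer batches than the loss scaling assumes, so its update is proportionally smaller.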

FSDP: When the Model Doesn't Fit on One GPU

DDP requires each GPU to hold a full copy of the model, its gradients, and the optimizer state. For large models (7B+), this often does not fit.

Fully Sharded Data Parallel (FSDP) shards the model across GPUs:

DDP:   GPU 0: full model + batch_0
       GPU 1: full model + batch_1

FSDP:  GPU 0: shard_0 of model + batch_0
       GPU 1: shard_1 of model + batch_1
       (parameters are all-gathered when needed, discarded after)

from functools import partial

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Assumes the process group is initialized, torch.cuda.set_device(local_rank)
# has been called, and `model` is a Llama-style model built on this process.

# Auto-wrap policy: shard each transformer block separately
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # Maximum memory savings
    device_id=torch.cuda.current_device(),  # The local GPU, not the global rank
)
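To get a feel for what full sharding buys, the following sketch compares per-GPU model state under DDP and FSDP. It assumes 16 bytes per parameter, a common rule of thumb for mixed-precision Adam (bf16 weights and gradients plus fp32 master weights and two fp32 Adam moments). It ignores activations and the temporarily all-gathered parameters, so treat the numbers as lower bounds.

```python
def model_state_gb(num_params: float, bytes_per_param: float = 16.0) -> float:
    """Weights + gradients + optimizer state in GB (1 GB = 1e9 bytes).

    The 16 bytes/param default is an assumption for mixed-precision Adam.
    """
    return num_params * bytes_per_param / 1e9


def per_gpu_state_gb(num_params: float, num_gpus: int, sharded: bool,
                     bytes_per_param: float = 16.0) -> float:
    """Per-GPU model state: replicated (DDP) or evenly sharded (FSDP)."""
    total = model_state_gb(num_params, bytes_per_param)
    return total / num_gpus if sharded else total


ddp = per_gpu_state_gb(7e9, num_gpus=8, sharded=False)
fsdp = per_gpu_state_gb(7e9, num_gpus=8, sharded=True)
print(f"DDP:  {ddp:.1f} GB per GPU")
print(f"FSDP: {fsdp:.1f} GB per GPU")
```

Under these assumptions, adding GPUs to DDP never reduces the per-GPU footprint, while FULL_SHARD divides it by the number of GPUs.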

FSDP vs DDP: When to Use Which

Strategy             When to use
Single GPU           Model fits on one GPU, speed is acceptable
DDP                  Model fits on one GPU, need more speed
FSDP                 Model doesn't fit on one GPU
Tensor Parallelism   Model so large that individual layers don't fit

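DDP works when a single GPU can hold the weights, gradients, optimizer state, and activations for its batch. On 80GB A100s that limit sits around 7B parameters, and only with a memory-lean setup such as bf16 weights and a compact optimizer state. For larger models, use FSDP.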
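The decision table can be written down as a small function. This is only a sketch of the table's logic: pick_strategy and its parameters are hypothetical, the memory inputs are your own estimates, and it leaves no headroom for activations or fragmentation.

```python
def pick_strategy(model_state_gb: float, largest_layer_gb: float,
                  gpu_gb: float, need_more_speed: bool = False) -> str:
    """Map rough memory estimates to a strategy from the table.

    model_state_gb:   weights + gradients + optimizer state
    largest_layer_gb: memory for the single largest layer
    gpu_gb:           memory of one GPU
    """
    if largest_layer_gb > gpu_gb:
        # Even one layer does not fit: split inside the layer.
        return "Tensor Parallelism"
    if model_state_gb > gpu_gb:
        # Layers fit individually, the whole model state does not.
        return "FSDP"
    return "DDP" if need_more_speed else "Single GPU"


print(pick_strategy(10, 1, 80))
print(pick_strategy(10, 1, 80, need_more_speed=True))
print(pick_strategy(112, 4, 80))
print(pick_strategy(2000, 100, 80))
```

In practice these strategies are often combined, so the function returns the minimum that is needed rather than the only valid choice.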

Model Parallelism (Pipeline Parallelism)

Split the model's layers across GPUs. GPU 0 runs layers 0-11, GPU 1 runs layers 12-23:

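import torch.nn as nn

# Naive model parallelism (pipeline bubbles are a problem)
# TransformerBlock is a placeholder for your own block implementation
class PipelineModel(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()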
        # First half on GPU 0
        self.embedding = nn.Embedding(vocab_size, d_model).to(0)
        self.layers_0_11 = nn.ModuleList([
            TransformerBlock(d_model) for _ in range(12)
        ]).to(0)

        # Second half on GPU 1
        self.layers_12_23 = nn.ModuleList([
            TransformerBlock(d_model) for _ in range(12)
        ]).to(1)
        self.lm_head = nn.Linear(d_model, vocab_size).to(1)

    def forward(self, x):
        x = self.embedding(x.to(0))  # Input must be on GPU 0

        for layer in self.layers_0_11:
            x = layer(x)  # GPU 0

        x = x.to(1)  # Transfer to GPU 1

        for layer in self.layers_12_23:
            x = layer(x)  # GPU 1

        return self.lm_head(x)  # GPU 1

Naive pipeline parallelism suffers from GPU idle time (pipeline bubbles). Production systems use microbatching to overlap computation across stages.
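The cost of bubbles can be estimated with a standard formula for a GPipe-style schedule: with p stages and m microbatches, the idle fraction is (p - 1) / (m + p - 1). The sketch below evaluates it, assuming all stages take equal time and ignoring communication; bubble_fraction is an illustrative helper.

```python
from fractions import Fraction


def bubble_fraction(num_stages: int, num_microbatches: int) -> Fraction:
    """Idle fraction of a GPipe-style pipeline schedule.

    Assumes equal compute time per stage and no communication cost.
    """
    return Fraction(num_stages - 1, num_microbatches + num_stages - 1)


# Naive two-GPU split (one microbatch): each GPU idles half the time.
print(bubble_fraction(2, 1))
# Same split with 8 microbatches per batch.
print(bubble_fraction(2, 8))
# Deeper pipelines need more microbatches to stay busy.
print(bubble_fraction(4, 16))
```

The naive two-GPU model above corresponds to a single microbatch, which is why splitting each batch into many microbatches is the standard remedy.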

Practical Recommendations

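For fine-tuning a 7B model (load_in_4bit, apply_lora, and load_model are placeholder helpers, not library functions):

# Option 1: QLoRA on single GPU (~8GB VRAM needed)
model = load_in_4bit(model_name)
model = apply_lora(model, r=16)
# Train normally

# Option 2: DDP with 4 GPUs (80GB each; every GPU holds a full replica)
# 24GB cards are too small for full fine-tuning: bf16 weights plus gradients alone are ~28GB
model = load_model(model_name, dtype=torch.bfloat16)
model = DDP(model, device_ids=[local_rank])  # local_rank: GPU index on this node
# Up to ~4x throughput, minus communication overhead

# Option 3: FSDP with 2 GPUs (40GB each)
model = load_model(model_name, dtype=torch.bfloat16)
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
# Tight on 2× A100 40GB: fits only with a compact optimizer state; fp32 Adam state needs more GPUs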

Debugging Distributed Training

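# Check gradient synchronization is working.
# Run this on every rank: a collective called from only one rank hangs.
for name, param in model.named_parameters():
    if param.grad is not None:
        local_norm = param.grad.detach().norm()
        avg_norm = local_norm.clone()  # Reduce a copy so the real gradients stay untouched
        dist.all_reduce(avg_norm, op=dist.ReduceOp.AVG)
        if rank == 0:
            # After a DDP backward pass the two values should match
            print(f"{name}: local={local_norm.item():.4f} avg={avg_norm.item():.4f}")

# Print from specific rank only
def print_rank(msg, rank=0):
    current_rank = dist.get_rank()
    if current_rank == rank:
        print(f"[Rank {current_rank}] {msg}")

# Check NCCL is working
t = torch.ones(1, device="cuda")  # Lands on the device chosen by torch.cuda.set_device
dist.all_reduce(t)  # Should not hang; t now holds world_size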

For serving distributed models in production, see our LLM inference and optimization guide.
