Introduction
FlashAttention is one of the most important systems papers in modern deep learning. It doesn't change the math of attention: it computes exact attention, and its outputs match standard attention up to floating-point rounding. Instead, it makes attention fast by recognizing that the bottleneck isn't compute, it's memory bandwidth.
Understanding FlashAttention explains why large context windows are now economically viable and why it's a core component of virtually every modern LLM training and serving stack.
The Bottleneck: It's Not FLOPs, It's Memory Bandwidth
Modern GPUs have enormous compute capacity (hundreds of TFLOPS) but the memory hierarchy creates a fundamental bottleneck:
GPU Memory Hierarchy (approximate figures for an A100-class system):
SRAM (on-chip registers/shared memory): ~20 MB, ~19 TB/s bandwidth
HBM (high-bandwidth memory, "GPU RAM"): ~80 GB, ~2 TB/s bandwidth
DRAM (host RAM): ~500 GB, ~50 GB/s bandwidth
SRAM has roughly 10x the bandwidth of HBM (~19 vs ~2 TB/s) but is about 4,000x smaller (~20 MB vs ~80 GB). The key insight of FlashAttention: standard attention spends most of its time reading and writing the N×N attention matrix to and from HBM, not doing compute.
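To make the bandwidth argument concrete, here is a back-of-the-envelope roofline check. It assumes an A100-class GPU with roughly 312 TFLOPS of dense FP16 tensor-core throughput (an assumption, not a figure from this article) and the ~2 TB/s HBM bandwidth listed above; the ~5 FLOPs per softmax element is also a rough assumption. A kernel whose FLOPs-per-byte falls below the ridge point is limited by memory traffic, not arithmetic.

```python
# Back-of-the-envelope roofline check (all hardware numbers are assumptions).
PEAK_FP16_FLOPS = 312e12   # assumed dense FP16 tensor-core peak, FLOP/s
HBM_BANDWIDTH = 2e12       # ~2 TB/s HBM bandwidth, bytes/s


def ridge_point(peak_flops: float, bandwidth: float) -> float:
    """FLOPs per byte a kernel needs before compute, not memory, becomes the limit."""
    return peak_flops / bandwidth


def softmax_intensity(bytes_per_elem: int = 2, flops_per_elem: int = 5) -> float:
    """Elementwise softmax pass over S: read one element, write one element,
    and do roughly 5 FLOPs (max, subtract, exp, sum, divide) per element."""
    return flops_per_elem / (2 * bytes_per_elem)


def qk_matmul_intensity(n: int, d: int, bytes_per_elem: int = 2) -> float:
    """S = Q @ K.T with N×d inputs: 2*N*N*d FLOPs; reads Q and K, writes S to HBM."""
    flops = 2 * n * n * d
    bytes_moved = bytes_per_elem * (2 * n * d + n * n)
    return flops / bytes_moved


ridge = ridge_point(PEAK_FP16_FLOPS, HBM_BANDWIDTH)
print(f"ridge point:      {ridge:.0f} FLOPs/byte")
print(f"softmax pass:     {softmax_intensity():.2f} FLOPs/byte")
print(f"Q @ K.T (N=4096): {qk_matmul_intensity(4096, 128):.1f} FLOPs/byte")
```

Under these assumptions the ridge point is about 156 FLOPs/byte, the softmax pass sits near 1.25, and even the Q @ K.T matmul lands around 120 because it has to write the N×N result to HBM (its intensity approaches d = 128 as N grows). Both are memory-bound.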
Standard Attention: The IO Problem
Standard self-attention computation:
import math
import torch

def standard_attention(Q, K, V):
    d_k = Q.shape[-1]
    S = Q @ K.transpose(-2, -1)   # Step 1: write S to HBM (N×N matrix)
    S = S / math.sqrt(d_k)        # Step 2: read S, write scaled S back
    P = torch.softmax(S, dim=-1)  # Step 3: read S, write P to HBM (N×N)
    O = P @ V                     # Step 4: read P, write O to HBM (N×d)
    return O
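For sequence length N=4096, d=128 (one attention head, FP16):
S matrix size: 4096 × 4096 × 2 bytes = 32 MB
Written to HBM: S twice (Steps 1 and 2), P once, O once ≈ 3 × 32 MB + 1 MB ≈ 100 MB per head per forward pass, with a similar volume read back
HBM accesses overall: Θ(Nd + N²), dominated by the N² term once N ≫ d
At N=32K tokens, S alone is 32768² × 2 bytes = 2 GB per head, or 64 GB for a 32-head layer at batch size 1, which nearly fills an 80 GB GPU before weights and other activations are even counted.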
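The byte counts above can be reproduced with a few lines of arithmetic. This assumes FP16 (2 bytes per element), a single head at batch size 1, and, for the last line, a hypothetical 32-head layer. It reports binary units (MiB, GiB), which is what the round figures above correspond to.

```python
BYTES_FP16 = 2
MiB, GiB = 2**20, 2**30


def score_matrix_bytes(n: int, bytes_per_elem: int = BYTES_FP16) -> int:
    """Size of one N×N matrix (S or P) for a single attention head."""
    return n * n * bytes_per_elem


def standard_hbm_writes(n: int, d: int, bytes_per_elem: int = BYTES_FP16) -> int:
    """Bytes written by the four steps of standard_attention for one head:
    S (step 1), scaled S (step 2), P (step 3), and the N×d output O (step 4)."""
    return 3 * score_matrix_bytes(n, bytes_per_elem) + n * d * bytes_per_elem


print(score_matrix_bytes(4096) / MiB)            # 32.0
print(standard_hbm_writes(4096, 128) / MiB)      # 97.0
print(score_matrix_bytes(32 * 1024) / GiB)       # 2.0
print(32 * score_matrix_bytes(32 * 1024) / GiB)  # 64.0, assuming 32 heads
```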
FlashAttention: Tiling + Recomputation
FlashAttention's solution: never materialize the full attention matrix in HBM. Instead, process attention in tiles that fit in SRAM.
The Tiling Algorithm
Divide Q into blocks of Br rows: Q_1, Q_2, ..., Q_{N/Br}
Divide K, V into blocks of Bc rows: K_1, V_1, ..., K_{N/Bc}, V_{N/Bc}
For each query block Q_i:
    Initialize O_i = 0, l_i = 0, m_i = -∞
    For each key/value block K_j, V_j:
        1. Load Q_i, K_j, V_j into SRAM ← small tiles
        2. Compute S_ij = Q_i @ K_j.T / sqrt(d_k) ← stays in SRAM
        3. Compute running softmax update:
               m_ij = max(m_i, rowmax(S_ij))
               P_ij = exp(S_ij - m_ij) ← stays in SRAM
               l_ij = exp(m_i - m_ij) · l_i + rowsum(P_ij)
        4. Update output (all scalings are per row):
               O_i = (l_i · exp(m_i - m_ij) · O_i + P_ij @ V_j) / l_ij
        5. Update running stats: m_i = m_ij, l_i = l_ij
    Write final O_i to HBM once
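A minimal NumPy sketch of the tiled loop, for a single head with no causal mask. It is meant to check the math, not to be fast: the tiles live in ordinary arrays rather than SRAM, and block_r / block_c are hypothetical stand-ins for Br and Bc. The output update keeps O_i normalized at every step: the old value is scaled by l_i · exp(m_i - m_ij), the new block's contribution is added, and the sum is divided by the new l_ij.

```python
import numpy as np


def standard_attention_np(Q, K, V):
    """Reference: materialize the full N×N matrix, then a row-wise softmax."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


def tiled_attention(Q, K, V, block_r=16, block_c=16):
    """Tiled forward pass with the running-max / running-sum softmax update."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]))
    for r0 in range(0, N, block_r):
        Q_i = Q[r0:r0 + block_r]
        rows = Q_i.shape[0]
        O_i = np.zeros((rows, V.shape[1]))
        l_i = np.zeros(rows)
        m_i = np.full(rows, -np.inf)
        for c0 in range(0, K.shape[0], block_c):
            K_j, V_j = K[c0:c0 + block_c], V[c0:c0 + block_c]
            S_ij = Q_i @ K_j.T * scale                    # Br × Bc scores
            m_ij = np.maximum(m_i, S_ij.max(axis=1))      # new running max
            P_ij = np.exp(S_ij - m_ij[:, None])
            alpha = np.exp(m_i - m_ij)                    # 0 on the first block
            l_ij = alpha * l_i + P_ij.sum(axis=1)
            # Rescale the old (normalized) output, add this block, renormalize.
            O_i = ((l_i * alpha)[:, None] * O_i + P_ij @ V_j) / l_ij[:, None]
            m_i, l_i = m_ij, l_ij
        O[r0:r0 + block_r] = O_i                          # "write O_i once"
    return O


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 32)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), standard_attention_np(Q, K, V)))  # True
```

N = 100 is deliberately not a multiple of the block size; Python slicing handles the short final tile.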
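The key trick: the online softmax update allows computing the correctly normalized softmax without ever seeing the full row at once. The math works out because softmax can be decomposed into a running maximum and a running sum: when a new block raises the row maximum from m_i to m_ij, everything accumulated so far is simply rescaled by exp(m_i - m_ij).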
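Written out for a single score row x split into two blocks x^(1) and x^(2), the decomposition looks like this (notation chosen here for illustration). For any block y, let m(y) be its maximum, f(y) the vector of exp(y_k - m(y)), and ℓ(y) the sum of f(y):

```latex
\begin{aligned}
m(x) &= \max\bigl(m(x^{(1)}),\ m(x^{(2)})\bigr) \\
f(x) &= \bigl[\, e^{m(x^{(1)}) - m(x)}\, f(x^{(1)}),\ \ e^{m(x^{(2)}) - m(x)}\, f(x^{(2)}) \,\bigr] \\
\ell(x) &= e^{m(x^{(1)}) - m(x)}\, \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\, \ell(x^{(2)}) \\
\operatorname{softmax}(x) &= f(x) / \ell(x)
\end{aligned}
```

Every entry of f(x) becomes exp(x_k - m(x)) regardless of which block it came from, so the result equals the softmax of the full row. In the tiled loop, x^(1) plays the role of the blocks already processed (statistics m_i, l_i) and x^(2) the current block S_ij.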
Memory Complexity
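Standard attention HBM accesses: Θ(Nd + N²)
FlashAttention HBM accesses: Θ(N²d²/M), where M is the SRAM size; writes drop to O(N × d) ← N² term eliminated from writes!
At N=32K, d=128, per head:
Standard: N² + Nd ≈ 1.07G + 4.2M elements ≈ 2 GB in FP16
FlashAttention: Nd + 2N ≈ 4.3M elements (O plus the per-row m and l) ← roughly 250x fewer elements written; reads shrink much less, because K and V are re-read once per query block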
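The per-head element counts can be checked directly, together with the part the write counts leave out: FlashAttention still re-reads K and V once per query block. This sketch counts every such re-read as an HBM access (in practice some are served from L2 cache), and block_r = 128 is a hypothetical query-block size.

```python
def standard_s_writes(n: int) -> int:
    """Elements in the N×N score matrix S that standard attention writes to HBM."""
    return n * n


def flash_writes(n: int, d: int) -> int:
    """Elements FlashAttention writes for one head: output O (N×d) plus m and l (N each)."""
    return n * d + 2 * n


def flash_kv_reads(n: int, d: int, block_r: int) -> int:
    """Elements of K and V (N×d each) read if both are re-read for every query block."""
    num_query_blocks = -(-n // block_r)   # ceiling division
    return num_query_blocks * 2 * n * d


N, d = 32 * 1024, 128
print(standard_s_writes(N))                               # 1073741824
print(flash_writes(N, d))                                 # 4259840
print(round(standard_s_writes(N) / flash_writes(N, d)))   # 252
print(flash_kv_reads(N, d, block_r=128) == 2 * N * N)     # True
```

The write savings are real, but with these assumptions the K/V re-reads alone still total 2N² elements. Larger query blocks, which a larger SRAM allows, mean fewer passes over K and V.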
Backward Pass: Recomputation
The backward pass of standard attention requires storing the N×N attention matrix P to compute gradients — a memory disaster for long sequences.
FlashAttention doesn't store P. Instead, during the backward pass, it recomputes the attention tiles in SRAM from the inputs Q, K, V together with the saved output O and the softmax statistics (m, l). This trades compute for memory:
Extra FLOPs from recomputation: ~1.1-1.3x
Memory savings: 5-20x (no N×N matrix stored)
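The extra compute is almost free because attention is bottlenecked by memory bandwidth, not arithmetic units: recomputing tiles in SRAM avoids reading the N×N matrix back from HBM, so the backward pass is typically faster than standard attention's despite the extra FLOPs.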
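A small NumPy sketch of why P does not need to be stored. It keeps a single per-row logsumexp L = m + log(l) (FlashAttention-2 saves this one vector instead of m and l separately) and rebuilds any tile of P in the backward pass as exp(S_ij - L_i). The forward function materializes S only to keep the sketch short, and the function names are illustrative, not a library API.

```python
import numpy as np


def forward_keep_lse(Q, K, V):
    """Forward pass that returns O and the per-row logsumexp L, discarding P.
    (Non-tiled for brevity; a real kernel computes the same L tile by tile.)"""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    m = S.max(axis=1)
    l = np.exp(S - m[:, None]).sum(axis=1)
    L = m + np.log(l)                       # logsumexp of each score row
    O = np.exp(S - L[:, None]) @ V          # exp(S - L) is exactly softmax(S)
    return O, L


def recompute_p_tile(Q_i, K_j, L_i):
    """Backward-pass recomputation of one P tile from Q, K and the saved L."""
    S_ij = Q_i @ K_j.T / np.sqrt(Q_i.shape[-1])
    return np.exp(S_ij - L_i[:, None])


rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((64, 16)) for _ in range(3))
O, L = forward_keep_lse(Q, K, V)

# Rebuild rows 16..31, columns 32..47 of P without ever having stored P.
tile = recompute_p_tile(Q[16:32], K[32:48], L[16:32])

# Check against a directly computed softmax (for the sketch only).
S_full = Q @ K.T / np.sqrt(Q.shape[-1])
P_full = np.exp(S_full - S_full.max(axis=1, keepdims=True))
P_full /= P_full.sum(axis=1, keepdims=True)
print(np.allclose(tile, P_full[16:32, 32:48]))  # True
```

Only L (N values) plus tensors that are kept anyway are needed, so the N×N matrix never has to sit in memory between the forward and backward passes.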
FlashAttention-2: Sequence Parallelism and Better Work Partitioning
FlashAttention-2 (2023) improved GPU utilization by:
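- Minimizing non-matmul operations: On GPUs such as the A100, non-matmul FLOPs (exponentials, rescaling) run far slower than tensor-core matmuls. FA2 restructures the online softmax, for example by keeping the output accumulator unnormalized and dividing by l only once at the end, so the GPU spends more of its time in matmuls.
- Parallelizing across sequence dimension: FA1 parallelized across batch and heads; FA2 also parallelizes across the query sequence, enabling better GPU occupancy for long sequences, where batch × heads is often small.
- Better work partitioning: Within each thread block, FA2 splits work across warps by query rows instead of by keys, reducing shared-memory communication and synchronization between warps.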
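The first two changes can be illustrated with a NumPy sketch. The accumulator stays unnormalized inside the loop (one per-row rescale by exp(m_old - m_new) per block, no division) and is divided by l once at the end. Query blocks never depend on each other, which is what lets FA2 hand them to separate thread blocks; here they simply run in a Python loop. Function and parameter names are hypothetical.

```python
import numpy as np


def fa2_query_block(Q_i, K, V, block_c=16):
    """One query block with deferred normalization (FA2-style accumulator)."""
    scale = 1.0 / np.sqrt(Q_i.shape[-1])
    acc = np.zeros((Q_i.shape[0], V.shape[1]))   # unnormalized output
    l = np.zeros(Q_i.shape[0])
    m = np.full(Q_i.shape[0], -np.inf)
    for c0 in range(0, K.shape[0], block_c):
        S = Q_i @ K[c0:c0 + block_c].T * scale
        m_new = np.maximum(m, S.max(axis=1))
        P = np.exp(S - m_new[:, None])
        alpha = np.exp(m - m_new)                 # rescale factor for earlier blocks
        l = alpha * l + P.sum(axis=1)
        acc = alpha[:, None] * acc + P @ V[c0:c0 + block_c]   # no division here
        m = m_new
    return acc / l[:, None]                       # single normalization at the end


def fa2_attention(Q, K, V, block_r=16, block_c=16):
    """Query blocks are independent: on a GPU each could run in its own
    thread block (sequence-dimension parallelism); here they run sequentially."""
    return np.vstack([fa2_query_block(Q[r0:r0 + block_r], K, V, block_c)
                      for r0 in range(0, Q.shape[0], block_r)])


rng = np.random.default_rng(2)
Q, K, V = (rng.standard_normal((50, 8)) for _ in range(3))
out = fa2_attention(Q, K, V)
```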
Result: ~2x speedup over FlashAttention-1, ~5-9x over standard PyTorch attention.
FlashAttention-3: H100 Specialization
FlashAttention-3 (2024) targets NVIDIA H100's specific hardware features:
- WGMMA (Warp Group Matrix Multiply Accumulate): New H100 instruction that computes larger matrix multiplications with better utilization
- TMA (Tensor Memory Accelerator): Hardware unit for asynchronous memory loads — FA3 overlaps compute with memory loading
- FP8 support: H100's FP8 tensor cores give 2x throughput over FP16; FA3 supports FP8 attention
Combined, FA3 achieves 75% of H100's theoretical FP16 FLOP throughput — close to the hardware ceiling.
Impact on Modern LLM Training and Serving
FlashAttention is now standard in every major LLM training stack:
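# HuggingFace Transformers: just set attn_implementation
# (requires the flash-attn package and a supported NVIDIA GPU)
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,  # FlashAttention kernels need fp16/bf16 inputs
)
# vLLM uses FlashAttention automatically for supported GPUs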
- Training speedup: 2-4x wall-clock time reduction for sequences > 2K tokens
- Context window enablement: 128K-1M context windows are only practical because FA eliminates the O(N²) memory requirement (compute still grows as O(N²))
- Training cost reduction: faster attention directly cuts GPU-hours for large training runs; the size of the saving depends on sequence length and model shape
Alternatives and the Broader Landscape
FlashAttention solves the exact-attention memory problem on a single GPU. Other approaches either scale exact attention further or change the attention computation itself:
- Ring Attention: Keeps attention exact but distributes it across multiple GPUs, passing K/V blocks around a ring of devices, for multi-million-token contexts
- Linear Attention (Retention, GLA): Replaces softmax with kernel feature maps or gated linear recurrences for O(N) complexity
- Sliding Window Attention (Mistral, Longformer): Only attend to a local window of W tokens; O(N·W) cost, linear in N for fixed W, but limited global context
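To show where linear attention's O(N) comes from, here is a minimal NumPy sketch of basic non-causal kernelized linear attention, the form that variants such as Retention and GLA extend with decay or gating (omitted here). The elu(x) + 1 feature map is one common choice and an assumption of this sketch. Without the softmax, the matrix products can be reassociated so the N×N matrix never appears. Note that this computes a different function from softmax attention.

```python
import numpy as np


def feature_map(x):
    """elu(x) + 1: a positive feature map (an assumed, common choice)."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def linear_attention_quadratic(Q, K, V):
    """Reference ordering: builds the N×N matrix phi(Q) phi(K)^T explicitly."""
    A = feature_map(Q) @ feature_map(K).T
    return (A @ V) / A.sum(axis=1, keepdims=True)


def linear_attention(Q, K, V):
    """Reassociated ordering: phi(Q) (phi(K)^T V) keeps only a d×d_v state,
    so time and memory grow linearly in sequence length N."""
    phi_q, phi_k = feature_map(Q), feature_map(K)
    kv = phi_k.T @ V            # d × d_v summary of all keys and values
    z = phi_k.sum(axis=0)       # d-vector used for normalization
    return (phi_q @ kv) / (phi_q @ z)[:, None]


rng = np.random.default_rng(3)
Q, K, V = (rng.standard_normal((200, 16)) for _ in range(3))
print(np.allclose(linear_attention(Q, K, V), linear_attention_quadratic(Q, K, V)))  # True
```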
For production serving today, FlashAttention (FA2/FA3) + grouped-query attention (GQA) + paged KV cache is the standard stack for efficient transformer inference.
Conclusion
FlashAttention's insight — that the GPU memory hierarchy, not FLOP count, determines attention speed — led to a 2-4x training speedup and enabled context windows that would have been impossible otherwise. It's a masterclass in hardware-aware algorithm design: the math doesn't change, but understanding where the bottleneck actually is unlocks an order-of-magnitude improvement. Any serious LLM practitioner should understand why it works.
FlashAttention enables long contexts. Learn about the LLM serving challenges those contexts create in our guide on KV Cache Optimization.