The bottleneck nobody wrote into the math — and the IO-aware fix that quietly broke open the long-context era.
01 · Hardware ≠ Math
The textbook formula looks innocent.
Three matrices — query (Q), key (K), value (V) — and one line of algebra. Roughly 4N²d FLOPs. On an A100 capable of 312 trillion FLOPs per second, even a long sequence should finish in milliseconds.
Attention(Q, K, V) = softmax(Q · Kᵀ / √d) · V
And yet — it doesn't. Training a transformer with long context is famously slow, famously memory-hungry, and famously OOMs the moment N pushes past a few thousand. Something is wrong with this picture.
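The "milliseconds" estimate is just a FLOP count divided by peak throughput. Here is a small sketch of that arithmetic. It assumes one head, counts a multiply-add as 2 FLOPs, and uses the 312 TFLOPs/s A100 peak quoted above. The helper name `attention_flops` is ours, not from any library.

```python
# Back-of-envelope FLOP count for standard attention, one head.
A100_PEAK_FLOPS = 312e12  # fp16/bf16 tensor-core peak quoted in the text

def attention_flops(n, d):
    """FLOPs for S = Q·Kᵀ plus O = P·V, counting a multiply-add as 2 FLOPs."""
    qk = 2 * n * n * d   # n*n dot products of length d
    pv = 2 * n * n * d   # n rows of P times the (n, d) value matrix
    return qk + pv       # = 4·N²·d; the softmax adds only O(N²) more

for n in (1024, 4096, 16384):
    flops = attention_flops(n, d=64)
    micros = flops / A100_PEAK_FLOPS * 1e6
    print(f"N={n:>6}: {flops / 1e9:8.2f} GFLOP, {micros:8.1f} us at peak")
```

Even at N = 16K this is well under a millisecond per head at peak. That is why the observed slowness has to come from somewhere other than arithmetic.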
01 · Diagnose first
Where does the time actually go?
When standard attention runs on a real GPU, what consumes the most wall-clock time?
Correct intuition
The matrix multiplications are fast; GPUs are built for them. The softmax is small. The actual time sink is the N×N attention matrix being written to main memory and read back, twice over: once as the scores S, once as the probabilities P. On an A100 at sequence length 1K, attention spends more wall-clock time on these reads and writes than on the matrix multiplications themselves.
01 · The constraint
FLOP-rich, bandwidth-poor.
An A100 can do 312 TFLOPs/s of fp16/bf16 tensor-core math but can stream only about 1.5 TB/s from main memory. An algorithm's wall-clock time is bounded by whichever resource saturates first. Arithmetic intensity (FLOPs per byte moved) decides the regime.
At an arithmetic intensity of 2 FLOPs/byte, those A100 figures put compute at roughly 1% of the run time and HBM traffic at roughly 99%. Standard attention sits around AI ≈ 2, deep in memory-bound territory.
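The regime split can be sketched with an idealized roofline model. This is a minimal sketch: it assumes run time is simply the larger of compute time and memory time, and it ignores latency, caches, and compute/memory overlap. The A100 constants are the ones quoted above, and the function name `roofline` is illustrative.

```python
# Roofline estimate: run time is the larger of compute time and memory time.
PEAK_FLOPS = 312e12  # FLOPs/s, A100 figure from the text
HBM_BW = 1.5e12      # bytes/s, A100 figure from the text

def roofline(flops, bytes_moved):
    """Return (seconds, regime) under an idealized roofline model."""
    compute_s = flops / PEAK_FLOPS
    memory_s = bytes_moved / HBM_BW
    regime = "compute-bound" if compute_s >= memory_s else "memory-bound"
    return max(compute_s, memory_s), regime

# Arithmetic intensity at which both sides take equal time (the "ridge point").
RIDGE = PEAK_FLOPS / HBM_BW
print(f"ridge point: {RIDGE:.0f} FLOPs/byte")

for intensity in (2, 50, 1000):
    seconds, regime = roofline(flops=intensity * 1e9, bytes_moved=1e9)
    print(f"AI = {intensity:>4} FLOPs/byte: {seconds * 1e3:.3f} ms, {regime}")
```

On these numbers, anything below about 208 FLOPs/byte is memory-bound. At AI ≈ 2, attention sits two orders of magnitude below the ridge, so adding compute cannot help.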
The result, in two numbers:
2–4× end-to-end training speedup, with exact attention and the same outputs.
10–20× memory reduction at sequence lengths 2K to 4K.
Two things to hold
(1) Attention is memory-bound, not compute-bound. (2) Flash attention is exact — same numbers, fewer trips to HBM.
Chapter 02
The hardware view
A GPU has two memories.
When you say "GPU memory" you mean HBM — the 40 or 80 GB that fills up. But each streaming multiprocessor also has its own SRAM. The numbers are jarring.
02 · The hierarchy
The memory hierarchy nobody mentions.
HBM · main GPU memory · 40–80 GB · 1.5–2.0 TB/s
L2 cache · shared, on-die · ~40 MB · ~5 TB/s
SRAM · per-SM, programmer-controlled · 192 KB × 108 SMs · ~19 TB/s
Why SRAM specifically
The L2 also sits on-die and is also fast — but you can't tell it what to keep. SRAM is the lever flash attention pulls because the programmer orchestrates it explicitly.
SRAM, relative to HBM: about 10× faster and about 200,000× smaller (one SM's 192 KB against 40 GB).
That's the entire game. If your data lives in SRAM when you compute on it, you fly. If you round-trip through HBM, you crawl.
02 · The size problem
How fast it leaves SRAM behind.
For sequence length N and one attention head, the score matrix S = QKᵀ has N² entries. Try a few values of N.
Example: at N = 1K, one attention matrix (single head, fp32) takes 4.0 MB. That is about 20× too large for one SM's SRAM, so it must live in HBM. Multiply by heads × layers × batch for the real damage.
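The same arithmetic across sequence lengths from 2⁷ = 128 to 2¹⁴ = 16K is sketched below. It assumes a single head, fp32 scores, and the 192 KB per-SM SRAM figure from the hierarchy above. The helper name is illustrative.

```python
# Size of one fp32 attention score matrix versus one SM's SRAM (A100 figures).
SRAM_PER_SM_BYTES = 192 * 1024
FP32_BYTES = 4

def score_matrix_bytes(n, bytes_per_elem=FP32_BYTES):
    """Bytes for a single head's N×N score matrix S."""
    return n * n * bytes_per_elem

for exponent in range(7, 15):  # N = 128 ... 16K
    n = 2 ** exponent
    size = score_matrix_bytes(n)
    ratio = size / SRAM_PER_SM_BYTES
    print(f"N={n:>6}: {size / 2**20:9.2f} MiB = {ratio:8.1f}x one SM's SRAM")
```

Only N = 128 fits. At N = 1K the matrix is about 21× one SM's SRAM, and at N = 16K a single head's matrix is 1 GiB.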
Chapter 03
Trace the bytes
What naive attention does to your GPU.
The formula looks like one operation. On the GPU it runs as a chain of separate kernels, and each one round-trips through HBM. Let's count the minimal three-step version.
03 · HBM traffic, byte by byte
Standard attention, traced.
Q, K, V each have shape (N, d). Watch how often the N×N matrix crosses the wire.
Step 01 · S = Q · Kᵀ
Read Q, K from HBM → write the N×N score matrix S back to HBM.
HBM elements: 2Nd + N²
Step 02 · P = softmax(S)
Read S from HBM, normalize row-wise, write the N×N probability matrix P back to HBM.
HBM elements: N² + N²
Step 03 · O = P · V
Read P and V from HBM, write the output O back to HBM.
HBM elements: N² + 2Nd
Total HBM traffic ≈ 4N² + 4Nd
The N×N matrices cross the wire four times (S written, S read back, P written, P read back), and both are discarded after step 3.
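Adding the per-step counts gives 4N² + 4Nd element transfers, and the 4N² part comes entirely from S and P. The sketch below does that bookkeeping for N = 4096, d = 64. It counts elements, not bytes, and the function name is ours.

```python
# Count HBM element reads/writes for the three-kernel standard attention trace.
def standard_attention_hbm(n, d):
    return {
        "S = Q·Kᵀ":       2 * n * d + n * n,        # read Q, K; write S
        "P = softmax(S)": n * n + n * n,            # read S; write P
        "O = P·V":        n * n + n * d + n * d,    # read P, V; write O
    }

n, d = 4096, 64
steps = standard_attention_hbm(n, d)
total = sum(steps.values())
quadratic = 4 * n * n  # write S, read S, write P, read P
for name, elems in steps.items():
    print(f"{name:<15} {elems:>12,}")
print(f"total           {total:>12,}  ({quadratic / total:.1%} from N×N matrices)")
print(f"N×N matrix vs (N, d) matrix: {n * n // (n * d)}x larger")
```

At this size over 98% of the traffic is the N×N matrices, and each one is N / d = 64 times larger than Q, K, V, or O.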
03 · Locate the cost
The quadratic culprit.
At long sequence length, which single quantity dominates HBM traffic?
Yes — and that's the lever
Q, K, V, and O are all (N, d) — they scale linearly in N. S and P are (N, N) — they scale quadratically. For N=4096 and d=64, that's a 64× difference per matrix. The plan writes itself: never materialize S and P in HBM.
Chapter 04
Why the obvious fix fails
The softmax that won't tile.
Each SM's SRAM holds kilobytes; S runs to megabytes. So tile it: compute a chunk at a time. Except the softmax wants to see the whole row before it can normalize.
04 · Numerical reminder
The whole row is in the denominator.
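In practice we always subtract the row maximum first: exp(12) already overflows float16 (largest finite value 65,504), and exp(100) overflows even float32. So the actual formula used everywhere is:
$$\mathrm{softmax}(x)_i = \frac{\exp(x_i - m)}{\sum_j \exp(x_j - m)}, \qquad m = \max_j x_j$$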
Both the numerator and the denominator depend on every element of the row. Split the row into two halves and compute each independently — and the two halves disagree on what "max" means.
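The effect of the max subtraction is easy to see in a few lines of plain Python. This is a minimal sketch: Python floats are float64, so it uses scores near 1000 to force the overflow that float16 would hit far earlier. Both function names are illustrative.

```python
import math

def softmax_naive(xs):
    """Textbook softmax: exponentiate raw scores directly."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(xs):
    """Subtract the row max first, so the largest exponent is exp(0) = 1."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

row = [1000.0, 1001.0, 1002.0]
try:
    softmax_naive(row)
except OverflowError:
    print("naive softmax: math.exp overflowed")
print("stable softmax:", [round(p, 4) for p in softmax_stable(row)])
```

Note that `softmax_stable` needs `max(xs)` over the whole row before it computes a single exponential. That requirement is exactly what makes tiling awkward.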
04 · The trap
What goes wrong with tiling?
You compute softmax of the first half using its own max. The second half arrives. What breaks?
That's the trap
Tile 1's max might be 3; tile 2's max might be 7. Now tile 1's exponentials were all computed against the wrong reference point, and its partial sum is wrong too. You can't just sum partial softmaxes — you'd be summing apples and pears.
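The trap can be reproduced directly with the tile maxes from the example, 3 and 7. This is a minimal sketch in plain Python. The tile values other than the two maxes are arbitrary choices for illustration.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

tile1, tile2 = [1.0, 3.0], [2.0, 7.0]       # tile maxes 3 and 7, as in the example

exact = softmax(tile1 + tile2)              # one pass over the whole row
per_tile = softmax(tile1) + softmax(tile2)  # each tile normalized on its own

print("exact:   ", [round(p, 4) for p in exact], "sum =", round(sum(exact), 6))
print("per-tile:", [round(p, 4) for p in per_tile], "sum =", round(sum(per_tile), 6))
```

Each tile sums to 1 on its own, so the concatenation sums to 2. Tile 1's probabilities are badly inflated because its denominator never saw the 7.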
This is why approximate methods (Linformer, Reformer, Performer, sparse attention) were the main options on the table for years. None produced consistent wall-clock speedups, because they cut FLOPs while largely ignoring memory traffic, which was the real constraint on GPUs.
Chapter 05
The mathematical core
The online softmax trick.
Two scalars per row of the attention matrix. That's the entire price you pay to avoid materializing the N×N matrix in HBM.
05 · Three lines
The recurrence.
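Running state (m, ℓ, O). A new tile arrives with values x_new and v_new. Update in three lines:
$$m' = \max\left(m,\ \max_j x^{\text{new}}_j\right)$$
$$\ell' = e^{m - m'}\,\ell + \sum_j e^{x^{\text{new}}_j - m'}$$
$$O' = e^{m - m'}\,O + \sum_j e^{x^{\text{new}}_j - m'}\,v^{\text{new}}_j$$
After the last tile, the output is O / ℓ.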
The trick is exp(m − m′). Whenever the max grows, every term you accumulated before was computed against the old reference. Multiplying by this factor shifts them to the new reference. Algebraically identical to having seen all values at once — but computed one tile at a time.
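The rescaling step can be checked in one line. Suppose ℓ summarizes old scores x_1, …, x_k against the old max m, and m′ ≥ m is the new max. Then:

$$e^{m - m'}\,\ell \;=\; e^{m - m'} \sum_{j \le k} e^{x_j - m} \;=\; \sum_{j \le k} e^{x_j - m'}$$

Adding the new tile's terms then gives exactly the denominator computed against m′ over every score seen so far. The same factor does the same job for each term of O, because each term is one of these exponentials multiplied by its v_j.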
05 · Live trace
Watching it run.
Row of 4 values, in two tiles:
Tile 1: x = [1, 3] · v = [1, 2]
Tile 2: x = [2, 5] · v = [3, 4]
Initial state, before any tile is processed: m (running max) = −∞, ℓ (running denominator) = 0, O (weighted output) = 0. These are the identity values of the running statistics.
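The recurrence can be run on this example in plain Python. This is a minimal sketch that treats v as scalars (one output feature) to keep the trace readable. The function name is ours.

```python
import math

def online_softmax_weighted_sum(tiles):
    """Stream (x, v) tiles through the running state (m, ℓ, O).

    Returns the softmax-weighted average of v and the state after each tile.
    O is kept unnormalized; the division by ℓ happens once at the end.
    """
    m, l, o = -math.inf, 0.0, 0.0            # identity state: nothing seen yet
    trace = []
    for xs, vs in tiles:
        m_new = max(m, max(xs))
        rescale = math.exp(m - m_new)         # exp(m − m′); 0.0 on the first tile since m = −∞
        weights = [math.exp(x - m_new) for x in xs]
        l = rescale * l + sum(weights)
        o = rescale * o + sum(w * v for w, v in zip(weights, vs))
        m = m_new
        trace.append((m, l, o))
    return o / l, trace

tiles = [([1.0, 3.0], [1.0, 2.0]),            # tile 1
         ([2.0, 5.0], [3.0, 4.0])]            # tile 2
result, trace = online_softmax_weighted_sum(tiles)
for k, (m, l, o) in enumerate(trace, start=1):
    print(f"after tile {k}: m = {m}, ℓ = {l:.4f}, O = {o:.4f}")
print(f"output O / ℓ = {result:.4f}")
```

After tile 1 the state is m = 3, ℓ ≈ 1.135, O ≈ 2.135. Tile 2 raises the max to 5, so ℓ and O are first multiplied by exp(3 − 5) ≈ 0.135 before tile 2's terms are added. The final O / ℓ ≈ 3.688 matches a one-shot softmax over all four values.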
Chapter 06
Tiling × online softmax = flash
Assembling the algorithm.
Tile Q, K, V to fit in SRAM. Use the online softmax to fuse the whole pipeline. The N×N matrix never touches HBM.
06 · Flash attention, forward
The full algorithm.
# Inputs: Q, K, V in HBM with shape (N, d). M = SRAM size.
# Loop order follows FlashAttention-2; Algorithm 1 of the 2022 paper puts the K, V blocks in the outer loop.
Bc = ⌈M / 4d⌉, Br = min(Bc, d)            # pick block sizes that fit in SRAM
Split Q into Tr blocks Qi of size (Br, d)
Split K, V into Tc blocks Kj, Vj of size (Bc, d)
for i = 1 .. Tr:                            # outer loop over output rows
    Load Qi from HBM to SRAM
    Init Oi ← 0, ℓi ← 0, mi ← −∞ (in SRAM)
    for j = 1 .. Tc:                        # inner loop over key/value blocks
        Load Kj, Vj from HBM to SRAM
        Sij = Qi · Kjᵀ (in SRAM)
        m̃ = max(mi, rowmax(Sij)), P̃ij = exp(Sij − m̃)
        ℓ̃ = exp(mi − m̃)·ℓi + rowsum(P̃ij)
        Oi = exp(mi − m̃)·Oi + P̃ij · Vj     # Oi stays unnormalized inside the loop
        mi ← m̃, ℓi ← ℓ̃
    Write Oi ← Oi / ℓi to HBM
    Write mi, ℓi to HBM (needed for backward pass)
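The forward pass above can be sketched as runnable NumPy and checked against the naive formula. This is a sketch under stated assumptions. NumPy arrays stand in for SRAM tiles, so it shows the arithmetic, not the memory savings. The block sizes Br and Bc are fixed small numbers instead of being derived from M. It also applies the 1/√d scaling from the attention formula, which the pseudocode leaves out. Both function names are ours.

```python
import numpy as np

def reference_attention(Q, K, V):
    """Standard attention: materializes the full N×N matrices S and P."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

def flash_attention_forward(Q, K, V, Br=16, Bc=16):
    """Tiled forward pass with online softmax; never builds an N×N array."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)                 # the 1/√d from the attention formula
    O = np.empty_like(Q)
    m_out, l_out = np.empty(N), np.empty(N)  # per-row stats saved for backward
    for i0 in range(0, N, Br):               # outer loop over query blocks
        Qi = Q[i0:i0 + Br] * scale
        rows = Qi.shape[0]
        Oi = np.zeros((rows, d))
        li = np.zeros(rows)
        mi = np.full(rows, -np.inf)
        for j0 in range(0, N, Bc):           # inner loop over key/value blocks
            Kj, Vj = K[j0:j0 + Bc], V[j0:j0 + Bc]
            Sij = Qi @ Kj.T                  # (rows, Bc) block of scores
            m_new = np.maximum(mi, Sij.max(axis=1))
            Pij = np.exp(Sij - m_new[:, None])
            alpha = np.exp(mi - m_new)       # exp(m − m̃); zeros on the first block
            li = alpha * li + Pij.sum(axis=1)
            Oi = alpha[:, None] * Oi + Pij @ Vj
            mi = m_new
        O[i0:i0 + Br] = Oi / li[:, None]     # normalize once, then write out
        m_out[i0:i0 + Br], l_out[i0:i0 + Br] = mi, li
    return O, m_out, l_out

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 50, 8))    # N = 50 is deliberately not a multiple of 16
O, m, l = flash_attention_forward(Q, K, V)
print("max abs difference vs. reference:", np.abs(O - reference_attention(Q, K, V)).max())
```

The largest intermediate is a (Br, Bc) score block. The N×N matrix exists only conceptually, which is the point of the algorithm.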
06 · Watching the kernel breathe
Tile by tile, on the chip.
The N×N matrix exists conceptually but is never assembled in HBM. Each tile is computed in SRAM, the result is folded into the running state, the tile is discarded.
[Animation: tiles stream from HBM (main memory) into on-chip SRAM, where the ops fuse; each tile is discarded after use, so HBM stays cold.]
06 · The accounting
How much HBM did we save?
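Standard attention does Θ(N² + Nd) HBM accesses. Flash attention does Θ(N²d² / M), where M is the SRAM size measured in elements. For N much larger than d, the ratio is roughly M / d², so the saving depends mainly on the SRAM budget and the head dimension, not on N.
Example: N = 4K, SRAM = 100 KB, head dimension d = 64 throughout. Standard ÷ flash HBM accesses ≈ 12.5×, since 100 KB holds 51,200 fp16 elements and 51,200 / 64² = 12.5.
Dao et al. prove a matching lower bound: no exact attention algorithm can asymptotically beat this for a range of SRAM sizes (d ≤ M ≤ Nd).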
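The two Θ-expressions can be compared numerically. This is a sketch that drops all constant factors. It assumes M is counted in fp16 elements, which is the reading that reproduces the 12.5× figure. The function names are ours.

```python
# Θ-level HBM access counts from the text, constants dropped.
def standard_accesses(n, d):
    return n * n + n * d

def flash_accesses(n, d, m_elems):
    return n * n * d * d / m_elems

def savings_ratio(n, d, sram_bytes, bytes_per_elem=2):
    m_elems = sram_bytes / bytes_per_elem     # SRAM budget M, measured in elements
    return standard_accesses(n, d) / flash_accesses(n, d, m_elems)

for n in (1024, 4096, 16384):
    print(f"N={n:>6}: {savings_ratio(n, d=64, sram_bytes=100 * 1024):.2f}x")
```

The ratio is M(1 + d/N) / d², so it drifts toward M / d² = 12.5 as N grows instead of increasing with N. To save more HBM traffic you need more SRAM per SM or a smaller head dimension.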
Chapter 07
Trading FLOPs for memory wins
The recomputation paradox.
Standard attention's backward pass needs the attention matrix P. Flash attention never stored it. So how does it get the gradients?
07 · The strange trick
Where do the gradients come from?
Flash attention saves only (mi, ℓi) per output row — O(N) extra memory, not O(N²). How does the backward pass reconstruct what it needs?
Right — and it's faster anyway
The backward kernel re-tiles Q and K, recomputes each block Sij = Qi·Kjᵀ, applies exp using the saved mi, and divides by ℓi. That gives back exactly the same Pij the forward pass used. The cost is extra FLOPs, but no N×N matrix is ever read from HBM. On modern GPUs that trade is a massive win.
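The reconstruction step can be sketched in NumPy. This covers only rebuilding each block of P from Q, K, and the saved per-row m and ℓ; it is not a full gradient computation. For brevity it computes m and ℓ in one pass instead of tile by tile, which gives the same final values the forward pass would save. The block sizes and names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, Br, Bc = 32, 8, 8, 8
Q, K = rng.standard_normal((2, N, d))
scale = 1.0 / np.sqrt(d)

# Forward statistics, computed naively here only so we have P to compare against.
S = (Q * scale) @ K.T
m = S.max(axis=1)                          # saved per row: final running max
l = np.exp(S - m[:, None]).sum(axis=1)     # saved per row: softmax denominator
P_stored = np.exp(S - m[:, None]) / l[:, None]

# Backward-style recomputation: rebuild any block P_ij from Q_i, K_j, m_i, ℓ_i alone.
def recompute_block(i0, j0):
    Qi, Kj = Q[i0:i0 + Br] * scale, K[j0:j0 + Bc]
    Sij = Qi @ Kj.T                        # recompute the scores for this tile
    return np.exp(Sij - m[i0:i0 + Br, None]) / l[i0:i0 + Br, None]

worst = max(
    np.abs(recompute_block(i0, j0) - P_stored[i0:i0 + Br, j0:j0 + Bc]).max()
    for i0 in range(0, N, Br) for j0 in range(0, N, Bc)
)
print("largest difference between recomputed and stored P:", worst)
```

Only 2N numbers (m and ℓ) are kept between passes. Everything else is recomputed from Q and K, which the backward pass reads anyway.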
07 · Synthesize
Rank the four ideas.
Order these from most central to the speedup down to least:
01 · Tiling Q, K, V into blocks that fit in SRAM
02 · Kernel fusion: all four ops in one CUDA kernel
03 · Online softmax (the exp(m − m′) rescaling)
04 · Recomputing P in the backward pass instead of storing it
Attention kernel on GPT-2: 7.6× faster than PyTorch's stock attention.
End-to-end training on GPT-2: 3× wall-clock training speedup.
The downstream effect is bigger than the benchmark. Attention's O(N²) memory cost is what kept transformer context windows at 1K–2K tokens for years. Flash attention made the memory cost O(N). The long-context era runs on this algorithm.
07 · Generalize the lesson
The lesson, generalized.
Someone claims a new attention variant: less compute, same HBM traffic. Compared to flash attention on an A100, what should you expect?
That's the lesson
This is why Linformer, Performer, Reformer kept claiming theoretical speedups but failing to beat standard attention in wall-clock time. They reduced FLOPs in a regime where FLOPs weren't the bottleneck. Flash attention attacked the actual constraint — exact, not approximate. Hardware is a part of the algorithm.
You can now defend it
Four facts to walk out with.
1. Attention on modern GPUs is memory-bound, not compute-bound.
2. HBM and SRAM differ by 10× in speed and 200,000× in size — and only SRAM is programmer-controlled.
3. The online softmax recurrence needs only two scalars per row to be exact.
4. Recomputation beats storage when FLOPs are cheaper than bandwidth.
Worth following up on: FlashAttention-2's parallelization fix, FlashAttention-3's H100 asynchrony, and PagedAttention for KV cache.
Dao · Fu · Ermon · Rudra · Ré · NeurIPS 2022 · arXiv 2205.14135