FEL ČVUT · Lecture · Attention Systems

Flash Attention.

The bottleneck nobody wrote into the math, and the IO-aware fix that quietly broke open the long-context era.
01 · Hardware ≠ Math

The textbook formula looks innocent.

Three matrices, query (Q), key (K), and value (V), and one line of algebra: roughly 4N²d FLOPs. On an A100, whose tensor cores are rated at 312 trillion FP16 FLOPs per second, even a long sequence should finish in milliseconds.

Attention(Q, K, V) = softmax(Q · Kᵀ / √d) · V
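For reference, a minimal pure-Python sketch of the textbook formula follows. The names naive_attention and attention_flops are illustrative, not from any library. The sketch deliberately materializes the full N×N matrices S and P, which is exactly what a standard implementation stores in HBM. It also encodes the 4N²d FLOP estimate: 2N²d for Q · Kᵀ plus 2N²d for the product with V.

```python
import math

def naive_attention(Q, K, V):
    """Textbook attention for one head; Q, K, V are lists of rows.

    Builds the full N x N score matrix S and probability matrix P.
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # S = Q K^T / sqrt(d): one entry per (query, key) pair
    S = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    # P = row-wise softmax(S), subtracting the row max for numerical stability
    P = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        z = sum(e)
        P.append([x / z for x in e])
    # O = P V: each output row is a probability-weighted average of the rows of V
    return [[sum(p_j * v[c] for p_j, v in zip(p, V)) for c in range(len(V[0]))]
            for p in P]

def attention_flops(N, d):
    """Approximate FLOPs: 2N^2 d for Q K^T plus 2N^2 d for P V (softmax ignored)."""
    return 4 * N * N * d

if __name__ == "__main__":
    flops = attention_flops(4096, 64)
    print(f"{flops / 1e9:.2f} GFLOPs per head")            # 4.29 GFLOPs per head
    print(f"{flops / 312e12 * 1e6:.1f} us at 312 TFLOP/s")  # 13.8 us at 312 TFLOP/s
```

At peak rate a single head at N = 4K needs only microseconds of arithmetic, so the real-world slowness needs a different explanation.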

And yet — it doesn't. Training a transformer with long context is famously slow, famously memory-hungry, and famously OOMs the moment N pushes past a few thousand. Something is wrong with this picture.

01 · Diagnose first

Where does the time actually go?

When standard attention runs on a real GPU, what consumes the most wall-clock time?

01 · The constraint

FLOP-rich, bandwidth-poor.

Modern GPUs such as the A100 can do 312 TFLOP/s of FP16 tensor-core math but only stream 1.5–2.0 TB/s from main memory (HBM). An algorithm's wall-clock time is governed by whichever side runs out first. Arithmetic intensity, the FLOPs performed per byte moved, decides the regime.

Standard attention sits around AI ≈ 2 FLOPs/byte, deep in memory-bound territory: at that intensity the compute units are busy only a percent or two of the time, and the remaining ~98% goes to waiting on HBM.
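The regime argument fits in a few lines. This sketch assumes the A100 peak figures quoted above (312 TFLOP/s, 1.5 TB/s) and a deliberately simple roofline model: compute time and memory time are estimated independently, and the larger one marks the bottleneck. Real kernels also pay for caches, latency, and launch overhead, which the model ignores. The names roofline and RIDGE are illustrative.

```python
# Simple roofline model: runtime is bounded by the slower of compute and HBM traffic.
# Assumed A100 peaks (from the text): 312 TFLOP/s FP16 tensor cores, 1.5 TB/s HBM.
PEAK_FLOPS = 312e12        # FLOP/s
HBM_BYTES_PER_S = 1.5e12   # bytes/s

def roofline(flops, bytes_moved, peak=PEAK_FLOPS, bandwidth=HBM_BYTES_PER_S):
    """Return (compute_time_s, memory_time_s, regime), each time estimated independently."""
    t_compute = flops / peak
    t_memory = bytes_moved / bandwidth
    regime = "memory-bound" if t_memory > t_compute else "compute-bound"
    return t_compute, t_memory, regime

# Arithmetic intensity at which both sides take equal time (the "ridge point").
RIDGE = PEAK_FLOPS / HBM_BYTES_PER_S

if __name__ == "__main__":
    print(f"ridge point ~ {RIDGE:.0f} FLOPs/byte")  # ridge point ~ 208 FLOPs/byte
    # A kernel at arithmetic intensity 2, moving 1 GB through HBM:
    t_c, t_m, regime = roofline(flops=2e9, bytes_moved=1e9)
    print(regime, f"compute share of time: {t_c / (t_c + t_m):.1%}")
```

Any kernel whose intensity sits far below the ridge point gains almost nothing from fewer FLOPs; only fewer bytes help.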

The result, in two numbers: a 2–4× end-to-end training speedup with exact attention (same outputs), and a 10–20× memory reduction at sequence lengths of 2K to 4K.

Two things to hold: (1) attention is memory-bound, not compute-bound; (2) flash attention is exact, producing the same numbers with fewer trips to HBM.
Chapter 02 · The hardware view

A GPU has two memories.

When you say "GPU memory" you mean HBM — the 40 or 80 GB that fills up. But each streaming multiprocessor also has its own SRAM. The numbers are jarring.

02 · The hierarchy

The memory hierarchy nobody mentions.

HBM (main GPU memory): 40–80 GB  ·  1.5–2.0 TB/s
L2 cache (shared, on-die): ~40 MB  ·  ~5 TB/s
SRAM (per SM, programmer-controlled): 192 KB × 108 SMs  ·  ~19 TB/s
Why SRAM specifically
The L2 also sits on-die and is also fast — but you can't tell it what to keep. SRAM is the lever flash attention pulls because the programmer orchestrates it explicitly.
Relative to HBM, SRAM is about 10× faster and, comparing one SM's 192 KB with 40 GB of HBM, about 200,000× smaller.

That's the entire game. If your data lives in SRAM when you compute on it, you fly. If you round-trip through HBM, you crawl.
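The headline ratios follow from the figures above. Here is a quick check, assuming the 40 GB HBM variant and treating ~19 TB/s as the SRAM bandwidth. The 200,000× figure compares HBM with a single SM's SRAM; against all 108 SMs together the gap is still roughly 1,900×.

```python
# A100 figures from the hierarchy above (40 GB HBM variant assumed).
HBM_BYTES = 40e9             # 40 GB
HBM_BW = 1.5e12              # 1.5 TB/s
SRAM_PER_SM = 192 * 1024     # 192 KB per streaming multiprocessor
NUM_SMS = 108
SRAM_BW = 19e12              # ~19 TB/s, an estimate

speed_ratio = SRAM_BW / HBM_BW               # ~12.7x  ("about 10x faster")
size_ratio_per_sm = HBM_BYTES / SRAM_PER_SM  # ~203,000x ("200,000x smaller")
total_sram = SRAM_PER_SM * NUM_SMS           # 21,233,664 bytes, about 21 MB on the whole chip
size_ratio_total = HBM_BYTES / total_sram    # ~1,900x even counting every SM

print(f"speed {speed_ratio:.1f}x | size per SM {size_ratio_per_sm:,.0f}x | "
      f"size vs all SRAM {size_ratio_total:,.0f}x")
```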

02 · The size problem

How fast it leaves SRAM behind.

For sequence length N and one attention head, the score matrix S = QKᵀ has N² entries. Try a few values of N.

Take N = 1K. A single-head fp32 attention matrix is 1024² × 4 bytes = 4.0 MB, about 20× larger than one SM's SRAM, so it must live in HBM. Across the range 2⁷ = 128 to 2¹⁴ = 16K the same matrix grows from 64 KB to 1 GB. Multiply by heads × layers × batch for the real damage.
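The sizes are easy to tabulate. The short sketch below (function name illustrative) prints one single-head matrix for N = 2⁷ through 2¹⁴ and compares it with one SM's 192 KB. At N = 1K the exact ratio is about 21×, which rounds to the 20× quoted above.

```python
SRAM_PER_SM_BYTES = 192 * 1024   # one A100 SM's SRAM

def score_matrix_bytes(N, bytes_per_elem=4):
    """Bytes for one N x N attention matrix, single head (fp32 by default)."""
    return N * N * bytes_per_elem

for p in range(7, 15):           # N = 2^7 = 128 ... 2^14 = 16K
    N = 2 ** p
    size = score_matrix_bytes(N)
    print(f"N = {N:>6}: {size / 2**20:9.2f} MiB "
          f"({size / SRAM_PER_SM_BYTES:8.1f}x one SM's SRAM)")
```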
Chapter 03 · Trace the bytes

What naive attention does to your GPU.

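The formula looks like one operation. On the GPU it runs as separate kernels (three in the minimal version traced below, more once masking and dropout are added), and each one round-trips through HBM. Let's count.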

03 · HBM traffic, byte by byte

Standard attention, traced.

Q, K, V each have shape (N, d). Watch how often the N×N matrix crosses the wire.

Step 01  ·  S = Q · Kᵀ. Read Q, K from HBM, then write the N×N score matrix S back to HBM.
HBM elements: 2Nd + N²
Step 02  ·  P = softmax(S). Read S from HBM, normalize row-wise, then write the N×N probability matrix P back to HBM.
HBM elements: N² + N²
Step 03  ·  O = P · V. Read P and V from HBM, then write the output O back to HBM.
HBM elements: N² + 2Nd
Total HBM traffic ≈ 4N² + 4Nd elements.
An N×N matrix crosses between the chip and HBM four times (S written, S read, P written, P read), and all of it is discarded after step 3.
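Adding up the three steps gives 4N² + 4Nd elements: one N²-sized transfer in step 1, two in step 2, and one in step 3. This sketch (names illustrative, d = 64, fp16 assumed for the byte count) shows how quickly the quadratic part takes over.

```python
def standard_attention_hbm_elements(N, d):
    """HBM elements read or written by the three-step standard attention."""
    step1 = 2 * N * d + N * N     # read Q and K, write S
    step2 = N * N + N * N         # read S, write P
    step3 = N * N + 2 * N * d     # read P and V, write O
    return step1 + step2 + step3  # = 4N^2 + 4Nd

def quadratic_share(N, d):
    """Fraction of the traffic caused by the N x N matrices S and P."""
    return 4 * N * N / standard_attention_hbm_elements(N, d)

for N in (512, 4096, 16384):
    total = standard_attention_hbm_elements(N, 64)
    print(f"N = {N:>5}: {total * 2 / 1e6:8.1f} MB (fp16), "
          f"N^2 share {quadratic_share(N, 64):.1%}")
```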
03 · Locate the cost

The quadratic culprit.

At long sequence length, which single quantity dominates HBM traffic?

Chapter 04 · Why the obvious fix fails

The softmax that won't tile.

SRAM holds kilobytes per SM; S runs to megabytes. So tile it and compute a chunk at a time. Except the softmax wants to see the whole row before it can normalize.

04 · Numerical reminder

The whole row is in the denominator.

In practice we always subtract the row maximum first — exp(100) overflows in float16. So the actual formula used everywhere is:

softmax(x)ᵢ  =  exp(xᵢ − m) / ∑ⱼ exp(xⱼ − m),    m = max(x)

Both the numerator and the denominator depend on every element of the row. Split the row into two halves and compute each independently — and the two halves disagree on what "max" means.

04 · The trap

What goes wrong with tiling?

You compute softmax of the first half using its own max. The second half arrives. What breaks?
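A quick way to see what breaks, using a hypothetical row of four scores: normalizing each half independently gives probabilities that sum to 2, and every individual value is wrong, because each half used its own max and its own denominator.

```python
import math

def softmax(xs):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    z = sum(e)
    return [v / z for v in e]

row = [1.0, 3.0, 2.0, 5.0]
full = softmax(row)

# Naive tiling: normalize each half on its own and concatenate.
tiled = softmax(row[:2]) + softmax(row[2:])

print(sum(full))    # ~1.0
print(sum(tiled))   # ~2.0: each half was normalized to 1 independently
```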

This is why approximate methods (Linformer, Reformer, Performer, sparse attention) were the main alternative for years. Most did not deliver consistent wall-clock speedups, because they cut FLOPs while largely ignoring the cost of HBM traffic.

Chapter 05 · The mathematical core

The online softmax trick.

Two scalars per row of the attention matrix. That's the entire price you pay to avoid materializing the N×N matrix in HBM.

05 · Three lines

The recurrence.

Running state (m, ℓ, O). A new tile arrives with values x_new and v_new. Update in three lines:

m′  =  max(m, max(x_new))
ℓ′  =  exp(m − m′) · ℓ   +   ∑ exp(x_new − m′)
O′  =  exp(m − m′) · O   +   ∑ exp(x_new − m′) · v_new

The trick is exp(m − m′). Whenever the max grows, every term accumulated so far was computed against the old reference; multiplying by this factor shifts them to the new one. After the last tile, the output is O / ℓ. The result is algebraically identical to having seen all values at once, but computed one tile at a time.
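The three-line recurrence translates almost directly into code. In this minimal scalar sketch (function name illustrative), each tile is a pair of score and value lists, the state starts at (−∞, 0, 0), and the division by ℓ happens once at the end. Python's math.exp(-math.inf) returns 0.0, so the first rescaling needs no special case.

```python
import math

def online_softmax_weighted_sum(tiles):
    """Stream (x_tile, v_tile) pairs; return sum_i softmax(x)_i * v_i.

    Only three running values survive between tiles:
    m (max so far), l (rescaled denominator), o (rescaled numerator).
    """
    m, l, o = -math.inf, 0.0, 0.0
    for x_new, v_new in tiles:
        m_new = max(m, max(x_new))
        # exp(m - m_new) re-references everything accumulated so far to the new max.
        scale = math.exp(m - m_new)
        l = scale * l + sum(math.exp(x - m_new) for x in x_new)
        o = scale * o + sum(math.exp(x - m_new) * v for x, v in zip(x_new, v_new))
        m = m_new
    return o / l   # normalize once, after the last tile

tiles = [([1.0, 3.0], [1.0, 2.0]), ([2.0, 5.0], [3.0, 4.0])]
print(online_softmax_weighted_sum(tiles))  # ≈ 3.688
```

The result matches a direct softmax-weighted sum over all four values, even though no tile ever sees the others.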

05 · Live trace

Watching it run.

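Row of 4 values, in two tiles. Tile 1: x = [1, 3], v = [1, 2]. Tile 2: x = [2, 5], v = [3, 4].

Initial state: m = −∞, ℓ = 0, O = 0. No tile processed yet; the running statistics are at their identity values.
After tile 1: m = 3, ℓ = e⁻² + 1 ≈ 1.1353, O = e⁻²·1 + 1·2 ≈ 2.1353.
After tile 2: m = 5, so the old state is rescaled by exp(3 − 5) = e⁻². ℓ = e⁻²·1.1353 + e⁻³ + 1 ≈ 1.2034, O = e⁻²·2.1353 + e⁻³·3 + 1·4 ≈ 4.4383.
Final output: O / ℓ ≈ 3.688, the same value as a softmax-weighted sum over all four values at once.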
Chapter 06 · Tiling × online softmax = flash

Assembling the algorithm.

Tile Q, K, V to fit in SRAM. Use the online softmax to fuse the whole pipeline. The N×N matrix never touches HBM.

06 · Flash attention, forward

The full algorithm.


# Inputs: Q, K, V in HBM, each of shape (N, d). M = SRAM size.
Bc = ⌈M / 4d⌉,  Br = min(Bc, d)                  # pick block sizes that fit in SRAM
Split Q into Tr = ⌈N / Br⌉ blocks Qi of size (Br, d)
Split K, V into Tc = ⌈N / Bc⌉ blocks Kj, Vj of size (Bc, d)
for i = 1 .. Tr:                                 # outer loop over output rows
    Load Qi from HBM to SRAM
    Init Oi ← 0,  ℓi ← 0,  mi ← −∞               # in SRAM
    for j = 1 .. Tc:                             # inner loop over key/value blocks
        Load Kj, Vj from HBM to SRAM
        Sij = Qi · Kjᵀ                           # (Br, Bc) tile, in SRAM
        m̃ = max(mi, rowmax(Sij)),  P̃ij = exp(Sij − m̃)
        ℓ̃ = exp(mi − m̃) · ℓi + rowsum(P̃ij)
        Oi = exp(mi − m̃) · Oi + P̃ij · Vj        # one scale factor per row
        mi ← m̃,  ℓi ← ℓ̃
    Write Oi ← Oi / ℓi to HBM                    # row-wise division
    Write mi, ℓi to HBM                          # needed for the backward pass
06 · Watching the kernel breathe

Tile by tile, on the chip.

The N×N matrix exists conceptually but is never assembled in HBM. Each tile is computed in SRAM, the result is folded into the running state, the tile is discarded.
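To make the loop structure concrete, here is a pure-Python teaching sketch of the forward pass in the pseudocode above, with hypothetical small block sizes. It follows the same outer-loop-over-queries order and rescaling steps, and it includes the 1/√d scale from the Chapter 01 formula, which the pseudocode leaves out. Python lists stand in for SRAM here, so the sketch demonstrates data flow and exactness, not speed. The function name is illustrative.

```python
import math

def flash_attention_forward(Q, K, V, Br=2, Bc=2):
    """Tiled single-head attention forward with online softmax (teaching sketch).

    Outer loop over query blocks, inner loop over key/value blocks. The full
    N x N score matrix is never built: at most one Bc-wide row of a score tile
    exists at a time. Returns (O, m, l): normalized outputs plus each row's max
    and denominator, the statistics the backward pass needs.
    """
    N, d, d_v = len(Q), len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    O, m_all, l_all = [], [], []
    for i0 in range(0, N, Br):                      # outer loop: query block Qi
        Qi = Q[i0:i0 + Br]
        Oi = [[0.0] * d_v for _ in Qi]              # unnormalized output accumulator
        mi = [-math.inf] * len(Qi)                  # running row max
        li = [0.0] * len(Qi)                        # running row denominator
        for j0 in range(0, len(K), Bc):             # inner loop: key/value block Kj, Vj
            Kj, Vj = K[j0:j0 + Bc], V[j0:j0 + Bc]
            for r, q in enumerate(Qi):
                # Row r of the tile Sij = Qi Kj^T / sqrt(d)
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kj]
                m_new = max(mi[r], max(s))
                alpha = math.exp(mi[r] - m_new)     # shifts old state to the new max
                p = [math.exp(x - m_new) for x in s]  # row r of P~ij
                li[r] = alpha * li[r] + sum(p)
                Oi[r] = [alpha * o + sum(p_j * v[c] for p_j, v in zip(p, Vj))
                         for c, o in enumerate(Oi[r])]
                mi[r] = m_new
        # Final normalization, then "write" the block back
        O.extend([o / li[r] for o in Oi[r]] for r in range(len(Qi)))
        m_all.extend(mi)
        l_all.extend(li)
    return O, m_all, l_all
```

On small inputs its output should match a naive implementation to within floating-point rounding, for any choice of Br and Bc, including blocks that do not divide N evenly.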

06 · The accounting

How much HBM did we save?

Standard attention does Θ(N² + Nd) HBM accesses; flash attention does Θ(N²d² / M), where M is the SRAM size. For N much larger than d the ratio is roughly M / d²: it is set by the SRAM budget and the head dimension, and N largely cancels.

Example: N = 4K and 100 KB of SRAM, which holds M = 51,200 fp16 elements.

Head dimension d = 64 throughout. Dao et al. prove a matching lower bound: no exact attention algorithm can asymptotically beat this for all SRAM sizes M in the range [d, Nd].

Standard ÷ flash, HBM accesses: ≈ 12.5×, which is M / d² = 51,200 / 4,096.
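The 12.5× figure can be reproduced from the two asymptotic counts with constants dropped. This sketch assumes, as that figure implies, that M counts fp16 elements (100 KB / 2 bytes = 51,200). The function name is illustrative. Because the ratio works out to M(N + d) / (N d²), it approaches M / d² as N grows.

```python
def hbm_access_ratio(N, d, M):
    """Standard / flash HBM accesses from the asymptotic counts (constants dropped).

    standard ~ N^2 + N d;  flash ~ N^2 d^2 / M, with M = SRAM size in elements.
    """
    standard = N * N + N * d
    flash = N * N * d * d / M
    return standard / flash

M = 100 * 1024 // 2               # 100 KB of SRAM holding 2-byte fp16 values
for N in (1024, 4096, 16384):
    print(f"N = {N:>5}: {hbm_access_ratio(N, 64, M):.2f}x")
print("large-N limit M / d^2 =", M / 64**2)   # large-N limit M / d^2 = 12.5
```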
Chapter 07 · Trading FLOPs for memory wins

The recomputation paradox.

Standard attention's backward pass needs the attention matrix P. Flash attention never stored it. So how does it get the gradients?

07 · The strange trick

Where do the gradients come from?

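Flash attention saves only (mi, ℓi) per output row, O(N) extra memory instead of O(N²). How does the backward pass reconstruct what it needs? It recomputes it: each tile of S = Q · Kᵀ is rebuilt in SRAM from Q and K, and the matching tile of P is recovered exactly as exp(S − mi) / ℓi. That costs extra FLOPs, but far less time than reading an N×N matrix back from HBM.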
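Here is a minimal sketch of the recomputation for one row with hypothetical scores. The forward pass keeps only (m, ℓ); the backward pass rebuilds each tile of the row from recomputed scores and gets P back exactly. In the real kernel the scores come from reloading Q and K tiles into SRAM; here they are just a list, and the helper names are illustrative.

```python
import math

def row_stats(s_row):
    """What the forward pass keeps per row: the max m and the denominator l."""
    m = max(s_row)
    return m, sum(math.exp(x - m) for x in s_row)

def recompute_p(s_tile, m, l):
    """Backward pass: recover one tile of P from recomputed scores and saved (m, l)."""
    return [math.exp(x - m) / l for x in s_tile]

s_row = [1.0, 3.0, 2.0, 5.0]   # one row of S = Q K^T, recomputed tile by tile
m, l = row_stats(s_row)        # two scalars saved from the forward pass
p_row = [p for j in (0, 2) for p in recompute_p(s_row[j:j + 2], m, l)]
print([round(p, 4) for p in p_row])
```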


On GPT-2, the attention kernel runs 7.6× faster than PyTorch's stock attention, and end-to-end training sees a wall-clock speedup as well.

The downstream effect is bigger than the benchmark. Attention's O(N²) memory cost is what kept transformer context windows at 1K–2K tokens for years. Flash attention made the memory cost O(N). The long-context era runs on this algorithm.

07 · Generalize the lesson

The lesson, generalized.

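Someone claims a new attention variant: less compute, same HBM traffic. Compared to flash attention on an A100, what should you expect? Little or no wall-clock gain. Attention is memory-bound, so runtime tracks bytes moved rather than FLOPs; cutting compute alone is the same trap the approximate methods of Chapter 04 fell into.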

You can now defend it

Four facts to walk out with.

1.  Attention on modern GPUs is memory-bound, not compute-bound.

2.  HBM and SRAM differ by 10× in speed and, per SM, 200,000× in size; only SRAM is programmer-controlled.

3.  The online softmax recurrence needs only two scalars per row to be exact.

4.  Recomputation beats storage when FLOPs are cheaper than bandwidth.

Worth following up on: FlashAttention-2's parallelization fix, FlashAttention-3's H100 asynchrony, and PagedAttention for the KV cache.

Dao · Fu · Ermon · Rudra · Ré  ·  NeurIPS 2022  ·  arXiv 2205.14135