The Anatomy of a Prompt: A Journey from Python to Silicon - Part 4 of 5 - Flash Attention: The Trick That Made GPUs 4x Faster (And Why Python Can't Do It)
GPUs sat idle and bored, and the economics of serving billions of requests were poor. How did we get out of this?
Chapter 6: The Solutions — Fighting Physics with Software
The Economic Problem
By 2021, a year before ChatGPT and DALL-E went wild, a pattern had become painfully clear to ML engineers: GPUs were expensive, but most of the time they sat idle.
Run a profiler on standard attention and you'd see it: the Tensor Cores, the expensive, fast math units, were idle 80% of the time. They weren't computing; they were waiting for memory.
The GPU cores could finish their math in microseconds, then they’d wait for hundreds of microseconds while data crawled from HBM.
The question became: we can't make HBM faster and we can't make SRAM bigger, so the only thing left to ask is, can we reorganize the work so that the GPU stops waiting?
The Standard Attention Problem
To understand whether we can reorganize the work, we need to trace back and see what that work actually looks like.
Let’s trace what happens to your prompt with 2,000 tokens of context.
Standard attention in PyTorch:
# Three separate operations = three separate HBM round-trips
# (Q, K, V assumed to be FP16 tensors of shape [batch, heads, 2000, head_dim])
import torch
import torch.nn.functional as F

scores = torch.matmul(Q, K.transpose(-2, -1))  # 2000×2000 matrix → write to HBM
scores = F.softmax(scores, dim=-1)             # Read from HBM → write to HBM
output = torch.matmul(scores, V)               # Read from HBM → write to HBM

The scores matrix is 2000 × 2000 = 4 million elements. In FP16 that is 8MB, written to HBM, then read back, and then written again.
Each line is a separate CUDA kernel launch. Between kernels, all intermediate data goes back to HBM. The GPU cores finish their math in microseconds, then sit idle for hundreds of microseconds waiting for the next memory load.
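To put rough numbers on that, here's a back-of-the-envelope sketch, not a profiler trace from this article, using assumed A100-class peak figures (roughly 312 TFLOP/s of FP16 tensor math and about 2 TB/s of HBM bandwidth):

# Back-of-the-envelope for just the softmax kernel on one head's 2000×2000 scores.
# Peak numbers below are assumptions (roughly A100-class), not measurements.
seq_len, bytes_fp16 = 2000, 2
score_bytes = seq_len * seq_len * bytes_fp16      # 8 MB score matrix
hbm_traffic = 2 * score_bytes                     # read it from HBM, write it back
softmax_flops = 5 * seq_len * seq_len             # exp, subtract, sum, divide (very rough)

peak_flops, peak_bw = 312e12, 2e12                # ~312 TFLOP/s FP16, ~2 TB/s HBM
print(f"math time:   {softmax_flops / peak_flops * 1e6:.2f} us")  # ~0.06 us of compute
print(f"memory time: {hbm_traffic / peak_bw * 1e6:.2f} us")       # ~8 us waiting on HBM

The exact numbers don't matter; the ratio does. The kernel spends roughly two orders of magnitude longer moving bytes than doing math, and it repeats that for every head in every layer.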
The Softmax Problem
The immediate, knee-jerk question is: why can't we fuse these operations and stay in fast SRAM?
Look at softmax. For token 5 attending to all 2,000 tokens:
softmax(score_i) = exp(score_i) / Σ_j exp(score_j), summed over ALL j

Softmax is the step that turns raw scores into the attention weights we need to predict the next token. To compute the softmax for any single element, you need the sum of exponentials across all 2,000 scores. You can't normalize tokens 0-500 without knowing what tokens 501-2000 contain.
This is why engineers assumed you had to materialize the full matrix. You need to see everything before you can normalize anything.
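A tiny experiment makes this concrete (hypothetical random scores, purely illustrative):

import torch

scores = torch.randn(2000)                    # one query row's scores against all 2,000 keys
full = torch.softmax(scores, dim=0)           # denominator sums exp() over ALL 2,000 scores
partial = torch.softmax(scores[:500], dim=0)  # denominator only knows the first 500 scores

print(full[:500].sum().item())   # a fraction of the total probability mass, roughly 0.25
print(partial.sum().item())      # 1.0, because it normalized as if the other 1,500 keys didn't exist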
The Flash Attention Insight
In 2022, Tri Dao asked a very different question: what if we never store the full matrix at all? He understood the problems ML engineers were facing and knew that neither the model's architecture nor the physical limits of the hardware were going to change.
The unlock was this: you can compute softmax incrementally by tracking two running statistics, a running maximum and a running sum.
So instead of keeping the whole score matrix in memory, you only need to carry those two numbers (plus a partial output) as you process the tokens in batches.
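Here is that bookkeeping as a tiny, standalone sketch (hypothetical scores, chunks of 500); the running pair ends up matching the denominator a normal softmax would have computed over the whole array:

import math
import torch

scores = torch.randn(2000)
running_max, running_sum = float("-inf"), 0.0

for chunk in scores.split(500):                    # see 500 scores at a time
    new_max = max(running_max, chunk.max().item())
    # re-express the old sum relative to the new maximum, then add this chunk
    running_sum = running_sum * math.exp(running_max - new_max) \
                  + torch.exp(chunk - new_max).sum().item()
    running_max = new_max

reference = torch.exp(scores - scores.max()).sum().item()
print(running_sum, reference)                      # equal up to floating-point noise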
Lo and behold: the Flash Attention algorithm
Here’s how it works for token 5 attending to 2,000 keys, processing in tiles of 500:
Tile 1 (keys 0-499)
Load Q[5], K[0:500], V[0:500] into SRAM (fits easily)
Compute local attention scores: 500 numbers
Find local max: m₁ = 2.1
Compute exponentials and sum: sum₁ = 847.3
Compute partial output
Store in registers: m₁, sum₁, partial_output
Discard the 500 scores (never written to HBM)
Tile 2 (keys 500-999)
Load next tile into SRAM
Compute local scores
Find local max: m₂ = 3.4 ← bigger than m₁!
THE KEY TRICK: Rescale the old work
───────────────────────────────────
The old sum used max=2.1
The new max is 3.4
Correction factor: exp(2.1 - 3.4) ≈ 0.27
Multiply old sum and output by 0.27
Add this tile's contribution
───────────────────────────────────
Tiles 3-4: Repeat, rescaling whenever a new max appears.
Final step: Divide by the final sum.
The magic: we process tiles sequentially, correcting for “wrong” earlier estimates by rescaling. The full attention matrix never exists — not in SRAM, not in HBM, nowhere.
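Put together, the whole per-query loop fits in a few lines. What follows is only an educational re-creation in PyTorch (one query vector, no scaling factor, hypothetical sizes and a made-up function name); the real Flash Attention does the same thing inside a single CUDA kernel, per block of queries, entirely in SRAM and registers:

import torch

def tiled_attention_one_query(q, K, V, tile=500):
    running_max = torch.tensor(float("-inf"))
    running_sum = torch.tensor(0.0)
    acc = torch.zeros_like(V[0])                       # partial (unnormalized) output

    for K_tile, V_tile in zip(K.split(tile), V.split(tile)):
        s = K_tile @ q                                 # local scores for this tile
        new_max = torch.maximum(running_max, s.max())
        correction = torch.exp(running_max - new_max)  # e.g. exp(2.1 - 3.4) ≈ 0.27
        p = torch.exp(s - new_max)                     # local exponentials
        acc = acc * correction + p @ V_tile            # rescale old work, add this tile
        running_sum = running_sum * correction + p.sum()
        running_max = new_max
        # the local scores in `s` are dropped here; never stored anywhere

    return acc / running_sum                           # final step: divide by the final sum

# Sanity check against standard attention (illustrative shapes)
q, K, V = torch.randn(64), torch.randn(2000, 64), torch.randn(2000, 64)
reference = torch.softmax(K @ q, dim=0) @ V
print(torch.allclose(tiled_attention_one_query(q, K, V), reference, atol=1e-5))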
Suddenly the same GPU runs faster, because its cores are no longer waiting on memory.
What Stays in SRAM?
Only the current Q, K, and V tiles, plus three small pieces of state per query: the running max, the running sum, and the partial output. The full 2000×2000 score matrix is never materialized, and the only thing written back to HBM is the final output.
The Comparison
Same GPU. Same math (actually more FLOPs due to rescaling). 2-4x faster because the cores stop waiting.
So is Flash Attention a Python package for PyTorch?
If you have read everything up to here, you already know the answer: Flash Attention is not a PyTorch module. It is a hand-written CUDA kernel in C++, because, as we said above, the GPU does not understand Python. The careful orchestration of cores and memory is managed inside that CUDA kernel.
Now you might ask: can't we find some way to do this in Python?
# This CANNOT implement Flash Attention:
for K_tile in K.split(tile_size):
    scores = torch.matmul(Q, K_tile.transpose(-2, -1))  # ← kernel launch, result goes to HBM
    # the next line would be ANOTHER kernel launch;
    # data round-trips through HBM between every operation

PyTorch operations are opaque kernels. Each torch.matmul() launches a kernel: load from HBM, compute, write to HBM, return to Python. You cannot keep intermediate results in SRAM across PyTorch operations.
The CUDA Kernel Requirement
To fuse everything into a single kernel that never leaves SRAM, you must write raw CUDA:
__global__ void flash_attention_kernel(float* Q, float* K, float* V, float* Out) {
    __shared__ float Q_tile[TILE_SIZE][HEAD_DIM];   // Lives in SRAM
    __shared__ float K_tile[TILE_SIZE][HEAD_DIM];

    float running_max = -INFINITY;
    float running_sum = 0.0f;

    for (int tile = 0; tile < num_tiles; tile++) {
        // Load tile into SRAM
        // Compute scores IN REGISTERS
        // Update running statistics
        // Rescale accumulator
        // Accumulate this tile's contribution
        // Scores are DISCARDED here — never touch HBM
    }

    // Single write to HBM at the end
    store_output(Out);
}

This entire loop executes as one kernel launch. The intermediate scores exist only in registers and shared memory.
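In practice you rarely write that kernel yourself. PyTorch 2.x exposes fused attention through torch.nn.functional.scaled_dot_product_attention, which can dispatch to a Flash-Attention-style kernel when the GPU, dtype, and shapes allow it (the standalone flash-attn package offers the same kernels directly). A minimal sketch with illustrative shapes:

import torch
import torch.nn.functional as F

# batch=1, heads=16, seq=2000, head_dim=64; FP16 on the GPU
Q = torch.randn(1, 16, 2000, 64, device="cuda", dtype=torch.float16)
K = torch.randn(1, 16, 2000, 64, device="cuda", dtype=torch.float16)
V = torch.randn(1, 16, 2000, 64, device="cuda", dtype=torch.float16)

# One call, one fused kernel: the 2000×2000 score matrix is never written to HBM
# (PyTorch picks the fastest available backend for this GPU/dtype/shape).
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)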
Approach            Kernel Launches   HBM Round-Trips   GPU Utilization
Standard PyTorch    3                 6                 ~20%
Flash Attention     1                 2                 ~70%

NVIDIA's True Moat Continues
Now I understand why my GPU struggles with AI but runs games fine.
Games have decades of standardized APIs behind them. AMD has invested years optimizing its DirectX and Vulkan drivers. When a game ships, it uses these standard paths.
For machine learning, the dominant path is CUDA, and CUDA is proprietary. Flash Attention was written for NVIDIA hardware. The kernel assumes specific shared memory sizes, specific warp behavior, specific memory coalescing patterns.
To run Flash Attention on AMD, someone has to rewrite it from scratch for ROCm, tuned for AMD’s different memory hierarchy. This work is happening, but it’s behind.
NVIDIA’s moat isn’t just the silicon. It’s 15 years of CUDA work — the thousands of person-years of kernel optimization that every ML framework depends on.
Next Week: Scale & Economics (The Finale)
Flash Attention solved single-GPU inference. Same hardware, 2-4x faster — just by reorganizing memory access.
But here’s the thing:
A 70B model in FP16 = 140GB. An H100 has 80GB. The model doesn’t fit.
A week-long Claude conversation might have 50,000+ tokens. At 1.6MB per token, that’s 80GB of KV cache — for one user.
How do you split a brain across 8 GPUs? How do you serve a million users? And who pays for $30,000 chips that cost $10K/year just in electricity?
Part 5: Tensor Parallelism, NVLink, and why output tokens cost 6x more than input.
The finale drops next week.