kvcache.cobanov.dev

Built by Mert Cobanov

KV Cache &
Flash Attention.

LLM inference optimizations are one of the parts of this field I keep coming back to. New techniques land at a steady clip: a smarter cache, a faster kernel, a better scheduler. Each is a specific, careful answer to a specific bottleneck. I built this page to work through the most important of them, as much for my own understanding as for whoever else reads it, visualized as clearly as I could manage. Eleven sections, top to bottom, each pairing a short explanation with Python and an animated diagram.

These started as my own notes while working through Rohit Ghumare’s excellent ai-engineering-from-scratch repo. I kept rewriting them with my own visualizations until they made sense, and this page is the result.

The baseline

01

Naive autoregressive decoding

A decoder-only transformer produces text one token at a time. During training, the entire sequence is fed in parallel and attention is computed once over the whole input. At inference time this is no longer possible: token t+1 cannot be sampled until token t has been emitted.

The simplest implementation handles this serial constraint by running the model from scratch at every step. Each new token triggers a fresh forward pass over the entire sequence so far: every position’s Q, K, and V vectors are recomputed, the full attention matrix is materialized, and only the last row of the output is actually used. At step t, the work done for positions 0 through t−1 is identical to the work done in the previous step. It is thrown away and recomputed.

The repeated per-position work scales as the triangular sum: emitting N tokens reprocesses N(N+1)/2 token positions, i.e. O(N²) in the generated length. A 100-token completion reprojects 5,050 positions; a 4,096-token completion reprojects more than eight million. A literal full-matrix implementation also computes O(N³) attention-score cells, which is even more wasteful.
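The counts quoted above can be checked directly. This is a minimal sketch, assuming step t re-runs a sequence of exactly t tokens and the prompt is ignored; the function names are illustrative and not part of any library.

```python
def naive_projection_rows(n_new: int) -> int:
    """Positions whose Q/K/V are recomputed when step t re-runs t tokens."""
    return sum(range(1, n_new + 1))                   # 1 + 2 + ... + N = N(N+1)/2


def naive_attention_cells(n_new: int) -> int:
    """Score cells if every step also materializes its full t x t matrix."""
    return sum(t * t for t in range(1, n_new + 1))    # N(N+1)(2N+1)/6


for n in (100, 4096):
    print(n, naive_projection_rows(n), naive_attention_cells(n))
```

The first count grows quadratically and the second cubically, which is why the full-matrix variant is the more wasteful of the two.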

naive_decode.py (python)
def naive_decode(prompt_tokens, n_new):
    tokens = list(prompt_tokens)
    projection_rows = 0
    attention_cells = 0
    for step in range(n_new):
        # recompute Q, K, V for the entire sequence every step
        Q = project_q(tokens)            # (N, d)
        K = project_k(tokens)            # (N, d)
        V = project_v(tokens)            # (N, d)

        attn = softmax(Q @ K.T) @ V      # full N x N matrix
        next_tok = sample(attn[-1])      # only the last row is used

        projection_rows += len(tokens)   # positions reprojected this step
        attention_cells += len(tokens) ** 2
        tokens.append(next_tok)
    return tokens, projection_rows, attention_cells
Figure: naive decoder, per-step work. At 8 generated tokens, 36 positions are recomputed.

Each row is one decode step. Filled cells along the row are the token positions whose Q/K/V projections are recomputed for that step. The highlighted row is the current step; the dimmer rows were already processed in earlier steps and then thrown away. Total filled cells = N(N+1)/2.

The triangle counts token positions whose Q/K/V projections are recomputed from scratch.

Two observations follow from this. First, the K and V vectors for any prefix position are fixed once that position has been computed from its causal prefix and the model weights; future tokens cannot change them. Recomputing them is wasted compute. Second, of the full N×N attention output, only one row (attn[-1]) is consumed at each step. The other rows are computed and discarded.

These two observations motivate the KV cache, which is the subject of the next section.

First optimization

02

The KV cache

The K and V projections for any prefix position are functions of that position’s hidden state, its causal prefix, and the model weights. Once a prefix position has been computed, future tokens are masked from changing it. If we keep those K and V vectors around, later decode steps can read them instead of reprojecting the prefix.

The fix is straightforward: keep a per-layer, per-head buffer and append to it every time we produce a new token. Each decode step projects a single new token’s Q, K, and V, appends K and V to the cache, and runs attention as a 1×N row against the cached keys. The repeated projection/state work drops from O(N) old positions per step to one new position. Across a generation of N tokens, that part falls from O(N²) to O(N). The new query still compares against all cached keys, so attention-score work over the generation remains length-dependent.

For the 4,096-token completion from the previous section, that means the naive decoder reprojects more than eight million token positions while the cached decoder projects only the 4,096 newly generated positions. The cache removes duplicated prefix work; it does not make the attention scan over a long prefix disappear.
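The claim that caching changes the cost but not the result can be checked on a toy example. The sketch below assumes a single layer and a single head, uses random stand-in vectors as hidden states, and keeps everything in plain Python lists; all names are illustrative. It compares the output of a from-scratch recomputation with the output of the cached path at every prefix length.

```python
import math
import random


def matvec(x, W):
    """Row vector x (length d) times matrix W (d x d), as plain lists."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, K, V):
    """softmax(q . K^T) . V for a single query row."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(w[t] * V[t][j] for t in range(len(V))) / z for j in range(len(V[0]))]


def naive_last_row(X, Wq, Wk, Wv):
    """Reproject K and V for every position, then compute the last query's output."""
    K = [matvec(x, Wk) for x in X]
    V = [matvec(x, Wv) for x in X]
    return attend(matvec(X[-1], Wq), K, V)


def cached_outputs(X, Wq, Wk, Wv):
    """Project each position once, append to the cache, attend over the cache."""
    K_cache, V_cache, outs = [], [], []
    for x in X:
        K_cache.append(matvec(x, Wk))
        V_cache.append(matvec(x, Wv))
        outs.append(attend(matvec(x, Wq), K_cache, V_cache))
    return outs


rng = random.Random(0)
d, n = 4, 6
Wq, Wk, Wv = ([[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(3))
X = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]   # stand-in hidden states

cached = cached_outputs(X, Wq, Wk, Wv)
for t in range(1, n + 1):
    naive = naive_last_row(X[:t], Wq, Wk, Wv)
    assert all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(naive, cached[t - 1]))
```

The naive path projects t positions at step t, while the cached path projects one; both still scan all t keys to produce the output row.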

kv_cache.py (python)
class KVCache:
    def __init__(self, n_layers, n_heads, d_head):
        self.K = [[[] for _ in range(n_heads)] for _ in range(n_layers)]
        self.V = [[[] for _ in range(n_heads)] for _ in range(n_layers)]

    def append(self, layer, head, k, v):
        self.K[layer][head].append(k)
        self.V[layer][head].append(v)

    def read(self, layer, head):
        return self.K[layer][head], self.V[layer][head]


def cached_decode(prompt, n_new, cache):
    # prefill: run the prompt once and fill the cache
    for tok in prompt:
        h = causal_hidden_state(tok, cache)   # depends on the prefix, not tok alone
        k, v = project_kv(h)
        cache.append(layer=0, head=0, k=k, v=v)

    # decode: only one current position is projected each step
    out = []
    for _ in range(n_new):
        h_cur = causal_hidden_state(out[-1] if out else prompt[-1], cache)
        q = project_q(h_cur)
        K, V = cache.read(0, 0)

        attn = softmax(q @ stack(K).T) @ stack(V)   # (1, N) weights @ (N, d) values -> (1, d)
        next_tok = sample(attn)

        h = causal_hidden_state(next_tok, cache)
        k, v = project_kv(h)
        cache.append(0, 0, k, v)
        out.append(next_tok)
    return out
Figure: cached decoder, per-step work. At 8 tokens, 8 positions are projected, 22% of the naive decoder's projection count.

Orange cells are newly projected K/V entries. Green cells are cached prefix entries read by the current query. They are not reprojected, but the query still attends over them.

Orange = newly computed this step. Green = served from the cache.

Two practical consequences. First, generation now has two distinct phases. The prefill phase processes the entire prompt in one parallel forward pass to populate the cache. The decode phase runs one token at a time against the cache. These have very different arithmetic intensity profiles and production engines schedule them differently.

Second, the savings in compute are paid for in memory. Every cached K and V vector lives in GPU HBM until the sequence finishes. For short sequences this is irrelevant. For 32K- or 128K-token contexts on large models it becomes the dominant constraint, which is what we look at next.

The tradeoff

03

The cost of keeping K and V around

The KV cache turns compute savings into a memory tax. Per layer, per token, the cache stores one K vector and one V vector, each of width d_head times the number of KV heads. With fp16 (2 bytes per element), that works out to:

bytes per token = 2 (K and V) × n_layers × n_kv_heads × d_head × 2 bytes

For Llama 3.1 8B (32 layers, 8 KV heads under GQA, d_head 128, fp16) the cache costs 128 KiB per token, which is about 4 GiB for a 32K context. Llama 3.1 70B has the same per-head width but 80 layers, so each token costs roughly 320 KiB and a 32K context needs about 10 GiB. At 128K context, the cache alone for 70B is around 40 GiB, the majority of an A100 before any model weights are loaded.

Two architectural decisions make this manageable. Grouped Query Attention (GQA) decouples the number of K/V heads from the number of Q heads: Llama 3.1 70B has 64 query heads sharing 8 KV heads, cutting the cache by 8x relative to 64-head multi-head attention. Multi-head Latent Attention (MLA), used in DeepSeek V2/V3, goes further by projecting K and V into a smaller shared latent space and decompressing on demand.

kv_cache_size.py (python)
def gqa_kv_bytes(n_tokens, n_layers, n_kv_heads, d_head, dtype_b=2, batch=1):
    """Cache size for GQA or MHA: store one K and one V per kv head per token."""
    return 2 * batch * n_tokens * n_layers * n_kv_heads * d_head * dtype_b


def mla_kv_bytes(n_tokens, n_layers, kv_lora_rank, qk_rope_dim, dtype_b=2, batch=1):
    """Cache size for MLA: store the compressed latent + RoPE channel only."""
    return batch * n_tokens * n_layers * (kv_lora_rank + qk_rope_dim) * dtype_b


gqa = {
    "Llama 3.1 8B":        dict(n_layers=32, n_kv_heads=8,  d_head=128),
    "Llama 3.1 70B":       dict(n_layers=80, n_kv_heads=8,  d_head=128),
    "Llama 3.1 405B":      dict(n_layers=126,n_kv_heads=8,  d_head=128),
    "Qwen 3 32B":          dict(n_layers=64, n_kv_heads=8,  d_head=128),
    "Qwen 2.5 72B":        dict(n_layers=80, n_kv_heads=8,  d_head=128),
    "Mistral Large 2 123B":dict(n_layers=88, n_kv_heads=8,  d_head=128),
}
mla = {
    "DeepSeek V3 671B":    dict(n_layers=61, kv_lora_rank=512, qk_rope_dim=64),
    "Kimi K2 1T":          dict(n_layers=61, kv_lora_rank=512, qk_rope_dim=64),
}

for name, cfg in gqa.items():
    gb = gqa_kv_bytes(32_000, **cfg) / 1024**3
    print(f"{name:24s} {gb:6.2f} GiB @ 32K context")

for name, cfg in mla.items():
    gb = mla_kv_bytes(32_000, **cfg) / 1024**3
    print(f"{name:24s} {gb:6.2f} GiB @ 32K context")

# Llama 3.1 8B               3.91 GiB
# Llama 3.1 70B              9.77 GiB
# Llama 3.1 405B            15.38 GiB
# Qwen 3 32B                 7.81 GiB
# Qwen 2.5 72B               9.77 GiB
# Mistral Large 2 123B      10.74 GiB
# DeepSeek V3 671B           2.09 GiB   <- MLA wins, despite being bigger
# Kimi K2 1T                 2.09 GiB
Figure: KV cache vs HBM ceiling (interactive calculator). The calculator shows the KV cache only, per replica, for one batch. Model weights are separate; for large models most of the GPU is already taken by weights before any cache is allocated. Values shown for the 70B preset at 32K context: 4,096 B per token per layer, 320.0 KiB per token, 9.77 GiB of KV cache, 12.2% of an H100's 80 GB HBM. Weights are about 140 GB at bf16; a typical deployment is 2-4 H100s with tensor parallelism. MLA is the reason the 671B DeepSeek V3 has a smaller cache than the 70B model.

The KV cache is also why throughput-oriented servers quantize the cache (fp8, int4) before they quantize anything else: every byte cut per element is multiplied by every layer, every head, every token, every concurrent sequence. At 70B and 32K context, switching from fp16 to fp8 halves the cache from roughly 10 GiB to 5 GiB per sequence, which doubles the number of sequences that fit in whatever HBM is left over after the weights.
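To make the capacity effect concrete, here is a small sketch that counts how many full-length sequences fit in a fixed cache budget. The 20 GiB budget is a hypothetical figure chosen for illustration, the function names are made up, and the sketch ignores the scale factors that real fp8 caches also store.

```python
GiB = 1024 ** 3


def kv_bytes_per_token(n_layers, n_kv_heads, d_head, dtype_bytes):
    """One K and one V vector per KV head, per layer."""
    return 2 * n_layers * n_kv_heads * d_head * dtype_bytes


def max_sequences(cache_budget_bytes, context_len, bytes_per_token):
    """Whole sequences of context_len tokens that fit in the budget."""
    return cache_budget_bytes // (context_len * bytes_per_token)


budget = 20 * GiB        # hypothetical HBM left for cache after weights
ctx = 32 * 1024          # 32K-token context

for name, nbytes in (("fp16", 2), ("fp8", 1)):
    per_tok = kv_bytes_per_token(n_layers=80, n_kv_heads=8, d_head=128, dtype_bytes=nbytes)
    fits = max_sequences(budget, ctx, per_tok)
    print(f"{name}: {per_tok // 1024} KiB/token, {fits} sequences in 20 GiB")
```

With the 70B shape used above, halving the element size halves the per-token cost, so the same budget holds twice as many sequences.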

Why kernels matter

04

The memory bandwidth bottleneck

The KV cache solved a compute problem. There is still a data-movement problem underneath it. At every decode step a GPU runs roughly softmax(q·Kᵀ)·V: a query against a length-N key tensor, a softmax, and a length-N value tensor. The arithmetic is modest. The data movement is not.

Modern GPUs have two memory tiers. HBM is the large off-chip pool: 80 GB on an H100, 192 GB on a B200, around 3 TB/s of bandwidth. SRAM is the on-chip scratchpad inside each streaming multiprocessor: roughly 256 KB per SM on H100, but around 30 TB/s, ten times faster than HBM. Every time a computation goes out to HBM and back, the kernel pays that 10x factor.

The naive attention kernel computes S = QKᵀ, writes the N×N score matrix to HBM, reads it back to compute P = softmax(S), writes P back, then reads it once more to multiply by V. For N = 4096 in fp16, each score or probability matrix is 32 MiB, so those four intermediate HBM legs alone move about 128 MiB per layer, per head. At sequence lengths like this, attention becomes memory-bound; the matrix multiply units run at a small fraction of their peak FLOPs because the kernel is waiting on bytes.
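The traffic figures can be reproduced with a few lines of arithmetic. This sketch assumes a head dimension of d = 128 (which makes Q, K, V, and O each 1 MiB at N = 4096), counts each tensor as moving through HBM exactly once per leg, and uses the rough 3 TB/s bandwidth quoted above; the function name is illustrative.

```python
MiB = 1024 ** 2


def standard_attention_hbm_bytes(n, d, dtype_bytes=2):
    """HBM bytes moved by unfused attention for one head of one layer."""
    qkv_loads = 3 * n * d * dtype_bytes       # read Q, K, V once each
    score_legs = 4 * n * n * dtype_bytes      # write S, read S, write P, read P
    out_write = n * d * dtype_bytes           # write O
    return qkv_loads + score_legs + out_write


total = standard_attention_hbm_bytes(4096, 128)
intermediates = 4 * 4096 * 4096 * 2

print(f"total HBM traffic:   {total / MiB:.1f} MiB")
print(f"N x N intermediates: {intermediates / MiB:.1f} MiB")
print(f"time at 3 TB/s:      {total / 3e12 * 1e6:.1f} microseconds per head, per layer")
```

Nearly all of the traffic comes from the N×N legs, and that share grows with N because the intermediates scale quadratically while Q, K, V, and O scale linearly.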

standard_attention.py (python)
def standard_attention(Q, K, V):
    # Q, K, V live in HBM.
    S = Q @ K.T            # materialize N x N in HBM
    P = softmax(S)         # read S from HBM, write P to HBM
    O = P @ V              # read P from HBM, read V, write O
    return O

# For N = 4096, fp16:
#   S, P each occupy 2 * 4096 * 4096 = 32 MiB
#   Write S, read S, write P, read P: ~128 MiB of
#   N x N intermediate traffic per layer, per head.
Figure: standard attention through the memory hierarchy, for an adjustable sequence length N. HBM: 80 GB at 3 TB/s. SRAM: 256 KB at 30 TB/s. The animation has eleven stages; stage 1 loads Q (N×d, a few hundred KB to a few MB; 1.0 MiB here) from HBM into SRAM, and a counter tracks cumulative HBM traffic.

The full kernel pushes 132.0 MiB through HBM at N = 4,096. Four of those legs move an N×N intermediate, each one 32.0 MiB. The actual answer (the O tensor) is only 1.0 MiB.

Legend: load Q, K, V · read/write S, P · write O. Across the eleven stages, each red bar is an N×N tensor being shuffled between HBM and SRAM.

The fix is not to add more FLOPs but to rearrange where they happen. If the full matrix never has to live in HBM, if we can do softmax and the value multiply while the relevant slice still lives in SRAM, the bandwidth bill collapses. That is the observation Flash Attention turns into an algorithm, which is the subject of the next section.

Rearranging the work

05

Flash Attention: tiling the N×N matrix out of HBM

The full attention output never depends on having the entire N×N matrix in one place. Each output row only needs the corresponding row of Q and all of K, V. We can split Q into row blocks and K, V into column blocks, then loop: load one Q block and one K-V block into SRAM, compute the partial scores, partially accumulate the output, repeat. The intermediate scores never leave the chip.

The catch is softmax. It is a global operation: the normalization constant depends on the max and sum across all N scores. If we process K in pieces, the max and sum we see in any one tile are partial. The trick (usually attributed to Milakov and Gimelshein, and put to work in this kernel by Dao et al.) is to keep a running pair: the largest score seen so far and the rescaled sum so far. When a new tile arrives we adjust both, rescale the partial output accordingly, and add the new tile’s contribution. In real arithmetic the final division gives the same softmax as the naive formulation; in floating point, the exact bits can differ because the additions and rescaling happen in a different order.

The end result is that the score and probability matrices are never written to HBM. The algorithm keeps O(N)-sized state for the output, row maxima, and row sums instead of materializing O(N²) intermediates. The exact HBM traffic still depends on tile sizes and how often K/V tiles are reread, but the expensive S/P read-write path from the previous section disappears. The actual wall-clock speedup depends on the kernel and GPU, but is commonly several times faster than unfused attention on A100/H100-class systems.

flash_attention.py (python)
import numpy as np


def flash_attention(Q, K, V, tile=64):
    """Tiled attention over NumPy arrays; "SRAM" here is only a mental model."""
    N, d = Q.shape
    O = np.zeros_like(Q)

    for i in range(0, N, tile):                   # outer loop over Q tiles
        Q_i = Q[i:i + tile]                       # load to SRAM
        rows = Q_i.shape[0]                       # last tile may be shorter

        m_i = np.full(rows, -np.inf)              # running max
        l_i = np.zeros(rows)                      # running sum
        O_i = np.zeros((rows, d))                 # running output

        for j in range(0, N, tile):               # inner loop over K, V tiles
            K_j = K[j:j + tile]
            V_j = V[j:j + tile]

            S_ij  = Q_i @ K_j.T                   # SRAM only, never to HBM
            m_new = np.maximum(m_i, S_ij.max(-1))
            P_ij  = np.exp(S_ij - m_new[:, None])
            scale = np.exp(m_i - m_new)           # rescale factor for the old state
            l_i   = scale * l_i + P_ij.sum(-1)
            O_i   = scale[:, None] * O_i + P_ij @ V_j
            m_i   = m_new

        O[i:i + tile] = O_i / l_i[:, None]        # one HBM write per Q tile

    return O
Figure: tiled traversal of the attention matrix (K columns across, Q rows down). The SRAM working set holds one Q tile, one K tile, one V tile, plus the running (m, ℓ) pair and partial O for each of the tile's 4 rows. Counters show the outer index i, the inner index j, the per-row accumulators (m starts at −∞, ℓ at 0.00), and HBM traffic so far for flash versus naive. Toy values: N=16, d=8, fp16. Legend: active tile · accumulated into current Q · finalized Q tile.

Outer loop walks the Q row blocks. Inside, an inner loop walks every K and V column block. Each (Qi, Kj, Vj) triple is loaded into SRAM, used to refresh the running (m, ℓ) and partial O for the rows in Qi, and discarded. When the inner loop ends, the finalized O tile is written out and we move to the next Q. The score matrix never exists as a single object; it is consumed tile by tile in place.


Three things matter for the rest of this page. First, Flash Attention is exact in the algorithmic sense, not approximate. Output and gradients match the unfused kernel up to floating-point rounding. Second, the algorithmic core is the online softmax; the kernel-level work is choosing tile sizes and scheduling to fit the GPU’s SRAM and warp structure. Third, every major version bump (v1 to v2 to v3 to v4) keeps the algorithm and rewrites the scheduling to match a new GPU architecture. We look at the online softmax in isolation next, then at the version history.

The numerical heart of Flash Attention

06

Online softmax

Softmax is defined globally: each element of the output depends on the sum of exponentials over the full input. In a streaming setting we only see one tile of scores at a time, so the obvious approach (compute max(x), then sum exp(x − max), then divide) would require two full passes over the scores. The online variant collapses this into a single pass by maintaining two running quantities: m, the largest score seen so far, and ℓ, the sum of exp(x − m) over everything seen so far.

When a new tile arrives, two adjustments happen. The running max may have to grow, in which case the running sum needs to be rescaled by exp(m_old − m_new) so it expresses everything relative to the new offset. The new tile’s contribution is then added in. After all tiles have been seen, dividing each exp(x − m) by the final ℓ gives the exact softmax.

Flash Attention fuses this with the matrix multiply: the running rescale factor is applied not just to the running sum but also to the running output tile, which is itself an accumulator over P·V contributions. So the algorithm never needs to revisit prior tiles. One linear sweep, one final division, the same real-valued result as ordinary softmax.
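The fused form can be shown for a single query row. The sketch below is a plain-Python illustration, not the kernel: it assumes the scores for that row are already computed, uses made-up two-dimensional value rows, and compares the streaming result against an ordinary softmax-weighted sum. Function names are illustrative.

```python
import math


def naive_attention_row(scores, values):
    """Reference: full softmax over all scores, then a weighted sum of values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def streaming_attention_row(scores, values, tile=4):
    """Same result in one sweep: rescale the running sum and output per tile."""
    m = -math.inf                      # running max
    l = 0.0                            # running sum of exp(score - m)
    acc = [0.0] * len(values[0])       # running un-normalized output row
    for start in range(0, len(scores), tile):
        block_s = scores[start:start + tile]
        block_v = values[start:start + tile]
        m_new = max(m, max(block_s))
        scale = math.exp(m - m_new)    # exp(-inf) is 0.0, so the first tile starts clean
        weights = [math.exp(s - m_new) for s in block_s]
        l = l * scale + sum(weights)
        acc = [a * scale + sum(w * v[j] for w, v in zip(weights, block_v))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]        # single final division


scores = [1.2, -0.4, 0.8, 2.1, 0.3, -1.5, 1.8, 0.7, -0.2, 2.4, 1.1, 0.5]
values = [[float(i), float(-i)] for i in range(len(scores))]   # made-up value rows

print(naive_attention_row(scores, values))
print(streaming_attention_row(scores, values))
```

The streaming version never stores more than one tile of weights, yet the two printed rows agree up to floating-point rounding.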

online_softmax.py (python)
import math

def online_softmax(scores, tile=4):
    """Numerically stable softmax; the statistics (m, l) come from one pass over tiles.
    Same real-valued result as naive softmax; fp rounding may differ.
    """
    m = -math.inf      # running max
    l = 0.0            # running sum of exponentials

    # one pass to learn m and l
    for s in range(0, len(scores), tile):
        block = scores[s:s + tile]
        m_new = max(m, *block)
        scale = math.exp(m - m_new) if m != -math.inf else 0.0
        l     = l * scale + sum(math.exp(x - m_new) for x in block)
        m     = m_new

    # second pass to emit normalized probabilities
    return [math.exp(x - m) / l for x in scores]


# In Flash Attention the second pass is fused into the same tile loop:
# each tile's output contribution is rescaled by exp(m_old - m_new) when
# the running max advances, so no scores need to be revisited.
Figure: tile-by-tile softmax over 12 scores, shown at tile 1 of 3 with running m = 2.100 and running ℓ = 1.761. Bars show the partial softmax exp(xi − m) / ℓ for each position, with the final naive softmax drawn for reference.
tile 0 (current), positions 0-3: 1.2, -0.4, 0.8, 2.1
tile 1 (pending), positions 4-7: 0.3, -1.5, 1.8, 0.7
tile 2 (pending), positions 8-11: -0.2, 2.4, 1.1, 0.5
How (m, ℓ) updated on tile 0: block max = max(1.2, -0.4, 0.8, 2.1) = 2.100; m_new = max(m_old, block max) = max(−∞, 2.100) = 2.100; no rescale is applied because ℓ_old = 0; ℓ_new = Σ exp(x − m_new) over the tile = 1.761.

First tile sets the baseline. No previous sum to rescale, so scale is undefined and ℓ is just the sum of this tile's exponents.

Step through each tile to watch m and ℓ evolve. The key move is the rescale when a new max arrives. The dashed line on the bar chart is the final naive softmax for reference.

Two things to internalize. First, this is not a compression or approximation. Every probability is the same as the unfused version in real arithmetic, with the usual floating-point rounding differences. Second, the trick is general: any operation of the form “normalize across N then aggregate” can be made streaming the same way. Layer norm, RMS norm, and the backward pass through softmax all admit similar treatments, and Flash Attention’s newer kernels exploit this on the gradient path too.

The kernel keeps moving

07

Four versions, four GPU generations

Flash Attention has been rewritten three times since 2022. The algorithm itself (tiling plus online softmax) is unchanged. What changes is the schedule: how the tiles map to warps, how memory copies overlap with math, which numeric formats are used, which special instructions on which generation of tensor core the kernel relies on.

v1 (2022) was the first end-to-end fused implementation on A100. v2 (2023) reworked the parallelization to keep the warps busy on causal workloads, where the upper-right of the matrix is masked out. v3 (2024) was designed around H100 features: warp specialization, asynchronous TMA copies, FP8 with per-block scaling. v4 (2026) targets Hopper and Blackwell through a CuTe-DSL implementation; on Blackwell it exploits fully asynchronous MMA, tensor memory, larger tiles, and software exp2/rescaling paths that fit the new balance of tensor-core throughput versus scalar units.

From a deployment perspective, the choice is determined by the GPU and by feature coverage in the serving or training framework. FA4 is no longer just a forward-only idea: the paper and public CuTe tree include backward kernels, and the PyPI beta exposes regular and variable-length APIs. Mature deployments may still fall back to FA3 or cuDNN for unsupported masks, head dimensions, or integration paths.

pick_flash.py (python)
def pick_flash(gpu: str, training: bool) -> str:
    """Pick the right FlashAttention kernel for the deployment."""
    if gpu in ("B200", "GB200"):
        return "FlashAttention-4"
    if gpu in ("H100", "H200"):
        return "FlashAttention-3"
    if gpu in ("A100", "A6000", "L40S"):
        return "FlashAttention-2"
    return "FlashAttention-2 (safe default)"


# Practical defaults (May 2026):
#   - Blackwell:              FA4 when the needed mask/head-dim path exists.
#   - Hopper:                 FA3 is the stable default; FA4 CuTe also targets it.
#   - Anything older:         FA2.
Figure: version timeline, one node per version. The node shown is the 2024 H100 release (Hopper async, FP8): warp specialization with async memory copies, an FP8 path with per-block scaling, designed around Hopper's WGMMA. Approximate fp16 forward utilization: ~740 TFLOPs/s fp16, ~75% of H100 SOL; 1.2 PFLOPs/s in FP8. Each version is matched to its GPU generation; performance figures are approximate.

The reason this matters in practice: the same model, with the same weights, running the same logical attention, can vary by a factor of three to five in wall-clock latency depending on which kernel is wired in. The difference between a workload feeling responsive and feeling sluggish is almost never the model. It is the kernel below the model.

Allocation, not arithmetic

08

PagedAttention: KV cache as virtual memory

The default way to allocate KV cache is to give each sequence a contiguous slab of HBM. That works for one sequence. For a server handling many sequences at once it breaks down. Sequence lengths vary wildly. A one-word reply and a thousand-token explanation can arrive in the same batch. None of the obvious allocation strategies does well across that spread.

Reserve the maximum context length for everyone and most of the cache sits idle. Allocate the current length and grow it as the sequence extends, and each grow event needs a reallocation and copy under HBM bandwidth pressure. Guess the expected length up front and you waste memory when you guess high or stall when you guess low. Each choice is bad in a different way.

The harder problem arrives when sequences finish. A sequence that ends early leaves a hole in HBM. The next request, of a different size, generally cannot use that hole because contiguous memory has to be, well, contiguous. The vLLM team reported in 2023 that existing serving systems wasted 60-80% of KV cache memory to fragmentation and over-reservation.

The PagedAttention proposal, introduced by the vLLM team, ports an idea from operating-system virtual memory to KV cache. Carve HBM into fixed-size physical blocks (16 tokens by default). Give each sequence a small page table mapping its logical token positions to the physical blocks that hold them. New tokens grow the page table by appending another block. A finished sequence returns all of its blocks to a free list. The attention kernel takes the page table as an extra input and gathers K, V from the physical blocks at lookup time.

Three consequences follow. External fragmentation mostly disappears because any free block can serve any sequence; the remaining internal fragmentation is bounded by the final partially filled block of each live sequence. The same HBM serves substantially more concurrent sequences, roughly 2-4x throughput at similar latency in the original paper. And the page-table indirection makes prefix sharing practical: two sequences that share full prefix blocks can point to the same physical blocks with refcounts, no copy, no recomputation. Partial-block sharing needs a copy or copy-on-write tail.
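The bound on internal fragmentation is easy to see with numbers. This sketch compares two reservation policies for a made-up set of live sequence lengths: reserving a hypothetical maximum context of 2,048 tokens per sequence versus rounding each sequence up to whole 16-token blocks. Function names and the workload are illustrative.

```python
def reserved_slots_contiguous(lengths, max_len):
    """Every sequence reserves the maximum context length up front."""
    return len(lengths) * max_len


def reserved_slots_paged(lengths, block=16):
    """Every sequence holds just enough whole blocks to cover its tokens."""
    return sum(-(-n // block) * block for n in lengths)   # ceil-div, then scale


lengths = [1, 1000, 200]      # hypothetical live sequence lengths, in tokens
used = sum(lengths)
contig = reserved_slots_contiguous(lengths, max_len=2048)
paged = reserved_slots_paged(lengths)

print(f"contiguous: {used}/{contig} slots used ({used / contig:.1%})")
print(f"paged:      {used}/{paged} slots used ({used / paged:.1%})")
print(f"paged tail slack: {paged - used} slots, always under {len(lengths) * 16}")
```

The paged waste is capped at one block minus one token per live sequence, independent of how long or short each sequence turns out to be.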

paged_kv_cache.py (python)
class PagedKVCache:
    """KV cache as virtual memory: fixed blocks + per-sequence page table."""
    BLOCK = 16  # tokens per physical block (vLLM default)

    def __init__(self, total_blocks: int):
        self.free: list[int] = list(range(total_blocks))
        self.page_table: dict[int, list[int]] = {}
        self.refcount: dict[int, int] = {}

    def _alloc_block(self) -> int:
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def _incref(self, block: int) -> None:
        self.refcount[block] += 1

    def _decref(self, block: int) -> None:
        self.refcount[block] -= 1
        if self.refcount[block] == 0:
            del self.refcount[block]
            self.free.append(block)

    def allocate(self, seq_id: int, n_tokens: int) -> None:
        n_blocks = -(-n_tokens // self.BLOCK)        # ceil-div
        self.page_table[seq_id] = [self._alloc_block() for _ in range(n_blocks)]

    def append_token(self, seq_id: int, current_len: int) -> None:
        # Only grow when crossing a block boundary.
        if current_len % self.BLOCK == 0:
            self.page_table[seq_id].append(self._alloc_block())

    def release(self, seq_id: int) -> None:
        for block in self.page_table.pop(seq_id):
            self._decref(block)

    def physical(self, seq_id: int, logical_pos: int) -> tuple[int, int]:
        block_idx = self.page_table[seq_id][logical_pos // self.BLOCK]
        return block_idx, logical_pos % self.BLOCK

    def share_prefix(self, src_id: int, dst_id: int, prefix_len: int) -> None:
        """Share full prefix blocks; partial tail tokens need copy or COW."""
        if prefix_len % self.BLOCK != 0:
            raise ValueError("share only full blocks; copy the partial tail")
        n = prefix_len // self.BLOCK
        shared = self.page_table[src_id][:n].copy()
        for block in shared:
            self._incref(block)
        self.page_table[dst_id] = shared
Figure: contiguous vs paged, walked through step by step. Step 1 of 6: empty pool.

Both allocators start empty. The pool is 64 token slots, drawn here as 16 physical blocks of 4 tokens each. vLLM uses 16-token blocks in production; the smaller value here is only to keep the diagram readable.

At step 1 the contiguous allocator (bump pointer + holes) shows 0 / 64 slots live and 0 wasted. The paged allocator (4-token blocks + page table) shows 0 / 64 slots live, 0 tail slack, and a free list holding blocks 0 through 15. Over the six steps, the contiguous side accumulates a hole it cannot reuse, while the paged side stays usable through the same sequence of events.

What makes PagedAttention interesting is that it is not a new attention. The math is identical. What changed is how the kernel addresses memory. That is also why every modern inference engine ships some variant of it: vLLM’s original 16-token blocks, SGLang’s RadixAttention with a trie of shared prefixes, TensorRT-LLM’s paged KV. The layouts differ; the underlying principle is the same one operating systems have used for sixty years.

Scheduler over kernel

09

Continuous batching

Single-sequence speedups stop mattering past a point. Production servers run many sequences in parallel, and the dominant inefficiency is no longer kernel time but slot idleness. Static batching builds a batch of N requests, runs it to completion, and only then starts the next. Because sequences have wildly different lengths (a one-word answer and a thousand-token explanation can arrive in the same batch), everyone waits for the slowest. Slot utilization on chat-style workloads under static batching is often below 30%.

Continuous batching, first shipped in Orca and now standard in vLLM, TensorRT-LLM, and SGLang, treats the batch as a fluid pool. At every decode step, any sequence that finished is removed and a queued sequence is admitted into its slot. The kernel still runs one token per active sequence per step, but the slots stay full. Five to ten times the throughput on the same hardware for typical chat traffic, with no change to the model.

The trick that makes this practical is that the kernel does not need sequences to be aligned in length. The KV cache for each sequence is addressed through its own page table; the attention kernel reads them with a gather. A new sequence’s prefill can be folded into the same iteration as ongoing decodes (“chunked prefill”), so admitting a request does not stall the batch.
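A toy simulation makes the slot-idleness argument concrete. It assumes every request needs only decode steps (prefill is ignored), that each active slot emits one token per step, and that requests are admitted in arrival order; the output lengths are made up and the function names are illustrative.

```python
def static_steps(lengths, batch):
    """Static batching: each batch runs until its longest request finishes."""
    return sum(max(lengths[i:i + batch]) for i in range(0, len(lengths), batch))


def continuous_steps(lengths, batch):
    """Continuous batching: refill free slots before every decode step."""
    queue = list(lengths)      # tokens still owed by each waiting request
    active = []                # tokens still owed by each running request
    steps = 0
    while queue or active:
        while queue and len(active) < batch:
            active.append(queue.pop(0))
        steps += 1                                     # one token per active slot
        active = [r - 1 for r in active if r > 1]      # drop requests that just finished
    return steps


lengths = [1, 8, 2, 8, 1, 3, 2, 1]   # hypothetical output lengths, in tokens
batch = 4
tokens = sum(lengths)

for name, steps in (("static", static_steps(lengths, batch)),
                    ("continuous", continuous_steps(lengths, batch))):
    print(f"{name:10s} {steps:2d} steps, slot utilization {tokens / (steps * batch):.0%}")
```

Both schedulers emit the same tokens; the continuous one simply finishes in fewer steps because finished slots are refilled instead of waiting for the longest request.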

continuous_batching.py (python)
def continuous_batch_step(active, waiting, max_batch):
    """One iteration of a continuous-batching scheduler."""
    # 1) drop sequences that finished during the previous step
    active = [s for s in active if not s.is_done()]

    # 2) admit new sequences into any free slots
    while waiting and len(active) < max_batch:
        new_seq = waiting.pop(0)
        new_seq.prefill()           # chunked-prefill is fine here too
        active.append(new_seq)

    # 3) run one decode step for every active sequence in parallel
    for seq in active:
        seq.advance_one_token()

    return active, waiting
Figure: static vs continuous batching on the same workload, stepped over 60 decode steps, with a running count of tokens emitted by each scheduler.

Each row is a GPU batch slot. Each column is a decode step. Coloured cells emit a token; empty cells in the static version are wasted capacity while the batch waits for the longest sequence to finish.


Continuous batching pairs naturally with paged KV. The scheduler can admit and evict sequences at block granularity without any data copy. Together with prefix sharing, the combination is what makes modern high-QPS LLM serving possible at all: most of the 2024–2026 throughput improvements at API providers come from scheduler work, not from new kernels.

More tokens per forward pass

10

Speculative decoding

Until this point every optimization has been about making one decode step cheaper. Speculative decoding asks a different question: can a single decode step produce more than one token? The serial constraint is real (token i conditions on token i-1), but it can be amortized by guessing.

The setup uses two models. A cheap draft model proposes k tokens by sampling autoregressively from the current prefix. The capable target model then evaluates the prefix followed by those k proposals in a single forward pass and checks each proposal in turn. A proposal is accepted with probability min(1, p/q), where p and q are the target and draft probabilities of that token. If it is accepted, the chain advances. The first rejected token is replaced by a fresh sample from the residual distribution (the positive part of the target probabilities minus the draft probabilities, renormalized), and the chain stops there. If all k draft tokens are accepted, the target distribution at the last verified position gives one additional token. The critical property of the acceptance rule is that the output distribution is exactly the target’s. Quality is unchanged.
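The exactness claim can be checked numerically for a single position. The sketch below computes, in closed form, the distribution of the token that one accept-or-resample step emits, and compares it with the target. The distributions `p` and `q` are made-up examples, and the check assumes the draft gives every token nonzero probability.

```python
import numpy as np

def one_step_output_distribution(p, q):
    """Exact distribution of the token emitted by one accept-or-resample step."""
    # Path 1: the draft proposes token x (prob q[x]) and it is accepted.
    accept = q * np.minimum(1.0, p / q)
    # Path 2: the proposal is rejected and we resample from the residual.
    reject_mass = 1.0 - accept.sum()
    residual = np.clip(p - q, 0.0, None)
    if residual.sum() > 0:
        residual = residual / residual.sum()
    return accept + reject_mass * residual

p = np.array([0.50, 0.30, 0.15, 0.05])   # target distribution (example)
q = np.array([0.25, 0.25, 0.25, 0.25])   # draft distribution (example)
out = one_step_output_distribution(p, q)
print(out, np.allclose(out, p))
```

The accepted mass is min(p, q) per token, and the rejected mass is redistributed over exactly the tokens where the draft under-proposed, which is why the two paths sum back to p.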

The expected number of tokens emitted per target forward pass depends on how often the draft and target agree and on the draft length k. With a well-matched draft, acceptance rates of 60-80% are common on code and structured prose, giving a few tokens per target call (roughly 2 to 4 at k = 5 if each acceptance is treated as independent). That translates into a roughly proportional reduction in wall-clock latency when target forward passes dominate the cost. EAGLE-2 and Medusa go further by folding the draft into the target model itself, reusing its hidden states so the draft has almost no overhead and stays well-calibrated to the target.
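As a rough worked check, assume each draft token is accepted independently with the same probability α. This is a simplification, since real acceptance varies by position and by content. A round then emits 1 + α + α² + … + α^k tokens in expectation, which is (1 − α^(k+1)) / (1 − α). The function name below is hypothetical.

```python
def expected_tokens_per_pass(alpha, k):
    """Expected tokens per target forward pass with k draft tokens,
    assuming each draft token is accepted independently with prob alpha."""
    if alpha == 1.0:
        return float(k + 1)          # every proposal accepted, plus the extra token
    return (1 - alpha ** (k + 1)) / (1 - alpha)

for alpha in (0.6, 0.7, 0.8):
    for k in (3, 5, 8):
        print(f"alpha={alpha} k={k}: {expected_tokens_per_pass(alpha, k):.2f}")
```

The value is capped at 1 / (1 − α) however long the draft gets, so raising the acceptance rate pays off more than lengthening the draft.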

speculative_decode.pypython
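import numpy as np
from random import random

def softmax(x):
    """Numerically stable softmax over a 1-D array of logits."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

def sample(p):
    """Draw one token id from a probability vector."""
    return int(np.random.choice(len(p), p=p))

def speculative_step(draft, target, prefix, k=5):
    """One round of draft-and-verify. Returns the tokens to append.

    Assumes draft.probs(ctx) returns a probability vector over the vocabulary
    and target.forward(tokens) returns one row of logits per input position.
    """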
    # 1) cheap model proposes k tokens autoregressively and keeps q_i
    proposals, q_dists = [], []
    ctx = list(prefix)
    for _ in range(k):
        q = draft.probs(ctx)
        tok = sample(q)
        proposals.append(tok)
        q_dists.append(q)
        ctx.append(tok)

    # 2) target model scores prefix + proposals in a SINGLE forward pass
    logits = target.forward(prefix + proposals)

    # 3) walk left to right, accept by rejection-sampling rule
    accepted = []
    for i, tok in enumerate(proposals):
        p_target = softmax(logits[len(prefix) + i - 1])
        q_draft  = q_dists[i]
        if random() < min(1.0, p_target[tok] / q_draft[tok]):
            accepted.append(tok)
        else:
            # resample one token from the adjusted residual distribution
            residual = (p_target - q_draft).clip(min=0)
            accepted.append(sample(residual / residual.sum()))
            return accepted

    # if every draft token is accepted, sample one extra target token
    p_next = softmax(logits[len(prefix) + k - 1])
    accepted.append(sample(p_next))
    return accepted
# Each call to speculative_step does ONE target forward pass and returns
# between 1 and k+1 tokens depending on the acceptance rate.
Figure: draft proposes, target verifies. A small draft (~1B) and a large target (~70B) extend the prefix “The decoder wrote the” with five proposals per round. Counters track rounds, tokens emitted, target passes, and tokens per call. Legend: prefix, accepted, resampled, extra target. A second panel compares target forward passes for speculative and vanilla decoding.

Target passes dominate wall clock. Draft cost is treated as negligible (EAGLE/Medusa reuse target hidden states).

In each round the draft runs k forward passes, then the target verifies all k proposals in one pass. If every proposal is accepted, the target emits one extra token from the same pass.

Speculative decoding shines where latency dominates: chat completions, code assistants, anywhere the user is waiting on the first few hundred tokens. It composes cleanly with everything above. The draft and the target each keep their own KV cache for the verified prefix, and entries written for rejected proposals are discarded. The target’s forward pass uses Flash Attention. Continuous batching admits new sequences alongside speculative ones. The technique does not save FLOPs in absolute terms (the target still has to evaluate every position, including proposals that end up rejected), but it cuts wall-clock time by spending extra parallel compute to remove serial steps.

Closing

11

Where this leaves us

Across these eleven sections we walked one thread. A naive decoder does too much work, and each optimization in the sequence removes a different kind of waste. The KV cache cut compute but cost memory. GQA and MLA brought the memory cost back down. Flash Attention rearranged HBM traffic so the kernel could run close to peak. PagedAttention turned the cache into something a multi-sequence server could actually manage. Continuous batching kept the slots full. Speculative decoding amortized the serial constraint. None of these is the main optimization, which is why vLLM, TensorRT-LLM, and SGLang ship them all at once. The large gap between a naive prototype and a tuned deployment comes from the composition, not from any one trick.

Writing this page was, honestly, an adventure for me. I built it as a way of arranging my own notes so they would still make sense to me six months from now, and I shared it because the visual form happens to be easier to follow than the way I keep these ideas in my head. LLM inference optimization is one of the corners of this field that genuinely excites me. The pace is wild, and the work is being done by people who think carefully about hardware, numerics, and scheduling all at once. Reading their papers, then trying to redraw what they did in a form I could explain to a friend, that is the part I enjoy most.

I do not have a strong prediction about where this goes next. Kernels will keep approaching peak. Attention will keep being rewritten for each new GPU generation. The boundary between “model” and “serving system” will keep blurring as drafts get folded into targets, attention gets fused with norms, and KV management moves into the model itself. Long contexts will make the cache problem worse before it gets better. I will keep updating this page as new pieces land, because the pace does not seem to be slowing.