Building a Production LLM Inference Engine from Scratch

I built an LLM inference engine from scratch, implementing the same core ideas behind vLLM without using vLLM itself. The result: 21.2x higher throughput than HuggingFace at 64 concurrent requests, with peak GPU memory holding flat at about 3 GB regardless of load.

This post is about the systems decisions that made that number possible. It's not a tutorial; it's an explanation of why each piece works the way it does.

The Problem with Naive LLM Serving

Before vLLM, the standard way to serve LLMs was simple: take a request, run a forward pass for each token until done, return the result. HuggingFace's pipeline() does exactly this. It works fine for one user. It falls apart completely at any real load.

Two things kill it:

1. Sequential execution

Each request blocks the GPU until it's done. If request A takes 3 seconds and request B arrives 1 second in, B waits 2 more seconds doing nothing, even though the GPU has spare capacity. At 64 concurrent users, everyone is queued behind everyone else. HuggingFace throughput stays roughly flat, between about 49 and 65 tok/sec in the benchmarks below, regardless of how many requests arrive, because it processes only one at a time.

2. KV cache memory waste

During autoregressive generation, every transformer layer caches its key and value tensors for all previous tokens (the KV cache). This avoids recomputing the keys and values for the full context on every step. Instead, each step computes K and V only for the new token, then computes attention for the new token's query against the cached history.
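To make the split concrete, here is a toy, pure-Python sketch of one decode step for a single attention head. `KVCache` and `attend` are illustrative names, not the engine's code: each step appends only the new token's K and V, and the new query then attends over everything cached.

```python
import math

def attend(q, keys, values):
    """Single-head attention of one query vector against cached keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)                      # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]

class KVCache:
    """Toy cache for one layer and one head: one K and one V vector per past token."""
    def __init__(self):
        self.keys: list[list[float]] = []
        self.values: list[list[float]] = []

    def decode_step(self, q, k_new, v_new):
        # Only the new token's K and V are added; earlier ones are reused as-is.
        self.keys.append(k_new)
        self.values.append(v_new)
        # The new query still attends over the whole cached history.
        return attend(q, self.keys, self.values)

cache = KVCache()
print(cache.decode_step([1.0, 0.0], [1.0, 0.0], [2.0, 3.0]))  # [2.0, 3.0]
print(cache.decode_step([1.0, 0.0], [1.0, 0.0], [4.0, 5.0]))  # [3.0, 4.0]
```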

The problem: naive implementations pre-allocate KV cache for the maximum possible sequence length. If your max is 2048 tokens but the actual response is 150 tokens, you've wasted 93% of that allocation. With 10 concurrent requests, you're wasting 10x that. The GPU runs out of memory long before it runs out of compute.
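The arithmetic is easy to check. This sketch assumes TinyLlama-1.1B's published configuration (22 layers, 4 KV heads, head_dim 64) and fp16 storage; the function name is hypothetical.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    # K and V (factor 2) for every layer and every KV head
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

# Assumed TinyLlama-1.1B geometry: 22 layers, 4 KV heads, head_dim 64, fp16
per_token = kv_bytes_per_token(22, 4, 64)

max_len, actual_len = 2048, 150
reserved = per_token * max_len       # bytes reserved up front for one request
used = per_token * actual_len        # bytes the response actually needs
waste_fraction = 1 - used / reserved

print(f"{per_token} B/token, {reserved / 2**20:.1f} MiB reserved per request")
print(f"wasted: {waste_fraction:.1%}")
```

That is 22,528 bytes per token and 44 MiB reserved per request at a 2048-token maximum, of which a 150-token response uses about 7%.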

Memory waste in practice: Naive allocation wastes ~60% of KV cache memory on padding and over-reservation. PagedAttention brings this to ~4%.

PagedAttention: OS Paging for KV Cache

The key insight in the vLLM paper is borrowed from operating systems: instead of contiguous pre-allocated memory per sequence, split the KV cache into fixed-size blocks (pages) and allocate them on demand.

Each sequence gets a block table, a mapping from logical block index to physical block index in the KV cache pool. Blocks are 16 tokens each. When a sequence needs more KV cache, the block manager allocates the next free physical block and adds it to that sequence's block table. When the sequence finishes, those blocks are returned to the free pool immediately.

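import torch

BLOCK_SIZE = 16  # tokens per KV cache block

class BlockManager:
    def __init__(self, num_layers, num_kv_heads, head_dim, num_blocks, device):
        # KV cache pool: [num_layers, 2 (K/V), num_blocks, block_size, num_kv_heads, head_dim]
        self.kv_cache = torch.zeros(
            num_layers, 2, num_blocks, BLOCK_SIZE, num_kv_heads, head_dim,
            dtype=torch.float16, device=device
        )
        self.free_blocks = list(range(num_blocks))       # physical block ids not in use
        self.block_tables: dict[int, list[int]] = {}     # seq_id -> physical blocks, in logical order

    def allocate(self, seq_id: int) -> int:
        # Take a free physical block and append it to this sequence's block table
        if not self.free_blocks:
            raise RuntimeError("KV cache pool exhausted")
        block = self.free_blocks.pop()
        self.block_tables.setdefault(seq_id, []).append(block)
        return block

    def free(self, seq_id: int):
        # Return all of this sequence's blocks to the free pool
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))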
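The code above doesn't show when `allocate` gets called. One plausible policy, sketched here with a torch-free stand-in (`BlockTables` and `ensure_capacity` are hypothetical, not the engine's API), is to add a block whenever a sequence's token count crosses a multiple of the block size:

```python
BLOCK_SIZE = 16

class BlockTables:
    """Torch-free stand-in for the block manager's bookkeeping (hypothetical)."""
    def __init__(self, num_blocks: int):
        self.free_blocks = list(range(num_blocks))
        self.block_tables: dict[int, list[int]] = {}

    def ensure_capacity(self, seq_id: int, num_tokens: int) -> None:
        """Allocate blocks on demand until seq_id can hold num_tokens tokens."""
        table = self.block_tables.setdefault(seq_id, [])
        while len(table) * BLOCK_SIZE < num_tokens:
            if not self.free_blocks:
                raise RuntimeError("out of KV cache blocks")
            table.append(self.free_blocks.pop())

    def free(self, seq_id: int) -> None:
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))

bt = BlockTables(num_blocks=8)
for n in range(1, 18):          # a sequence growing one token per decode step
    bt.ensure_capacity(0, n)
print(bt.block_tables[0])       # [7, 6]: a second block was added at token 17
bt.free(0)
print(len(bt.free_blocks))      # 8
```

A sequence therefore holds ceil(num_tokens / 16) blocks at any moment, and freeing returns them all to the pool in one call.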

The waste now comes only from the last, partially filled block of each sequence: at most block_size - 1 unused slots. With block_size=16, the worst case is 15 wasted slots per sequence and the average is about half a block; for sequences of a couple of hundred tokens, that works out to roughly 4% waste across a typical workload. Compare that to ~60% with naive pre-allocation.
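A quick sketch of that fragmentation math, assuming the same 16-token blocks (helper names are illustrative):

```python
import math

BLOCK_SIZE = 16  # tokens per block, as in the engine

def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    return math.ceil(num_tokens / block_size)

def paged_waste_slots(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Unused slots for one sequence; only its last, partially filled block contributes."""
    return blocks_needed(num_tokens, block_size) * block_size - num_tokens

for n in (150, 200):
    blocks, waste = blocks_needed(n), paged_waste_slots(n)
    print(f"{n} tokens -> {blocks} blocks, {waste} unused slots "
          f"({waste / (blocks * BLOCK_SIZE):.1%} of allocated)")

# Averaged over every length from 1 to 2048, waste is (block_size - 1) / 2 slots.
avg = sum(paged_waste_slots(n) for n in range(1, 2049)) / 2048
print(avg)  # 7.5
```

A 200-token sequence, for example, occupies 13 blocks (208 slots) and leaves 8 of them unused, just under 4%.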

The tradeoff: attention now runs over non-contiguous memory. The physical blocks for a given sequence are scattered across the pool, so standard attention kernels, which expect each sequence's keys and values in one contiguous tensor, no longer apply directly. This is where the custom CUDA kernel comes in.

The Custom CUDA Kernel

You can't just call torch.matmul over a sequence's KV cache anymore. You need a kernel that gathers K and V from non-contiguous physical blocks via the block table, computes attention, and returns the result.

I wrote paged_attention_decode_batched in CUDA C++, bound to Python via pybind11. The grid layout is dim3(num_heads, num_seqs): one CUDA block per (head, sequence) pair, and each block computes attention for that one query head of that one sequence:

#include <cuda_fp16.h>

__global__ void paged_attention_decode_batched(
    float* __restrict__ output,          // [num_seqs, num_heads, head_dim]
    const __half* __restrict__ q,        // [num_seqs, num_heads, head_dim]
    const __half* __restrict__ kv_cache, // [num_layers, 2, num_blocks, block_size, num_kv_heads, head_dim]
    const int* __restrict__ block_tables,// [num_seqs, max_blocks_per_seq]
    const int* __restrict__ seq_lens,    // [num_seqs]
    int layer_idx, int num_kv_heads, int head_dim, int block_size, int max_blocks
) {
    int head_idx = blockIdx.x;
    int seq_idx  = blockIdx.y;
    int kv_head  = head_idx / (gridDim.x / num_kv_heads); // GQA head mapping (gridDim.x == num_heads)

    int seq_len = seq_lens[seq_idx];
    float scale = 1.0f / sqrtf((float)head_dim);

    // Numerically stable softmax: scores are shifted by max_score before exp()
    float max_score = -1e9f, sum_exp = 0.0f;
    // ... first pass: gather K from non-contiguous blocks via block_tables, compute scaled QK scores, find max_score
    // ... second pass: gather V, compute the weighted sum using exp(score - max_score) / sum_exp
}

Key implementation details:

GQA support: TinyLlama uses Grouped Query Attention, with 32 query heads but only 4 KV heads. The kernel maps each query head index to its KV head index via integer division, so each group of 8 query heads reads the same KV head's cached keys and values.
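A small Python check of that mapping, using the head counts above (`kv_head_for` is an illustrative name, not from the engine):

```python
def kv_head_for(head_idx: int, num_heads: int, num_kv_heads: int) -> int:
    # Same integer arithmetic as the kernel's head_idx / (gridDim.x / num_kv_heads),
    # with gridDim.x equal to the number of query heads.
    return head_idx // (num_heads // num_kv_heads)

NUM_HEADS, NUM_KV_HEADS = 32, 4  # TinyLlama's GQA configuration

groups: dict[int, list[int]] = {}
for h in range(NUM_HEADS):
    groups.setdefault(kv_head_for(h, NUM_HEADS, NUM_KV_HEADS), []).append(h)

print(groups[0])  # [0, 1, 2, 3, 4, 5, 6, 7]
print(groups[3])  # [24, 25, 26, 27, 28, 29, 30, 31]
```

Since only 4 KV heads are cached, the KV cache is 32 / 4 = 8x smaller than it would be with one KV head per query head.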

Numerically stable softmax: The kernel makes two passes over the context. The first pass computes the scaled QK scores and finds the max score; the second gathers V and computes the softmax-weighted value sum. Subtracting the max score before exponentiating (the log-sum-exp trick) keeps exp() from overflowing without changing the result.
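For illustration, here is a pure-Python sketch of the two-pass pattern described above, next to a one-pass online-softmax variant for contrast. Neither is the kernel's code, and the function names are hypothetical. Both give the same result on scores large enough that a naive `exp()` overflows.

```python
import math

def softmax_weighted_sum_two_pass(scores, values):
    # Pass 1: find the max score.
    m = max(scores)
    # Pass 2: exponentiate relative to the max, accumulate weighted values and the normalizer.
    acc = [0.0] * len(values[0])
    total = 0.0
    for s, v in zip(scores, values):
        w = math.exp(s - m)          # always <= 1, so it cannot overflow
        total += w
        for d, vd in enumerate(v):
            acc[d] += w * vd
    return [a / total for a in acc]

def softmax_weighted_sum_online(scores, values):
    # One pass: keep a running max and rescale the accumulators whenever it grows.
    m, total = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for s, v in zip(scores, values):
        new_m = max(m, s)
        correction = math.exp(m - new_m)   # rescales everything accumulated so far
        w = math.exp(s - new_m)
        total = total * correction + w
        acc = [a * correction + w * vd for a, vd in zip(acc, v)]
        m = new_m
    return [a / total for a in acc]

scores = [1000.0, 999.0, 998.0]
values = [[1.0], [2.0], [3.0]]
try:
    math.exp(scores[0])
except OverflowError:
    print("naive exp(1000.0) overflows")
print(softmax_weighted_sum_two_pass(scores, values))
print(softmax_weighted_sum_online(scores, values))
```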

Non-contiguous KV gather: For each token position, the kernel looks up the physical block index in the block table (block_tables[seq_idx][token / block_size]), then reads K and V at slot token % block_size within that physical block. This indirection through the block table, at kernel level, is the core of PagedAttention.
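The address translation itself is two integer operations. This sketch (function names hypothetical) also shows the flattened index the kernel would use, assuming `block_tables` is laid out row-major as `[num_seqs, max_blocks]`, matching the parameter comment in the kernel signature:

```python
BLOCK_SIZE = 16

def physical_slot(block_table: list[int], token_pos: int, block_size: int = BLOCK_SIZE):
    """Map a token position in a sequence to (physical block, slot within that block)."""
    logical_block = token_pos // block_size
    return block_table[logical_block], token_pos % block_size

def flat_block_table_index(seq_idx: int, token_pos: int, max_blocks: int,
                           block_size: int = BLOCK_SIZE) -> int:
    # Index into a flat int array holding all block tables, one row of max_blocks per sequence
    return seq_idx * max_blocks + token_pos // block_size

# Hypothetical block table: logical blocks 0, 1, 2 live in physical blocks 7, 2, 9.
table = [7, 2, 9]
print(physical_slot(table, 0))    # (7, 0)
print(physical_slot(table, 17))   # (2, 1)
print(physical_slot(table, 40))   # (9, 8)
```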

Correctness is verified against a PyTorch reference implementation. Max absolute difference in float16: <0.05 across all test cases.
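The engine's reference is written in PyTorch; as a dependency-free stand-in, this sketch shows the same kind of check in pure Python. It scatters random K/V into shuffled physical blocks, runs a loop that mirrors the kernel's (head, sequence) grid with GQA head mapping, and compares the result against attention over the original contiguous data. All names are hypothetical, and the tiny block size and dimensions only keep the example fast.

```python
import math
import random

BLOCK_SIZE = 4  # deliberately small so each sequence spans several blocks

def attention(q, keys, values):
    """Reference single-head attention over contiguous lists of K and V vectors."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z for d in range(len(values[0]))]

def paged_decode(q, k_pool, v_pool, block_tables, seq_lens, num_kv_heads):
    """Loop over (sequence, head) pairs like the kernel grid, gathering K/V via block tables."""
    num_heads = len(q[0])
    group = num_heads // num_kv_heads              # query heads per KV head (GQA)
    out = []
    for seq_idx, seq_len in enumerate(seq_lens):
        row = []
        for head_idx in range(num_heads):
            kv_head = head_idx // group
            keys, values = [], []
            for t in range(seq_len):
                block = block_tables[seq_idx][t // BLOCK_SIZE]   # logical -> physical
                slot = t % BLOCK_SIZE
                keys.append(k_pool[block][slot][kv_head])
                values.append(v_pool[block][slot][kv_head])
            row.append(attention(q[seq_idx][head_idx], keys, values))
        out.append(row)
    return out

rng = random.Random(0)
num_heads, num_kv_heads, head_dim, num_blocks = 4, 2, 8, 16
seq_lens = [5, 11]

def rand_vec():
    return [rng.uniform(-1.0, 1.0) for _ in range(head_dim)]

# Contiguous per-sequence K/V: [seq][token][kv_head][dim]
k_ref = [[[rand_vec() for _ in range(num_kv_heads)] for _ in range(n)] for n in seq_lens]
v_ref = [[[rand_vec() for _ in range(num_kv_heads)] for _ in range(n)] for n in seq_lens]
q = [[rand_vec() for _ in range(num_heads)] for _ in seq_lens]

# Scatter into a shared pool at shuffled physical block ids
free = list(range(num_blocks))
rng.shuffle(free)
k_pool = [[None] * BLOCK_SIZE for _ in range(num_blocks)]
v_pool = [[None] * BLOCK_SIZE for _ in range(num_blocks)]
block_tables = []
for s, n in enumerate(seq_lens):
    table = [free.pop() for _ in range(math.ceil(n / BLOCK_SIZE))]
    block_tables.append(table)
    for t in range(n):
        k_pool[table[t // BLOCK_SIZE]][t % BLOCK_SIZE] = k_ref[s][t]
        v_pool[table[t // BLOCK_SIZE]][t % BLOCK_SIZE] = v_ref[s][t]

paged = paged_decode(q, k_pool, v_pool, block_tables, seq_lens, num_kv_heads)
group = num_heads // num_kv_heads
max_diff = max(
    abs(a - b)
    for s in range(len(seq_lens))
    for h in range(num_heads)
    for a, b in zip(
        paged[s][h],
        attention(q[s][h],
                  [tok[h // group] for tok in k_ref[s]],
                  [tok[h // group] for tok in v_ref[s]]),
    )
)
print(f"max abs difference: {max_diff:.2e}")
```

Because both paths run the same float64 arithmetic, the difference here is essentially zero; a float16 CUDA kernel compared against such a reference shows small rounding differences instead, which is why the comparison uses a tolerance.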

Continuous Batching: The Scheduler

PagedAttention solves memory. Continuous batching solves throughput. Static batching, which is what HuggingFace's generate() does when given a batch of prompts, groups requests into a batch before generation starts and processes them together. The batch is fixed for the entire generation, so if one sequence finishes early, its slot in the batch sits idle until the longest sequence in the batch completes.
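The idle time is easy to quantify. This sketch (hypothetical helper, made-up output lengths) computes the fraction of (sequence, step) slots that do useful work in one static batch:

```python
def static_batch_utilization(lengths: list[int]) -> float:
    """Fraction of (sequence, step) slots doing useful work in one static batch."""
    steps = max(lengths)                 # the batch runs until its longest member finishes
    return sum(lengths) / (len(lengths) * steps)

# Hypothetical output lengths (in tokens) for a batch of four requests
print(static_batch_utilization([150, 40, 60, 150]))
```

Here 400 of the batch's 600 slots do useful work, so a third are wasted on sequences that have already finished.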

Continuous batching fixes this. The scheduler maintains a waiting queue and a running set. On every token step:

def step(self) -> dict[int, str]:
    # Prefill any waiting sequences whose KV blocks fit in the free pool
    # (self.waiting is a collections.deque; self.running is a list)
    while self.waiting and self._can_allocate(self.waiting[0]):
        seq = self.waiting.popleft()
        self._prefill(seq)
        self.running.append(seq)

    # Nothing to decode this step
    if not self.running:
        return {}

    # Batch all running sequences into one forward pass
    batch = self._build_decode_batch(self.running)
    logits = self.model.decode_batch(batch)

    # Sample the next token for each sequence
    outputs = {}
    finished = []
    for seq in self.running:
        token = self._sample(logits[seq.id], seq.config)
        seq.tokens.append(token)
        outputs[seq.id] = self.tokenizer.decode([token])
        if self._is_done(seq):
            finished.append(seq)

    # Free finished sequences immediately so waiting requests can use their blocks
    for seq in finished:
        self.running.remove(seq)
        self.block_manager.free(seq.id)

    return outputs

The critical property: when a sequence finishes, its blocks are freed at the end of that step, and a waiting request can be prefilled at the start of the next one. The GPU never sits idle waiting for stragglers. At 64 concurrent sequences, all 64 are batched into a single forward pass on every token step and processed in parallel.

This is why throughput scales with concurrency while HuggingFace stays flat. At concurrency=64, we're doing 1340 tok/sec. HuggingFace is at 63 tok/sec, essentially the same as at concurrency=4 (57.3 tok/sec), because it's still processing one request at a time.
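A toy simulation makes the difference visible. This sketch counts decode steps only, treats a fixed number of running slots as a stand-in for `_can_allocate`'s block check, and ignores prefill cost; the functions and request lengths are hypothetical.

```python
from collections import deque

def continuous_steps(lengths: list[int], max_running: int) -> int:
    """Token steps to finish all requests when freed slots are refilled every step."""
    waiting = deque(lengths)
    running: list[int] = []          # remaining tokens for each running sequence
    steps = 0
    while waiting or running:
        # Admit waiting requests while there is room (the prefill phase of step()).
        while waiting and len(running) < max_running:
            running.append(waiting.popleft())
        # One batched decode step: every running sequence emits one token.
        running = [r - 1 for r in running]
        steps += 1
        # Finished sequences leave immediately, freeing their slot.
        running = [r for r in running if r > 0]
    return steps

def static_steps(lengths: list[int], batch_size: int) -> int:
    """Token steps when each fixed batch runs until its longest member finishes."""
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))

lengths = [150, 40, 60, 150, 30, 30, 30, 30]
print(static_steps(lengths, 4), continuous_steps(lengths, 4))
```

With these lengths, static batches of 4 take 180 steps, while refilling freed slots finishes in 150, the length of the longest request.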

Benchmark Results

Benchmarked on an A10G (24 GB VRAM) on Modal, using TinyLlama-1.1B with 150 tokens per request. The HuggingFace baseline runs requests sequentially, with no continuous batching and no PagedAttention.

Concurrency   This engine (tok/sec)   HuggingFace (tok/sec)   Speedup
4             178.0                   57.3                    3.1x
16            672.8                   49.3                    13.6x
32            1187.8                  64.7                    18.4x
64            1340.1                  63.1                    21.2x

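Peak GPU memory: 2.96 to 3.02 GB across 1 to 64 concurrent sequences. The KV cache pool is allocated once when the block manager starts, so peak memory doesn't grow with load; what changes with concurrency is how many of the pool's blocks are in use. Because PagedAttention allocates blocks per actual token length instead of reserving the maximum sequence length for every request, all 64 sequences fit inside that fixed pool.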
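A rough sketch of how much of the KV pool this workload actually occupies. It assumes TinyLlama-1.1B's published geometry (22 layers, 4 KV heads, head_dim 64) and fp16 storage, and counts only the 150 tokens per request, ignoring the prompt; the names are illustrative.

```python
import math

BLOCK_SIZE = 16
# Assumed TinyLlama-1.1B KV geometry: 22 layers, 4 KV heads, head_dim 64, fp16 (2 bytes)
BYTES_PER_TOKEN = 2 * 22 * 4 * 64 * 2            # K and V for one token
BYTES_PER_BLOCK = BLOCK_SIZE * BYTES_PER_TOKEN

def blocks_in_use(concurrency: int, tokens_per_seq: int) -> int:
    """Blocks occupied when every sequence holds tokens_per_seq tokens of KV cache."""
    return concurrency * math.ceil(tokens_per_seq / BLOCK_SIZE)

for conc in (4, 16, 32, 64):
    used = blocks_in_use(conc, 150) * BYTES_PER_BLOCK
    print(f"{conc:>2} sequences: {used / 2**20:6.1f} MiB of KV blocks in use")
```

Even at 64 sequences that comes to 220 MiB of blocks in use. If the model weights are stored in fp16, they alone take roughly 2.2 GB (1.1B parameters at 2 bytes each), which would account for most of the ~3 GB peak.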

HuggingFace throughput is effectively constant (~60 tok/sec) because it's sequential. Concurrency doesn't help it. This engine scales because all running sequences are batched into one GPU forward pass per token step.
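The speedup column can be recomputed directly from the table. This sketch also divides aggregate throughput by concurrency to get the average decode rate per sequence; it uses only the numbers above.

```python
# Throughput from the benchmark table: concurrency -> (this engine, HuggingFace), tok/sec
results = {4: (178.0, 57.3), 16: (672.8, 49.3), 32: (1187.8, 64.7), 64: (1340.1, 63.1)}

speedups = {conc: round(engine / hf, 1) for conc, (engine, hf) in results.items()}
per_sequence = {conc: engine / conc for conc, (engine, _) in results.items()}

for conc in results:
    print(f"{conc:>2}: speedup {speedups[conc]}x, "
          f"{per_sequence[conc]:.1f} tok/sec per sequence")
```

Aggregate throughput rises with concurrency while the per-sequence rate falls, from about 44.5 tok/sec at 4 sequences to about 20.9 at 64.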


The full source (block manager, scheduler, CUDA kernel, transformer, and benchmarks) is on GitHub. The live demo runs on Modal at the link in the nav.