
How Continuous Batching Works

Category: Machine Learning · Tags: LLM · Created: 2025-10-15 21:10:00

Batch inference in LLMs

When deep learning models run inference, they usually process a batch of inputs simultaneously to maximize hardware utilization and throughput. For example, in a classification task each input might be a single vector; processing one input at a time reduces the computation to a series of vector-matrix multiplications. For a large model, every vector-matrix multiplication requires loading the entire weight matrix from memory, which is inefficient. If the model processes a batch of inputs instead, it can perform matrix-matrix multiplications, which reuse each loaded weight matrix across the whole batch and are much more efficient.
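As a toy sketch of the difference (the sizes here are made up for illustration):

import torch

W = torch.randn(4096, 4096)               # a weight matrix

# one input at a time: 8 separate vector-matrix products,
# each of which has to read W from memory again
xs = [torch.randn(4096) for _ in range(8)]
ys = [x @ W for x in xs]

# batched: a single matrix-matrix product reuses W across all 8 inputs
X = torch.stack(xs)                        # [8, 4096]
Y = X @ W                                  # [8, 4096]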

In LLMs, the input often consists of a batch of sequences, where each sequence can have a different length. To process multiple sequences in parallel, we have to use padding and masking techniques.

When processing sequences of varying lengths, traditional batching methods require all inputs in a batch to have the same length. This is typically achieved by padding shorter sequences with special tokens to match the length of the longest sequence in the batch.

When computing attention scores, the model uses a mask to ignore the padded tokens, ensuring that they do not affect the attention mechanism.
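A minimal sketch of this padding-plus-masking scheme (the token ids and pad_id below are made up):

import torch

# a hypothetical batch of token-id sequences with different lengths
seqs = [[5, 9, 2], [7, 1], [4, 8, 3, 6]]
max_len = max(len(s) for s in seqs)
pad_id = 0

# pad every sequence to the length of the longest one
batch = torch.tensor([s + [pad_id] * (max_len - len(s)) for s in seqs])

# attention mask: True for real tokens, False for padding
mask = torch.tensor([[True] * len(s) + [False] * (max_len - len(s)) for s in seqs])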

However, padding can lead to inefficiencies, especially when there is a significant variance in sequence lengths within a batch. To address this, large language models (LLMs) often employ a technique known as continuous batching.

What is Continuous Batching?

Continuous batching allows the model to group sequences of varying lengths into a single batch without excessive padding. The idea is to concatenate multiple sequences into a single long sequence, while keeping track of the start and end positions of each sequence within the batch.

Here’s a simplified illustration of continuous batching:
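Suppose we pack three sequences of lengths 3, 5, and 2 (the token ids are made up):

# three sequences of different lengths
seq_a = [11, 12, 13]                  # length 3
seq_b = [21, 22, 23, 24, 25]          # length 5
seq_c = [31, 32]                      # length 2

# concatenated into one long sequence, with no padding
packed = seq_a + seq_b + seq_c        # length 10

# cumulative lengths record where each sequence starts and ends
cu_lens = [0, 3, 8, 10]               # seq_a: [0, 3), seq_b: [3, 8), seq_c: [8, 10)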

Transformer-based LLMs are built from embedding layers, multi-head self-attention layers, feed-forward layers, and so on. Embedding layers and feed-forward layers handle continuous batching naturally, since they operate on each token independently. The multi-head self-attention layers, however, need special care to ensure that tokens from different sequences do not attend to each other. Because we keep track of the start and end positions of each sequence, we can restrict attention to within the boundaries of each sequence.
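One way to picture this restriction is as a block-diagonal causal mask over the packed sequence (a sketch, reusing the cu_lens values from the illustration above):

import torch

cu_lens = [0, 3, 8, 10]
total = cu_lens[-1]

# block-diagonal causal mask: True where attention is allowed
mask = torch.zeros(total, total, dtype=torch.bool)
for start, end in zip(cu_lens[:-1], cu_lens[1:]):
    n = end - start
    # each token attends only to itself and earlier tokens in its own sequence
    mask[start:end, start:end] = torch.tril(torch.ones(n, n, dtype=torch.bool))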

After running the whole packed sequence through the model, we can extract the outputs corresponding to each original sequence based on their end positions.

If a sequence has generated all its tokens (i.e., an end-of-sequence token is produced or it reaches its maximum length), we simply remove it from the next batch to be processed.

I think this is the key idea behind continuous batching in LLMs. It allows the model to efficiently process multiple sequences of varying lengths in parallel, maximizing hardware utilization.

Other posts on the Internet may tell you that continuous batching is about refilling the batch with new sequences as soon as some sequences finish generating. But LLMs generate tokens one at a time, so once a sequence has finished, we can simply drop it from the batch before the next token-generation step.

Some even describe traditional batching as static batching: the batch keeps its fixed size throughout the inference process until all sequences are finished, even if some of them finished generating tokens long ago. I doubt anyone would implement such naive static batching in a production system.

How to Implement Continuous Batching

To implement continuous batching, we need to manage the sequences and their positions carefully. We can maintain a list of the lengths of the active sequences; during attention computation, we split the packed sequence into segments corresponding to the original sequences and compute attention within each segment in the traditional way.

When you read the following code snippets, please refer to the full implementation in this notebook.

Attention Computation with Continuous Batching

We can maintain a list of cumulative lengths to keep track of the start and end positions of each sequence in the concatenated batch.

cu_lens = [0]
for seq in sequences:
    cu_lens.append(cu_lens[-1] + len(seq))

During attention computation (with cu_lens converted to a tensor, as in generate_one_step below), we iterate over each sequence segment, compute attention separately, and then concatenate the outputs.

o = []
for i in range(cu_lens.shape[0] - 1):
    start = cu_lens[i].item()
    end = cu_lens[i + 1].item()
    seqlen_i = end - start

    # slice out the tokens belonging to sequence i: [q_head, seqlen_i, head_dim]
    q_i = q[:, start:end, :]
    k_i = k[:, start:end, :]
    v_i = v[:, start:end, :]

    scores = q_i @ k_i.transpose(-2, -1) * self.scale

    # causal mask restricted to this sequence's tokens
    mask = torch.triu(torch.ones(seqlen_i, seqlen_i, device=q_i.device, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, -torch.inf)
    weights = F.softmax(scores, dim=-1)

    # [q_head, seqlen_i, head_dim]
    o_i = weights @ v_i

    # [seqlen_i, q_head, head_dim]
    o_i = o_i.transpose(0, 1)

    o.append(o_i.flatten(1))

o = torch.concat(o, dim=0)
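In production systems, this Python loop is typically replaced by a fused variable-length attention kernel (for example, FlashAttention's varlen interface), which consumes the cumulative lengths directly instead of materializing per-segment tensors.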

Rotary Position Embedding with Continuous Batching

When applying rotary position embeddings, we also need to consider the positions of each token within its original sequence. We can maintain a positions tensor that keeps track of the position of each token in the concatenated batch.

positions = []
for seq in sequences:
    seq_len = len(seq)
    positions.append(torch.arange(seq_len))

# concatenate the per-sequence position ids into a single tensor
positions = torch.cat(positions)

q = self.rotary_embedding(q, positions)
k = self.rotary_embedding(k, positions)

We can use the positions tensor as indices to get the correct positional embeddings for each token in the concatenated batch.

cos = self.cos_cached[positions]
sin = self.sin_cached[positions]
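For completeness, with the cached tables indexed this way, the standard rotate-half formulation of RoPE applies cos and sin as follows (a sketch; it assumes q and k are laid out so that cos and sin broadcast over the head dimension):

def rotate_half(x):
    # split the last dimension in half, swap the halves, negate one of them
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

q = q * cos + rotate_half(q) * sin
k = k * cos + rotate_half(k) * sin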

Generating Next Tokens

In the output layer, we can extract the logits corresponding to the last token of each sequence based on their end positions.

class Qwen3ForCausalLM(nn.Module):
    def __init__(self, config: Qwen3Config):
        super().__init__()
        
        self.model = Qwen3Model(config)
        
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, cu_lens: torch.Tensor) -> torch.Tensor:

        # [total_len, hidden_size]: hidden states for the packed batch
        x = self.model(input_ids, positions, cu_lens)

        # [num_seqs, vocab_size]: logits for the last token of each sequence
        x = self.lm_head(x[cu_lens[1:] - 1, :])

        return x

Batching Management

With continuous batching, we can now run inference on multiple sequences of varying lengths efficiently. We can maintain a list of active sequences, and after generating the next token for each sequence, we can update the list by removing sequences that have finished generating tokens.

Here is a simplified example of how to manage the batching process:

First, we define a Request class to hold the tokens of each sequence.

class Request:
    """
    A request contains the tokens of a sequence to be processed.
    """
    def __init__(self, tokens):
        self.tokens = tokens

Each time we get a batch of requests, we concatenate their tokens, build their positions and cumulative lengths, and then run one generation step.

def generate_one_step(model, requests: list[Request]):
    """
    Generate one token for each request in the batch.
    """

    tokens = []
    positions = []
    cu_lens = [0]

    for req in requests:
        tokens.extend(req.tokens)
        positions.extend(range(len(req.tokens)))
        cu_lens.append(cu_lens[-1] + len(req.tokens))

    tokens = torch.tensor(tokens, dtype=torch.long, device=device)
    positions = torch.tensor(positions, dtype=torch.long, device=device)
    cu_lens = torch.tensor(cu_lens, dtype=torch.long, device=device)
    
    # [len(requests), vocab_size]
    logits = model(tokens, positions, cu_lens)
    next_tokens = torch.argmax(logits, dim=-1)

    return next_tokens.tolist()

In the main generation loop, we keep generating tokens for the batch of requests, and append the new tokens to each request. If a request generates an end-of-sequence token, we remove it from the batch.

def generate(model: Qwen3ForCausalLM, tokenizer, prompts: list[str], enable_think=True, max_new_tokens=64):
    """
    Generate tokens for a batch of prompts using continuous batching.
    If a request is finished, move it out of the active batch.
    """

    requests = []
    for prompt in prompts:
        prompt = apply_chat_template(prompt, enable_think)
        tokens = tokenizer.encode(prompt).ids
        req = Request(tokens)
        requests.append(req)
    eos_token = tokenizer.encode("<|im_end|>").ids[0]

    finished = []
    new_tokens = 0
    while len(requests) and new_tokens < max_new_tokens:
        new_tokens += 1
        tokens = generate_one_step(model, requests)
        for req, token in zip(requests, tokens):
            req.tokens.append(token)
        
        # move finished requests out of the active batch,
        # keeping them so their outputs are not lost
        finished.extend(req for req in requests if req.tokens[-1] == eos_token)
        requests = [req for req in requests if req.tokens[-1] != eos_token]

    return finished + requests
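A hypothetical usage, assuming the model and tokenizer are loaded as in the notebook:

prompts = [
    "What is continuous batching?",
    "Explain the KV cache in one sentence.",
]
finished = generate(model, tokenizer, prompts, enable_think=False, max_new_tokens=32)
for req in finished:
    print(tokenizer.decode(req.tokens))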

Conclusion

In this post, I explained how continuous batching works in large language models (LLMs) and implemented a simple version of it in this notebook. I hope the explanation and code snippets help you understand the concept.
