
Commit 22e39df

docs: add continuous batching page (#41847)
* docs: add continuous batching page
* docs(cb): add `generate_batch` example
* docs(cb): add `opentelemetry` and `serving` section
* feat: add `TODO` note about opentelemetry dependency
* docs(cb): add supported features
* docs(cb): add unsupported features
* docs(cb): add `ContinuousBatchingManager` example
* docs(cb): cross-reference CB in optimizing inference
1 parent 63fbd50 commit 22e39df

File tree

5 files changed: +204, -0 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -119,6 +119,8 @@
  title: Tools
  - local: transformers_as_backend
  title: Inference server backends
+ - local: continuous_batching
+   title: Continuous Batching
  title: Inference
  - isExpanded: false
  sections:
docs/source/en/continuous_batching.md

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Continuous Batching

Continuous Batching (CB) is an advanced technique that optimizes the inference of transformer models by dynamically grouping multiple requests into batches. This approach maximizes GPU utilization and throughput, especially for workloads with many variable-length inputs.

We are particularly interested in having Continuous Batching in transformers for the following use cases:
- Evaluation of models on large datasets with variable-length inputs
- Generating outputs for multiple sequences for GRPO policies

CB is what makes inference engines like vLLM or SGLang efficient. That being said, transformers does not aim to be a production-ready inference engine, but a complete framework for model development. For this reason, CB is available in `transformers serve`.

If you are not familiar with some of the core concepts CB is built upon, we invite you to read the associated blog post: [Continuous Batching: Efficient Inference for Large Language Models](https://huggingface.co/blog/continuous-batching). _broken link for now_

## API Reference

## Usage Examples

The main way to use CB in transformers is via the `generate_batch` method.

Unlike `generate`, CB takes already tokenized inputs, known as input IDs. Each sequence of input IDs is represented as a `list[int]` in Python.

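For instance, preparing such inputs from raw prompts might look like this (a minimal sketch; the prompts are made up and the model name is the one reused in the examples below):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")

prompts = ["What is continuous batching?", "Explain paged attention in one sentence."]
# each tokenized prompt is a `list[int]`; `generate_batch` expects a list of such sequences
batch_inputs = [tokenizer(prompt)["input_ids"] for prompt in prompts]
```
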
For a more detailed example, please refer to: [examples/continuous_batching](./path/to/example)

### `generate_batch` example

We have created a `ContinuousMixin` that `GenerationMixin` inherits from, so that all autoregressive text models support CB.

This adds the `generate_batch` method to all models that inherit from `GenerationMixin`.

You can use it as follows:

```py
import datasets
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="sdpa_paged",
    device_map="cuda",  # if you need cuda
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")

# prepare a batch of inputs
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(100))  # number of samples to run on
tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

generation_config = GenerationConfig(
    max_new_tokens=32,
    use_cuda_graph=False,  # not supported for the simple version
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    max_batch_tokens=512,  # max number of tokens in a batch, just a default value you should tune based on your hardware
)

batch_outputs = model.generate_batch(
    inputs=simple_batch_inputs,
    generation_config=generation_config,
)

for request_id, output in batch_outputs.items():
    generated_text = tokenizer.decode(output.generated_tokens, skip_special_tokens=True)
    print(f"Request {request_id} output: {generated_text}")
```

### `ContinuousBatchingManager` example

If you want more control over how requests are scheduled with CB, you can use the `ContinuousBatchingManager` class directly.

This is what we use in `transformers serve`, because requests arrive asynchronously and we can leverage the asynchronous nature of the CB process to make things more efficient.

Under the hood, the `ContinuousBatchingManager` creates a background thread that pulls requests from a python `queue.Queue` and decides which ones to batch together at each forward pass.

Note that the manager is thread-safe!

```py
import datasets
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.generation.continuous_batching import RequestStatus

MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="sdpa_paged",
    device_map="cuda",  # if you need cuda
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")

# prepare a batch of inputs
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(100))  # number of samples to run on
tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

# generation configuration for the manager
generation_config = GenerationConfig(
    max_new_tokens=32,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    max_batch_tokens=512,  # max number of tokens in a batch, just a default value you should tune based on your hardware
)

# initialize the manager, a method made available by the `ContinuousMixin`
manager = model.init_continuous_batching(generation_config=generation_config)

# start the background thread
manager.start()

# this is for demonstration purposes only, in practice this is most useful to do concurrently
for i, input_ids in enumerate(simple_batch_inputs):
    request_id = manager.add_request(input_ids=input_ids, request_id=f"request_{i}")  # if you do not specify a request_id, one will be generated for you

# this can be done in another thread
for id, request in manager.get_result():
    generated_text = tokenizer.decode(request.generated_tokens, skip_special_tokens=True)
    print(f"Request {id} output: {generated_text}")

# you can also get the result for a specific request id
result = manager.get_result(request_id="request_5")  # this is blocking and will wait for the result to be ready

# or stream the results of a request
manager.add_request(
    input_ids=simple_batch_inputs[0],
    request_id="streaming_request",
    stream=True,
)
for chunk in manager.request_id_iter(request_id="streaming_request"):
    generated_text = tokenizer.decode(chunk.generated_tokens, skip_special_tokens=True)
    print(generated_text)
    # FIXME: stop iteration in `request_id_iter` when finished instead of doing it externally
    if chunk.status == RequestStatus.FINISHED:
        break

# stop the background thread before exiting the process
manager.stop()
```
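
Because the manager is thread-safe, requests can be added from one thread while results are consumed from another. Below is a minimal sketch of that pattern; it reuses `manager`, `simple_batch_inputs` and `tokenizer` from the example above (after `manager.start()` and before `manager.stop()`), and only the calls already shown there:

```py
import threading

def submit_requests():
    # enqueue requests from a separate thread; the manager is thread-safe
    for i, input_ids in enumerate(simple_batch_inputs):
        manager.add_request(input_ids=input_ids, request_id=f"request_{i}")

producer = threading.Thread(target=submit_requests)
producer.start()

# meanwhile, block on each result from the main thread
for i in range(len(simple_batch_inputs)):
    request = manager.get_result(request_id=f"request_{i}")  # blocking, as shown above
    print(tokenizer.decode(request.generated_tokens, skip_special_tokens=True))

producer.join()
manager.stop()
```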

## Supported & Unsupported Features

### Supported Features

- Dynamic scheduling of variable-length requests
- Chunked prefill
- Paged Attention Cache
- Sliding window attention
- Chat templates (see the sketch below)

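For example, since chat templates are supported, chat-formatted inputs can be turned into input IDs with `apply_chat_template` and then passed to `generate_batch` (a minimal sketch; the conversations are made up and the model is the one used in the examples above):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")

conversations = [
    [{"role": "user", "content": "What is continuous batching?"}],
    [{"role": "user", "content": "Explain paged attention in one sentence."}],
]
# with tokenize=True, `apply_chat_template` returns the input IDs directly
chat_inputs = [
    tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=True)
    for conversation in conversations
]
# `chat_inputs` can then be passed to `generate_batch` as shown earlier
```
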
### Unsupported Features

At the moment, the following features are not supported by CB. We plan to add support for:

- Prefix caching
- Beam search
- Tool calling

The following are unplanned, but we may consider adding them depending on community requests:

- MTP (multi token prediction)
- Medusa

## Performance Considerations

## Integration with Serving

You can use CB in `transformers serve` by passing the `--continuous-batching` flag when starting the server.
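
For example (other serving options omitted):

```sh
transformers serve --continuous-batching
```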

## Monitoring

We have added `opentelemetry` support to Continuous Batching to help you monitor its performance in production. To enable it, install the `open-telemetry` extra when installing `transformers`:

```sh
# this installs `opentelemetry-api`, `opentelemetry-sdk` and `opentelemetry-exporter-otlp`
pip install transformers[open-telemetry]
```

This enables trace and metric collection in CB. You will then have to set up a backend to collect and visualize the traces and metrics.
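
How you wire this up depends on your observability stack. As a sketch, assuming an OTLP-compatible collector (for example Jaeger or the OpenTelemetry Collector) is running locally, the standard OpenTelemetry environment variables can point the exporter at it; the endpoint and service name below are illustrative, not defaults enforced by transformers:

```sh
# standard OpenTelemetry environment variables read by the SDK / OTLP exporter
export OTEL_SERVICE_NAME=transformers-continuous-batching
export OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317

transformers serve --continuous-batching
```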

docs/source/en/llm_optims.md

Lines changed: 6 additions & 0 deletions
@@ -393,3 +393,9 @@ model = AutoModelForCausalLM.from_pretrained(
  "mistralai/Mistral-7B-v0.1", quantization_config=quant_config, device_map="auto"
  )
  ```
+
+ ## Continuous Batching
+
+ When serving LLMs for inference, you may have multiple requests arriving at different times. Continuous Batching (CB) is a technique that groups incoming requests into batches to maximize GPU utilization and throughput.
+
+ See the [Continuous Batching](./continuous_batching) guide for more details on how to use CB in transformers.

setup.py

Lines changed: 1 addition & 0 deletions
@@ -392,6 +392,7 @@ def run(self):
  extras["benchmark"] = deps_list("optimum-benchmark")

  # OpenTelemetry dependencies for metrics collection in continuous batching
+ # TODO: refactor this to split API and SDK; the SDK and exporter should only be needed to run code that collects metrics, whereas the API is what people need to instrument their code, handling the exporter themselves
  extras["open-telemetry"] = deps_list("opentelemetry-api") + ["opentelemetry-exporter-otlp", "opentelemetry-sdk"]

  # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 1 addition & 0 deletions
@@ -919,6 +919,7 @@ def __iter__(self):
        if result is not None:
            yield result

+   # FIXME: stop iteration when request status is finished?
    def request_id_iter(self, request_id: str) -> Generator[GenerationOutput]:
        """Iterate over results matching a specific request id as they become available."""
        request_cancelled = False
