Draft: Proper chunked prefill bucketing #295

Draft · wants to merge 1 commit into base: main
3 changes: 2 additions & 1 deletion vllm_hpu_extension/bucketing/common.py

@@ -55,7 +55,8 @@ def generate_prompt_buckets(self):
                 max_num_prefill_seqs = self.max_num_prefill_seqs,
                 block_size = self.block_size,
                 max_num_batched_tokens = self.max_num_batched_tokens,
-                max_model_len = self.max_model_len)
+                max_model_len = self.max_model_len,
+                max_num_blocks = self.num_hpu_blocks)
             self.log_generate_info(True)
         else:
             logger().info("Bucketing is off - skipping prompt buckets generation")
54 changes: 44 additions & 10 deletions vllm_hpu_extension/bucketing/exponential.py

@@ -28,11 +28,15 @@ def check_for_user_flags(self, phase):


     def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
-                           max_num_batched_tokens, max_model_len):
+                           max_num_batched_tokens, max_model_len, max_num_blocks):
         self.check_for_user_flags('prompt')
-        use_merged_prefill = get_config().merged_prefill
+        use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
-        max_prompt_seq = max_model_len
+        # NOTE(kzawora): v1 requires chunked prefill, and we assume it is
+        # not going to be supported in v0 HPU code
+        enable_chunked_prefill = get_config().engine_version == 'v1'
+        # NOTE(kzawora): chunked prefill scenarios will never exceed the
+        # max_num_batched_tokens upper boundary, regardless of max_model_len
+        max_prompt_seq = max_model_len if not enable_chunked_prefill else max_num_batched_tokens

         # cfgs shape: [min, step, max, limit]
         prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
@@ -54,8 +58,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
                                            prompt_seq_bucket_cfg,
                                            block_size,
                                            prefix_caching,
+                                           enable_chunked_prefill,
                                            max_num_batched_tokens,
-                                           max_model_len)
+                                           max_model_len,
+                                           max_num_blocks)

         return sorted(prompt_buckets)

@@ -89,8 +95,10 @@ def generate_prompt_buckets(bs_bucket_config,
                            seq_bucket_config,
                            block_size,
                            prefix_caching,
+                           enable_chunked_prefill,
                            max_num_batched_tokens=None,
-                           max_model_len=None):
+                           max_model_len=None,
+                           max_num_blocks=None):
     _, _, bmax, _ = seq_bucket_config
     batch_size_buckets = warmup_range_with_limit(bs_bucket_config)
     long_context = False
@@ -103,7 +111,7 @@ def generate_prompt_buckets(bs_bucket_config,
     for bs in batch_size_buckets:
         for b in seq_bucket_config:
             buckets_3d.append((bs, b, 0))
-            max_blocks_range = (bmax - b) // block_size
+            max_blocks_range = (bmax - b) // block_size if not max_num_blocks else max_num_blocks
Copilot AI (Jul 17, 2025): [nitpick] The conditional logic is unclear. Consider using 'max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size' to be more explicit about None checking.

Suggested change:
-            max_blocks_range = (bmax - b) // block_size if not max_num_blocks else max_num_blocks
+            max_blocks_range = max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size
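The distinction matters whenever 0 is a legal value. A quick standalone sketch of the pitfall Copilot is pointing at (values are illustrative, not from the PR):

    # `not max_num_blocks` treats an explicitly passed 0 as "unset"
    bmax, b, block_size = 2048, 512, 128
    max_num_blocks = 0  # hypothetical explicit value

    via_truthiness = (bmax - b) // block_size if not max_num_blocks else max_num_blocks
    via_none_check = max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size

    print(via_truthiness)  # 12  (silently falls back to the computed range)
    print(via_none_check)  # 0   (honors the explicit argument)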

             if max_blocks_range == 0:
                 continue
             else:
@@ -131,10 +139,36 @@ def generate_prompt_buckets(bs_bucket_config,
     filtered_buckets = buckets
     if max_num_batched_tokens is not None and max_model_len is not None:
         # Remove buckets exceeding batch token budget
-        filtered_buckets = list(
-            filter(
-                lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
-                    and bucket[1] <= max_model_len, buckets))
+        if not enable_chunked_prefill:
+            filtered_buckets = list(
+                filter(
+                    lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
+                        and bucket[1] <= max_model_len, buckets))
+        else:
+            def filter_fn(bucket):
+                # NOTE(kzawora): chunked prefill scenarios will never exceed the
+                # max_num_batched_tokens upper boundary, regardless of max_model_len
+                _, seq, block = bucket
+                is_seq_in_bounds = seq <= max_num_batched_tokens
+                is_block_in_bounds = block <= max_num_blocks
+                return is_seq_in_bounds and is_block_in_bounds
+            # For each (bs, seq), keep all buckets that do not exceed
+            # max_model_len, plus the first one that does; filter the rest.
+            from collections import defaultdict
Copilot AI (Jul 17, 2025): Import statements should be placed at the top of the file, not inside functions. Move this import to the module level.

Suggested change:
-            from collections import defaultdict

+            first_exceed_seen = defaultdict(bool)
+            def keep_bucket(idx_bucket):
+                _, bucket = idx_bucket
+                bs, seq, block = bucket
+                exceeds = (seq + block * block_size) > max_model_len
+                key = (bs, seq)
+                if not exceeds:
+                    return filter_fn(bucket)
+                elif not first_exceed_seen[key] and filter_fn(bucket):
+                    first_exceed_seen[key] = True
+                    return True
+                else:
+                    return False
+            filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
Copilot AI (Jul 17, 2025): [nitpick] This complex nested lambda expression reduces readability. Consider using a list comprehension or separating into multiple steps for better clarity.

Suggested change:
-            filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
+            filtered_buckets = [bucket for _, bucket in enumerate(buckets) if keep_bucket((_, bucket))]
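An equally readable option is a plain loop (a sketch using the same keep_bucket defined above, not part of the PR):

    filtered_buckets = []
    for idx_bucket in enumerate(buckets):
        # keep_bucket expects the (index, bucket) pair produced by enumerate
        if keep_bucket(idx_bucket):
            filtered_buckets.append(idx_bucket[1])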


     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
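To see the chunked-prefill filter end to end, here is a minimal self-contained sketch of the keep_bucket logic above; block_size, the budgets, and the bucket list are illustrative assumptions, not values from the PR:

    from collections import defaultdict

    block_size = 128              # illustrative
    max_num_batched_tokens = 2048
    max_model_len = 1024
    max_num_blocks = 16

    # (batch_size, query_len, num_context_blocks) buckets, in generation order
    buckets = [(1, 512, 0), (1, 512, 2), (1, 512, 4),
               (1, 512, 8), (1, 512, 16), (1, 4096, 0)]

    def filter_fn(bucket):
        _, seq, block = bucket
        return seq <= max_num_batched_tokens and block <= max_num_blocks

    first_exceed_seen = defaultdict(bool)

    def keep_bucket(idx_bucket):
        _, bucket = idx_bucket
        bs, seq, block = bucket
        exceeds = (seq + block * block_size) > max_model_len
        if not exceeds:
            return filter_fn(bucket)
        if not first_exceed_seen[(bs, seq)] and filter_fn(bucket):
            first_exceed_seen[(bs, seq)] = True
            return True
        return False

    kept = [b for i, b in enumerate(buckets) if keep_bucket((i, b))]
    print(kept)  # [(1, 512, 0), (1, 512, 2), (1, 512, 4), (1, 512, 8)]
    # (1, 512, 16) is dropped: (1, 512) already produced its first bucket past
    # max_model_len; (1, 4096, 0) is dropped: seq exceeds max_num_batched_tokens.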
48 changes: 41 additions & 7 deletions vllm_hpu_extension/bucketing/linear.py

@@ -11,9 +11,10 @@

 class LinearBucketingStrategy:
     def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
-                           max_num_batched_tokens, max_model_len):
+                           max_num_batched_tokens, max_model_len, max_num_blocks):
         use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
+        chunked_prefill = get_config().engine_version == 'v1'

         max_prompt_seq = max_model_len
@@ -50,7 +51,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
                                            prompt_seq_bucket_cfg,
                                            block_size,
                                            prefix_caching,
-                                           max_num_batched_tokens)
+                                           chunked_prefill,
+                                           max_num_batched_tokens,
+                                           max_model_len,
+                                           max_num_blocks)

         return sorted(prompt_buckets)

@@ -129,7 +133,10 @@ def generate_prompt_buckets(bs_bucket_config,
                            seq_bucket_config,
                            block_size,
                            prefix_caching,
-                           max_num_batched_tokens=None):
+                           enable_chunked_prefill,
+                           max_num_batched_tokens=None,
+                           max_model_len=None,
+                           max_num_blocks=None):
     _, _, bmax = seq_bucket_config
     batch_size_buckets = warmup_range(bs_bucket_config)
     seq_bucket_config = warmup_range(seq_bucket_config)
@@ -157,10 +164,37 @@ def generate_prompt_buckets(bs_bucket_config,
     filtered_buckets = buckets
     if max_num_batched_tokens is not None:
         # Remove buckets exceeding batch token budget
-        filtered_buckets = list(
-            filter(
-                lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
-                buckets))
+        if not enable_chunked_prefill:
+            filtered_buckets = list(
+                filter(
+                    lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
+                    buckets))
+        else:
+            def filter_fn(bucket):
+                # NOTE(kzawora): chunked prefill scenarios will never exceed the
+                # max_num_batched_tokens upper boundary, regardless of max_model_len
+                _, seq, block = bucket
+                is_seq_in_bounds = seq <= max_num_batched_tokens
+                is_block_in_bounds = block <= max_num_blocks
+                return is_seq_in_bounds and is_block_in_bounds
+            # For each (bs, seq), keep all buckets that do not exceed
+            # max_model_len, plus the first one that does; filter the rest.
+            from collections import defaultdict
Copilot AI (Jul 17, 2025): Import statements should be placed at the top of the file, not inside functions. Move this import to the module level.

Suggested change:
-            from collections import defaultdict

+            first_exceed_seen = defaultdict(bool)
+            def keep_bucket(idx_bucket):
+                _, bucket = idx_bucket
+                bs, seq, block = bucket
+                exceeds = (seq + block * block_size) > max_model_len
+                key = (bs, seq)
+                if not exceeds:
+                    return filter_fn(bucket)
+                elif not first_exceed_seen[key] and filter_fn(bucket):
+                    first_exceed_seen[key] = True
+                    return True
+                else:
+                    return False
+            filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
Copilot AI (Jul 17, 2025): [nitpick] This complex nested lambda expression reduces readability. Consider using a list comprehension or separating into multiple steps for better clarity.

Suggested change:
-            filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
+            filtered_buckets = [bucket for _, bucket in enumerate(buckets) if keep_bucket((_, bucket))]



     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens