Draft: Proper chunked prefill bucketing #295
base: main
Changes from all commits
Diff of the first changed file:
@@ -28,11 +28,15 @@ def check_for_user_flags(self, phase):
 
     def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
-                           max_num_batched_tokens, max_model_len):
+                           max_num_batched_tokens, max_model_len, max_num_blocks):
         self.check_for_user_flags('prompt')
         use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
-        max_prompt_seq = max_model_len
+        # NOTE(kzawora): v1 requires chunked prefill,
+        # and we assume it is not going to be supported in v0 hpu code
+        enable_chunked_prefill = get_config().engine_version == 'v1'
+        # NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
+        max_prompt_seq = max_model_len if not enable_chunked_prefill else max_num_batched_tokens
 
         # cfgs shape: [min, step, max, limit]
         prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1
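For context on what the new `max_prompt_seq` choice does in practice, here is a minimal standalone sketch; the numeric values are assumptions for illustration, not settings from this PR:

```python
# Illustrates the max_prompt_seq selection added above: with chunked prefill,
# prompt sequence buckets are capped by the per-step token budget instead of
# the model's full context length. Values below are assumed examples.
max_model_len = 32768          # assumed model context length
max_num_batched_tokens = 2048  # assumed per-step token budget

for enable_chunked_prefill in (False, True):
    max_prompt_seq = (max_model_len if not enable_chunked_prefill
                      else max_num_batched_tokens)
    print(enable_chunked_prefill, max_prompt_seq)
# False 32768 -> prompt buckets span the full context length
# True 2048   -> a single prefill chunk never exceeds the token budget
```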
@@ -54,8 +58,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
             prompt_seq_bucket_cfg,
             block_size,
             prefix_caching,
+            enable_chunked_prefill,
             max_num_batched_tokens,
-            max_model_len)
+            max_model_len,
+            max_num_blocks)
 
         return sorted(prompt_buckets)
@@ -89,8 +95,10 @@ def generate_prompt_buckets(bs_bucket_config,
                             seq_bucket_config,
                             block_size,
                             prefix_caching,
+                            enable_chunked_prefill,
                             max_num_batched_tokens=None,
-                            max_model_len=None):
+                            max_model_len=None,
+                            max_num_blocks=None):
     _, _, bmax, _ = seq_bucket_config
     batch_size_buckets = warmup_range_with_limit(bs_bucket_config)
     long_context = False
@@ -103,7 +111,7 @@ def generate_prompt_buckets(bs_bucket_config,
     for bs in batch_size_buckets:
         for b in seq_bucket_config:
             buckets_3d.append((bs, b, 0))
-            max_blocks_range = (bmax - b) // block_size
+            max_blocks_range = (bmax - b) // block_size if not max_num_blocks else max_num_blocks
             if max_blocks_range == 0:
                 continue
             else:
@@ -131,10 +139,36 @@ def generate_prompt_buckets(bs_bucket_config,
     filtered_buckets = buckets
     if max_num_batched_tokens is not None and max_model_len is not None:
         # Remove buckets exceeding batch token budget
-        filtered_buckets = list(
-            filter(
-                lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
-                and bucket[1] <= max_model_len, buckets))
+        if not enable_chunked_prefill:
+            filtered_buckets = list(
+                filter(
+                    lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
+                    and bucket[1] <= max_model_len, buckets))
+        else:
+            def filter_fn(bucket):
+                # NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
+                _, seq, block = bucket
+                is_seq_in_bounds = seq <= max_num_batched_tokens
+                is_block_in_bounds = block <= max_num_blocks
+                # New logic: allow all buckets up to and including the first that exceeds max_model_len, then filter the rest
+                return is_seq_in_bounds and is_block_in_bounds
+            # Find the first bucket that exceeds max_model_len
+            # For each (bs, seq), keep all buckets that do not exceed model len, and the first that does
+            from collections import defaultdict
Copilot review comment (on the `from collections import defaultdict` line): Import statements should be placed at the top of the file, not inside functions. Move this import to the module level.
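The suggested-change block attached to this comment was not preserved; as an illustration only, moving the import to module level might look like the following (file layout assumed):

```python
# Top of the bucketing module (assumed layout): keep stdlib imports together
# instead of importing defaultdict inside generate_prompt_buckets.
import math
from collections import defaultdict
```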
+            first_exceed_seen = defaultdict(bool)
+            def keep_bucket(idx_bucket):
+                _, bucket = idx_bucket
+                bs, seq, block = bucket
+                exceeds = (seq + block * block_size) > max_model_len
+                key = (bs, seq)
+                if not exceeds:
+                    return filter_fn(bucket)
+                elif not first_exceed_seen[key] and filter_fn(bucket):
+                    first_exceed_seen[key] = True
+                    return True
+                else:
+                    return False
+            filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
Copilot review comment (on the `filtered_buckets = list(map(...))` line): [nitpick] This complex nested lambda expression reduces readability. Consider using a list comprehension or separating into multiple steps for better clarity.
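The attached suggested-change block was likewise not preserved. One possible way to address both review comments above, sketched as a standalone helper rather than the PR's actual code (the helper name is made up for illustration), is a plain loop instead of the `map`/`filter`/`enumerate` chain:

```python
from collections import defaultdict  # module-level, per the first comment


def filter_chunked_prefill_buckets(buckets, block_size, max_num_batched_tokens,
                                   max_model_len, max_num_blocks):
    """Keep (bs, seq, block) buckets within the seq/block bounds; for each
    (bs, seq) pair also keep the first bucket whose context length exceeds
    max_model_len, mirroring the keep_bucket logic in the diff."""
    first_exceed_seen = defaultdict(bool)
    filtered = []
    for bs, seq, block in buckets:
        if seq > max_num_batched_tokens or block > max_num_blocks:
            continue  # out of bounds, never kept
        exceeds = (seq + block * block_size) > max_model_len
        if not exceeds:
            filtered.append((bs, seq, block))
        elif not first_exceed_seen[(bs, seq)]:
            first_exceed_seen[(bs, seq)] = True
            filtered.append((bs, seq, block))
    return filtered
```

A plain loop keeps the "first bucket past `max_model_len`" bookkeeping explicit and is straightforward to unit-test in isolation.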
 
     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
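To make the chunked-prefill filtering rule concrete, here is a small numeric walk-through under assumed settings (block_size=128, max_model_len=1024, seq bucket fixed at 512, seq/block bounds assumed to be satisfied):

```python
# Walk the block-count dimension for one (bs, seq) bucket and apply the rule
# from the diff: keep every bucket whose context fits in max_model_len, plus
# the first one that overflows it. All values here are assumed for the example.
block_size, max_model_len, seq = 128, 1024, 512

kept, overflow_kept = [], False
for block in range(0, 9):                   # candidate block counts 0..8
    context_len = seq + block * block_size  # tokens covered by this bucket
    if context_len <= max_model_len:
        kept.append(block)
    elif not overflow_kept:
        kept.append(block)                  # first bucket past max_model_len
        overflow_kept = True

print(kept)  # [0, 1, 2, 3, 4, 5]: blocks 0-4 fit (512..1024 tokens), 5 is the first overflow
```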
Diff of the second changed file:
@@ -11,9 +11,10 @@
 
 class LinearBucketingStrategy:
     def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
-                           max_num_batched_tokens, max_model_len):
+                           max_num_batched_tokens, max_model_len, max_num_blocks):
         use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
+        chunked_prefill = get_config().engine_version == 'v1'
 
         max_prompt_seq = max_model_len
@@ -50,7 +51,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
             prompt_seq_bucket_cfg,
             block_size,
             prefix_caching,
-            max_num_batched_tokens)
+            chunked_prefill,
+            max_num_batched_tokens,
+            max_model_len,
+            max_num_blocks)
 
         return sorted(prompt_buckets)
@@ -129,7 +133,10 @@ def generate_prompt_buckets(bs_bucket_config,
                            seq_bucket_config,
                            block_size,
                            prefix_caching,
-                           max_num_batched_tokens=None):
+                           enable_chunked_prefill,
+                           max_num_batched_tokens=None,
+                           max_model_len=None,
+                           max_num_blocks=None):
     _, _, bmax = seq_bucket_config
     batch_size_buckets = warmup_range(bs_bucket_config)
     seq_bucket_config = warmup_range(seq_bucket_config)
@@ -157,10 +164,37 @@ def generate_prompt_buckets(bs_bucket_config,
     filtered_buckets = buckets
     if max_num_batched_tokens is not None:
         # Remove buckets exceeding batch token budget
-        filtered_buckets = list(
-            filter(
-                lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
-                buckets))
+        if not enable_chunked_prefill:
+            filtered_buckets = list(
+                filter(
+                    lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens,
+                    buckets))
+        else:
+            def filter_fn(bucket):
+                # NOTE(kzawora): Chunked prefill scenarios will never exceed upper boundary of max_num_batched_tokens, regardless of max_model_len
+                _, seq, block = bucket
+                is_seq_in_bounds = seq <= max_num_batched_tokens
+                is_block_in_bounds = block <= max_num_blocks
+                # New logic: allow all buckets up to and including the first that exceeds max_model_len, then filter the rest
+                return is_seq_in_bounds and is_block_in_bounds
+            # Find the first bucket that exceeds max_model_len
+            # For each (bs, seq), keep all buckets that do not exceed model len, and the first that does
+            from collections import defaultdict
Copilot review comment (on the `from collections import defaultdict` line): Import statements should be placed at the top of the file, not inside functions. Move this import to the module level.
+            first_exceed_seen = defaultdict(bool)
+            def keep_bucket(idx_bucket):
+                _, bucket = idx_bucket
+                bs, seq, block = bucket
+                exceeds = (seq + block * block_size) > max_model_len
+                key = (bs, seq)
+                if not exceeds:
+                    return filter_fn(bucket)
+                elif not first_exceed_seen[key] and filter_fn(bucket):
+                    first_exceed_seen[key] = True
+                    return True
+                else:
+                    return False
+            filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
Copilot review comment (on the `filtered_buckets = list(map(...))` line): [nitpick] This complex nested lambda expression reduces readability. Consider using a list comprehension or separating into multiple steps for better clarity.
 
     if len(filtered_buckets) == 0:
         # we can handle this if we ignore max_num_batched_tokens
Copilot review comment (on the `max_blocks_range` assignment): [nitpick] The conditional logic is unclear. Consider using `max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size` to be more explicit about None checking.
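As a sketch of the explicit check suggested above (the helper name is invented for illustration), and of why it differs from the truthiness test in the diff when a caller passes `max_num_blocks=0`:

```python
def resolve_max_blocks_range(bmax, b, block_size, max_num_blocks=None):
    # Explicit None check: only fall back to the derived range when
    # max_num_blocks was not provided at all.
    return max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size


# The truthiness version ("if not max_num_blocks") would treat 0 as "not set"
# and silently substitute the derived value; the None check preserves it.
print(resolve_max_blocks_range(bmax=4096, b=1024, block_size=128, max_num_blocks=None))  # 24
print(resolve_max_blocks_range(bmax=4096, b=1024, block_size=128, max_num_blocks=0))     # 0
```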