Draft: Proper chunked prefill bucketing #295
base: main
Pull Request Overview
This PR implements proper chunked prefill bucketing by modifying bucketing strategies to handle the new chunked prefill mode in engine version v1. The changes add support for chunked prefill scenarios with different filtering logic and bucket generation parameters.
- Adds max_num_blocks parameter to bucket generation methods
- Implements chunked prefill detection based on engine version
- Introduces specialized filtering logic for chunked prefill scenarios
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
File | Description |
---|---|
vllm_hpu_extension/bucketing/linear.py | Updates LinearBucketingStrategy to support chunked prefill with new filtering logic and parameter handling |
vllm_hpu_extension/bucketing/exponential.py | Updates ExponentialBucketingStrategy with similar chunked prefill support and bucket range calculations |
vllm_hpu_extension/bucketing/common.py | Passes the new max_num_blocks parameter to bucket generation methods |
Comments suppressed due to low confidence (1)
vllm_hpu_extension/bucketing/linear.py:136
- The parameter name is inconsistent between function signature and usage. The function parameter is named 'enable_chunked_prefill' but the passed variable is 'chunked_prefill'.
enable_chunked_prefill,
return is_seq_in_bounds and is_block_in_bounds
# Find the first bucket that exceeds max_model_len
# For each (bs, seq), keep all buckets that do not exceed model len, and the first that does
from collections import defaultdict
Import statements should be placed at the top of the file, not inside functions. Move this import to the module level.
from collections import defaultdict
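The rule stated in the code comment under review (per (bs, seq) pair, keep every bucket that fits within the model length plus the first one that exceeds it) can be sketched as below, with the import at module level as the comment recommends. This is a minimal illustration and not the PR's code: the function name `filter_by_model_len`, the `(bs, seq, ctx_blocks)` tuple layout, and the length formula `seq + ctx_blocks * block_size` are all assumptions.

```python
from collections import defaultdict


def filter_by_model_len(buckets, max_model_len, block_size):
    """Per (bs, seq), keep buckets within max_model_len plus the first that exceeds it.

    Assumes each bucket is a (bs, seq, ctx_blocks) tuple (hypothetical layout).
    """
    # Group the context-block counts by their (bs, seq) pair.
    grouped = defaultdict(list)
    for bs, seq, ctx in buckets:
        grouped[(bs, seq)].append(ctx)

    kept = []
    for (bs, seq), ctxs in grouped.items():
        overflow_taken = False
        for ctx in sorted(ctxs):
            total_len = seq + ctx * block_size  # assumed length formula
            if total_len <= max_model_len:
                kept.append((bs, seq, ctx))
            elif not overflow_taken:
                # Keep only the first bucket that goes past the limit.
                kept.append((bs, seq, ctx))
                overflow_taken = True
    return kept
```

With `block_size=128` and `max_model_len=256`, the buckets `(1, 128, 0)` through `(1, 128, 3)` have assumed lengths 128, 256, 384, and 512, so the first three are kept and the last is dropped.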
Copilot uses AI. Check for mistakes.
return is_seq_in_bounds and is_block_in_bounds
# Find the first bucket that exceeds max_model_len
# For each (bs, seq), keep all buckets that do not exceed model len, and the first that does
from collections import defaultdict
Import statements should be placed at the top of the file, not inside functions. Move this import to the module level.
from collections import defaultdict
@@ -103,7 +111,7 @@ def generate_prompt_buckets(bs_bucket_config,
     for bs in batch_size_buckets:
         for b in seq_bucket_config:
             buckets_3d.append((bs, b, 0))
-            max_blocks_range = (bmax - b) // block_size
+            max_blocks_range = (bmax - b) // block_size if not max_num_blocks else max_num_blocks
[nitpick] The conditional logic is unclear. Consider using 'max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size' to be more explicit about None checking.
- max_blocks_range = (bmax - b) // block_size if not max_num_blocks else max_num_blocks
+ max_blocks_range = max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size
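The two forms in this suggestion are not strictly equivalent, which a small side-by-side sketch makes visible. The wrapper function names below are hypothetical and exist only to isolate the two expressions; whether 0 is a meaningful value for max_num_blocks is not stated in the PR.

```python
def blocks_range_falsy(bmax, b, block_size, max_num_blocks=None):
    # Form used in the PR: falls back whenever max_num_blocks is falsy (None or 0).
    return (bmax - b) // block_size if not max_num_blocks else max_num_blocks


def blocks_range_explicit(bmax, b, block_size, max_num_blocks=None):
    # Suggested form: falls back only when max_num_blocks is None.
    return max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size
```

Both return the computed range when `max_num_blocks` is `None` and the passed value when it is a positive integer. They differ only for `max_num_blocks=0`: the first falls back to the computed range, the second returns 0.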
return True
else:
return False
filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
[nitpick] This complex nested lambda expression reduces readability. Consider using a list comprehension or separating into multiple steps for better clarity.
- filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
+ filtered_buckets = [bucket for _, bucket in enumerate(buckets) if keep_bucket((_, bucket))]
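To check that the two forms select the same buckets, here is a minimal sketch with a stand-in predicate. The `keep_bucket` body and the sample buckets are assumptions made for illustration, since the PR's actual predicate is not shown in full. The sketch also names the index `idx` instead of reusing `_` as a value, which may read more clearly.

```python
def keep_bucket(idx_bucket):
    # Hypothetical predicate: receives an (index, bucket) pair, as in the PR,
    # and keeps buckets whose seq dimension is at most 256.
    _, (bs, seq, ctx) = idx_bucket
    return seq <= 256


buckets = [(1, 128, 0), (1, 256, 0), (1, 512, 0)]

# Original form: filter the enumerated pairs, then drop the index.
via_map_filter = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))

# Comprehension form with a named index.
via_comprehension = [
    bucket for idx, bucket in enumerate(buckets) if keep_bucket((idx, bucket))
]
```

Both lists contain the same buckets in the same order, so the change is purely about readability.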
return True
else:
return False
filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
[nitpick] This complex nested lambda expression reduces readability. Consider using a list comprehension or separating into multiple steps for better clarity.
- filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets))))
+ filtered_buckets = [bucket for _, bucket in enumerate(buckets) if keep_bucket((_, bucket))]
No description provided.