Add missing Block size + Update Configs to not hardcode rope_scaling #1128
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1128
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7cdb226 with merge base 5986ed2.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchchat/model.py (Outdated)

```python
old_context_len = 8192  # original llama3 length

def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
    # Check for the presence of the required keys
    assert set(rope_scaling.keys()) >= {"factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"}
```
Could this be clearer with issubset and a raised ValueError, for a more informative error than a bare assert?

```python
required_keys = {"factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"}
if not required_keys.issubset(rope_scaling.keys()):
    raise ValueError(f"Missing required keys in apply_scaling. Expected: {required_keys}")
```
lgtm!
1 - I know I wasn't asked to review, but since we did a lot of work getting rope embeddings happy for distributed, I wanted to check this.
2 - Verified there is no issue with a distributed run.
3 - Left a minor suggestion to use issubset, which could read a bit more clearly.
Thanks for the review, always welcome!! Good idea with the ValueError raise.
Previously, block_size was not included in the model_params.json of llama3/3.1 models, so a hardcoded default of 2048 was used. This PR adds block_size to those configs.
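For illustration, here is roughly what a Llama 3.1 config entry could carry once block_size and rope_scaling are explicit. This is a sketch only: the values reflect the published Llama 3.1 8B defaults and the field names are illustrative, not copied from this PR.

```python
# Hypothetical sketch of a llama 3.1 model_params.json entry, written as a
# Python dict for readability; values are the published Llama 3.1 8B defaults,
# not copied from this PR.
llama3_1_8b_params = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "block_size": 131072,  # previously absent, so a hardcoded 2048 was used
    "rope_scaling": {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
}
```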
For rope_scaling (llama 3.1), the parameters were hardcoded inside apply_scaling() instead of being taken in as TransformerArgs built from model_params. The hardcoded values happen to line up for 3.1, but this PR makes them an explicit read from the config.
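A minimal sketch of what the parameterized apply_scaling could look like, assuming the standard Llama 3.1 RoPE frequency-scaling formula and the key names from the diff above (the ValueError check follows the reviewer's suggestion); the actual torchchat implementation may differ in its details:

```python
import math
from typing import Any, Dict

import torch


def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]) -> torch.Tensor:
    # Validate that the config supplies every parameter we need.
    required_keys = {"factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"}
    if not required_keys.issubset(rope_scaling.keys()):
        raise ValueError(f"Missing required keys in apply_scaling. Expected: {required_keys}")

    scale_factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency bands: leave unscaled.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency bands: scale down by the full factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Mid-range bands: smoothly interpolate between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
```

With this shape, rope_scaling would presumably be populated on TransformerArgs from model_params.json and threaded through to wherever the rotary frequencies are precomputed, rather than being baked into the function body.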