Enable FP6-LLM kernel build on Windows #305
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/305
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b7d8ba1 with merge base 8a4e693.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @matthewdouglas! Thank you for your pull request and welcome to our community.

Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
return decorator


def benchmark_model(model, num_runs, input_tensor):
These look like linting changes? I can't quite see the difference.
Maybe a change in end-of-line symbols? IIRC, Windows uses different line endings (CRLF instead of LF).
u_int32_t *Frag1_PTR = read_RPTR_Frag1;
u_int32_t *Frag2_PTR = read_RPTR_Frag2;
__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t Reg[][4],
                                                    uint32_t * __restrict__ read_RPTR_Frag1,
TIL about uint32_t vs u_int32_t, lol
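For context (an aside, not from the thread): u_int32_t is a non-standard BSD/glibc typedef that MSVC does not provide, while uint32_t is the standard fixed-width type from <cstdint>, which is presumably why the Windows build needs this rename. A minimal sketch of the portable spelling:

#include <cstdint>   // standard fixed-width integer types

// u_int32_t x = 0;  // non-standard (BSD/glibc typedef); fails to compile under MSVC
uint32_t y = 0u;     // portable across MSVC, GCC, Clang, and NVCC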
"-O3" if not debug_mode else "-O0", | ||
] | ||
} | ||
if not IS_WINDOWS: |
Yeah, it feels like we should be testing this in CI. It shouldn't be too hard to use a Windows machine for CPU, but I'm not sure how abundant CUDA-enabled Windows machines are in the GitHub org.
int slice_id) {
__device__ __forceinline__ void B_FromSharedToReg(uint32_t Reg[][4],
                                                  half (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
                                                  int slice_id) {
Missing some __restrict__ here. For Reg[][4], I don't know if we can add __restrict__ directly. Otherwise, maybe we need to change it to a pointer (so we can add back __restrict__). From what I know, Reg[][4] is still passed as a pointer, but it allows us to do 2D indexing (the last dim is a compile-time constant, so it translates to 4 * first_index + second_index).
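For reference, a sketch of one way this could look (my own illustration, not code from the PR; consume_regs and demo_kernel are hypothetical names): a parameter declared as uint32_t Reg[][4] is adjusted to the pointer type uint32_t (*Reg)[4], so __restrict__ can be placed on that pointer declarator directly while keeping the 2D indexing:

#include <cstdint>

// The array-of-arrays parameter decays to a pointer-to-array, so
// __restrict__ attaches to the pointer declarator; Reg[i][j] still
// compiles to 4 * i + j addressing.
__device__ __forceinline__ void consume_regs(uint32_t (* __restrict__ Reg)[4]) {
    Reg[0][1] = Reg[0][0] + 1u;
}

__global__ void demo_kernel() {
    uint32_t regs[8][4] = {};  // decays to uint32_t (*)[4] at the call site
    consume_regs(regs);
}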
u_int32_t *OutputRegs = reinterpret_cast<u_int32_t*>(Reg);
u_int32_t *Frag1_PTR = read_RPTR_Frag1;
u_int32_t *Frag2_PTR = read_RPTR_Frag2;
__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t Reg[][4],
Missing __restrict__ here.
Done via #396
Summary: Handles the case when params.json is explicitly given and the JSON doesn't mention the tokenizer config, but the command line does.
Test Plan: python torchchat.py generate --device cpu --checkpoint-path /Users/mnachin/models/Meta-Llama-3-8B/original/consolidated.00.pth --params-path=/Users/mnachin/models/Meta-Llama-3-8B/original/params.json --temperature 0 --tiktoken
This PR includes a small set of changes to enable building the FP6-LLM kernels, and the torch extension in general, natively on Windows. Tested with MSVC 19.39 (VS2022 17.9) and NVCC 12.4.
I have not yet validated these changes with GCC, so I'm keeping this in draft mode for now.