Automatic Differentiation #6913
IaroslavElistratov started this conversation in Show and tell
One kernel-authoring pain point I've noticed is that writing backward kernels can be challenging.
I implemented a PoC that auto-diffs Triton IR and then wraps the resulting forward/backward IR pair into a torch.autograd.Function. Tested on Flash-Attention-v2 (removes ~300 lines of user code), Layer-Norm (removes ~120 lines of user code), other Triton tutorials, and 10 other simpler kernels. Numerical correctness was validated against PyTorch implementations.
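To illustrate the wrapping step, here is a minimal sketch of how a generated forward/backward pair can be exposed through torch.autograd.Function. The names `fwd_kernel`, `bwd_kernel`, and `GeneratedFunction` are hypothetical stand-ins (not the repo's actual API), and plain PyTorch ops are used in place of Triton kernels so the sketch runs anywhere:

```python
import torch

# Hypothetical stand-ins for an autodiff-generated forward/backward kernel pair.
# In the real project these would be Triton kernels produced from the forward IR.
def fwd_kernel(x):
    return x * x                    # forward: y = x^2

def bwd_kernel(x, grad_y):
    return 2.0 * x * grad_y         # backward: dL/dx = 2x * dL/dy

class GeneratedFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)    # stash inputs needed by the backward pass
        return fwd_kernel(x)

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        return bwd_kernel(x, grad_y)

# Usage: gradients flow through the generated backward automatically.
x = torch.randn(4, requires_grad=True)
y = GeneratedFunction.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, 2 * x.detach()))  # check against the analytic gradient
```

The user only ever calls the autograd.Function; the hand-written backward kernel (and the bookkeeping around it) is what the autodiff pass is meant to eliminate.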
Repo: https://github.com/IaroslavElistratov/triton-autodiff
Would love to hear your thoughts. Please note these are very early results.