Precision Support + TileLang Integration #80
Conversation
For the tolerances we took inspiration from TorchBench; you can look at all of the references / tolerance decisions there. The specific tolerances we used are inspired by TorchBench and BackendBench, where 1e-02 is used for everything at fp16.
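As a rough illustration of the tolerance choice discussed above, here is a minimal sketch of a tolerance-based correctness check. The 1e-02 values for fp16/bf16 follow this comment; the fp32 values, the table, and the function name are illustrative assumptions, not KernelBench's actual eval code.

```python
import torch

# Illustrative tolerance table; 1e-02 for fp16/bf16 mirrors the discussion above,
# while the fp32 entries are a common default and only an assumption here.
TOLERANCES = {
    torch.float32: {"atol": 1e-5, "rtol": 1e-5},
    torch.float16: {"atol": 1e-2, "rtol": 1e-2},
    torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
}

def outputs_match(reference: torch.Tensor, candidate: torch.Tensor) -> bool:
    # Compare the candidate kernel output against the reference output
    # using the tolerance associated with the reference dtype.
    tol = TOLERANCES[reference.dtype]
    return torch.allclose(reference, candidate, atol=tol["atol"], rtol=tol["rtol"])
```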
nathanjpaek left a comment:
lgtm, added tilelang
Thanks to @nathanjpaek for the comprehensive testing! We checked that, after specifying the precision argument, the eval function uses the target precision. We also checked the timing across the different precisions, and tested against the main branch version without explicit precision specification (which defaults to fp32).
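For illustration, a minimal sketch of the kind of check described here: asserting that the tensors seen during evaluation actually carry the requested dtype. The helper name and structure are assumptions, not the actual eval code.

```python
import torch

def check_precision(tensors, target_dtype: torch.dtype) -> None:
    # Verify that every floating-point tensor reaching the eval path
    # was actually cast to the requested precision.
    for t in tensors:
        if t.is_floating_point():
            assert t.dtype == target_dtype, f"expected {target_dtype}, got {t.dtype}"

# Example: inputs prepared for an fp16 run should all be torch.float16.
inputs = [torch.randn(8, 8, dtype=torch.float16), torch.randn(8, dtype=torch.float16)]
check_precision(inputs, torch.float16)
```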
* initial implementation for various precision support on input and weights (during forward pass); use same precision for both
  Co-authored-by: Simon Guo <[email protected]>
  Co-authored-by: Sahan Paliskara <[email protected]>
* add tilelang
* update requirements for tilelang
* add precision to other files
* tested and updated readme
---------
Co-authored-by: Sahan Paliskara <[email protected]>
Co-authored-by: Nathan Paek <[email protected]>
Co-authored-by: nathanjp <[email protected]>
KernelBench by default uses the PyTorch tensor precision, which is fp32. All results reported on KernelBench so far are fp32. However, as more inference and training techniques move towards lower precision, it is important that we support a variety of precisions to understand performance comprehensively. We address the issue raised in #79.
Specifically, we:
* add a precision argument (fp32, fp16, bf16) to cast inputs and weights into the target precision
* adjust the correctness tolerances for fp16 and bf16 (see the tolerance discussion above)
We will also add this info in the model generation prompt in another PR.
Now, when running KernelBench, you can specify the desired precision as an argument; a sketch of how such an argument can drive the cast is shown below.
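A minimal sketch of how a precision argument can be mapped onto the cast of inputs and weights. The argument values mirror the precisions listed above, but the mapping, helper name, and signature are assumptions for illustration, not KernelBench's exact interface.

```python
import torch

# Map the string argument to a torch dtype; the accepted values mirror
# the precisions listed in the PR description.
DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

def cast_model_and_inputs(model: torch.nn.Module, inputs, precision: str = "fp32"):
    # Cast both the model weights and the forward-pass inputs to the same
    # target precision, as this PR does for inputs and weights.
    dtype = DTYPE_MAP[precision]
    model = model.to(dtype)
    inputs = [
        x.to(dtype) if torch.is_tensor(x) and x.is_floating_point() else x
        for x in inputs
    ]
    return model, inputs

# Usage (illustrative):
model, inputs = cast_model_and_inputs(torch.nn.Linear(16, 16), [torch.randn(4, 16)], precision="bf16")
```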