Add Triton Backend #35
Conversation
@simonguozirui Any update on when this will be merged to main?

@PaliC is this still up to date, or are there big changes in KernelBench since this PR was drafted?

@AffectionateCurry and I are back working to merge this. For JIT-compiled languages it is quite easy to do so; for frameworks that require building and linking it is much more complicated.

@simonguozirui and @AffectionateCurry - amazing work. I am currently actively looking into KernelBench + extensions myself. Let me know if you guys are open to collaboration.
simonguozirui left a comment
Great work @AffectionateCurry @nathanjpaek!
Look through my comments and see what you all think. I like many of the abstractions you all put in (along with the changes @PaliC made).
Can we ensure the new changes don't break the existing CUDA pipeline? Test CUDA / Triton / CuTe on both the local L40S lab machine and Modal cloud execution.
```text
anthropic
modal
numpy
openai
```
Why do we remove requirements.txt? We should keep it; we can think about using uv later, but let's not get rid of it here.
@nathanjpaek don't we also need to add tilelang here?
```diff
 with open(eval_file_path, "w") as f:
-    json.dump(eval_results, f)
+    json.dump(eval_results, f, indent=4)
```
great check
```diff
 elif config.dataset_src == "local":
-    problem_idx_in_dataset = config.problem_id - 1  # due to dataset list being 0-indexed locally
+    problem_idx_in_dataset = (
+        config.problem_id - 1
```
@pythonomar22 this is something we will get rid of with your new benchmark data class, so we don't have to deal with all these nasty off-by-one issues.
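For illustration, a minimal sketch of what such a data class could centralize (all names here are hypothetical, not the actual KernelBench API):

```python
from dataclasses import dataclass

@dataclass
class BenchmarkProblem:
    """Hypothetical wrapper so callers never do raw index arithmetic."""
    problem_id: int   # 1-indexed, as used in configs and filenames
    dataset_src: str  # "huggingface" or "local"

    @property
    def dataset_index(self) -> int:
        # Convert to the 0-indexed list position exactly once, here,
        # instead of sprinkling `- 1` across the scripts.
        return self.problem_id - 1
```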
```python
# Use appropriate prompt constructor based on backend
if config.backend == "cuda":
    custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
```
@AffectionateCurry I see what you mean here now. We can refactor this later with a better prompt template!
```python
    custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
    raise ValueError(
        f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."
```
Nice catch here. We shall update the README before the GPU MODE hackathon to list these as available options.
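As a sketch, `get_prompt_for_backend` could dispatch on a table like the one below; the Triton/CuTe constructor names are placeholders (only the CUDA constructor appears in this diff):

```python
from typing import Callable, Dict

# Hypothetical prompt constructors; in the real code these would be the
# backend-specific template functions.
def prompt_generate_custom_cuda_from_prompt_template(ref_arch_src: str) -> str: ...
def prompt_generate_custom_triton_from_prompt_template(ref_arch_src: str) -> str: ...
def prompt_generate_custom_cute_from_prompt_template(ref_arch_src: str) -> str: ...

PROMPT_CONSTRUCTORS: Dict[str, Callable[[str], str]] = {
    "cuda": prompt_generate_custom_cuda_from_prompt_template,
    "triton": prompt_generate_custom_triton_from_prompt_template,
    "cute": prompt_generate_custom_cute_from_prompt_template,
}

def get_prompt_for_backend(ref_arch_src: str, backend: str) -> str:
    try:
        constructor = PROMPT_CONSTRUCTORS[backend]
    except KeyError:
        raise ValueError(
            f"Unsupported backend: {backend}. Must be 'cuda', 'triton', or 'cute'."
        )
    return constructor(ref_arch_src)
```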
```python
    deleted manually by the caller.
    This is a hack that is needed for triton code as compile / exec do not play well
    with the @triton.jit decorator.
```
We did this for some of the multi-turn KernelBench experiments too, so we might need to support this for the CUDA code path as well.
Right now @AffectionateCurry you should state that this is only invoked for alternative backends.
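For context, a minimal sketch of the temp-file workaround (function and module names are illustrative): `@triton.jit` retrieves the kernel's source via `inspect`, which fails for code executed from an in-memory string, so the generated code has to live in a real file that is imported and later deleted by the caller.

```python
import importlib.util
import tempfile

def load_custom_model_from_file(model_src: str):
    """Write generated kernel source to a real file, import it, and return
    the ModelNew class plus the temp file path (caller must delete it)."""
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False)
    tmp.write(model_src)
    tmp.close()

    spec = importlib.util.spec_from_file_location("generated_kernel", tmp.name)
    temp_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(temp_module)  # runs the generated code
    return temp_module.ModelNew, tmp.name
```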
```python
    return ModelNew, temp_file


# def load_tilelang_model(
```
we can do this in a later PR
```python
# Create a new module based on that spec
temp_module = importlib.util.module_from_spec(spec)
# Execute the code in the module's namespace
spec.loader.exec_module(temp_module)
```
How safe is this, haha? We should really understand it [in case of any reward hacking].
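Not very, in general: `exec_module` runs arbitrary generated code in-process. One common mitigation (a sketch, not what this PR does) is to run the import in a subprocess with a timeout, so a hanging or misbehaving kernel can't take down, or tamper with, the evaluator:

```python
import importlib.util
import multiprocessing as mp

def _import_candidate(src_path: str) -> None:
    # Runs in a child process: import (and thereby execute) the module.
    spec = importlib.util.spec_from_file_location("candidate", src_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

def import_with_timeout(src_path: str, timeout_s: float = 60.0) -> bool:
    proc = mp.Process(target=_import_candidate, args=(src_path,))
    proc.start()
    proc.join(timeout_s)
    if proc.is_alive():
        proc.terminate()  # runaway or hanging kernel
        proc.join()
        return False
    return proc.exitcode == 0
```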
```python
import os
from .utils import read_file


"""
```
great first step, we will replace this with something a bit more modular later on!
```python
    return tensor.to(device=device)


# Apply backend-specific dtype casting for float tensors
# if backend.lower() == "tilelang":
```
Did we write this function just for tilelang?
In general I actually quite like this abstraction; we can do some checks etc.
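A sketch of the general form this abstraction could take, with the tilelang branch mirroring the commented-out code above (the helper name and the fp16 choice are purely illustrative):

```python
import torch

def prepare_tensor_for_backend(tensor: torch.Tensor, device, backend: str) -> torch.Tensor:
    """Move a tensor to the target device and apply any backend-specific
    dtype rules; a no-op for everything except the sketched tilelang case."""
    tensor = tensor.to(device=device)
    # Illustrative: some backends may only accept certain float dtypes.
    if backend.lower() == "tilelang" and tensor.is_floating_point():
        tensor = tensor.to(dtype=torch.float16)
    return tensor
```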
@nathanjpaek thanks for adding a one-shot example. @AffectionateCurry and I have checked the current state of the PR. We will make prompt construction and eval logic cleaner across backends in future PRs. Great job @nathanjpaek @AffectionateCurry on your first PR! Thank you to @PaliC @msaroufim @Zacharias030 and the PyTorch team for your help!
…ce#35)

* triton_backend_v2
* fix eval bugs
* fix issues
* revert eval
* remove traceback
* remove cot
* improve eval
* looked over pr and added future support for other languages
* updated requirements
* added back requirements.txt
* add cute one shot addition example
* remove unncessary files and redo requirements
* let's see if that fixes it
* fix config in file suggested soksoerey
* move natalia's old file into change log

---------

Co-authored-by: AffectionateCurry <[email protected]>
Co-authored-by: nathanjpaek <[email protected]>
Co-authored-by: Simon Guo <[email protected]>
This PR adds a Triton backend to KernelBench. To invoke it, simply add `backend="triton"` to the following 4 scripts (use them as normal otherwise).
This PR also adds a `{error_type}_name` field to the eval JSON. The reason for this is that it makes classifying errors (especially for Triton) much easier. For example, from the error log it isn't obvious what an error is (i.e. you might get `at 37:15:\n h_start = pooled_row * stride - padding\n w_start = pooled_col * stride - padding\n\n # Initialize the max value\n max_val = tl.full((1,), float('-inf'), tl.float32)\n\n # Itera...`). But if the error name is `triton.compiler.errors.UnsupportedLanguageConstruct`, it's a lot more obvious.
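For reference, a fully qualified error name like the one above can be recovered from a caught exception with something like:

```python
def error_name(exc: BaseException) -> str:
    # e.g. "triton.compiler.errors.UnsupportedLanguageConstruct"
    cls = type(exc)
    return f"{cls.__module__}.{cls.__qualname__}"
```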
Testing: I've tested the 4 scripts in both the Triton and CUDA variants and they seem to work normally (outside of `scripts/generate_and_eval_single_sample_modal.py`, which should be equivalent to `scripts/generate_and_eval_single_sample.py`).

Todo:
Below is the GitHub Copilot-generated summary, which is honestly pretty useful for navigating large PRs.
==========================================================================================
This pull request includes several changes to improve code readability and add new functionality to the `scripts/eval_from_generations.py` and `scripts/generate_and_eval_single_sample.py` files. The most notable changes include reformatting code for better readability, adding a new backend configuration option, and enhancing error logging.

Code readability improvements:

* `scripts/eval_from_generations.py`: Reorganized import statements and reformatted multiple lines of code to follow PEP 8 guidelines.
* `scripts/generate_and_eval_single_sample.py`: Reorganized import statements and reformatted multiple lines of code to follow PEP 8 guidelines.

New functionality:

* `scripts/eval_from_generations.py`: Added a new configuration option `backend` to specify the backend for kernel implementation (cuda or triton).

Error logging enhancements:

* `scripts/eval_from_generations.py`: Enhanced error logging by adding more detailed error messages and including error names in the metadata. [1] [2]

Miscellaneous:

* `scripts/eval_from_generations.py`: Added indentation to JSON output in `add_to_eval_results_file` for better readability.
* `scripts/generate_and_eval_single_sample.py`: Added the `backend` configuration option to the `EvalConfig` class.
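Roughly, the config change amounts to one new field; a minimal sketch (the real `EvalConfig` has many more fields and may not be a plain dataclass):

```python
from dataclasses import dataclass

@dataclass
class EvalConfig:
    # ... existing fields elided ...
    backend: str = "cuda"  # kernel language to evaluate: "cuda", "triton", or "cute"
```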