Static Kernel Code Checker #110
PaliC left a comment:
There's a lot of work to do. One immediate piece of feedback: if you want to ban something, I would just replace it with a useful error, or use ast to capture it (as you need to do with try and except) and emit a useful error when those nodes show up.
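A minimal sketch of the AST approach the reviewer describes, assuming a hypothetical helper name; the real checker may use different rules and messages:

```python
import ast

def find_banned_nodes(source: str) -> list[str]:
    """Walk the AST and report banned constructs with line numbers.

    Hypothetical helper illustrating the suggestion: instead of regex-matching
    for "try", capture ast.Try nodes directly and emit a useful error.
    """
    errors = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Try):
            errors.append(
                f"line {node.lineno}: try/except is banned in kernel code; "
                "let errors surface so broken kernels are not rewarded"
            )
    return errors

# Demo: a kernel that swallows its own failures.
print(find_banned_nodes("try:\n    run_kernel()\nexcept Exception:\n    pass\n"))
```

Because the check operates on AST nodes rather than text, it is immune to whitespace tricks and comments that defeat naive regex matching.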
uv.lock
Outdated
As discussed, remove this :P
More seriously, just update .gitignore.
@@ -0,0 +1,270 @@
"""
Tests for kernel_static_checker.py
Would recommend using fixtures for this (this is a common strategy in web dev / JS testing).
I can give you some examples if you want.
Yeah, I will clean up the test file with some adversarial examples and reusable pieces; it was a start.
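A fixture-based layout along the lines suggested might look like the sketch below. All names are hypothetical, and a tiny regex check stands in for the PR's real checker entry point:

```python
import re
import pytest

def rejects_torch_wrapping(code: str) -> bool:
    """Stand-in for the checker: flag any torch.nn usage."""
    return re.search(r"torch\.nn", code) is None

# Reusable valid / adversarial kernel snippets shared across tests.
VALID_KERNELS = {
    "triton_add": "import triton\n\[email protected]\ndef add_kernel(x_ptr):\n    pass\n",
}
ADVERSARIAL_KERNELS = {
    "torch_wrapper": "import torch\n\ndef forward(x):\n    return torch.nn.functional.relu(x)\n",
}

@pytest.fixture(params=sorted(VALID_KERNELS), ids=sorted(VALID_KERNELS))
def valid_kernel(request):
    return VALID_KERNELS[request.param]

@pytest.fixture(params=sorted(ADVERSARIAL_KERNELS), ids=sorted(ADVERSARIAL_KERNELS))
def adversarial_kernel(request):
    return ADVERSARIAL_KERNELS[request.param]

def test_valid_kernel_passes(valid_kernel):
    assert rejects_torch_wrapping(valid_kernel)

def test_adversarial_kernel_fails(adversarial_kernel):
    assert not rejects_torch_wrapping(adversarial_kernel)
```

Adding a new adversarial example then only requires a new dictionary entry; every test that consumes the fixture picks it up automatically.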
src/code_check_old.py
Outdated
has_tilelang = "@T.prim_func" in kernel_code or "tvm.build" in kernel_code

has_custom_implementation = any([
    has_triton_kernel,
For the Python DSLs, an idea would be to patch the decorators so that they do some inspection / timing, to validate that the kernel is actually executed and takes up most of the runtime.
good idea, we can use that in the runtime-based validation later on.
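One way to realize the decorator-patching idea for later runtime validation, sketched with hypothetical names and an identity decorator standing in for the real DSL decorator (e.g. triton.jit):

```python
import functools
import time

def make_timed_decorator(original_decorator, stats):
    """Wrap a DSL decorator so every call through the compiled kernel is
    timed; a later check can confirm the kernel actually ran and accounts
    for most of the runtime. Illustrative only, not the project's API."""
    def patched(fn):
        compiled = original_decorator(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = compiled(*args, **kwargs)
            stats[fn.__name__] = stats.get(fn.__name__, 0.0) + (
                time.perf_counter() - start
            )
            return result

        return wrapper
    return patched

# Demo: an identity "decorator" stands in for the real one.
stats = {}
timed_jit = make_timed_decorator(lambda f: f, stats)

@timed_jit
def fake_kernel(x):
    return x * 2

fake_kernel(21)
```

After the run, an empty `stats` entry for a kernel would mean the submitted code never invoked it, which is exactly the wrapping behavior the checker wants to catch.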
src/code_check_old.py
Outdated
# Kevin Rule 1: PyTorch wrapping detection
# Zero reward for kernels containing torch.nn or torch.nn.functional
pytorch_patterns = [
    "torch.nn.functional",
I would just overwrite these libraries to emit a useful error if you want to ban torch
I see what you mean by overwriting now. Let's keep this a static checker for now, and we can think about smarter runtime checks later.
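For reference, the "overwrite banned APIs to emit a useful error" idea could look roughly like this. It is demonstrated on a stand-in module (a runtime checker would apply it to torch.nn.functional itself before executing the kernel), and all names are illustrative:

```python
import types

def ban_attribute(module, name, reason):
    """Replace module.<name> with a stub that raises a descriptive error
    the moment a submitted kernel touches the banned API."""
    def banned(*args, **kwargs):
        raise RuntimeError(f"{module.__name__}.{name} is banned: {reason}")
    setattr(module, name, banned)

# Stand-in for torch.nn.functional, so the sketch runs without torch.
fake_functional = types.ModuleType("fake_torch_nn_functional")
fake_functional.relu = lambda x: max(x, 0.0)

ban_attribute(
    fake_functional, "relu",
    "implement the op in your own kernel instead of wrapping PyTorch",
)
```

Compared to static pattern matching, this catches dynamic access paths (getattr, aliasing) that a regex would miss, at the cost of requiring the code to actually run.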
"torch.softmax", "torch.log_softmax", "torch.tanh", "torch.sigmoid",
"torch.hardsigmoid", "torch.silu", "torch.mish",
# Normalization
"torch.batch_norm", "torch.group_norm", "torch.layer_norm",
Just overwrite these with an error; I think that's better.
and dsl tests Co-authored-by: Simon Guo <[email protected]> Co-authored-by: Ethan <[email protected]>
These are cool. I didn't take a careful enough look to make precise comments, but in general I like the idea of all these tests. Also make sure people have the option to use or not use them (especially for the L3+ kernels, where it might be too hard to avoid PyTorch ops entirely).
Run with pytest:
    pytest src/kernelbench/unit_tests/test_precision.py -v
or
    uv run pytest src/kernelbench/unit_tests/test_precision.py -v
@ethanboneh this is a bit redundant and too extensive; perhaps we can clean it up with the real precision checks in the results PR.
Thanks to @alexzhang13 and @PaliC for the thoughtful reviews and discussions! The goal of this PR is to introduce an initial checker system. While I like @PaliC's AST-based checker idea, this PR focuses on a regex-based approach, which already catches a large class of problematic patterns in practice. Many of these patterns are derived from observations during Kevin RL training, as well as issues highlighted in Jiwei's recent blog post. One advantage of regex-based checking is that it allows us to efficiently scan and validate large numbers of already-generated kernels. Going forward, we can complement this with more dynamic checkers in the evaluation pipeline, and I also plan to explore adding an LM-as-a-judge checker.

This work is part of our ongoing effort to mitigate reward hacking, as discussed in #74, and we hope it can serve as a useful resource for the community. The current set of patterns is not meant to be exhaustive, and contributions of additional checks are very welcome. Since this system is still experimental, we will merge it gated behind a flag. Users can configure groups of patterns that are either disallowed or treated as warnings, depending on their use case. See the file for more details on our suggested use cases, particularly a note on the degree to which PyTorch computational ops should be allowed.
Kernel Static Checker - Pattern-based validation for GPU kernel code.
See src/kernel_static_checker.py. The goal is to flag reward hacking patterns (both strictly prohibited and possibly suspicious ones) by statically examining the code.
Warning: this list is by no means complete, nor is it a replacement for runtime checks. We welcome feedback and contributions as the community finds new ways of hacking.
Usage:
result = validate_kernel_static(code, backend="cuda")
which returns a tuple (valid, errors, warnings)
Right now, all of these checks are regex-matching-based.
In the future we can add AST-based checking (@PaliC) to make it more reliable, along with more runtime-level checks and LM-as-a-judge.
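As a rough illustration of the regex-based, group-configurable design described above (the pattern lists, group names, and function signature here are simplified stand-ins, not the shipped checker):

```python
import re

# Hypothetical pattern groups; the real checker defines its own lists and
# lets users choose which groups are hard errors vs. warnings.
PATTERN_GROUPS = {
    "pytorch_wrapping": [r"torch\.nn\.functional\.\w+", r"torch\.nn\.\w+"],
    "exception_swallowing": [r"\btry\s*:", r"\bexcept\b"],
}

def check_patterns(code, disallow=("pytorch_wrapping",),
                   warn=("exception_swallowing",)):
    """Simplified sketch of the (valid, errors, warnings) contract."""
    errors, warnings = [], []
    for groups, sink in ((disallow, errors), (warn, warnings)):
        for name in groups:
            for pat in PATTERN_GROUPS[name]:
                if re.search(pat, code):
                    sink.append(f"{name}: matched {pat!r}")
    return (not errors, errors, warnings)

valid, errors, warnings = check_patterns("out = torch.nn.functional.relu(x)")
```

A user who wants PyTorch ops allowed for harder levels could simply move "pytorch_wrapping" from `disallow` into `warn`.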
Source of checks: