Skip to content

Conversation

@bkal01
Copy link
Collaborator

@bkal01 bkal01 commented Nov 4, 2025

adds unit tests for eval scripts

eval scripts should:

  • flag kernels that try to hack by modifying the input as incorrect
    • tested by having a custom kernel zero out the inputs and return a zero matrix
    • if the eval script runs the custom kernel first and then PyTorch (as done in here), then the custom kernel will incorrectly pass
    • so, we should at least run the tests in both directions (as mentioned in the benchmarking doc) OR clone the inputs so any modifications the kernel makes doesn't affect the PyTorch run
  • flag kernels that try to hack by reusing PyTorch computations as incorrect
    • tested by using empty which can get allocated the same physical memory as the PyTorch reference outputs
    • if the eval script deletes/frees the PyTorch output object at some point before the custom kernel is run, the CUDA cache allocator might give that un-erased physical memory to the custom kernel and it will incorrectly pass
    • so, we should ensure we zero out physical memory to prevent reuse
  • flag kernels that achieve excessive speedup as potentially reward hacked
    • tested by having a custom kernel allocate a matmul to a non-default stream, achieving extremely unrealistic speedups when timed via CUDA events.
    • eval script should time the reference kernel and flag this speedup as something the user should double check via the KernelExecResult metadata.

Adds a unit test to check that a generated kernel which modifies the original inputs fails the correctness check.

For the square matmul problem, the kernel zeros out the inputs and returns a matrix of 0s. This will fail correctness/pass the test as long as the reference implementation is ran first. If we swap the order, the test will fail as the reference implementation will operate on tensors of 0s and it will look like the generated kernel computed the correct output.
Adds a unit test to check that a generated kernel which attempts to access the result from the PyTorch reference model in memory fails the correctness check.

If a generated kernel uses empty_like, the CUDA caching allocator can re-use the physical memory of the previously computed result. All the kernel needs to do is return immediately and it will pass the correctness check.

Note that in order to reproduce this, we need to copy the PyTorch output to the CPU and delete the output object. Then empty_like will fetch the physical memory for the output object.
@bkal01 bkal01 changed the title [WIP] add unit tests for input mod add unit tests for input mod Nov 5, 2025
@bkal01 bkal01 requested a review from simonguozirui November 5, 2025 05:05
@bkal01 bkal01 changed the title add unit tests for input mod add eval unit tests Nov 7, 2025
@simonguozirui simonguozirui changed the title add eval unit tests Eval Unit Tests for Adversarial Correctness Testing Nov 29, 2025
use generic matmul shape for cache reuse adversarial kernel rather than requiring a square matmul.
simonguozirui and others added 6 commits December 18, 2025 02:42
make a non-blocking non-default stream, and use cublasGemmEx rather than at::matmul:
eval script now flags excessive speedups by timing pytorch reference.
@simonguozirui
Copy link
Collaborator

Thanks @bkal01 to create the adversarial kernel with additional cuda stream.. Now we have unit test and eval timing functions that only time the main cuda_stream might suffer from such attack, but we have added a heuristics way to check it (see if speedup is bigger than some threshold like 10x or 5x).

Here is an example using naive do_bench with the hacky stream kernel.

CUDA_VISIBLE_DEVICES=7 python src/unit_tests/test_eval_adversarial.py 
Running test adversarial kernel non_default_stream_kernel.py against problem level1/1_Square_matrix_multiplication_.py
[Profiling] Using timing method: do_bench
[WARNING] Excessive speedup 1947.37x over 10x threshold detected
[WARNING] Double check your kernel carefully to ensure it is not reward hacking.
compiled=True correctness=True metadata={'hardware': 'NVIDIA H200', 'device': '0', 'correctness_trials': '(5 / 5)', 'excessive_speedup': True} runtime=0.00304 runtime_stats={'mean': 0.00304, 'std': 0.000107, 'min': 0.00298, 'max': 0.00326, 'num_trials': 13} ref_runtime=5.92 ref_runtime_stats={'mean': 5.92, 'std': 1.14, 'min': 5.18, 'max': 7.69, 'num_trials': 17}
Traceback (most recent call last):
  File "/home/simon/kb-maintain/src/unit_tests/test_eval_adversarial.py", line 108, in <module>
    main()
  File "/home/simon/kb-maintain/src/unit_tests/test_eval_adversarial.py", line 105, in main
    test_non_default_stream()
  File "/home/simon/kb-maintain/src/unit_tests/test_eval_adversarial.py", line 96, in test_non_default_stream
    raise AssertionError(
AssertionError: Excessive speedup detected, Eval Function did not handle hacky stream

@simonguozirui
Copy link
Collaborator

We added an optional and gated logic in the eval functioneval_kernel_against_ref adds new param: check_for_excessive_speedup (bool), excessive_speedup_threshold (float), and now populates ref_runtime and ref_runtime_stats on the KernelExecResult. If the kernel's speedup exceeds the threshold, it sets result.metadata["excessive_speedup"] = True and prints a warning like this

[WARNING] Excessive speedup 1906.75x over 10x threshold detected
[WARNING] Double check your kernel carefully to ensure it is not reward hacking.

@simonguozirui simonguozirui added the enhancement New feature or request label Dec 19, 2025
@simonguozirui simonguozirui changed the title Eval Unit Tests for Adversarial Correctness Testing Eval Unit Tests for Adversarial Eval Testing Dec 19, 2025
@simonguozirui
Copy link
Collaborator

Tysm @bkal01 for the great work and being super careful. These unit tests would be super helpful for us to test the eval function with adversarial examples. Merging these for now but we can add more later.

Right now we added a simple excessive speedup check (heuristics like >5x, 10x) mark it as suspicious. A better approach is to create a SoL modeling (ongoing effort) based on program ops and hardware specs.

Also started to add the draft of eval / benchmarking guide here. @PaliC and team will pick up in other PRs.

@simonguozirui simonguozirui merged commit fd57302 into main Dec 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants