Skip to content

Conversation

@wychi
Copy link
Contributor

@wychi wychi commented Aug 1, 2025

Summary:

Exceptions during autotune kernel precompilation are now systematically captured and reported via the chromium_event_logger, enabling better debugging and analysis of autotune failures.

Currently, exceptions are dumped to the console in the following format::

[0/0] RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[0/0] Runtime error during autotuning: 
[0/0] No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[0/0] Ignoring this choice.

The exception tracebacks:

# inner exception
traceback:
  File "/torch/_inductor/runtime/triton_heuristics.py", line 603, in _make_launchers
    launchers.append(result.make_launcher())
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/torch/_inductor/runtime/triton_heuristics.py", line 1503, in make_launcher
    self.kernel.load_kernel(device)
  File "/torch/_inductor/runtime/static_cuda_launcher.py", line 113, in load_kernel
    (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(

# wrapped exception
traceback:
  File "/usr/local/fbcode/platform010/lib/python3.12/concurrent/futures/thread.py", line 59, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<trimmed>#link-tree/torch/_inductor/select_algorithm.py", line 2596, in precompile_with_captured_stdout
    choice.precompile()
  File "<trimmed>#link-tree/torch/_inductor/select_algorithm.py", line 1881, in precompile
    self.bmreq.precompile()
  File "<trimmed>#link-tree/torch/_inductor/autotune_process.py", line 660, in precompile
    getattr(mod, self.kernel_name).precompile()
  File "<trimmed>#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile
    self._make_launchers()
  File "<trimmed>#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers
    raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")

With this change, the exception details will also be logged in the metadata of the {name}_template_precompiling event.

The format:

{
  "exceptions": [
    {
      "choice_type": "triton",
      "choice": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0",
      "exception_message": "No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.",
      "exception": "OutOfMemoryError",
      "required_memory": "262144",
      "hardware_limit": "232448"
    }
  ]
}

Test Plan:
buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt

Rollback Plan:

Differential Revision: D79420953

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159688

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 7f1f847 with merge base eb25a95 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D79420953

@pytorch pytorch deleted a comment from pytorch-bot bot Aug 1, 2025
@wychi
Copy link
Contributor Author

wychi commented Aug 1, 2025

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Aug 1, 2025
@wychi wychi requested review from masnesral and stashuk-olek August 1, 2025 23:07
get_chromium_event_logger().add_event_data(
event_name, autotune_choices_stats=payload
)
sys.stderr.write(f"Autotune Choices Stats:\n{payload}\n")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how we log stuff here? No logging API?

Copy link
Contributor Author

@wychi wychi Aug 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to know more about other logging APIs.

the logging is mainly done by calling get_chromium_event_logger().add_event_data(). So the data will be part of _template_autotuning event.
The sys.stderr.write is just for local debug. I can remove that for sure.

Btw, this part of code is actually another PR (#159496) and was already landed. I'll make necessary changes with another PR if needed.

Copy link
Contributor

@stashuk-olek stashuk-olek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 5, 2025
@wychi wychi force-pushed the export-D79420953 branch from df87084 to e6b6017 Compare August 5, 2025 20:40
pytorch-bot bot pushed a commit that referenced this pull request Aug 5, 2025
Summary:


Exceptions during autotune kernel precompilation are now systematically captured and reported via the chromium_event_logger, enabling better debugging and analysis of autotune failures.


Currently, exceptions are dumped to the console in the following format::
```
[0/0] RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[0/0] Runtime error during autotuning: 
[0/0] No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[0/0] Ignoring this choice.
```

With this change, the exception details will also be logged in the metadata of the `{name}_template_precompiling` event.


The format:
```
{
  "exceptions": [
    {
      "choice_type": "triton",
      "choice": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0",
      "exception_message": "No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.",
      "exception": "OutOfMemoryError",
      "required_memory": "262144",
      "hardware_limit": "232448"
    }
  ]
}
```

Test Plan:
buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt

Rollback Plan:

Reviewed By: stashuk-olek

Differential Revision: D79420953
Summary:
Pull Request resolved: pytorch#159688

Pull Request resolved: pytorch#159687

Exceptions during autotune kernel precompilation are now systematically captured and reported via the chromium_event_logger, enabling better debugging and analysis of autotune failures.

Currently, exceptions are dumped to the console in the following format::
```
[0/0] RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[0/0] Runtime error during autotuning:
[0/0] No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help..
[0/0] Ignoring this choice.
```

With this change, the exception details will also be logged in the metadata of the `{name}_template_precompiling` event.

The format:
```
{
  "exceptions": [
    {
      "choice_type": "triton",
      "choice": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0",
      "exception_message": "No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.",
      "exception": "OutOfMemoryError",
      "required_memory": "262144",
      "hardware_limit": "232448"
    }
  ]
}
```

Test Plan:
buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt

Rollback Plan:

Reviewed By: stashuk-olek

Differential Revision: D79420953
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D79420953

@wychi wychi force-pushed the export-D79420953 branch from e6b6017 to 7f1f847 Compare August 5, 2025 20:44
if not pt2_compile_substack:
return

current_event = pt2_compile_substack[-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This smells like an anti-pattern. @jamesjwu wdyt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not ideal, but it involves minimal changes. The assumption here is that the log occurs within the _template_precompiling event, which holds true with the current implementation.

Another approach would be to pass the exceptions by altering the return type of make_precompile_fn() and wait_on_futures().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the "anti-pattern" is not very harmful, I want to proceed with the change to gather some preliminary data.

Based on whether the data proves useful or valuable, I will either revert the PR if it’s not helpful or find a better location to log the errors if further action is needed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
Summary:

Exceptions during autotune kernel precompilation are now systematically captured and reported via the chromium_event_logger, enabling better debugging and analysis of autotune failures.

Currently, exceptions are dumped to the console in the following format::
```
[0/0] RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[0/0] Runtime error during autotuning:
[0/0] No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help..
[0/0] Ignoring this choice.
```

The exception tracebacks:
```
# inner exception
traceback:
  File "/torch/_inductor/runtime/triton_heuristics.py", line 603, in _make_launchers
    launchers.append(result.make_launcher())
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/torch/_inductor/runtime/triton_heuristics.py", line 1503, in make_launcher
    self.kernel.load_kernel(device)
  File "/torch/_inductor/runtime/static_cuda_launcher.py", line 113, in load_kernel
    (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(

# wrapped exception
traceback:
  File "/usr/local/fbcode/platform010/lib/python3.12/concurrent/futures/thread.py", line 59, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<trimmed>#link-tree/torch/_inductor/select_algorithm.py", line 2596, in precompile_with_captured_stdout
    choice.precompile()
  File "<trimmed>#link-tree/torch/_inductor/select_algorithm.py", line 1881, in precompile
    self.bmreq.precompile()
  File "<trimmed>#link-tree/torch/_inductor/autotune_process.py", line 660, in precompile
    getattr(mod, self.kernel_name).precompile()
  File "<trimmed>#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile
    self._make_launchers()
  File "<trimmed>#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers
    raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
```

With this change, the exception details will also be logged in the metadata of the `{name}_template_precompiling` event.

The format:
```
{
  "exceptions": [
    {
      "choice_type": "triton",
      "choice": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0",
      "exception_message": "No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.",
      "exception": "OutOfMemoryError",
      "required_memory": "262144",
      "hardware_limit": "232448"
    }
  ]
}
```

Test Plan:
buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt

Rollback Plan:

Differential Revision: D79420953

Pull Request resolved: pytorch#159688
Approved by: https://github.com/stashuk-olek
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Summary:

Exceptions during autotune kernel precompilation are now systematically captured and reported via the chromium_event_logger, enabling better debugging and analysis of autotune failures.

Currently, exceptions are dumped to the console in the following format::
```
[0/0] RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[0/0] Runtime error during autotuning:
[0/0] No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help..
[0/0] Ignoring this choice.
```

The exception tracebacks:
```
# inner exception
traceback:
  File "/torch/_inductor/runtime/triton_heuristics.py", line 603, in _make_launchers
    launchers.append(result.make_launcher())
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/torch/_inductor/runtime/triton_heuristics.py", line 1503, in make_launcher
    self.kernel.load_kernel(device)
  File "/torch/_inductor/runtime/static_cuda_launcher.py", line 113, in load_kernel
    (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(

# wrapped exception
traceback:
  File "/usr/local/fbcode/platform010/lib/python3.12/concurrent/futures/thread.py", line 59, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<trimmed>#link-tree/torch/_inductor/select_algorithm.py", line 2596, in precompile_with_captured_stdout
    choice.precompile()
  File "<trimmed>#link-tree/torch/_inductor/select_algorithm.py", line 1881, in precompile
    self.bmreq.precompile()
  File "<trimmed>#link-tree/torch/_inductor/autotune_process.py", line 660, in precompile
    getattr(mod, self.kernel_name).precompile()
  File "<trimmed>#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile
    self._make_launchers()
  File "<trimmed>#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers
    raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
```

With this change, the exception details will also be logged in the metadata of the `{name}_template_precompiling` event.

The format:
```
{
  "exceptions": [
    {
      "choice_type": "triton",
      "choice": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0",
      "exception_message": "No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.",
      "exception": "OutOfMemoryError",
      "required_memory": "262144",
      "hardware_limit": "232448"
    }
  ]
}
```

Test Plan:
buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt

Rollback Plan:

Differential Revision: D79420953

Pull Request resolved: pytorch#159688
Approved by: https://github.com/stashuk-olek
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants