Log max_autotune exceptions (#159687) (#159688)

Summary:

Exceptions during autotune kernel precompilation are now systematically captured and reported via the chromium_event_logger, enabling better debugging and analysis of autotune failures.

Currently, exceptions are dumped to the console in the following format:
```
[0/0] RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[0/0] Runtime error during autotuning:
[0/0] No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help..
[0/0] Ignoring this choice.
```

The corresponding exception tracebacks:
```
# inner exception
traceback:
  File "/torch/_inductor/runtime/triton_heuristics.py", line 603, in _make_launchers
    launchers.append(result.make_launcher())
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/torch/_inductor/runtime/triton_heuristics.py", line 1503, in make_launcher
    self.kernel.load_kernel(device)
  File "/torch/_inductor/runtime/static_cuda_launcher.py", line 113, in load_kernel
    (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel(

# wrapped exception
traceback:
  File "/usr/local/fbcode/platform010/lib/python3.12/concurrent/futures/thread.py", line 59, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<trimmed>#link-tree/torch/_inductor/select_algorithm.py", line 2596, in precompile_with_captured_stdout
    choice.precompile()
  File "<trimmed>#link-tree/torch/_inductor/select_algorithm.py", line 1881, in precompile
    self.bmreq.precompile()
  File "<trimmed>#link-tree/torch/_inductor/autotune_process.py", line 660, in precompile
    getattr(mod, self.kernel_name).precompile()
  File "<trimmed>#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile
    self._make_launchers()
  File "<trimmed>#link-tree/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers
    raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
```

With this change, the exception details will also be logged in the metadata of the `{name}_template_precompiling` event.

The metadata format:
```
{
  "exceptions": [
    {
      "choice_type": "triton",
      "choice": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0",
      "exception_message": "No valid triton configs. OutOfMemoryError: out of resource: triton_mm Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.",
      "exception": "OutOfMemoryError",
      "required_memory": "262144",
      "hardware_limit": "232448"
    }
  ]
}
```
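
A minimal sketch (not part of this PR) of how a downstream consumer might read this metadata; the helper name `summarize_autotune_exceptions` and the way the JSON string is obtained from the event are assumptions, only the field layout shown above comes from this change:
```
import json


def summarize_autotune_exceptions(metadata_json: str) -> list[str]:
    # Hypothetical helper: turn the "exceptions" metadata shown above into
    # human-readable lines for log analysis.
    lines = []
    for entry in json.loads(metadata_json).get("exceptions", []):
        if entry.get("exception") == "OutOfMemoryError":
            required = int(entry["required_memory"])
            limit = int(entry["hardware_limit"])
            lines.append(
                f"{entry['choice_type']} choice OOM: required {required}, "
                f"hardware limit {limit} (short by {required - limit})"
            )
        else:
            lines.append(
                f"{entry['choice_type']} choice failed: {entry['exception_message']}"
            )
    return lines
```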

Test Plan:
buck2 run //scripts/wychi:test_autotune_mm 2>&1 > /tmp/mylog.txt

Rollback Plan:

Differential Revision: D79420953

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159688
Approved by: https://github.com/stashuk-olek
Author: Wenyuan Chi
Date: 2025-08-08 01:30:08 +00:00
Committed by: PyTorch MergeBot
Parent: 03b254e49f
Commit: d68c323692

@@ -2650,11 +2650,13 @@ class AlgorithmSelectorCache(PersistentCache):
def wait_on_futures():
    log.debug("Waiting on futures")
    counters["inductor"]["select_algorithm_precompile"] += 1
    exceptions: list[tuple[ChoiceCaller, BaseException]] = []
    for future in as_completed(
        futures,
        timeout=precompilation_timeout_seconds,
    ):
        if e := future.exception():
            exceptions.append((futures[future], e))
            from torch._inductor.codegen.cuda.cuda_kernel import (
                CUDATemplateCaller,
            )
@@ -2682,6 +2684,8 @@ class AlgorithmSelectorCache(PersistentCache):
                    futures.get(future),
                    elapsed_times.get(future),
                )

    if exceptions:
        _log_autotune_exceptions(exceptions)

    executor.shutdown(wait=True)
@@ -3452,5 +3456,61 @@ def _log_autotune_choices_stats(
        sys.stderr.write(f"Autotune Choices Stats:\n{payload}\n")


def _log_autotune_exceptions(
    exceptions: list[tuple[ChoiceCaller, BaseException]],
) -> None:
    """Log autotune exceptions to chromium event logger."""
    if not exceptions:
        return

    try:
        pt2_compile_substack = get_chromium_event_logger().get_pt2_compile_substack()
        if not pt2_compile_substack:
            return

        current_event = pt2_compile_substack[-1]
        if not current_event.endswith("_template_precompiling"):
            return

        exception_details = []
        for choice, exc in exceptions:
            try:
                choice_type = (
                    "triton" if isinstance(choice, TritonTemplateCaller) else "other"
                )
                data = {
                    "choice_type": choice_type,
                    "choice": choice.description,
                    "exception_message": str(exc),
                }

                exc_type_match = re.search(r"(\w+):", str(exc))
                if exc_type_match:
                    data["exception"] = exc_type_match.group(1)

                if "OutOfMemoryError" in str(exc):
                    required_match = re.search(r"Required: (\d+)", str(exc))
                    if required_match:
                        data["required_memory"] = required_match.group(1)

                    limit_match = re.search(r"Hardware limit:\s*(\d+)", str(exc))
                    if limit_match:
                        data["hardware_limit"] = limit_match.group(1)

                exception_details.append(data)
            except Exception:
                # Don't let logging errors break the main flow
                continue

        if exception_details:
            metadata = json.dumps({"exceptions": exception_details})
            get_chromium_event_logger().try_add_event_data(
                current_event, metadata=metadata
            )
    except Exception:
        # Silently ignore logging errors to avoid breaking autotune
        pass


# ensure lowering is imported so that `extern_kernels.*` is populated
from . import lowering  # noqa: F401
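
For illustration, a standalone snippet (not part of the commit) showing what the regex extraction in `_log_autotune_exceptions` yields for the sample error message quoted in the summary:
```
import re

# Sample message from the summary above; the same patterns used in
# _log_autotune_exceptions are applied to it.
msg = (
    "No valid triton configs. OutOfMemoryError: out of resource: triton_mm "
    "Required: 262144 Hardware limit:232448 Reducing block sizes or "
    "`num_stages` may help."
)

exc_type = re.search(r"(\w+):", msg)
required = re.search(r"Required: (\d+)", msg)
limit = re.search(r"Hardware limit:\s*(\d+)", msg)

assert exc_type and exc_type.group(1) == "OutOfMemoryError"  # "exception"
assert required and required.group(1) == "262144"            # "required_memory"
assert limit and limit.group(1) == "232448"                   # "hardware_limit"
```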