Files
pytorch/torch/_inductor/runtime/compile_tasks.py
Colin L Reliability Rice 0495cab545 Wire in pt2_triton_builds (#159897)
Summary:
This allows us to start seeing the failure rate on these models (and
potentially alert on it).

Test Plan:
```
FORCE_LOG_TRITON_BUILDS_TO_PROD=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run @//mode/opt :compile 2>&1 | tee out
```
P1889607054

Waiting for scuba table to generate, but manual logging shows it should show up at https://fburl.com/scuba/pt2_triton_builds_inc_archive/7852kt8h soon.

Rollback Plan:

Reviewed By: masnesral

Differential Revision: D79308333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159897
Approved by: https://github.com/masnesral
2025-08-06 07:39:51 +00:00

77 lines
2.2 KiB
Python

from __future__ import annotations
import functools
import linecache
import os
import sys
import time
import warnings
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, TYPE_CHECKING
from torch._utils_internal import log_triton_builds
if TYPE_CHECKING:
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
def _reload_python_module(
key: str, path: str, set_sys_modules: bool = True
) -> ModuleType:
with open(path) as f:
try:
code = compile(f.read(), path, "exec", dont_inherit=True)
except Exception as e:
raise RuntimeError(
f"Failed to import {path}\n{type(e).__name__}: {e}"
) from None
mod = ModuleType(f"{__name__}.{key}")
mod.__file__ = path
mod.key = key # type: ignore[attr-defined]
exec(code, mod.__dict__, mod.__dict__)
if set_sys_modules:
sys.modules[mod.__name__] = mod
return mod
@functools.cache
def _set_triton_ptxas_path() -> None:
if os.environ.get("TRITON_PTXAS_PATH") is not None:
return
ptxas = Path(__file__).absolute().parents[1] / "bin" / "ptxas"
if not ptxas.exists():
return
if ptxas.is_file() and os.access(ptxas, os.X_OK):
os.environ["TRITON_PTXAS_PATH"] = str(ptxas)
else:
warnings.warn(f"{ptxas} exists but is not an executable")
def _worker_compile_triton(
    load_kernel: Callable[[], CachingAutotuner],
    extra_env: dict[str, str],
    extra_config: dict[str, Any],
) -> tuple[CachingAutotuner, int]:
    """Compile one Triton kernel inside a compile-worker subprocess.

    Applies *extra_env* to ``os.environ`` and *extra_config* as an inductor
    config patch, then loads and warm-compiles the kernel. Every attempt is
    reported via ``log_triton_builds`` — with ``fail=None`` on success and
    the exception's message on failure (the exception is re-raised).

    Returns:
        The pickled-ready kernel and the elapsed compile time in microseconds.
    """
    _set_triton_ptxas_path()
    os.environ.update(extra_env)
    from torch._inductor import config

    with config.patch(extra_config):
        fail: str | None = None
        try:
            t0 = time.time_ns()
            kernel = load_kernel()
            kernel.precompile(warm_cache_only=True)
            elapsed_us = (time.time_ns() - t0) // 1000
            kernel.prepare_for_pickle()
            # We can release this memory in the compile subprocesses:
            linecache.clearcache()
            return kernel, elapsed_us
        except Exception as e:
            fail = str(e)
            raise
        finally:
            # Runs on both the success and failure paths.
            log_triton_builds(fail=fail)