Wire in pt2_triton_builds (#159897)

Summary:
This allows us to start seeing the failure rate on these models (and
potentially alert on it).

Test Plan:
```
FORCE_LOG_TRITON_BUILDS_TO_PROD=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run @//mode/opt :compile 2>&1 | tee out
```
P1889607054

Waiting for the Scuba table to generate, but manual logging shows it should show up at https://fburl.com/scuba/pt2_triton_builds_inc_archive/7852kt8h soon.

Rollback Plan:

Reviewed By: masnesral

Differential Revision: D79308333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159897
Approved by: https://github.com/masnesral
Author: Colin L Reliability Rice
Date: 2025-08-06 07:39:47 +00:00
Committed by: PyTorch MergeBot
Parent: abfe403981
Commit: 0495cab545
3 changed files with 45 additions and 24 deletions
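The diff applies one pattern to both the in-process and the subprocess Triton compile paths: remember the failure string (if any), re-raise the original error, and always report the attempt to `log_triton_builds` in a `finally` block. A minimal standalone sketch of that pattern, where `compile_with_build_logging` and `compile_kernel` are hypothetical stand-ins for illustration only (and `log_triton_builds` is the no-op OSS stub added at the bottom of this diff):

```
from typing import Optional

from torch._utils_internal import log_triton_builds  # no-op stub in OSS (added below)


def compile_with_build_logging(compile_kernel) -> object:
    # `compile_kernel` is a hypothetical zero-arg callable standing in for the
    # real kernel load + precompile step wrapped by this change.
    fail: Optional[str] = None
    try:
        return compile_kernel()
    except Exception as e:
        fail = str(e)  # remember why the build failed...
        raise          # ...but still propagate the original error
    finally:
        # Runs on both success (fail is None) and failure (fail is the message),
        # so every build attempt is reported exactly once.
        log_triton_builds(fail=fail)
```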


@@ -49,6 +49,7 @@ from torch._inductor.runtime.compile_tasks import (
 )
 from torch._inductor.utils import clear_on_fresh_cache
 from torch._inductor.virtualized import V
+from torch._utils_internal import log_triton_builds
 from torch.hub import _Faketqdm, tqdm
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._triton import has_triton_package
@@ -479,22 +480,29 @@ class AsyncCompile:
                 log_waitcounter=True,
                 waitcounter_name_override="compile_triton",
             ):
-                start_ns = time_ns()
-                _set_triton_ptxas_path()
-                kernel = load_kernel()
-                kernel.set_compile_info(compile_id, is_backward)
-                kernel.precompile(
-                    warm_cache_only=False,
-                    static_triton_bundle_key=CompiledTritonKernels.key(source_code),
-                )
-                elapsed_us = (time_ns() - start_ns) // 1000
-                get_metrics_context().add_top_n(
-                    "triton_kernel_compile_times_us", kernel_name, elapsed_us
-                )
-                info = kernel.autotune_cache_info or {}
-                info["compile_time_us"] = elapsed_us
-                _add_triton_kernel_info(kernel_name, info)
-                return kernel
+                fail = None
+                try:
+                    start_ns = time_ns()
+                    _set_triton_ptxas_path()
+                    kernel = load_kernel()
+                    kernel.set_compile_info(compile_id, is_backward)
+                    kernel.precompile(
+                        warm_cache_only=False,
+                        static_triton_bundle_key=CompiledTritonKernels.key(source_code),
+                    )
+                    elapsed_us = (time_ns() - start_ns) // 1000
+                    get_metrics_context().add_top_n(
+                        "triton_kernel_compile_times_us", kernel_name, elapsed_us
+                    )
+                    info = kernel.autotune_cache_info or {}
+                    info["compile_time_us"] = elapsed_us
+                    _add_triton_kernel_info(kernel_name, info)
+                    return kernel
+                except Exception as e:
+                    fail = str(e)
+                    raise
+                finally:
+                    log_triton_builds(fail=fail)
 
     def multi_kernel(self, *args, **kwargs) -> Any:
         from torch._inductor.codegen.multi_kernel import MultiKernelCall

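Because the hook is imported by name into the consuming modules, a test or a quick experiment would patch it where it is used rather than where it is defined. A rough sketch, assuming the file above is torch/_inductor/async_compile.py (the filename is not shown in this diff, so that module path is an assumption):

```
from unittest import mock

# Assumption: torch/_inductor/async_compile.py binds `log_triton_builds` into its
# own namespace at import time, so the patch target is the consumer module,
# not torch._utils_internal.
with mock.patch("torch._inductor.async_compile.log_triton_builds") as hook:
    ...  # trigger a torch.compile run that builds (or fails to build) a Triton kernel
    # Every build attempt should produce exactly one call; fail=None on success.
    print(hook.call_args_list)
```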

@@ -10,6 +10,8 @@ from pathlib import Path
 from types import ModuleType
 from typing import Any, Callable, TYPE_CHECKING
 
+from torch._utils_internal import log_triton_builds
+
 
 if TYPE_CHECKING:
     from torch._inductor.runtime.triton_heuristics import CachingAutotuner
@@ -57,11 +59,18 @@ def _worker_compile_triton(
     from torch._inductor import config
 
     with config.patch(extra_config):
-        start_ns = time.time_ns()
-        kernel = load_kernel()
-        kernel.precompile(warm_cache_only=True)
-        elapsed_ns = time.time_ns() - start_ns
-        kernel.prepare_for_pickle()
-        # We can release this memory in the compile subprocesses:
-        linecache.clearcache()
-        return kernel, elapsed_ns // 1000
+        fail = None
+        try:
+            start_ns = time.time_ns()
+            kernel = load_kernel()
+            kernel.precompile(warm_cache_only=True)
+            elapsed_ns = time.time_ns() - start_ns
+            kernel.prepare_for_pickle()
+            # We can release this memory in the compile subprocesses:
+            linecache.clearcache()
+            return kernel, elapsed_ns // 1000
+        except Exception as e:
+            fail = str(e)
+            raise
+        finally:
+            log_triton_builds(fail=fail)


@@ -354,3 +354,7 @@ def get_default_numa_options():
     Must return None or NumaOptions, but not specifying to avoid circular import.
     """
     return None
+
+
+def log_triton_builds(fail: Optional[str]):
+    pass
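The OSS hook is intentionally a no-op; per the commit summary, the internal build presumably supplies an implementation that records one row per build attempt in the pt2_triton_builds Scuba table. As a rough, assumption-laden sketch of what a local override could look like for debugging (not the actual internal logger):

```
import logging
from typing import Optional

log = logging.getLogger("pt2_triton_builds")


def log_triton_builds(fail: Optional[str]) -> None:
    # Hypothetical local override: emit one record per Triton build attempt,
    # mirroring what an internal logger might send to the pt2_triton_builds table.
    if fail is None:
        log.info("triton build succeeded")
    else:
        log.warning("triton build failed: %s", fail)
```

Since the inductor modules import the symbol by name, routing their calls to an override like this would still require patching the consumer modules, as sketched earlier.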