Mirror of https://github.com/pytorch/pytorch.git
Wire in pt2_triton_builds (#159897)
Summary: This allows us to start seeing the failure rate on these models (and potentially alert on it).

Test Plan:
```
FORCE_LOG_TRITON_BUILDS_TO_PROD=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run @//mode/opt :compile 2>&1 | tee out
```
P1889607054

Waiting for the scuba table to generate, but manual logging shows it should show up at https://fburl.com/scuba/pt2_triton_builds_inc_archive/7852kt8h soon.

Rollback Plan:

Reviewed By: masnesral

Differential Revision: D79308333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159897
Approved by: https://github.com/masnesral
Committed by: PyTorch MergeBot
Parent: abfe403981
Commit: 0495cab545
torch/_inductor/async_compile.py
@@ -49,6 +49,7 @@ from torch._inductor.runtime.compile_tasks import (
 )
 from torch._inductor.utils import clear_on_fresh_cache
 from torch._inductor.virtualized import V
+from torch._utils_internal import log_triton_builds
 from torch.hub import _Faketqdm, tqdm
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._triton import has_triton_package
@@ -479,22 +480,29 @@ class AsyncCompile:
                 log_waitcounter=True,
                 waitcounter_name_override="compile_triton",
             ):
-                start_ns = time_ns()
-                _set_triton_ptxas_path()
-                kernel = load_kernel()
-                kernel.set_compile_info(compile_id, is_backward)
-                kernel.precompile(
-                    warm_cache_only=False,
-                    static_triton_bundle_key=CompiledTritonKernels.key(source_code),
-                )
-                elapsed_us = (time_ns() - start_ns) // 1000
-                get_metrics_context().add_top_n(
-                    "triton_kernel_compile_times_us", kernel_name, elapsed_us
-                )
-                info = kernel.autotune_cache_info or {}
-                info["compile_time_us"] = elapsed_us
-                _add_triton_kernel_info(kernel_name, info)
-                return kernel
+                fail = None
+                try:
+                    start_ns = time_ns()
+                    _set_triton_ptxas_path()
+                    kernel = load_kernel()
+                    kernel.set_compile_info(compile_id, is_backward)
+                    kernel.precompile(
+                        warm_cache_only=False,
+                        static_triton_bundle_key=CompiledTritonKernels.key(source_code),
+                    )
+                    elapsed_us = (time_ns() - start_ns) // 1000
+                    get_metrics_context().add_top_n(
+                        "triton_kernel_compile_times_us", kernel_name, elapsed_us
+                    )
+                    info = kernel.autotune_cache_info or {}
+                    info["compile_time_us"] = elapsed_us
+                    _add_triton_kernel_info(kernel_name, info)
+                    return kernel
+                except Exception as e:
+                    fail = str(e)
+                    raise
+                finally:
+                    log_triton_builds(fail=fail)
 
     def multi_kernel(self, *args, **kwargs) -> Any:
         from torch._inductor.codegen.multi_kernel import MultiKernelCall
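The hunk above is the core of the change: the hot-path compile is wrapped in try/except/finally so the logging hook fires exactly once per build attempt, with `fail=None` on success and the stringified exception on failure, while the exception still propagates to the caller. A minimal standalone sketch of that pattern (the `build` helper and printing sink here are illustrative, not the PyTorch internals):

```python
from typing import Callable, Optional


def log_triton_builds(fail: Optional[str]) -> None:
    # Stand-in sink; in PyTorch the real hook lives in torch._utils_internal.
    print("triton build:", "ok" if fail is None else f"failed: {fail}")


def build(load_kernel: Callable[[], object]) -> object:
    fail: Optional[str] = None
    try:
        return load_kernel()  # the compile work; may raise
    except Exception as e:
        fail = str(e)
        raise  # the caller still sees the original exception
    finally:
        log_triton_builds(fail=fail)  # fires on both paths, exactly once


build(lambda: "kernel")   # prints "triton build: ok"
# build(lambda: 1 / 0)    # would print "triton build: failed: division by zero" and re-raise
```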
torch/_inductor/runtime/compile_tasks.py
@@ -10,6 +10,8 @@ from pathlib import Path
 from types import ModuleType
 from typing import Any, Callable, TYPE_CHECKING
 
+from torch._utils_internal import log_triton_builds
+
 
 if TYPE_CHECKING:
     from torch._inductor.runtime.triton_heuristics import CachingAutotuner
@@ -57,11 +59,18 @@ def _worker_compile_triton(
     from torch._inductor import config
 
     with config.patch(extra_config):
-        start_ns = time.time_ns()
-        kernel = load_kernel()
-        kernel.precompile(warm_cache_only=True)
-        elapsed_ns = time.time_ns() - start_ns
-        kernel.prepare_for_pickle()
-        # We can release this memory in the compile subprocesses:
-        linecache.clearcache()
-        return kernel, elapsed_ns // 1000
+        fail = None
+        try:
+            start_ns = time.time_ns()
+            kernel = load_kernel()
+            kernel.precompile(warm_cache_only=True)
+            elapsed_ns = time.time_ns() - start_ns
+            kernel.prepare_for_pickle()
+            # We can release this memory in the compile subprocesses:
+            linecache.clearcache()
+            return kernel, elapsed_ns // 1000
+        except Exception as e:
+            fail = str(e)
+            raise
+        finally:
+            log_triton_builds(fail=fail)
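The subprocess worker path gets the same try/except/finally treatment, so builds that fail inside a compile worker are counted too. One detail both call sites share: elapsed time is taken as a `time_ns()` delta and floor-divided by 1000 to report whole microseconds. A hypothetical helper capturing that arithmetic (`timed_us` is not a PyTorch function):

```python
import time
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")


def timed_us(fn: Callable[[], T]) -> Tuple[T, int]:
    # Wall-clock delta in nanoseconds, floor-divided to whole microseconds:
    # the same elapsed_ns // 1000 arithmetic as the hunks above.
    start_ns = time.time_ns()
    result = fn()
    return result, (time.time_ns() - start_ns) // 1000


result, us = timed_us(lambda: sum(range(1_000_000)))
```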
torch/_utils_internal.py
@@ -354,3 +354,7 @@ def get_default_numa_options():
     Must return None or NumaOptions, but not specifying to avoid circular import.
     """
     return None
+
+
+def log_triton_builds(fail: Optional[str]):
+    pass
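In the open-source build the new `log_triton_builds` hook is a deliberate no-op; internal builds presumably substitute an implementation that writes to the pt2_triton_builds scuba table mentioned in the summary. A hypothetical sketch of wiring in a custom sink by replacing the stub (not an official API):

```python
import torch._utils_internal

build_counts = {"ok": 0, "failed": 0}


def _counting_log_triton_builds(fail):
    # fail is None on success, else the stringified exception.
    build_counts["failed" if fail is not None else "ok"] += 1


# Install early: async_compile.py and compile_tasks.py bind the name with
# "from torch._utils_internal import log_triton_builds" at import time, so
# the override must be in place before those modules are imported.
torch._utils_internal.log_triton_builds = _counting_log_triton_builds
```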