Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
See also #163972, which was intended to be this PR. Triton (release/3.5.x) ships a CUDA 12.8 ptxas by default. This PR bundles a ptxas for CUDA 13, which helps https://github.com/pytorch/pytorch/issues/163801 when users run on new devices like THOR and Spark.

Fixes https://github.com/pytorch/pytorch/issues/163801

Test Plan: Check the binary size increase against nightly or the v2.9 RC. Install the binary on a working THOR and GB200/GH100 machine (reproducing the original issue on THOR first); with the binary built from this PR, the issue should be gone without any additional user setting. Testing on GB200 ensures no regression.

Reference: https://github.com/pytorch/pytorch/pull/119750 and 5c814e2527
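One way to verify the "no additional user setting" claim (a minimal sketch, not part of this PR, assuming a wheel that bundles torch/bin/ptxas) is to check which ptxas torch.compile will hand to Triton:

import os
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path

# No-op if the user already set TRITON_PTXAS_PATH; otherwise this points
# Triton at torch/bin/ptxas when that file exists and is executable.
_set_triton_ptxas_path()
print(os.environ.get("TRITON_PTXAS_PATH"))  # bundled path, or None if absent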
Note: with this PR, torch.compile is supposed to find ptxas via "torch/_inductor/runtime/compile_tasks.py" and "_set_triton_ptxas_path". Use cases that do not go through "_set_triton_ptxas_path" may not be able to use the CUDA 13 ptxas binary. As is, the Triton world does not know this new CUDA 13 ptxas exists: if a user assumes pytorch/bin/ptxas is enough and deletes the ptxas shipped with Triton, then c6ad34f7eb/python/triton/knobs.py (L216) would still complain that ptxas is not found, since Triton is unaware of the bundled binary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163988
Approved by: https://github.com/atalman
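For reference, the lookup in _set_triton_ptxas_path walks up from this module to the torch package root, which is why the bundled binary must live at torch/bin/ptxas; a sketch with a hypothetical install prefix:

from pathlib import Path

# Hypothetical install location of this module.
p = Path("/site-packages/torch/_inductor/runtime/compile_tasks.py")
# parents[0] = .../runtime, parents[1] = .../_inductor, parents[2] = .../torch
print(p.parents[2] / "bin" / "ptxas")  # /site-packages/torch/bin/ptxas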
77 lines
2.2 KiB
Python
from __future__ import annotations

import functools
import linecache
import os
import sys
import time
import warnings
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, TYPE_CHECKING

from torch._utils_internal import log_triton_builds


if TYPE_CHECKING:
    from torch._inductor.runtime.triton_heuristics import CachingAutotuner


def _reload_python_module(
    key: str, path: str, set_sys_modules: bool = True
) -> ModuleType:
    # Compile and execute a generated source file as a fresh module so that
    # Inductor can (re)load kernels produced by the compile workers.
    with open(path) as f:
        try:
            code = compile(f.read(), path, "exec", dont_inherit=True)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {path}\n{type(e).__name__}: {e}"
            ) from None
        mod = ModuleType(f"{__name__}.{key}")
        mod.__file__ = path
        mod.key = key  # type: ignore[attr-defined]
        exec(code, mod.__dict__, mod.__dict__)
        if set_sys_modules:
            sys.modules[mod.__name__] = mod
        return mod


@functools.cache
def _set_triton_ptxas_path() -> None:
    # Prefer the ptxas bundled with the torch wheel (torch/bin/ptxas) unless
    # the user has already pointed Triton at one via TRITON_PTXAS_PATH.
    if os.environ.get("TRITON_PTXAS_PATH") is not None:
        return
    ptxas = Path(__file__).absolute().parents[2] / "bin" / "ptxas"
    if not ptxas.exists():
        return
    if ptxas.is_file() and os.access(ptxas, os.X_OK):
        os.environ["TRITON_PTXAS_PATH"] = str(ptxas)
    else:
        warnings.warn(f"{ptxas} exists but is not an executable")


def _worker_compile_triton(
    load_kernel: Callable[[], CachingAutotuner],
    extra_env: dict[str, str],
    extra_config: dict[str, Any],
) -> tuple[CachingAutotuner, int]:
    # Runs in a compile subprocess: compile one Triton kernel and return it
    # together with the elapsed compile time in microseconds.
    _set_triton_ptxas_path()
    os.environ.update(extra_env)
    from torch._inductor import config

    with config.patch(extra_config):
        fail = None
        try:
            start_ns = time.time_ns()
            kernel = load_kernel()
            kernel.precompile(warm_cache_only=True)
            elapsed_ns = time.time_ns() - start_ns
            kernel.prepare_for_pickle()
            # We can release this memory in the compile subprocesses:
            linecache.clearcache()
            return kernel, elapsed_ns // 1000
        except Exception as e:
            fail = str(e)
            raise
        finally:
            log_triton_builds(fail=fail)