mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Changes by apply order:

1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`.
2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`.
3. Reorder `.absolute()` ~~/ `.resolve()`~~ and `.parent`: always resolve the path first.
   `.parent{...}.absolute()` -> `.absolute().parent{...}`
4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.)
   `.parent.parent.parent.parent` -> `.parents[3]`
5. ~~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~~
   ~~`.parents[3]` -> `.parents[4 - 1]`~~
6. ~~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374
Approved by: https://github.com/justinchuby, https://github.com/malfet
68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import os
|
|
import sys
|
|
import warnings
|
|
from pathlib import Path
|
|
from types import ModuleType
|
|
from typing import Any, Callable, Dict
|
|
|
|
|
|
def _reload_triton_kernel_in_subproc(reload_module, kernel_name):
    """Reload a module and return one of its Triton kernels.

    Args:
        reload_module: Zero-argument callable that returns the module
            holding the kernel.
        kernel_name: Attribute name of the kernel on that module.

    Returns:
        Whatever ``_module_to_triton_kernel`` returns for the freshly
        reloaded module and ``kernel_name``.
    """
    reloaded = reload_module()
    return _module_to_triton_kernel(reloaded, kernel_name)
|
|
|
|
|
|
def _module_to_triton_kernel(mod, kernel_name):
    """Fetch ``kernel_name`` from ``mod`` and attach a reload hook to it.

    The hook is stored on the kernel as ``_reload_in_subproc``: a partial of
    ``_reload_triton_kernel_in_subproc`` bound to the module's own
    ``_reload_in_subproc`` callable and to ``kernel_name``.

    Args:
        mod: Module object exposing both the kernel attribute and a
            ``_reload_in_subproc`` attribute.
        kernel_name: Attribute name of the kernel on ``mod``.

    Returns:
        The kernel object, with ``_reload_in_subproc`` set on it.
    """
    # Look the kernel up first so a missing kernel is reported before a
    # missing ``_reload_in_subproc`` on the module.
    triton_kernel = getattr(mod, kernel_name)
    reload_hook = functools.partial(
        _reload_triton_kernel_in_subproc, mod._reload_in_subproc, kernel_name
    )
    triton_kernel._reload_in_subproc = reload_hook
    return triton_kernel
|
|
|
|
|
|
def _reload_python_module_in_subproc(key, path):
    """Load the module identified by ``key`` from the source file at ``path``.

    If ``torch._inductor.codecache`` is already present in ``sys.modules``,
    its ``PyCodeCache.load_by_key_path`` does the loading. The lookup goes
    through ``sys.modules`` only, so this function never triggers that import
    itself. Otherwise the local ``_reload_python_module`` is used.

    Args:
        key: Identifier of the module to load.
        path: Location of the module's source file.

    Returns:
        The loaded module, as returned by whichever loader was used.
    """
    loaded_codecache = sys.modules.get("torch._inductor.codecache")
    if not loaded_codecache:
        return _reload_python_module(key, path)
    return loaded_codecache.PyCodeCache.load_by_key_path(key, path)
|
|
|
|
|
|
def _reload_python_module(key, path):
    """Compile and execute the Python source at ``path`` as a fresh module.

    The module is named ``f"{__name__}.{key}"``. Its ``__file__`` and ``key``
    attributes are set before the code runs, and it is registered in
    ``sys.modules`` only after execution succeeds.

    Args:
        key: Identifier used as the last component of the module name and
            stored on the module as ``key``.
        path: Location of the source file to load.

    Returns:
        The newly created, executed module.

    Raises:
        RuntimeError: If the source cannot be read or compiled. The message
            carries the original exception's type name and text.
    """
    # Read raw bytes so ``compile`` decodes the source itself (UTF-8 by
    # default, or the encoding named by a PEP 263 coding cookie) rather than
    # depending on the process locale's preferred encoding.
    with open(path, "rb") as f:
        try:
            code = compile(f.read(), path, "exec", dont_inherit=True)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {path}\n{type(e).__name__}: {e}"
            ) from None
    # The file is closed at this point; no handle is held while the module
    # body executes.
    mod = ModuleType(f"{__name__}.{key}")
    mod.__file__ = path
    mod.key = key  # type: ignore[attr-defined]
    exec(code, mod.__dict__, mod.__dict__)
    sys.modules[mod.__name__] = mod
    return mod
|
|
|
|
|
|
@functools.lru_cache(None)
def _set_triton_ptxas_path() -> None:
    """Point ``TRITON_PTXAS_PATH`` at a ``bin/ptxas`` found near this file.

    Nothing happens when the variable is already set, or when no ``ptxas``
    entry exists under the ``bin`` directory next to this file's parent
    directory. If the entry exists but is not an executable regular file, a
    warning is emitted and the environment is left untouched. The function is
    memoised, so its body runs at most once per process unless the cache is
    cleared.
    """
    if "TRITON_PTXAS_PATH" in os.environ:
        return

    # ``parents[1]`` is the directory one level above the one holding this file.
    candidate = Path(__file__).absolute().parents[1] / "bin" / "ptxas"
    if not candidate.exists():
        return

    runnable = candidate.is_file() and os.access(candidate, os.X_OK)
    if not runnable:
        warnings.warn(f"{candidate} exists but is not an executable")
        return

    os.environ["TRITON_PTXAS_PATH"] = str(candidate)
|
|
|
|
|
|
def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):
    """Prepare the environment, then precompile a kernel in warm-cache mode.

    Steps, in order: run ``_set_triton_ptxas_path()``, merge ``extra_env``
    into ``os.environ``, obtain the kernel from ``load_kernel()`` and call its
    ``precompile(warm_cache_only=True)``.

    Args:
        load_kernel: Zero-argument callable returning an object that has a
            ``precompile`` method.
        extra_env: Environment variables to apply to this process before the
            kernel is loaded.
    """
    _set_triton_ptxas_path()
    os.environ.update(extra_env)
    kernel = load_kernel()
    kernel.precompile(warm_cache_only=True)
|