pytorch/torch/_inductor/triton_bundler.py
James Wu cbc0964636 Store statically launchable CachingAutotuners inside CompiledFXGraph.triton_bundle (#149054)
This PR adds statically launchable CachingAutotuners to FXGraphCache's cache entry.

Regular CachingAutotuners, with triton kernels attached to them, are not good candidates for caching: they are very large and take huge amounts of space, since they track all of the various binary files along with assorted metadata. We could probably figure out which information we could delete from the kernel and still have it work, but with StaticCudaLauncher we no longer have to. Instead, we can cache every compiled triton kernel that is statically launchable.

Because StaticTritonCompileResult is serializable and designed to have a very small memory footprint, we can save it into FXGraphCache without significantly increasing the cache size. We store it as part of CompiledFxGraph.triton_bundle.

Then, on load, we repopulate the CachingAutotuners into our CompiledTritonKernels cache.
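The save/load flow then looks roughly like the sketch below (a simplified illustration; `compiled_fx_graph` is a placeholder for the real CompiledFxGraph/FXGraphCache plumbing, which is elided here):

```python
from torch._inductor.triton_bundler import TritonBundler

# Cache write: bundle everything we compiled into the FX graph cache entry.
TritonBundler.begin_compile()
# ... Inductor compiles Triton kernels; each compile calls TritonBundler.put(),
# and statically launchable kernels also call put_static_autotuner() ...
bundle, metadata = TritonBundler.collect()  # also finalizes via end_compile()
compiled_fx_graph.triton_bundle = bundle    # placeholder for cache plumbing

# Cache read: re-emit kernel files into the local Triton cache and repopulate
# CompiledTritonKernels with the statically launchable CachingAutotuners.
metadata = TritonBundler.read_and_emit(compiled_fx_graph.triton_bundle)
```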

The upsides of this are many:
- We no longer need to call into a separate process on cache hit
- We can *guarantee* that the triton kernel we got from our cache entry is the one we use to launch again, so no worries about triton's own caching logic
- Once we achieve feature parity and all torch.compiled triton kernels are statically launchable, we can clean up a bunch of TritonBundler code and simplify the cache hit logic.

Fixes #149449

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149054
Approved by: https://github.com/oulgen
2025-03-30 17:51:11 +00:00

import copy
import dataclasses
import logging
import os
import shutil
import uuid
from pathlib import Path
from typing import Optional

from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
from torch._utils_internal import justknobs_check
from torch.utils._filelock import FileLock

from .runtime.runtime_utils import triton_cache_dir
from .utils import _IS_WINDOWS, GPU_KERNEL_BIN_EXTS

log = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
class TritonBundleEntry:
"""
When we have compiled a triton kernel, we take note of that kernel by
its triton generated hash, its device, and where this kernel is located.
This is the minimum information we can use to later retrieve this kernel
from file system.
"""
kernel_hash: str
device: int
directory: str


@dataclasses.dataclass(frozen=True)
class TritonKernelArtifact:
"""
Artifact for an individual kernel converted to bytes.
Bytes could be a cubin, json, ttir, or ttgir.
"""
filename: str
payload: bytes = dataclasses.field(repr=False) # Do not display binary


@dataclasses.dataclass(frozen=True)
class StaticallyLaunchedAutotuner:
    """
    Represents a statically compiled CachingAutotuner object that we can
    save directly in the cache. A CachingAutotuner is made up of a list of
    StaticTritonCompileResults, each of which uses the cubin from a
    TritonKernelArtifact. Kernels saved here have their cubin files stored
    by a corresponding TritonBundleEntry.
    """
cache_key: str
kernel_name: str
kernel: "CachingAutotuner" # type: ignore[name-defined] # noqa: F821


@dataclasses.dataclass(frozen=True)
class TritonKernelArtifacts:
"""
Collection of artifacts for a particular kernel.
"""
kernel_hash: str
device: int
artifacts: list[TritonKernelArtifact]


@dataclasses.dataclass(frozen=True)
class TritonBundlerMetadata:
"""
Metadata used for instrumentation
"""
cached_kernel_names: list[str]
statically_launched_kernel_names: list[str]


@dataclasses.dataclass(frozen=True)
class TritonBundle:
"""
Serializable bundle to save into FXGraphCache
"""
kernel_artifacts: list[TritonKernelArtifacts]
static_autotuners: list[StaticallyLaunchedAutotuner]


class TritonBundler:
    """
    Lightweight Triton kernel bundler that notes each time we compile a triton
    kernel. When collect is called, it converts all the previously noted kernels
    and their artifacts into a structured bytes blob, and when read_and_emit is
    later called it writes this structured blob back to the file system.

    Intended life cycle:
    - TritonBundler.begin_compile is called when we start compiling in Inductor
    - TritonBundler.put is called each time a Triton kernel is compiled
    - TritonBundler.collect is called when a cache entry is being generated
    - TritonBundler.end_compile is called to indicate bundling is completed;
      collect executes this function as well
    - TritonBundler.read_and_emit is called when a cache entry is read
    """
_entries: Optional[list[TritonBundleEntry]] = None
_static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] = None
    # __grp__kernel_name.json contains metadata with source code paths;
    # we use this as a sentinel value for search and replace
    _REPLACE_BYTES: bytes = b"[REPLACE]"

    @staticmethod
def is_enabled() -> bool:
from torch._inductor import config
if config.force_disable_caches:
return False
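        # An explicit config setting takes precedence over the JK rollout below.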
if (b := config.bundle_triton_into_fx_graph_cache) is not None:
return b
if not config.is_fbcode():
return False
return justknobs_check(
"pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2"
)

    @classmethod
def begin_compile(cls) -> None:
"""
Initializes the TritonBundler.
The current TritonBundler bundle is finalized by TritonBundler.collect.
"""
if not TritonBundler.is_enabled():
return
log.debug("TritonBundler.begin_compile is called")
assert cls._entries is None
cls._entries = []
cls._static_autotuners = []

    @classmethod
    def end_compile(cls) -> None:
        """
        Finalizes the TritonBundler. If collect has not been called yet, this
        discards the current bundle.
        """
log.debug("TritonBundler.end_compile is called")
cls._entries = None
cls._static_autotuners = None

    @classmethod
def put(cls, kernel_hash: str, device: int) -> None:
"""
Lazily observes that we have seen a Triton kernel compilation. Remembers
it for when collect is later called.
"""
if (entries := cls._entries) is not None:
entries.append(
TritonBundleEntry(kernel_hash, device, triton_cache_dir(device))
)

    @classmethod
def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: # type: ignore[name-defined] # noqa: F821
from torch._inductor import config
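        # Bundling static autotuners only makes sense with the static CUDA launcher.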
assert config.use_static_cuda_launcher
if (entries := cls._static_autotuners) is not None:
# Clear a bunch of unpicklable values and make a copy to save
# for FXGraphCache
old_values = kernel.prepare_for_pickle()
new_kernel = copy.deepcopy(kernel)
new_kernel._reload_kernel = None
entries.append(
StaticallyLaunchedAutotuner(
key,
new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"),
new_kernel,
)
)
            # Put the values back, since we still need the kernel now
(
kernel.fn.fn,
kernel.fn.__globals__,
kernel.fn.used_global_vals,
kernel.fn.repr,
kernel.launchers,
) = old_values

    @classmethod
def collect_static_autotuners(
cls,
) -> tuple[list[StaticallyLaunchedAutotuner], list[str]]:
if not cls._static_autotuners:
return [], []
else:
log.info(
"Saving %d statically launchable CachingAutotuners",
len(cls._static_autotuners),
)
static_autotuner_names = [i.kernel_name for i in cls._static_autotuners]
counters["inductor"]["triton_bundler_save_static_autotuner"] += 1
return cls._static_autotuners, static_autotuner_names

    @classmethod
def load_autotuners(
cls, static_autotuners: Optional[list[StaticallyLaunchedAutotuner]]
) -> list[str]:
"""
Load statically launchable CachingAutotuners into async_compile.CompiledTritonKernels
cache.
"""
if not static_autotuners:
return []
from torch._inductor.async_compile import CompiledTritonKernels
from torch._inductor.codecache import StaticAutotunerFuture
log.info("Loading %d statically launchable autotuners", len(static_autotuners))
kernel_names = []
with dynamo_timed("TritonBundler.load_cached_static_autotuners"):
for result in static_autotuners:
# We make a future instead of returning the kernel here so that
# kernels that are not statically launchable (i.e. cache miss)
# can launch a worker without waiting on the blocking step of
# StaticAutotunerFuture.result().
CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture(
result.kernel
)
counters["inductor"]["triton_bundler_load_static_autotuner"] += 1
kernel_names.append(result.kernel_name)
return kernel_names

    @classmethod
def collect(
cls,
) -> tuple[TritonBundle, Optional[TritonBundlerMetadata]]:
"""
This is the main function called when a cache write happens. This function
converts all the previously remembered kernels into bundled format so that
it can be written into a cache entry.
This function also finalizes the current bundle.
"""
from torch._inductor import config
if not TritonBundler.is_enabled():
cls.end_compile()
set_feature_use("triton_bundling", False)
return TritonBundle([], []), None
set_feature_use("triton_bundling", True)
with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True):
entries = cls._entries
if entries is not None:
result: list[TritonKernelArtifacts] = []
kernel_names: list[str] = []
for entry in entries:
artifacts: list[TritonKernelArtifact] = []
path = os.path.join(entry.directory, entry.kernel_hash)
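                    # The kernel directory may have been cleaned up since we
                    # noted the kernel; if so, skip it.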
if not os.path.exists(path):
continue
for filename in os.listdir(path):
filepath = os.path.join(path, filename)
try:
assert os.path.isfile(filepath)
with open(filepath, "rb") as file:
payload = file.read()
if filepath.endswith(".json"):
# Make sure there's no sentinel value
if TritonBundler._REPLACE_BYTES in payload:
log.warning(
"Bundle contains illegal %s, payload: %s",
TritonBundler._REPLACE_BYTES,
payload,
)
raise AssertionError(
"Bundle contains illegal bytes"
)
# Remove the path from payload
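                                # (keeps the bundle relocatable; read_and_emit
                                # patches the target directory back in on load)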
payload = payload.replace(
str.encode(path), TritonBundler._REPLACE_BYTES
)
artifacts.append(
TritonKernelArtifact(filename, payload)
)
counters["inductor"]["triton_bundler_save_kernel"] += 1
except Exception:
log.debug("failed to collect triton kernel", exc_info=True)
extension = os.path.splitext(filename)[1]
if extension in GPU_KERNEL_BIN_EXTS.values():
                            # Each kernel has a bunch of files like .cubin (for cuda), .spv (for xpu), .json, .ttir
# Just append one of them without the extension
kernel_names.append(Path(filename).stem)
if artifacts:
result.append(
TritonKernelArtifacts(
entry.kernel_hash,
entry.device,
artifacts,
)
)
if config.use_static_cuda_launcher:
static_autotuners, static_kernel_names = (
cls.collect_static_autotuners()
)
else:
static_autotuners = []
static_kernel_names = []
cls.end_compile()
return TritonBundle(result, static_autotuners), TritonBundlerMetadata(
kernel_names, static_kernel_names
)
return TritonBundle([], []), None

    @staticmethod
def read_and_emit(bundle: TritonBundle) -> Optional[TritonBundlerMetadata]:
"""
This is the main function called when a cache read happens. This function
converts the bundled format back into individual files and writes them
to the filesystem.
NOTE: When we are writing to the filesystem, we assume exclusive access
to the target directory.
This means that if the target folder already exists and is non-empty,
we bail out.
Exclusive access means that no other process should be writing to
or reading from the target directory.
"""
from torch._inductor import config
if not TritonBundler.is_enabled():
return None
with dynamo_timed(
key="TritonBundler.read_and_emit", log_pt2_compile_event=True
):
kernel_names: list[str] = []
for artifacts in bundle.kernel_artifacts:
basedir = triton_cache_dir(artifacts.device)
directory = os.path.join(basedir, artifacts.kernel_hash)
if os.path.exists(directory) and len(os.listdir(directory)) != 0:
# If directory already exists, we bail out and leave
# local disk to take care of caching
log.debug(
"Bailing out TritonBundler.read_and_emit, %s is non empty",
directory,
)
continue
Path(basedir).mkdir(parents=True, exist_ok=True)
# Random ID to avoid any collisions
rnd_id = str(uuid.uuid4())
tmp_dir = os.path.join(basedir, f"tmp.{rnd_id}")
os.makedirs(tmp_dir)
for artifact in artifacts.artifacts:
filepath = os.path.join(tmp_dir, artifact.filename)
with open(filepath, "wb") as file:
payload = artifact.payload
if artifact.filename.endswith(".json"):
payload = payload.replace(
TritonBundler._REPLACE_BYTES, str.encode(directory)
)
file.write(payload)
counters["inductor"]["triton_bundler_read_and_emit_kernel"] += 1
extension = os.path.splitext(artifact.filename)[1]
if extension in GPU_KERNEL_BIN_EXTS.values():
                        # Each kernel has a bunch of files like .cubin (for cuda), .spv (for xpu), .json, .ttir
# Just append one of them without the extension
kernel_names.append(Path(artifact.filename).stem)
if _IS_WINDOWS:
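                    # os.replace cannot replace an existing directory on Windows,
                    # so take a lock and remove any existing directory first.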
with FileLock(directory + ".lock"):
if os.path.exists(directory):
shutil.rmtree(directory)
os.replace(tmp_dir, directory)
else:
# Atomic on POSIX systems
os.replace(tmp_dir, directory)
if config.use_static_cuda_launcher:
static_kernel_names = TritonBundler.load_autotuners(
bundle.static_autotuners
)
else:
static_kernel_names = []
return TritonBundlerMetadata(kernel_names, static_kernel_names)