pytorch/torch/_inductor/standalone_compile.py
rzou 723c27ed78 [standalone_compile] binary format write should be atomic (#162432)
We update CompiledArtifact.save to call write_atomic instead of file.write so that the binary artifact is written atomically.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162432
Approved by: https://github.com/oulgen
2025-09-09 18:43:13 +00:00
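For context, an atomic write prevents a concurrent reader (or a crash) from
observing a partially written artifact file. Below is a minimal sketch of the
general write-then-rename idiom; it is an illustration of the pattern, not
necessarily write_atomic's exact implementation:

    import os
    import tempfile

    def atomic_write_sketch(path: str, data: bytes) -> None:
        # Write to a temp file in the destination directory, then rename it
        # into place; os.replace is atomic, so readers see either the old
        # file or the complete new one, never a partial write.
        fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(data)
            os.replace(tmp, path)
        except BaseException:
            os.unlink(tmp)
            raise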


from __future__ import annotations

import copy
import logging
import os
import pickle
import shutil
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING

import torch.fx
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir
from torch._inductor.utils import BoxedBool, InputType
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import config


if TYPE_CHECKING:
    from collections.abc import Sequence

    from torch.compiler._cache import CacheInfo
    from torch.fx import GraphModule

log = logging.getLogger(__name__)


class CompiledArtifact:
    """
    A CompiledArtifact represents a precompiled Inductor artifact that can be
    invoked to avoid repeated compilation.

    A CompiledArtifact is obtained by calling standalone_compile(gm, example_inputs),
    which creates a fresh CompiledArtifact from a GraphModule and example inputs.

    Later, the CompiledArtifact can be saved to disk via the CompiledArtifact.save
    function, either as a single binary file or unpacked into the provided folder.

    CompiledArtifact.load re-creates a CompiledArtifact from the binary or
    unpacked data.

    Finally, the CompiledArtifact can be invoked via the __call__ method to
    execute the precompiled artifact.
    """

    _compiled_fn: Callable[..., Any]
    _artifacts: Optional[tuple[bytes, CacheInfo]]

    def __init__(
        self,
        compiled_fn: Callable[..., Any],
        artifacts: Optional[tuple[bytes, CacheInfo]],
    ):
        self._compiled_fn = compiled_fn
        self._artifacts = artifacts

    def __call__(self, *args: Any) -> Any:
        return self._compiled_fn(*args)

    def save(
        self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        with dynamo_timed("CompiledArtifact.save"):
            if self._artifacts is None:
                raise RuntimeError(
                    "CompiledArtifact.save failed to save since there's no artifact to save"
                )
            artifact_bytes, cache_info = self._artifacts
            assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
            key = cache_info.aot_autograd_artifacts[0]

            if format == "binary":
                # can't assert that it is a file since it might not exist yet
                assert not os.path.isdir(path)

                from torch.utils._appending_byte_serializer import BytesWriter

                from .codecache import torch_key

                writer = BytesWriter()
                writer.write_bytes(torch_key())
                writer.write_str(key)
                writer.write_bytes(artifact_bytes)

                from torch._inductor.codecache import write_atomic

                write_atomic(path, writer.to_bytes())
            else:
                assert format == "unpacked"
                if os.path.exists(path):
                    assert os.path.isdir(path)
                    shutil.rmtree(path, ignore_errors=True)

                from .codecache import FxGraphCache

                with temporary_cache_dir(path):
                    # This function unpacks the cache artifacts to disk
                    loaded_cache_info = torch.compiler.load_cache_artifacts(
                        artifact_bytes
                    )
                    assert loaded_cache_info is not None
                    # Now write all the output_code artifacts to disk so that
                    # they can be inspected and modified
                    for key in loaded_cache_info.inductor_artifacts:
                        subdir = FxGraphCache._get_tmp_dir_for_key(key)
                        assert os.path.exists(subdir)
                        for path in sorted(os.listdir(subdir)):
                            with open(os.path.join(subdir, path), "rb") as f:
                                graph = pickle.load(f)
                            output_file = graph.write_to_disk()
                            log.info("Output code written to: %s", output_file)
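
    # After an unpacked save, `path` is a self-contained cache directory;
    # load(format="unpacked") below expects to find exactly one AOTAutograd
    # entry under <path>/aotautograd/.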

    @staticmethod
    def load(
        *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> CompiledArtifact:
        path = normalize_path_separator(path)
        with dynamo_timed("CompiledArtifact.load"):
            if format == "binary":
                # path should point at the binary artifact file, not a directory
                assert not os.path.isdir(path)

                with open(path, "rb") as file:
                    artifacts = file.read()

                from torch.utils._appending_byte_serializer import BytesReader

                from .codecache import torch_key

                reader = BytesReader(artifacts)
                assert reader.read_bytes() == torch_key()
                key = reader.read_str()
                artifact_bytes = reader.read_bytes()
                assert reader.is_finished()

                torch.compiler.load_cache_artifacts(artifact_bytes)

                cache_dir_ctx: AbstractContextManager[None] = nullcontext()
            else:
                assert format == "unpacked"
                assert os.path.isdir(path)
                autograd_cache_dir = os.path.join(path, "aotautograd")
                assert os.path.isdir(autograd_cache_dir)
                files = list(os.listdir(autograd_cache_dir))
                assert len(files) == 1
                key = files[0]
                cache_dir_ctx = temporary_cache_dir(path)

            with (
                cache_dir_ctx,
                config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
            ):
                with torch._functorch.config.patch(strict_autograd_cache=True):
                    from torch._functorch._aot_autograd.autograd_cache import (
                        AOTAutogradCache,
                    )

                    entry = AOTAutogradCache._lookup(
                        key,
                        local=True,
                        remote=False,
                        args=[],
                        cache_info={},
                        aot_config=None,
                    )
                assert entry is not None

                from .compile_fx import _CompileFxKwargs

                fx_config = _CompileFxKwargs(
                    cudagraphs=BoxedBool(False),
                    boxed_forward_device_index=BoxedDeviceIndex(0),
                )

                context = torch._guards.TracingContext(
                    FakeTensorMode(shape_env=ShapeEnv())
                )
                with torch._guards.tracing(context):
                    compiled_fn = entry.wrap_post_compile(
                        [], entry.sanitized_aot_config, fx_config
                    )
            return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)
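
# Note: binary artifacts embed torch_key(), so CompiledArtifact.load only
# accepts binary artifacts produced by the same PyTorch build that saved them.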


def standalone_compile(
    gm: GraphModule,
    example_inputs: Sequence[InputType],
    *,
    dynamic_shapes: Any,
    options: Any,
) -> CompiledArtifact:
    from torch.compiler._cache import CacheArtifactManager

    from .compile_fx import compile_fx

    ignore_shape_env = False
    if dynamic_shapes == "from_example_inputs":
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        # tells compile_fx to ignore the shape_envs on the ambient context
        # and the graph_module.
        ignore_shape_env = True
    elif dynamic_shapes == "from_tracing_context":
        # Reuse fake_mode from the TracingContext.
        # NB: The TracingContext only exists if we're currently in a torch.compile backend.
        context = torch._guards.TracingContext.get()
        assert context.fake_mode is not None
        fake_mode = context.fake_mode
    elif dynamic_shapes == "from_graph":
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        # Strategy: find a FakeTensor in the graph output and grab its FakeTensorMode.
        # The graph passed to standalone_compile must be an Inductor-approved graph,
        # which means that there is at least one Tensor output and the output node
        # contains a flat list of Tensors.
        last_node = next(iter(reversed(gm.graph.nodes)))
        assert last_node.op == "output"
        assert len(last_node.args) == 1

        def handle_node(node: torch.fx.Node) -> None:
            nonlocal fake_mode
            if "example_value" in node.meta:
                maybe_tensor = node.meta["example_value"]
                if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor):
                    fake_mode = maybe_tensor.fake_mode

        # If gm came from Dynamo, then last_node.args[0] is always a list,
        # even for single-Tensor returns.
        #
        # It's possible to get into a situation where last_node.args[0]
        # is a Node (and not a list!). This happens if you call split_module
        # on the graph. We allow for this case since it is common.
        if isinstance(last_node.args[0], torch.fx.Node):
            handle_node(last_node.args[0])
        else:
            for node in last_node.args[0]:
                handle_node(node)
    else:
        raise ValueError(
            f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}."
        )
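
    # Recap of the accepted `dynamic_shapes` values handled above:
    #   "from_example_inputs"  - fresh ShapeEnv; shapes come from example_inputs,
    #                            ignoring ambient/graph shape_envs
    #   "from_tracing_context" - reuse the fake mode of the enclosing
    #                            torch.compile TracingContext
    #   "from_graph"           - recover the FakeTensorMode recorded in the
    #                            graph's "example_value" node metadata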

    context = torch._guards.TracingContext(fake_mode)
    with (
        torch._guards.tracing(context),
        CacheArtifactManager.with_fresh_cache(),
        config.patch("triton.autotune_at_compile_time", True),
    ):
        # compile_fx can mutate gm
        gm = copy.deepcopy(gm)
        compiled_fn = compile_fx(
            gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
        )
        assert callable(compiled_fn)

        artifacts = torch.compiler.save_cache_artifacts()
        if artifacts is None:
            log.warning(
                "standalone_compile artifact generation failed, cannot save. "
                "Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
            )

    return CompiledArtifact(compiled_fn, artifacts)