Subprocess compile (#146134)

Add a mode to `fx_codegen_and_compile()` to compile in a separate process. This prepares for async compile, where we'll compile and run eager in parallel (and it also lets us move the compile phase to a remote machine).
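
As a rough usage sketch of turning the new mode on from user code: the environment variable name and the SUBPROCESS value come from the compile_fx.py hunk below; the toy function is made up, and everything else is ordinary torch.compile usage.

# Sketch: enable subprocess compile for a whole run. fx_compile_mode is read
# from TORCHINDUCTOR_FX_COMPILE_MODE when torch._inductor.compile_fx is first
# imported, so set the variable before any compile is triggered.
import os

os.environ["TORCHINDUCTOR_FX_COMPILE_MODE"] = "SUBPROCESS"

import torch


def f(x: torch.Tensor) -> torch.Tensor:
    return (x.sin() + 1.0).relu()


compiled = torch.compile(f)
print(compiled(torch.randn(8)))  # codegen and compile happen in a worker process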

Added a test which runs the test_torchinductor tests with subprocess compiling turned on.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146134
Approved by: https://github.com/jamesjwu
Author: Aaron Orenstein
Date: 2025-03-03 08:32:35 -08:00
Committed by: PyTorch MergeBot
Parent: 8f361c808b
Commit: 07f876e960

10 changed files with 924 additions and 248 deletions

View File

@@ -25,13 +25,17 @@ from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inducto
check_model,
CommonTemplate,
copy_tests,
TestFailure,
)
importlib.import_module("filelock")
# xfail by default, set is_skip=True to skip
test_failures = {}
test_failures = {
# TypeError: cannot pickle 'generator' object
"test_layer_norm_graph_pickler": TestFailure(("cpu"), is_skip=True),
}
def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"):

View File

@@ -0,0 +1,95 @@
# Owner(s): ["module: fx"]
#
# Tests compiling the inductor tests in a subprocess.
#
import contextlib
import importlib
import os
import sys
from unittest.mock import patch
import torch
import torch.library
from torch._inductor.compile_fx import _InProcessFxCompile, FxCompile, FxCompileMode
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
import inductor.test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
check_model,
check_model_gpu,
copy_tests,
TestFailure,
)
importlib.import_module("filelock")
# xfail by default, set is_skip=True to skip
test_failures = {
# TypeError: cannot pickle 'generator' object
"test_layer_norm": TestFailure(("cpu", "cuda"), is_skip=True),
}
class TestSubprocess(TestCase):
def setUp(self):
torch._dynamo.reset()
FxCompile._reset_stats()
TestCase.setUp(self)
self._stack = contextlib.ExitStack()
self._stack.enter_context(
patch(
"torch._inductor.compile_fx.fx_compile_mode",
FxCompileMode.SUBPROCESS,
)
)
def tearDown(self):
# Check that the test didn't instigate an in-process compile - which
# would mean that something about the fx graph failed to serialize. If
# some tests are expected to fail then we should probably add a list of
# expected failures here.
self.assertEqual(
FxCompile._compile_stats[type(_InProcessFxCompile)].codegen_and_compile, 0
)
self._stack.close()
TestCase.tearDown(self)
torch._dynamo.reset()
if HAS_CPU:
class CpuTests(TestSubprocess):
common = check_model
device = "cpu"
copy_tests(
inductor.test_torchinductor.CommonTemplate, CpuTests, "cpu", test_failures
)
if HAS_GPU and not TEST_WITH_ASAN:
class GPUTests(TestSubprocess):
common = check_model_gpu
device = GPU_TYPE
copy_tests(
inductor.test_torchinductor.CommonTemplate, GPUTests, GPU_TYPE, test_failures
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_CPU or HAS_GPU:
run_tests(needs="filelock")

View File

@@ -20,6 +20,8 @@ import unittest
import unittest.mock
import weakref
from pathlib import Path
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
from unittest.mock import patch
import numpy as np
@@ -133,6 +135,10 @@ from torch.testing._internal.inductor_utils import (
from torch.testing._internal.triton_utils import requires_cuda
_T = TypeVar("_T")
_P = ParamSpec("_P")
HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
aten = torch.ops.aten
@@ -6395,16 +6401,16 @@ class CommonTemplate:
return x.cos().sin().softmax(-1)
x = torch.randn(16, 256, device=self.device)
_, (coda_a0,) = run_and_get_kernels(a, x)
_, (coda_b0,) = run_and_get_kernels(b, x)
_, (coda_c0,) = run_and_get_kernels(c, x)
_, (coda_a0,) = _run_and_get_stripped_kernels(a, x)
_, (coda_b0,) = _run_and_get_stripped_kernels(b, x)
_, (coda_c0,) = _run_and_get_stripped_kernels(c, x)
self.assertEqual(coda_a0, coda_c0)
# compile in a different order
torch.compiler.reset()
_, (coda_c1,) = run_and_get_kernels(c, x)
_, (coda_a1,) = run_and_get_kernels(a, x)
_, (coda_b1,) = run_and_get_kernels(b, x)
_, (coda_c1,) = _run_and_get_stripped_kernels(c, x)
_, (coda_a1,) = _run_and_get_stripped_kernels(a, x)
_, (coda_b1,) = _run_and_get_stripped_kernels(b, x)
self.assertEqual(coda_a0, coda_a1)
self.assertEqual(coda_b0, coda_b1)
self.assertEqual(coda_c0, coda_c1)
@@ -6417,9 +6423,9 @@ class CommonTemplate:
"__init__",
lambda self, _: CompileContext_init(self, CompileId(999, 999)),
):
_, (coda_a2,) = run_and_get_kernels(a, x)
_, (coda_c2,) = run_and_get_kernels(c, x)
_, (coda_b2,) = run_and_get_kernels(b, x)
_, (coda_a2,) = _run_and_get_stripped_kernels(a, x)
_, (coda_c2,) = _run_and_get_stripped_kernels(c, x)
_, (coda_b2,) = _run_and_get_stripped_kernels(b, x)
self.assertEqual(coda_a0, coda_a2)
self.assertEqual(coda_b0, coda_b2)
self.assertEqual(coda_c0, coda_c2)
@@ -6442,7 +6448,7 @@ class CommonTemplate:
return x
x = torch.randn(16, 256, device=self.device)
_, (code0, code1) = run_and_get_kernels(b, x)
_, (code0, code1) = _run_and_get_stripped_kernels(b, x)
self.assertEqual(code0, code1)
@patch.object(cpp_prefix_path, "cache_clear", lambda: None)
@@ -6464,8 +6470,8 @@ class CommonTemplate:
x = torch.randn(16, 256, device=self.device)
y = torch.randn(256, 256, device=self.device)
_, (code0,) = run_and_get_kernels(a, x)
_, (code1,) = run_and_get_kernels(b, x, y)
_, (code0,) = _run_and_get_stripped_kernels(a, x)
_, (code1,) = _run_and_get_stripped_kernels(b, x, y)
self.assertEqual(code0, code1)
def test_flip(self):
@@ -14270,6 +14276,20 @@ if HAS_CPU:
self.assertEqual(ret_opt, fn(pytype, dtype))
def _strip_tmp_path(code: str) -> str:
"""
Canonicalize things that look like a tmp path so they can be compared.
"""
return re.sub('#include ".*?"', '#include "<tmppath>"', code)
def _run_and_get_stripped_kernels(
fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> tuple[_T, list[str]]:
result, codes = run_and_get_kernels(fn, *args, **kwargs)
return result, [_strip_tmp_path(code) for code in codes]
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

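For reference, the `_strip_tmp_path` canonicalization above only rewrites quoted #include paths; a standalone illustration with a made-up temp path:

import re

def _strip_tmp_path(code: str) -> str:
    # Same substitution as above: any quoted include path becomes "<tmppath>".
    return re.sub('#include ".*?"', '#include "<tmppath>"', code)

sample = '#include "/tmp/tmpk3x9ab/cpp_prefix.h"\nextern "C" void kernel();'
print(_strip_tmp_path(sample))
# #include "<tmppath>"
# extern "C" void kernel();
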
View File

@@ -1169,6 +1169,24 @@ class FxGraphCache:
log.warning("fx graph unable to write to cache", exc_info=True)
counters["inductor"]["fxgraph_cache_write_error"] += 1
@staticmethod
def _check_for_hop(gm: torch.fx.GraphModule) -> None:
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if (
isinstance(node.target, torch._ops.HigherOrderOperator)
and not node.target.cacheable()
):
raise BypassFxGraphCache(
f"Can't cache HigherOrderOperator: {node.target.name()}"
)
if node.op == "getattr" and isinstance(
getattr(gm, node.target), torch._C.ScriptObject
):
raise BypassFxGraphCache("Can't cache torchbind objects")
@staticmethod
def _check_can_cache(gm: torch.fx.GraphModule) -> None:
"""
@@ -1205,22 +1223,8 @@ class FxGraphCache:
log.debug("fx graph cache no shape env")
raise BypassFxGraphCache("No shape env")
# We skip caching if there are any torchbind objects.
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if (
isinstance(node.target, torch._ops.HigherOrderOperator)
and not node.target.cacheable()
):
raise BypassFxGraphCache(
f"Can't cache HigherOrderOperator: {node.target.name()}"
)
if node.op == "getattr" and isinstance(
getattr(gm, node.target), torch._C.ScriptObject
):
raise BypassFxGraphCache("Can't cache torchbind objects")
# We skip caching if there are any HOPs or torchbind objects.
FxGraphCache._check_for_hop(gm)
@staticmethod
def prepare_key(

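The serialized-compile path added later in this PR reuses `_check_for_hop` as a quick "can this graph leave the process?" test and falls back to an in-process compile when it raises. A condensed sketch of that pattern — the helper name and the toy graph are made up; the check and exception come from the hunk above:

import logging

import torch
from torch._inductor.codecache import BypassFxGraphCache, FxGraphCache

log = logging.getLogger(__name__)

def _can_ship_to_child(gm: torch.fx.GraphModule) -> bool:
    # Non-cacheable HigherOrderOperators and torchbind objects can't be
    # serialized to a worker, so treat the cache bypass as "compile in-process".
    try:
        FxGraphCache._check_for_hop(gm)
        return True
    except BypassFxGraphCache as reason:
        log.debug("falling back to in-process compile: %s", reason)
        return False

gm = torch.fx.symbolic_trace(lambda x: x + 1)
print(_can_ship_to_child(gm))  # True: a plain graph has no HOPs or torchbind objects
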
View File

@@ -12,8 +12,8 @@ import sys
import time
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import AbstractContextManager
from dataclasses import dataclass
from inspect import currentframe
from itertools import count
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
@@ -54,18 +54,12 @@ from torch._functorch.aot_autograd import (
make_boxed_func,
SerializableAOTDispatchCompiler,
)
from torch._inductor.codecache import (
BypassFxGraphCache,
code_hash,
FxGraphCache,
output_code_log,
)
from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._inductor.output_code import (
CompiledAOTI,
CompiledFxGraph,
CompiledFxGraphConstants,
CompiledFxGraphConstantsWithGm,
get_expanded_dims,
index_expanded_dims,
@@ -121,7 +115,7 @@ from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator, Sequence
from torch._inductor.output_code import _StrideExprStr
from torch._ops import OpOverload
@@ -151,21 +145,20 @@ if TYPE_CHECKING:
)
# For testing - use the serde FxCompile scheme to debug serialization and
# deserialization of GraphModule and CompiledFxGraph.
class FxCompileMode(enum.Enum):
NORMAL = 0
# For testing - use the serde FxCompile scheme to debug serialization and
# deserialization of GraphModule and CompiledFxGraph.
SERIALIZE = 1
# Compile using a subprocess instead of in-process.
SUBPROCESS = 2
def _fx_compile_mode_default() -> FxCompileMode:
name = "TORCHINDUCTOR_FX_COMPILE_MODE"
value = os.environ.get(name)
NORMAL = FxCompileMode.NORMAL
if value is None:
return NORMAL
return FxCompileMode.NORMAL
try:
value = value.upper()
return FxCompileMode[value]
@@ -186,7 +179,6 @@ def _fx_compile_mode_default() -> FxCompileMode:
fx_compile_mode = _fx_compile_mode_default()
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs")
@@ -849,12 +841,23 @@ def _compile_fx_inner(
return compiled_graph
class _FxCompileStat:
# Count of successful compiles of this type
codegen_and_compile: int = 0
def __repr__(self) -> str:
return f"codegen_and_compile: {self.codegen_and_compile}"
class FxCompile(ABC):
"""
An FxCompile represents a mechanism that can turn a GraphModule into an
OutputCode.
"""
# Some stats for logging/debugging
_compile_stats: dict[type[FxCompile], _FxCompileStat] = defaultdict(_FxCompileStat)
# TODO: We should probably eventually add some kind of async version of this
# so we can kick off a compile and then go do other things - but we'll need
# to know what kind of API we want for that first.
@@ -867,6 +870,10 @@ class FxCompile(ABC):
graph_kwargs: _CompileFxKwargs,
) -> OutputCode: ...
@classmethod
def _reset_stats(cls) -> None:
cls._compile_stats.clear()
class _InProcessFxCompile(FxCompile):
@override
@@ -881,6 +888,7 @@ class _InProcessFxCompile(FxCompile):
# to propagate it further on
# TODO: _CompileFxKwargs actually has stronger types than in the
# signature, need to tighten it up
assert "cudagraphs" in graph_kwargs and graph_kwargs["cudagraphs"] is not None
cudagraphs: BoxedBool = graph_kwargs["cudagraphs"]
static_input_idxs: Sequence[int] = graph_kwargs.get("static_input_idxs", ())
@@ -1215,6 +1223,8 @@ class _InProcessFxCompile(FxCompile):
)
)
self._compile_stats[type(self)].codegen_and_compile += 1
return CompiledFxGraph(
compiled_fn,
graph,
@@ -1232,195 +1242,6 @@ class _InProcessFxCompile(FxCompile):
)
def _current_fake_mode() -> torch._subclasses.FakeTensorMode:
fake_mode = None
if context := torch._guards.TracingContext.try_get():
fake_mode = context.fake_mode
if fake_mode is not None:
return fake_mode
shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
return torch._subclasses.FakeTensorMode(shape_env=shape_env)
@dataclass
class _WireProtocolInput:
"""
For _SerializedFxCompile - encapsulates all the data being transferred
(sent) from the parent to the child.
"""
gm: torch.fx.GraphModule
example_inputs: Sequence[InputType]
inputs_to_check: Sequence[int]
graph_kwargs: _CompileFxKwargs
# TODO: Add additional state to transfer to the child.
def serialize(self) -> _WireProtocolPickledInput:
"""
Turns this object into a _WireProtocolPickledInput which can be
directly transferred across a stream.
"""
from torch.fx._graph_pickler import GraphPickler
return _WireProtocolPickledInput(GraphPickler.dumps(self))
@dataclass
class _WireProtocolPickledInput:
value: bytes
def deserialize(self) -> _WireProtocolInput:
"""
Turn this streamable object back into a _WireProtocolInput.
"""
from torch.fx._graph_pickler import GraphPickler
fake_mode = _current_fake_mode()
result = GraphPickler.loads(self.value, fake_mode)
assert isinstance(result, _WireProtocolInput)
return result
@dataclass
class _WireProtocolOutput:
"""
For _SerializedFxCompile - encapsulates all the data being transferred
(returned) back from the child to the parent.
"""
graph: OutputCode
def serialize(self) -> _WireProtocolPickledOutput:
"""
Turns this object into a _WireProtocolPickledOutput which can be
directly transferred across a stream.
"""
from torch.fx._graph_pickler import GraphPickler
if isinstance(self.graph, CompiledFxGraph):
self.graph.prepare_for_serialization()
return _WireProtocolPickledOutput(GraphPickler.dumps(self))
@dataclass
class _WireProtocolPickledOutput:
value: bytes
def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput:
"""
Turn this streamable object back into a _WireProtocolOutput.
"""
from torch.fx._graph_pickler import GraphPickler
fake_mode = _current_fake_mode()
result = GraphPickler.loads(self.value, fake_mode)
assert isinstance(result, _WireProtocolOutput)
if isinstance(result.graph, CompiledFxGraph):
result.graph.after_deserialization(constants)
return result
class _SerializedFxCompile(FxCompile):
"""
This is used to represent an FxCompile which occurs across a serialized
boundary.
"""
@override
def codegen_and_compile(
self,
gm: GraphModule,
example_inputs: Sequence[InputType],
inputs_to_check: Sequence[int],
graph_kwargs: _CompileFxKwargs,
) -> OutputCode:
# _context = torch._guards.TracingContext.try_get()
constants = CompiledFxGraphConstantsWithGm(gm)
try:
input = _WireProtocolInput(
gm,
example_inputs,
inputs_to_check,
graph_kwargs,
).serialize()
except (AttributeError, BypassFxGraphCache):
# For example: AttributeError: Can't pickle local object
# 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
# TODO: scuba record about not being able to do this?
log.debug("Unable to pickle input graph or example inputs", exc_info=True)
# Fallback to in-process
return _InProcessFxCompile().codegen_and_compile(
gm, example_inputs, inputs_to_check, graph_kwargs
)
output = self._send_to_child(input).deserialize(constants)
self._postprocess(output)
# TODO: Do we need to figure out what changed in TracingContext in the
# child and plumb that back up to the parent?
return output.graph
@abstractmethod
def _send_to_child(
self, pickled_input: _WireProtocolPickledInput
) -> _WireProtocolPickledOutput:
# The implementation of this should transfer `input` to the child, call
# `_run_in_child(input)` and transfer the result back.
...
def _postprocess(self, output: _WireProtocolOutput) -> None:
pass
@classmethod
def _run_in_child(
cls,
pickled_input: _WireProtocolPickledInput,
extra_env: Optional[Mapping[str, str]] = None,
) -> _WireProtocolPickledOutput:
with contextlib.ExitStack() as stack:
if extra_env is not None:
import unittest
stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env))
# TODO: Should we split the input into multiple sections where each
# section sets up state for the previous section? (i.e. a Config section
# which we decode and apply, followed by a FakeTensorMode section which
# we decode and apply, etc)
input = pickled_input.deserialize()
stack.enter_context(DebugContext())
output_graph = _InProcessFxCompile().codegen_and_compile(
input.gm,
input.example_inputs,
input.inputs_to_check,
input.graph_kwargs,
)
return _WireProtocolOutput(
output_graph,
).serialize()
# This is a debugging/testing implementation of FxCompile which serializes the
# input and output but still runs the FxCompile in-process.
class _DebugSerdeFxCompile(_SerializedFxCompile):
@override
def _send_to_child(
self, pickled_input: _WireProtocolPickledInput
) -> _WireProtocolPickledOutput:
# For debugging just serde the input and output but don't run in a
# subprocess.
return self._run_in_child(pickled_input)
def fx_codegen_and_compile(
gm: GraphModule,
example_inputs: Sequence[InputType],
@@ -1430,12 +1251,17 @@ def fx_codegen_and_compile(
**graph_kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode:
scheme: FxCompile
if fx_compile_mode == FxCompileMode.NORMAL:
scheme = _InProcessFxCompile()
elif fx_compile_mode == FxCompileMode.SERIALIZE:
from .compile_fx_ext import _DebugSerdeFxCompile
scheme = _DebugSerdeFxCompile()
else:
raise NotImplementedError
elif fx_compile_mode == FxCompileMode.SUBPROCESS:
from .compile_fx_subproc import _SubprocessFxCompile
scheme = _SubprocessFxCompile()
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)

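Since `fx_compile_mode` is a module-level value derived from TORCHINDUCTOR_FX_COMPILE_MODE at import time, the new test flips the scheme with a patch instead of the environment. A rough standalone sketch of the same approach — the toy function is made up; the patch target mirrors the new TestSubprocess.setUp():

from unittest.mock import patch

import torch
from torch._inductor.compile_fx import FxCompileMode

def f(x):
    return x.cos().sin()

# Sketch: force the SUBPROCESS scheme for this compile only, without touching
# the environment.
with patch("torch._inductor.compile_fx.fx_compile_mode", FxCompileMode.SUBPROCESS):
    out = torch.compile(f)(torch.randn(4))
print(out.shape)
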
View File

@@ -0,0 +1,604 @@
from __future__ import annotations
import contextlib
import dataclasses
import functools
import logging
import os
import queue
import sys
import warnings
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, Union
from typing_extensions import override, Self, TypeGuard
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch.fx
from torch._inductor.codecache import BypassFxGraphCache, FxGraphCache
from torch._inductor.metrics import CachedMetricsDeltas, CachedMetricsHelper
from torch._inductor.output_code import (
CompiledFxGraph,
CompiledFxGraphConstants,
CompiledFxGraphConstantsWithGm,
OutputCode,
)
from torch.utils._ordered_set import OrderedSet
from . import config
from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile, log
from .debug import DebugContext
from .graph import GraphLowering
from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401
from .virtualized import V
if TYPE_CHECKING:
import types
from collections.abc import Mapping, Sequence
from torch._inductor.utils import InputType
from torch.fx import GraphModule
@dataclass
class _VirtualizedSerializer:
"""
This handles the data for serializing Virtualized.
"""
# The values here get serialized. We don't grab everything because some of
# the fields can't be serialized.
aot_compilation: Any = None
choices: Any = None
local_buffer_context: Any = None
ops: Any = None
kernel: Any = None
current_node: Any = None
@classmethod
def serialize(cls) -> _VirtualizedSerializer:
"""
Turn the current state of torch._inductor.virtualized.V into a
serializable structure.
"""
kwargs = {}
for f in dataclasses.fields(cls):
kwargs[f.name] = getattr(V, f.name)
return _VirtualizedSerializer(**kwargs)
def patch(self) -> _VirtualizedSerializerContextManager:
"""
Returns a context manager which patches the saved values into the
current environment. While patched, any value not listed above will be
poisoned so that reads will raise an error.
"""
return _VirtualizedSerializerContextManager(self)
class _VirtualizedSerializerContextManager(contextlib.ExitStack):
"""
Helper for _VirtualizedSerializer.patch()
"""
def __init__(self, virtualized: _VirtualizedSerializer) -> None:
super().__init__()
self.virtualized = virtualized
@override
def __enter__(self) -> Self:
super().__enter__()
for set_name in dir(V):
if not set_name.startswith("set_"):
continue
name = set_name[4:]
name = name.removesuffix("_handler")
set_handler = getattr(V, set_name)
if hasattr(self.virtualized, name):
value = getattr(self.virtualized, name)
else:
# poison any values that we don't serialize so that any
# unset accesses are caught.
value = torch._inductor.virtualized._PoisonedVirtual
self.enter_context(set_handler(value))
return self
def _is_fallback_handler(op: object) -> bool:
try:
return op._is_fallback_handler # type: ignore[attr-defined]
except AttributeError:
return False
class _LoweringSerializer:
"""
This handles the data for serializing lowering.lowerings.
"""
# A full implementation would make sure that all lowerings are copied over
# (or at least detected and raise a bypass when a non-standard lowering is
# used). For now we just handle tests by looking for lowerings that were
# overridden with a forced fallback.
fallbacks: OrderedSet[str]
def __init__(self) -> None:
from . import lowering
self.fallbacks = OrderedSet(
str(k) for k, v in lowering.lowerings.items() if _is_fallback_handler(v)
)
def patch(self) -> _LoweringSerializerContextManager:
return _LoweringSerializerContextManager(self)
class _LoweringSerializerContextManager(contextlib.ExitStack):
"""
Helper for _LoweringSerializer.patch()
"""
def __init__(self, lowering: _LoweringSerializer) -> None:
super().__init__()
self.lowering = lowering
@override
def __enter__(self) -> Self:
super().__enter__()
from . import lowering
for k, v in lowering.lowerings.items():
name = str(k)
if name in self.lowering.fallbacks:
if not _is_fallback_handler(v):
self.enter_context(lowering.force_fallback(k)) # type: ignore[arg-type]
return self
@dataclass
class _WireProtocolInput:
"""
For _SerializedFxCompile - encapsulates all the data being transferred
(sent) from the parent to the child.
"""
gm: torch.fx.GraphModule
example_inputs: Sequence[InputType]
inputs_to_check: Sequence[int]
graph_kwargs: _CompileFxKwargs
tracing_context: Optional[torch._guards.TracingContext]
config: dict[str, object]
virtualized: _VirtualizedSerializer
deterministic_guard_for_testing: Optional[
torch.testing._internal.common_utils.DeterministicGuard
]
logger_state: _LoggerState
lowering: _LoweringSerializer
def serialize(self) -> _WireProtocolPickledInput:
"""
Turns this object into a _WireProtocolPickledInput which can be
directly transferred across a stream.
"""
from torch.fx._graph_pickler import GraphPickler
return _WireProtocolPickledInput(GraphPickler.dumps(self))
def _current_fake_mode() -> torch._subclasses.FakeTensorMode:
fake_mode = None
if context := torch._guards.TracingContext.try_get():
fake_mode = context.fake_mode
if fake_mode is not None:
return fake_mode
shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
return torch._subclasses.FakeTensorMode(shape_env=shape_env)
@dataclass
class _WireProtocolPickledInput:
value: bytes
def deserialize(self) -> _WireProtocolInput:
"""
Turn this streamable object back into a _WireProtocolInput.
"""
from torch.fx._graph_pickler import GraphPickler
fake_mode = _current_fake_mode()
result = GraphPickler.loads(self.value, fake_mode)
assert isinstance(result, _WireProtocolInput)
return result
@dataclass
class _WireProtocolOutput:
"""
For _SerializedFxCompile - encapsulates all the data being transferred
(returned) back from the child to the parent.
"""
graph: OutputCode
metrics: CachedMetricsDeltas
logs: list[logging.LogRecord]
warning_replay: Optional[list[warnings.WarningMessage]]
def serialize(self) -> _WireProtocolPickledOutput:
"""
Turns this object into a _WireProtocolPickledOutput which can be
directly transferred across a stream.
"""
from torch.fx._graph_pickler import GraphPickler
if isinstance(self.graph, CompiledFxGraph):
self.graph.prepare_for_serialization()
return _WireProtocolPickledOutput(GraphPickler.dumps(self))
@dataclass
class _WireProtocolPickledOutput:
value: bytes
def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput:
"""
Turn this streamable object back into a _WireProtocolOutput.
"""
from torch.fx._graph_pickler import GraphPickler
fake_mode = _current_fake_mode()
result = GraphPickler.loads(self.value, fake_mode)
assert isinstance(result, _WireProtocolOutput)
if isinstance(result.graph, CompiledFxGraph):
result.graph.after_deserialization(constants)
return result
class _LoggerState:
"""
This class is for tracking logging that happens during an out-of-process
compile so we can "replay" those messages when the compile is done. Used as
a context manager which returns the captured logs (object).
"""
loggers: dict[str, int]
# The actual log capturing mechanism - this should be None when we're not
# actively capturing logs.
captured_logs: Optional[_CapturedLogs] = None
def __init__(self) -> None:
# Mapping from logger name to level.
self.loggers = {}
def filter(
logger: Union[logging.Logger, logging.PlaceHolder],
) -> TypeGuard[logging.Logger]:
if not isinstance(logger, logging.Logger):
# Assume that Placeholders propagate
return False
# We only want to track torch._inductor logging
if not logger.name.startswith("torch._inductor"):
return False
# If this logger propagates then assume we'll track its parent
if logger.propagate:
return False
return True
root = logging.getLogger("torch._inductor")
if sys.version_info < (3, 12):
# logging.getChildren() doesn't exist until 3.12
logging._acquireLock() # type: ignore[attr-defined]
try:
for logger in root.manager.loggerDict.values():
if filter(logger):
self.loggers[logger.name] = logger.level
finally:
logging._releaseLock() # type: ignore[attr-defined]
else:
q = [root]
while q:
logger = q.pop()
if filter(logger):
self.loggers[logger.name] = logger.level
q.extend(logger.getChildren())
def __enter__(self) -> _CapturedLogs:
assert self.captured_logs is None
self.captured_logs = _CapturedLogs(self)
self.captured_logs.apply()
return self.captured_logs
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> None:
assert self.captured_logs is not None
self.captured_logs.remove()
class _CapturedLogs:
"""
Helper for _LoggerState - this class actually attaches to the logger in
the child process and grabs the log messages themselves.
"""
state: _LoggerState
queue: queue.Queue[logging.LogRecord]
handlers: Optional[dict[str, logging.Handler]]
def __init__(self, state: _LoggerState) -> None:
self.state = state
# A queue of the log entries
# TODO: For memory purposes should we log to a file and then respond with that?
self.queue = queue.Queue(-1)
# Mapping from name to handler (only valid when applied)
self.handlers = None
def finish(self) -> list[logging.LogRecord]:
assert self.handlers is None
logs = []
try:
while True:
logs.append(self.queue.get_nowait())
except queue.Empty:
pass
return logs
def remove(self) -> None:
assert self.handlers is not None
handlers, self.handlers = self.handlers, None
for name, handler in handlers.items():
logger = logging.getLogger(name)
logger.removeHandler(handler)
def apply(self) -> None:
from logging.handlers import QueueHandler
assert self.handlers is None
self.handlers = {}
for name, level in self.state.loggers.items():
logger = logging.getLogger(name)
handler = QueueHandler(self.queue)
self.handlers[name] = handler
logger.addHandler(handler)
if level != logging.NOTSET:
logger.setLevel(level)
class _SerializedFxCompile(FxCompile):
"""
This is used to represent an FxCompile which occurs across a serialized
boundary.
"""
@override
def codegen_and_compile(
self,
gm: GraphModule,
example_inputs: Sequence[InputType],
inputs_to_check: Sequence[int],
graph_kwargs: _CompileFxKwargs,
) -> OutputCode:
def fallback() -> OutputCode:
return _InProcessFxCompile().codegen_and_compile(
gm, example_inputs, inputs_to_check, graph_kwargs
)
try:
# _check_for_hop raises BypassFxGraphCache when it detects something
# we can't cache (or serialize)
FxGraphCache._check_for_hop(gm)
except BypassFxGraphCache as e:
log.debug("Skipping %s compile: %s", type(self), e)
return fallback()
context = torch._guards.TracingContext.try_get()
constants = CompiledFxGraphConstantsWithGm(gm)
logger_state = _LoggerState()
lowering = _LoweringSerializer()
# If we're running tests then grab the DeterministicGuard (don't want to
# import this if it isn't already imported because it has side-effects)
deterministic_guard_for_testing: Optional[
torch.testing._internal.common_utils.DeterministicGuard
] = None
try:
deterministic_guard_for_testing = (
torch.testing._internal.common_utils.DeterministicGuard._current_state()
)
except AttributeError:
pass
try:
input = _WireProtocolInput(
gm,
example_inputs,
inputs_to_check,
graph_kwargs,
context,
config.save_config_portable(),
_VirtualizedSerializer.serialize(),
deterministic_guard_for_testing,
logger_state,
lowering,
).serialize()
except (AttributeError, BypassFxGraphCache):
# For example: AttributeError: Can't pickle local object
# 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
# TODO: scuba record about not being able to do this?
log.debug("Unable to pickle input graph or example inputs", exc_info=True)
return fallback()
output = self._send_to_child(input).deserialize(constants)
self._postprocess(output)
self._compile_stats[type(self)].codegen_and_compile += 1
# TODO: Do we need to figure out what changed in TracingContext in the
# child and plumb that back up to the parent?
return output.graph
@abstractmethod
def _send_to_child(
self, pickled_input: _WireProtocolPickledInput
) -> _WireProtocolPickledOutput:
# The implementation of this should transfer `input` to the child, call
# `_run_in_child(input)` and transfer the result back.
...
def _postprocess(self, output: _WireProtocolOutput) -> None:
pass
@classmethod
def _run_in_child(
cls,
pickled_input: _WireProtocolPickledInput,
extra_env: Optional[Mapping[str, str]] = None,
) -> _WireProtocolPickledOutput:
metrics = CachedMetricsHelper()
with contextlib.ExitStack() as stack:
if extra_env is not None:
import unittest
stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env))
# Save warnings to "replay" in the parent
warning_replay = stack.enter_context(warnings.catch_warnings(record=True))
# TODO: Should we split the input into multiple sections where each
# section sets up state for the previous section? (i.e. a Config section
# which we decode and apply, followed by a FakeTensorMode section which
# we decode and apply, etc)
input = pickled_input.deserialize()
stack.enter_context(input.virtualized.patch())
stack.enter_context(input.lowering.patch())
stack.enter_context(config.patch(input.config))
captured_logs = stack.enter_context(input.logger_state)
if input.deterministic_guard_for_testing:
stack.enter_context(input.deterministic_guard_for_testing)
stack.enter_context(torch._guards.tracing(input.tracing_context))
stack.enter_context(DebugContext())
output_graph = _InProcessFxCompile().codegen_and_compile(
input.gm,
input.example_inputs,
input.inputs_to_check,
input.graph_kwargs,
)
logs = captured_logs.finish()
return _WireProtocolOutput(
output_graph, metrics.get_deltas(), logs, warning_replay
).serialize()
# This is a debugging/testing implementation of FxCompile which serializes the
# input and output but still runs the FxCompile in-process.
class _DebugSerdeFxCompile(_SerializedFxCompile):
@override
def _send_to_child(
self, pickled_input: _WireProtocolPickledInput
) -> _WireProtocolPickledOutput:
# For debugging just serde the input and output but don't run in a
# subprocess.
return self._run_in_child(pickled_input)
class _OutOfProcessFxCompile(_SerializedFxCompile):
"""
Represents an FxCompile which is run outside the current process (in
either a subprocess or possibly even a separate machine).
"""
def _postprocess(self, output: _WireProtocolOutput) -> None:
# Since our metrics were gathered in a subprocess make sure to add them
# here.
CachedMetricsHelper.apply_deltas(output.metrics)
# This is used by tests to check the output for specific details. For
# remote things (subproc and RE) we need to do the `save_output_code`
# here since it didn't happen earlier in-process. In the future if this
# doesn't have "source_code" (it's a CompiledAOTI, for example) and we
# need it we'll have to grab it and serialize it separately from the
# child.
if GraphLowering.save_output_code is not None:
GraphLowering.save_output_code(output.graph.source_code) # type: ignore[attr-defined]
# And forward our collected logs. The cache is cleared when the outer
# function exits.
@functools.lru_cache(None)
def getLogger(name: str) -> logging.Logger:
return logging.getLogger(name)
if output.warning_replay:
for w in output.warning_replay:
warnings.warn_explicit(
message=w.message,
category=w.category,
filename=w.filename,
lineno=w.lineno,
source=w.source,
)
for record in output.logs:
logger = getLogger(record.name)
logger.handle(record)
# For debugging - create a _FxCompile which writes the serialized data to a file
# and then exits.
#
# TODO: make this a FxCompileMode value?
#
# The "child runner" should look something like this:
#
# import torch
# from torch._inductor import compile_fx
# idx = 0
# with open(f"/tmp/pytorch_compile_fx_tmp_input_{idx}.bin", "rb") as f:
# input = compile_fx._WireProtocolPickledInput(f.read())
# result = compile_fx._SubprocessFxCompile._run_in_child(input)
# with open(f"/tmp/pytorch_compile_fx_tmp_output_{idx}.bin", "wb") as f:
# f.write(result.value)
#
class _DebugFileFxCompile(_OutOfProcessFxCompile):
file_index = 0
@override
def _send_to_child(
self, pickled_input: _WireProtocolPickledInput
) -> _WireProtocolPickledOutput:
idx = _DebugFileFxCompile.file_index
_DebugFileFxCompile.file_index += 1
name = f"/tmp/aorenste/pytorch_compile_fx_tmp_input_{idx}.bin"
with open(name, "wb") as f:
f.write(pickled_input.value)
print(f"Wrote to {name}")
if False:
name = f"/tmp/aorenste/pytorch_compile_fx_tmp_actual_{idx}.bin"
actual = self._run_in_child(pickled_input)
with open(name, "wb") as f:
f.write(actual.value)
return actual
elif False:
name = f"/tmp/aorenste/pytorch_compile_fx_tmp_output_{idx}.bin"
with open(name, "rb") as f:
result = _WireProtocolPickledOutput(f.read())
print(f"Read from {name}")
return result
else:
os._exit(-1)

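`_LoggerState` / `_CapturedLogs` boil down to the stdlib QueueHandler pattern: attach a queue-backed handler in the child, drain the queue into LogRecord objects, ship them back, and re-emit them in the parent via Logger.handle(). A self-contained illustration of just that mechanism (no inductor involved; logger names and messages are made up):

import logging
import queue
from logging.handlers import QueueHandler

# "Child" side: capture records into a queue instead of emitting them.
log_queue = queue.Queue(-1)
child_logger = logging.getLogger("demo.worker")
child_logger.setLevel(logging.DEBUG)
child_logger.addHandler(QueueHandler(log_queue))
child_logger.propagate = False

child_logger.debug("compiled %d kernels", 3)
child_logger.warning("falling back for %s", "aten.nonzero")

# Drain the queue; in the real flow these records travel back to the parent
# inside _WireProtocolOutput.logs.
records = []
try:
    while True:
        records.append(log_queue.get_nowait())
except queue.Empty:
    pass

# "Parent" side: the parent never installed the QueueHandler, so re-emitting
# each record through its own logger hits the normal handlers. Simulate that
# here by detaching the capture handler before replaying.
child_logger.handlers.clear()
child_logger.propagate = True
logging.basicConfig(level=logging.DEBUG)
for record in records:
    logging.getLogger(record.name).handle(record)
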
View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import atexit
import functools
import os
from typing import Optional, TYPE_CHECKING
from typing_extensions import override
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch.fx
from torch._inductor.compile_worker.subproc_pool import (
AnyPool,
SubprocKind,
SubprocPool,
)
from torch._inductor.utils import clear_inductor_caches
from .compile_fx_ext import (
_OutOfProcessFxCompile,
_WireProtocolPickledInput,
_WireProtocolPickledOutput,
)
from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401
if TYPE_CHECKING:
from collections.abc import Mapping
class _SubprocessFxCompile(_OutOfProcessFxCompile):
@override
def _send_to_child(
self, input: _WireProtocolPickledInput
) -> _WireProtocolPickledOutput:
# TODO: Do we need to copy across some kind of logging IDs? (ChromiumEventLogger)
pool = self.process_pool()
# TODO: This is probably the wrong thing to do long-term - but for now
# let's share the cache so we can identify tests broken by this later.
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
f = pool.submit(_SubprocessFxCompile._run_in_child_subprocess, input, extra_env)
# For debugging: If we want to print status updates...
# last = time.time()
# while not f.done():
# print("tick...")
# time.sleep(0.125)
# now = time.time()
# if now - last > 1:
# last = now
output = f.result()
return output
@staticmethod
@functools.cache
def process_pool() -> AnyPool:
pool = SubprocPool(
# TODO: Consider raising this limit if we start using async w/
# subprocess and want to compile multiple graphs in parallel.
1,
kind=SubprocKind.SPAWN,
)
atexit.register(pool.shutdown)
return pool
@classmethod
def _run_in_child_subprocess(
cls,
pickled_input: _WireProtocolPickledInput,
extra_env: Optional[Mapping[str, str]],
) -> _WireProtocolPickledOutput:
# TODO: In subprocess mode we need to clear the inductor caches.
# The problem:
# 1. We compile in worker A which fills stuff in tmpdir
# 2. parent clears inductor caches which deletes tmpdirs and tells
# cpp_prefix_path() to clear its LRU cache
# 3. We compile a second time in subproc A - but since we never told
# cpp_prefix_path() in worker A to clear its LRU it thinks the
# tmpdir still exists and fails to compile.
#
# TODO: We probably should be using a separate tmpdir in the worker
# anyway... but we should probably still respect clear_inductor_caches()
# in the parent... maybe?
#
# TODO: We could be less aggressive by keeping a clock which gets
# incremented when we clear the cache, send the clock to the worker and
# only clear caches if the clock changed since last time.
#
clear_inductor_caches()
torch._inductor.metrics.reset()
# TODO: turn off config.fx_graph_async_compile
result = cls._run_in_child(pickled_input, extra_env)
return result

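`process_pool()` above follows a common shape: build the worker pool lazily on first use, memoize it with functools.cache, and register an atexit hook so the worker is torn down with the parent. The same pattern sketched with a stdlib executor standing in for SubprocPool (the pool type and `_square` are stand-ins, not the PR's code):

import atexit
import functools
from concurrent.futures import ProcessPoolExecutor

@functools.cache
def process_pool() -> ProcessPoolExecutor:
    # Created once, the first time work needs it; shut down automatically when
    # the parent process exits.
    pool = ProcessPoolExecutor(max_workers=1)
    atexit.register(pool.shutdown)
    return pool

def _square(x: int) -> int:
    return x * x

if __name__ == "__main__":
    print(process_pool().submit(_square, 7).result())  # 49
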
View File

@@ -1903,6 +1903,9 @@ def fallback_handler(kernel, add_to_fallback_set=True):
wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs)
)
# This lets us detect that a lowering is a fallback handler.
handler._is_fallback_handler = True # type: ignore[attr-defined]
return handler

View File

@@ -96,6 +96,14 @@ class NullHandler:
"""
# If a virtualized value is set to _PoisonedVirtual then any attempt to get the
# value will result in an exception being raised. This is useful if we want to
# trap uninitialized reads of virtualized globals - for example when compiling
# in a subprocess we don't want the child reading globals that weren't copied
# from the parent.
_PoisonedVirtual = object()
class Virtualized(Generic[T]):
"""
Implements a global variable that redirects via thread local variable
@@ -110,11 +118,12 @@ class Virtualized(Generic[T]):
"""
def __init__(self, vname: str, default: Union[Callable[[], T], type[NullHandler]]):
self._vname = vname
self._key: str = f"__torchinductor_{vname}"
self._default = default
def _set_handler(self, value: T) -> AbstractContextManager[None]:
prior = self._get_handler()
prior = self._get_handler(False)
setattr(threadlocal, self._key, value)
@contextmanager
@@ -126,9 +135,14 @@ class Virtualized(Generic[T]):
return ctx()
def _get_handler(self) -> T:
def _get_handler(self, check_poisoned: bool = True) -> T:
try:
return getattr(threadlocal, self._key)
value = getattr(threadlocal, self._key)
if check_poisoned and value is _PoisonedVirtual:
raise RuntimeError(
f"Attempt to use poisoned virtualized value '{self._vname}'."
)
return value
except AttributeError:
# TODO: To be honest, I feel we probably should just error in this
# case, instead of making a null handler that will probably error

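The poisoning is just a sentinel check on read; stripped of the Virtualized machinery, the idea looks roughly like this (all names here are made up):

from typing import Any

_POISONED = object()  # stand-in for _PoisonedVirtual

class Slot:
    """Tiny stand-in for one Virtualized global: poisoned values raise on read."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._value: Any = _POISONED

    def set(self, value: Any) -> None:
        self._value = value

    def get(self) -> Any:
        if self._value is _POISONED:
            raise RuntimeError(f"Attempt to use poisoned virtualized value '{self._name}'.")
        return self._value

ops = Slot("ops")
try:
    ops.get()
except RuntimeError as e:
    print(e)  # the child never received this global, so the read fails loudly
ops.set("real handler")
print(ops.get())
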
View File

@@ -2031,20 +2031,24 @@ class DeterministicGuard:
self.warn_only = warn_only
self.fill_uninitialized_memory = fill_uninitialized_memory
def __enter__(self):
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined]
torch.use_deterministic_algorithms(
self.deterministic,
warn_only=self.warn_only)
@classmethod
def _current_state(cls):
return cls(
torch.are_deterministic_algorithms_enabled(),
warn_only=torch.is_deterministic_algorithms_warn_only_enabled(),
fill_uninitialized_memory=torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined]
)
def _update(self):
torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory # type: ignore[attr-defined]
def __enter__(self):
self._restore = self._current_state()
self._update()
def __exit__(self, exception_type, exception_value, traceback):
torch.use_deterministic_algorithms(
self.deterministic_restore,
warn_only=self.warn_only_restore)
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore # type: ignore[attr-defined]
self._restore._update()
class AlwaysWarnTypedStorageRemoval:
def __init__(self, always_warn):
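
With this refactor the parent can snapshot the current determinism settings and the child can re-apply them by entering the snapshot as a context manager, which is how the subprocess compile path above uses it (codegen_and_compile captures the state, _run_in_child enters it). A rough sketch of that usage:

import torch
from torch.testing._internal.common_utils import DeterministicGuard

# Parent: snapshot whatever determinism settings are currently in effect.
snapshot = DeterministicGuard._current_state()

# Child: entering the snapshot applies those settings for the duration of the
# compile and restores the child's previous settings on exit.
with snapshot:
    assert torch.are_deterministic_algorithms_enabled() == snapshot.deterministic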