Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Subprocess compile (#146134)
Add a mode to `fx_codegen_and_compile()` to compile in a separate process. This is to prepare for async compile, where we'll compile and run eager in parallel (and also be able to move the compile phase to a remote computer). Added a test that runs the test_torchinductor tests with subprocess compiling turned on. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146134 Approved by: https://github.com/jamesjwu
Committed by PyTorch MergeBot
parent 8f361c808b · commit 07f876e960
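For context, here is a minimal usage sketch (not part of the diff below) of how the new compile mode gets selected. The toy function `f` and its input shape are hypothetical; `FxCompileMode`, the module-level `fx_compile_mode` flag, and the `TORCHINDUCTOR_FX_COMPILE_MODE` environment variable are the pieces this change adds or uses.

import torch
from unittest.mock import patch
from torch._inductor.compile_fx import FxCompileMode

def f(x):
    # any torch.compile-able function works here
    return x.sin().cos()

x = torch.randn(16, 256)

# Option 1: patch the module-level flag, which is what the new test does in setUp().
with patch("torch._inductor.compile_fx.fx_compile_mode", FxCompileMode.SUBPROCESS):
    torch.compile(f)(x)  # codegen/compile runs in a SubprocPool worker

# Option 2: export TORCHINDUCTOR_FX_COMPILE_MODE=SUBPROCESS before importing
# torch._inductor.compile_fx; _fx_compile_mode_default() reads it at import time.

If the FX graph can't be pickled (for example, it contains a non-cacheable HigherOrderOperator), _SerializedFxCompile falls back to the in-process compile, which is exactly what the new test's tearDown() asserts did not happen.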
@ -25,13 +25,17 @@ from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inducto
|
||||
check_model,
|
||||
CommonTemplate,
|
||||
copy_tests,
|
||||
TestFailure,
|
||||
)
|
||||
|
||||
|
||||
importlib.import_module("filelock")
|
||||
|
||||
# xfail by default, set is_skip=True to skip
|
||||
test_failures = {}
|
||||
test_failures = {
|
||||
# TypeError: cannot pickle 'generator' object
|
||||
"test_layer_norm_graph_pickler": TestFailure(("cpu"), is_skip=True),
|
||||
}
|
||||
|
||||
|
||||
def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"):
|
||||
|
95
test/inductor/test_compile_subprocess.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Owner(s): ["module: fx"]
|
||||
|
||||
#
|
||||
# Tests compiling the inductor tests in a subprocess.
|
||||
#
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.library
|
||||
from torch._inductor.compile_fx import _InProcessFxCompile, FxCompile, FxCompileMode
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.testing._internal.common_utils import TEST_WITH_ASAN
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
import inductor.test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
||||
from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
||||
check_model,
|
||||
check_model_gpu,
|
||||
copy_tests,
|
||||
TestFailure,
|
||||
)
|
||||
|
||||
|
||||
importlib.import_module("filelock")
|
||||
|
||||
# xfail by default, set is_skip=True to skip
|
||||
test_failures = {
|
||||
# TypeError: cannot pickle 'generator' object
|
||||
"test_layer_norm": TestFailure(("cpu", "cuda"), is_skip=True),
|
||||
}
|
||||
|
||||
|
||||
class TestSubprocess(TestCase):
|
||||
def setUp(self):
|
||||
torch._dynamo.reset()
|
||||
FxCompile._reset_stats()
|
||||
|
||||
TestCase.setUp(self)
|
||||
|
||||
self._stack = contextlib.ExitStack()
|
||||
self._stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.compile_fx.fx_compile_mode",
|
||||
FxCompileMode.SUBPROCESS,
|
||||
)
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
# Check that the test didn't instigate an in-process compile - which
|
||||
# would mean that something about the fx graph failed to serialize. If
|
||||
# some tests are expected to fail then we should probably add a list of
|
||||
# expected failures here.
|
||||
self.assertEqual(
|
||||
FxCompile._compile_stats[type(_InProcessFxCompile)].codegen_and_compile, 0
|
||||
)
|
||||
self._stack.close()
|
||||
TestCase.tearDown(self)
|
||||
torch._dynamo.reset()
|
||||
|
||||
|
||||
if HAS_CPU:
|
||||
|
||||
class CpuTests(TestSubprocess):
|
||||
common = check_model
|
||||
device = "cpu"
|
||||
|
||||
copy_tests(
|
||||
inductor.test_torchinductor.CommonTemplate, CpuTests, "cpu", test_failures
|
||||
)
|
||||
|
||||
if HAS_GPU and not TEST_WITH_ASAN:
|
||||
|
||||
class GPUTests(TestSubprocess):
|
||||
common = check_model_gpu
|
||||
device = GPU_TYPE
|
||||
|
||||
copy_tests(
|
||||
inductor.test_torchinductor.CommonTemplate, GPUTests, GPU_TYPE, test_failures
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
if HAS_CPU or HAS_GPU:
|
||||
run_tests(needs="filelock")
|
@ -20,6 +20,8 @@ import unittest
|
||||
import unittest.mock
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import Callable, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
@ -133,6 +135,10 @@ from torch.testing._internal.inductor_utils import (
|
||||
from torch.testing._internal.triton_utils import requires_cuda
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
|
||||
|
||||
aten = torch.ops.aten
|
||||
@ -6395,16 +6401,16 @@ class CommonTemplate:
|
||||
return x.cos().sin().softmax(-1)
|
||||
|
||||
x = torch.randn(16, 256, device=self.device)
|
||||
_, (coda_a0,) = run_and_get_kernels(a, x)
|
||||
_, (coda_b0,) = run_and_get_kernels(b, x)
|
||||
_, (coda_c0,) = run_and_get_kernels(c, x)
|
||||
_, (coda_a0,) = _run_and_get_stripped_kernels(a, x)
|
||||
_, (coda_b0,) = _run_and_get_stripped_kernels(b, x)
|
||||
_, (coda_c0,) = _run_and_get_stripped_kernels(c, x)
|
||||
self.assertEqual(coda_a0, coda_c0)
|
||||
|
||||
# compile in a different order
|
||||
torch.compiler.reset()
|
||||
_, (coda_c1,) = run_and_get_kernels(c, x)
|
||||
_, (coda_a1,) = run_and_get_kernels(a, x)
|
||||
_, (coda_b1,) = run_and_get_kernels(b, x)
|
||||
_, (coda_c1,) = _run_and_get_stripped_kernels(c, x)
|
||||
_, (coda_a1,) = _run_and_get_stripped_kernels(a, x)
|
||||
_, (coda_b1,) = _run_and_get_stripped_kernels(b, x)
|
||||
self.assertEqual(coda_a0, coda_a1)
|
||||
self.assertEqual(coda_b0, coda_b1)
|
||||
self.assertEqual(coda_c0, coda_c1)
|
||||
@ -6417,9 +6423,9 @@ class CommonTemplate:
|
||||
"__init__",
|
||||
lambda self, _: CompileContext_init(self, CompileId(999, 999)),
|
||||
):
|
||||
_, (coda_a2,) = run_and_get_kernels(a, x)
|
||||
_, (coda_c2,) = run_and_get_kernels(c, x)
|
||||
_, (coda_b2,) = run_and_get_kernels(b, x)
|
||||
_, (coda_a2,) = _run_and_get_stripped_kernels(a, x)
|
||||
_, (coda_c2,) = _run_and_get_stripped_kernels(c, x)
|
||||
_, (coda_b2,) = _run_and_get_stripped_kernels(b, x)
|
||||
self.assertEqual(coda_a0, coda_a2)
|
||||
self.assertEqual(coda_b0, coda_b2)
|
||||
self.assertEqual(coda_c0, coda_c2)
|
||||
@ -6442,7 +6448,7 @@ class CommonTemplate:
|
||||
return x
|
||||
|
||||
x = torch.randn(16, 256, device=self.device)
|
||||
_, (code0, code1) = run_and_get_kernels(b, x)
|
||||
_, (code0, code1) = _run_and_get_stripped_kernels(b, x)
|
||||
self.assertEqual(code0, code1)
|
||||
|
||||
@patch.object(cpp_prefix_path, "cache_clear", lambda: None)
|
||||
@ -6464,8 +6470,8 @@ class CommonTemplate:
|
||||
|
||||
x = torch.randn(16, 256, device=self.device)
|
||||
y = torch.randn(256, 256, device=self.device)
|
||||
_, (code0,) = run_and_get_kernels(a, x)
|
||||
_, (code1,) = run_and_get_kernels(b, x, y)
|
||||
_, (code0,) = _run_and_get_stripped_kernels(a, x)
|
||||
_, (code1,) = _run_and_get_stripped_kernels(b, x, y)
|
||||
self.assertEqual(code0, code1)
|
||||
|
||||
def test_flip(self):
|
||||
@ -14270,6 +14276,20 @@ if HAS_CPU:
|
||||
self.assertEqual(ret_opt, fn(pytype, dtype))
|
||||
|
||||
|
||||
def _strip_tmp_path(code: str) -> str:
|
||||
"""
|
||||
Canonicalize things that look like a tmp path so they can be compared.
|
||||
"""
|
||||
return re.sub('#include ".*?"', '#include "<tmppath>"', code)
|
||||
|
||||
|
||||
def _run_and_get_stripped_kernels(
|
||||
fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
|
||||
) -> tuple[_T, list[str]]:
|
||||
result, codes = run_and_get_kernels(fn, *args, **kwargs)
|
||||
return result, [_strip_tmp_path(code) for code in codes]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
|
@ -1169,6 +1169,24 @@ class FxGraphCache:
|
||||
log.warning("fx graph unable to write to cache", exc_info=True)
|
||||
counters["inductor"]["fxgraph_cache_write_error"] += 1
|
||||
|
||||
@staticmethod
|
||||
def _check_for_hop(gm: torch.fx.GraphModule) -> None:
|
||||
for module in gm.modules():
|
||||
if not isinstance(module, torch.fx.GraphModule):
|
||||
continue
|
||||
for node in module.graph.nodes:
|
||||
if (
|
||||
isinstance(node.target, torch._ops.HigherOrderOperator)
|
||||
and not node.target.cacheable()
|
||||
):
|
||||
raise BypassFxGraphCache(
|
||||
f"Can't cache HigherOrderOperator: {node.target.name()}"
|
||||
)
|
||||
if node.op == "getattr" and isinstance(
|
||||
getattr(gm, node.target), torch._C.ScriptObject
|
||||
):
|
||||
raise BypassFxGraphCache("Can't cache torchbind objects")
|
||||
|
||||
@staticmethod
|
||||
def _check_can_cache(gm: torch.fx.GraphModule) -> None:
|
||||
"""
|
||||
@ -1205,22 +1223,8 @@ class FxGraphCache:
|
||||
log.debug("fx graph cache no shape env")
|
||||
raise BypassFxGraphCache("No shape env")
|
||||
|
||||
# We skip caching if there are any torchbind objects.
|
||||
for module in gm.modules():
|
||||
if not isinstance(module, torch.fx.GraphModule):
|
||||
continue
|
||||
for node in module.graph.nodes:
|
||||
if (
|
||||
isinstance(node.target, torch._ops.HigherOrderOperator)
|
||||
and not node.target.cacheable()
|
||||
):
|
||||
raise BypassFxGraphCache(
|
||||
f"Can't cache HigherOrderOperator: {node.target.name()}"
|
||||
)
|
||||
if node.op == "getattr" and isinstance(
|
||||
getattr(gm, node.target), torch._C.ScriptObject
|
||||
):
|
||||
raise BypassFxGraphCache("Can't cache torchbind objects")
|
||||
# We skip caching if there are any HOPs or torchbind objects.
|
||||
FxGraphCache._check_for_hop(gm)
|
||||
|
||||
@staticmethod
|
||||
def prepare_key(
|
||||
|
@ -12,8 +12,8 @@ import sys
|
||||
import time
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from inspect import currentframe
|
||||
from itertools import count
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
@ -54,18 +54,12 @@ from torch._functorch.aot_autograd import (
|
||||
make_boxed_func,
|
||||
SerializableAOTDispatchCompiler,
|
||||
)
|
||||
from torch._inductor.codecache import (
|
||||
BypassFxGraphCache,
|
||||
code_hash,
|
||||
FxGraphCache,
|
||||
output_code_log,
|
||||
)
|
||||
from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
|
||||
from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
|
||||
from torch._inductor.debug import save_args_for_compile_fx_inner
|
||||
from torch._inductor.output_code import (
|
||||
CompiledAOTI,
|
||||
CompiledFxGraph,
|
||||
CompiledFxGraphConstants,
|
||||
CompiledFxGraphConstantsWithGm,
|
||||
get_expanded_dims,
|
||||
index_expanded_dims,
|
||||
@ -121,7 +115,7 @@ from .virtualized import V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from collections.abc import Generator, Sequence
|
||||
|
||||
from torch._inductor.output_code import _StrideExprStr
|
||||
from torch._ops import OpOverload
|
||||
@ -151,21 +145,20 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
# For testing - use the serde FxCompile scheme to debug serialization and
|
||||
# deserialization of GraphModule and CompiledFxGraph.
|
||||
class FxCompileMode(enum.Enum):
|
||||
NORMAL = 0
|
||||
# For testing - use the serde FxCompile scheme to debug serialization and
|
||||
# deserialization of GraphModule and CompiledFxGraph.
|
||||
SERIALIZE = 1
|
||||
# Compile using a subprocess instead of in-process.
|
||||
SUBPROCESS = 2
|
||||
|
||||
|
||||
def _fx_compile_mode_default() -> FxCompileMode:
|
||||
name = "TORCHINDUCTOR_FX_COMPILE_MODE"
|
||||
value = os.environ.get(name)
|
||||
NORMAL = FxCompileMode.NORMAL
|
||||
if value is None:
|
||||
return NORMAL
|
||||
return FxCompileMode.NORMAL
|
||||
try:
|
||||
value = value.upper()
|
||||
return FxCompileMode[value]
|
||||
@ -186,7 +179,6 @@ def _fx_compile_mode_default() -> FxCompileMode:
|
||||
|
||||
fx_compile_mode = _fx_compile_mode_default()
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
|
||||
pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs")
|
||||
@ -849,12 +841,23 @@ def _compile_fx_inner(
|
||||
return compiled_graph
|
||||
|
||||
|
||||
class _FxCompileStat:
|
||||
# Count of successful compiles of this type
|
||||
codegen_and_compile: int = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"codegen_and_compile: {self.codegen_and_compile}"
|
||||
|
||||
|
||||
class FxCompile(ABC):
|
||||
"""
|
||||
An FxCompile represents a mechanism that can turn a GraphModule into an
|
||||
OutputCode.
|
||||
"""
|
||||
|
||||
# Some stats for logging/debugging
|
||||
_compile_stats: dict[type[FxCompile], _FxCompileStat] = defaultdict(_FxCompileStat)
|
||||
|
||||
# TODO: We should probably eventually add some kind of async version of this
|
||||
# so we can kick off a compile and then go do other things - but we'll need
|
||||
# to know what kind of API we want for that first.
|
||||
@ -867,6 +870,10 @@ class FxCompile(ABC):
|
||||
graph_kwargs: _CompileFxKwargs,
|
||||
) -> OutputCode: ...
|
||||
|
||||
@classmethod
|
||||
def _reset_stats(cls) -> None:
|
||||
cls._compile_stats.clear()
|
||||
|
||||
|
||||
class _InProcessFxCompile(FxCompile):
|
||||
@override
|
||||
@ -881,6 +888,7 @@ class _InProcessFxCompile(FxCompile):
|
||||
# to propagate it further on
|
||||
# TODO: _CompileFxKwargs actually has stronger types than in the
|
||||
# signature, need to tighten it up
|
||||
|
||||
assert "cudagraphs" in graph_kwargs and graph_kwargs["cudagraphs"] is not None
|
||||
cudagraphs: BoxedBool = graph_kwargs["cudagraphs"]
|
||||
static_input_idxs: Sequence[int] = graph_kwargs.get("static_input_idxs", ())
|
||||
@ -1215,6 +1223,8 @@ class _InProcessFxCompile(FxCompile):
|
||||
)
|
||||
)
|
||||
|
||||
self._compile_stats[type(self)].codegen_and_compile += 1
|
||||
|
||||
return CompiledFxGraph(
|
||||
compiled_fn,
|
||||
graph,
|
||||
@ -1232,195 +1242,6 @@ class _InProcessFxCompile(FxCompile):
|
||||
)
|
||||
|
||||
|
||||
def _current_fake_mode() -> torch._subclasses.FakeTensorMode:
|
||||
fake_mode = None
|
||||
if context := torch._guards.TracingContext.try_get():
|
||||
fake_mode = context.fake_mode
|
||||
if fake_mode is not None:
|
||||
return fake_mode
|
||||
|
||||
shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
|
||||
return torch._subclasses.FakeTensorMode(shape_env=shape_env)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolInput:
|
||||
"""
|
||||
For _SerializedFxCompile - encapsulates all the data being transferred
|
||||
(sent) from the parent to the child.
|
||||
"""
|
||||
|
||||
gm: torch.fx.GraphModule
|
||||
example_inputs: Sequence[InputType]
|
||||
inputs_to_check: Sequence[int]
|
||||
graph_kwargs: _CompileFxKwargs
|
||||
# TODO: Add additional state to transfer to the child.
|
||||
|
||||
def serialize(self) -> _WireProtocolPickledInput:
|
||||
"""
|
||||
Turns this object into a _WireProtocolPickledInput which can be
|
||||
directly transferred across a stream.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
return _WireProtocolPickledInput(GraphPickler.dumps(self))
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolPickledInput:
|
||||
value: bytes
|
||||
|
||||
def deserialize(self) -> _WireProtocolInput:
|
||||
"""
|
||||
Turn this streamable object back into a _WireProtocolInput.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
fake_mode = _current_fake_mode()
|
||||
result = GraphPickler.loads(self.value, fake_mode)
|
||||
assert isinstance(result, _WireProtocolInput)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolOutput:
|
||||
"""
|
||||
For _SerializedFxCompile - encapsulates all the data being transferred
|
||||
(returned) back from the child to the parent.
|
||||
"""
|
||||
|
||||
graph: OutputCode
|
||||
|
||||
def serialize(self) -> _WireProtocolPickledOutput:
|
||||
"""
|
||||
Turns this object into a _WireProtocolPickledOutput which can be
|
||||
directly transferred across a stream.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
if isinstance(self.graph, CompiledFxGraph):
|
||||
self.graph.prepare_for_serialization()
|
||||
return _WireProtocolPickledOutput(GraphPickler.dumps(self))
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolPickledOutput:
|
||||
value: bytes
|
||||
|
||||
def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput:
|
||||
"""
|
||||
Turn this streamable object back into a _WireProtocolOutput.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
fake_mode = _current_fake_mode()
|
||||
result = GraphPickler.loads(self.value, fake_mode)
|
||||
assert isinstance(result, _WireProtocolOutput)
|
||||
if isinstance(result.graph, CompiledFxGraph):
|
||||
result.graph.after_deserialization(constants)
|
||||
return result
|
||||
|
||||
|
||||
class _SerializedFxCompile(FxCompile):
|
||||
"""
|
||||
This is used to represent an FxCompile which occurs across a serialized
|
||||
boundary.
|
||||
"""
|
||||
|
||||
@override
|
||||
def codegen_and_compile(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
example_inputs: Sequence[InputType],
|
||||
inputs_to_check: Sequence[int],
|
||||
graph_kwargs: _CompileFxKwargs,
|
||||
) -> OutputCode:
|
||||
# _context = torch._guards.TracingContext.try_get()
|
||||
constants = CompiledFxGraphConstantsWithGm(gm)
|
||||
|
||||
try:
|
||||
input = _WireProtocolInput(
|
||||
gm,
|
||||
example_inputs,
|
||||
inputs_to_check,
|
||||
graph_kwargs,
|
||||
).serialize()
|
||||
except (AttributeError, BypassFxGraphCache):
|
||||
# For example: AttributeError: Can't pickle local object
|
||||
# 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
|
||||
|
||||
# TODO: scuba record about not being able to do this?
|
||||
log.debug("Unable to pickle input graph or example inputs", exc_info=True)
|
||||
|
||||
# Fallback to in-process
|
||||
return _InProcessFxCompile().codegen_and_compile(
|
||||
gm, example_inputs, inputs_to_check, graph_kwargs
|
||||
)
|
||||
|
||||
output = self._send_to_child(input).deserialize(constants)
|
||||
|
||||
self._postprocess(output)
|
||||
|
||||
# TODO: Do we need to figure out what changed in TracingContext in the
|
||||
# child and plumb that back up to the parent?
|
||||
|
||||
return output.graph
|
||||
|
||||
@abstractmethod
|
||||
def _send_to_child(
|
||||
self, pickled_input: _WireProtocolPickledInput
|
||||
) -> _WireProtocolPickledOutput:
|
||||
# The implementation of this should transfer `input` to the child, call
|
||||
# `_run_in_child(input)` and transfer the result back.
|
||||
...
|
||||
|
||||
def _postprocess(self, output: _WireProtocolOutput) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _run_in_child(
|
||||
cls,
|
||||
pickled_input: _WireProtocolPickledInput,
|
||||
extra_env: Optional[Mapping[str, str]] = None,
|
||||
) -> _WireProtocolPickledOutput:
|
||||
with contextlib.ExitStack() as stack:
|
||||
if extra_env is not None:
|
||||
import unittest
|
||||
|
||||
stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env))
|
||||
|
||||
# TODO: Should we split the input into multiple sections where each
|
||||
# section sets up state for the previous section? (i.e. a Config section
|
||||
# which we decode and apply, followed by a FakeTensorMode section which
|
||||
# we decode and apply, etc)
|
||||
input = pickled_input.deserialize()
|
||||
|
||||
stack.enter_context(DebugContext())
|
||||
|
||||
output_graph = _InProcessFxCompile().codegen_and_compile(
|
||||
input.gm,
|
||||
input.example_inputs,
|
||||
input.inputs_to_check,
|
||||
input.graph_kwargs,
|
||||
)
|
||||
|
||||
return _WireProtocolOutput(
|
||||
output_graph,
|
||||
).serialize()
|
||||
|
||||
|
||||
# This is a debugging/testing implementation of FxCompile which serializes the
|
||||
# input and output but still runs the FxCompile in-process.
|
||||
class _DebugSerdeFxCompile(_SerializedFxCompile):
|
||||
@override
|
||||
def _send_to_child(
|
||||
self, pickled_input: _WireProtocolPickledInput
|
||||
) -> _WireProtocolPickledOutput:
|
||||
# For debugging just serde the input and output but don't run in a
|
||||
# subprocess.
|
||||
return self._run_in_child(pickled_input)
|
||||
|
||||
|
||||
def fx_codegen_and_compile(
|
||||
gm: GraphModule,
|
||||
example_inputs: Sequence[InputType],
|
||||
@ -1430,12 +1251,17 @@ def fx_codegen_and_compile(
|
||||
**graph_kwargs: Unpack[_CompileFxKwargs],
|
||||
) -> OutputCode:
|
||||
scheme: FxCompile
|
||||
|
||||
if fx_compile_mode == FxCompileMode.NORMAL:
|
||||
scheme = _InProcessFxCompile()
|
||||
elif fx_compile_mode == FxCompileMode.SERIALIZE:
|
||||
from .compile_fx_ext import _DebugSerdeFxCompile
|
||||
|
||||
scheme = _DebugSerdeFxCompile()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
elif fx_compile_mode == FxCompileMode.SUBPROCESS:
|
||||
from .compile_fx_subproc import _SubprocessFxCompile
|
||||
|
||||
scheme = _SubprocessFxCompile()
|
||||
|
||||
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
|
||||
|
||||
|
604
torch/_inductor/compile_fx_ext.py
Normal file
@ -0,0 +1,604 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import override, Self, TypeGuard
|
||||
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
import torch.fx
|
||||
from torch._inductor.codecache import BypassFxGraphCache, FxGraphCache
|
||||
from torch._inductor.metrics import CachedMetricsDeltas, CachedMetricsHelper
|
||||
from torch._inductor.output_code import (
|
||||
CompiledFxGraph,
|
||||
CompiledFxGraphConstants,
|
||||
CompiledFxGraphConstantsWithGm,
|
||||
OutputCode,
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from . import config
|
||||
from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile, log
|
||||
from .debug import DebugContext
|
||||
from .graph import GraphLowering
|
||||
from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import types
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from torch._inductor.utils import InputType
|
||||
from torch.fx import GraphModule
|
||||
|
||||
|
||||
@dataclass
|
||||
class _VirtualizedSerializer:
|
||||
"""
|
||||
This handles the data for serializing Virtualized.
|
||||
"""
|
||||
|
||||
# The values here get serialized. We don't grab everything because some of
|
||||
# the fields can't be serialized.
|
||||
aot_compilation: Any = None
|
||||
choices: Any = None
|
||||
local_buffer_context: Any = None
|
||||
ops: Any = None
|
||||
kernel: Any = None
|
||||
current_node: Any = None
|
||||
|
||||
@classmethod
|
||||
def serialize(cls) -> _VirtualizedSerializer:
|
||||
"""
|
||||
Turn the current state of torch._inductor.virtualized.V into a
|
||||
serializable structure.
|
||||
"""
|
||||
kwargs = {}
|
||||
for f in dataclasses.fields(cls):
|
||||
kwargs[f.name] = getattr(V, f.name)
|
||||
return _VirtualizedSerializer(**kwargs)
|
||||
|
||||
def patch(self) -> _VirtualizedSerializerContextManager:
|
||||
"""
|
||||
Returns a context manager which patches the saved values into the
|
||||
current environment. While patched, any value not listed above will be
|
||||
poisoned so that reads will raise an error.
|
||||
"""
|
||||
return _VirtualizedSerializerContextManager(self)
|
||||
|
||||
|
||||
class _VirtualizedSerializerContextManager(contextlib.ExitStack):
|
||||
"""
|
||||
Helper for _VirtualizedSerializer.patch()
|
||||
"""
|
||||
|
||||
def __init__(self, virtualized: _VirtualizedSerializer) -> None:
|
||||
super().__init__()
|
||||
self.virtualized = virtualized
|
||||
|
||||
@override
|
||||
def __enter__(self) -> Self:
|
||||
super().__enter__()
|
||||
|
||||
for set_name in dir(V):
|
||||
if not set_name.startswith("set_"):
|
||||
continue
|
||||
name = set_name[4:]
|
||||
name = name.removesuffix("_handler")
|
||||
set_handler = getattr(V, set_name)
|
||||
if hasattr(self.virtualized, name):
|
||||
value = getattr(self.virtualized, name)
|
||||
else:
|
||||
# poison any values that we don't serialize so that any
|
||||
# unset accesses are caught.
|
||||
value = torch._inductor.virtualized._PoisonedVirtual
|
||||
self.enter_context(set_handler(value))
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _is_fallback_handler(op: object) -> bool:
|
||||
try:
|
||||
return op._is_fallback_handler # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
class _LoweringSerializer:
|
||||
"""
|
||||
This handles the data for serializing lowering.lowering
|
||||
"""
|
||||
|
||||
# A full implementation would make sure that all lowerings are copied over
|
||||
# (or at least detected and raise a bypass when a non-standard lowering is
|
||||
# used). For now we just handle tests by looking for lowerings that were
|
||||
# overridden with a forced fallback.
|
||||
fallbacks: OrderedSet[str]
|
||||
|
||||
def __init__(self) -> None:
|
||||
from . import lowering
|
||||
|
||||
self.fallbacks = OrderedSet(
|
||||
str(k) for k, v in lowering.lowerings.items() if _is_fallback_handler(v)
|
||||
)
|
||||
|
||||
def patch(self) -> _LoweringSerializerContextManager:
|
||||
return _LoweringSerializerContextManager(self)
|
||||
|
||||
|
||||
class _LoweringSerializerContextManager(contextlib.ExitStack):
|
||||
"""
|
||||
Helper for _LoweringSerializer.patch()
|
||||
"""
|
||||
|
||||
def __init__(self, lowering: _LoweringSerializer) -> None:
|
||||
super().__init__()
|
||||
self.lowering = lowering
|
||||
|
||||
@override
|
||||
def __enter__(self) -> Self:
|
||||
super().__enter__()
|
||||
|
||||
from . import lowering
|
||||
|
||||
for k, v in lowering.lowerings.items():
|
||||
name = str(k)
|
||||
if name in self.lowering.fallbacks:
|
||||
if not _is_fallback_handler(v):
|
||||
self.enter_context(lowering.force_fallback(k)) # type: ignore[arg-type]
|
||||
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolInput:
|
||||
"""
|
||||
For _SerializedFxCompile - encapsulates all the data being transferred
|
||||
(sent) from the parent to the child.
|
||||
"""
|
||||
|
||||
gm: torch.fx.GraphModule
|
||||
example_inputs: Sequence[InputType]
|
||||
inputs_to_check: Sequence[int]
|
||||
graph_kwargs: _CompileFxKwargs
|
||||
tracing_context: Optional[torch._guards.TracingContext]
|
||||
config: dict[str, object]
|
||||
virtualized: _VirtualizedSerializer
|
||||
deterministic_guard_for_testing: Optional[
|
||||
torch.testing._internal.common_utils.DeterministicGuard
|
||||
]
|
||||
logger_state: _LoggerState
|
||||
lowering: _LoweringSerializer
|
||||
|
||||
def serialize(self) -> _WireProtocolPickledInput:
|
||||
"""
|
||||
Turns this object into a _WireProtocolPickledInput which can be
|
||||
directly transferred across a stream.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
return _WireProtocolPickledInput(GraphPickler.dumps(self))
|
||||
|
||||
|
||||
def _current_fake_mode() -> torch._subclasses.FakeTensorMode:
|
||||
fake_mode = None
|
||||
if context := torch._guards.TracingContext.try_get():
|
||||
fake_mode = context.fake_mode
|
||||
if fake_mode is not None:
|
||||
return fake_mode
|
||||
|
||||
shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
|
||||
return torch._subclasses.FakeTensorMode(shape_env=shape_env)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolPickledInput:
|
||||
value: bytes
|
||||
|
||||
def deserialize(self) -> _WireProtocolInput:
|
||||
"""
|
||||
Turn this streamable object back into a _WireProtocolInput.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
fake_mode = _current_fake_mode()
|
||||
result = GraphPickler.loads(self.value, fake_mode)
|
||||
assert isinstance(result, _WireProtocolInput)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolOutput:
|
||||
"""
|
||||
For _SerializedFxCompile - encapsulates all the data being transferred
|
||||
(returned) back from the child to the parent.
|
||||
"""
|
||||
|
||||
graph: OutputCode
|
||||
metrics: CachedMetricsDeltas
|
||||
logs: list[logging.LogRecord]
|
||||
warning_replay: Optional[list[warnings.WarningMessage]]
|
||||
|
||||
def serialize(self) -> _WireProtocolPickledOutput:
|
||||
"""
|
||||
Turns this object into a _WireProtocolPickledOutput which can be
|
||||
directly transferred across a stream.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
if isinstance(self.graph, CompiledFxGraph):
|
||||
self.graph.prepare_for_serialization()
|
||||
return _WireProtocolPickledOutput(GraphPickler.dumps(self))
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolPickledOutput:
|
||||
value: bytes
|
||||
|
||||
def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput:
|
||||
"""
|
||||
Turn this streamable object back into a _WireProtocolOutput.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
fake_mode = _current_fake_mode()
|
||||
result = GraphPickler.loads(self.value, fake_mode)
|
||||
assert isinstance(result, _WireProtocolOutput)
|
||||
if isinstance(result.graph, CompiledFxGraph):
|
||||
result.graph.after_deserialization(constants)
|
||||
return result
|
||||
|
||||
|
||||
class _LoggerState:
|
||||
"""
|
||||
This class is for tracking logging that happens during an out-of-process
|
||||
compile so we can "replay" those messages when the compile is done. Used as
|
||||
a context manager which returns the captured logs (object).
|
||||
"""
|
||||
|
||||
loggers: dict[str, int]
|
||||
# The actual log capturing mechanism - this should be None when we're not
|
||||
# actively capturing logs.
|
||||
captured_logs: Optional[_CapturedLogs] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Mapping from logger name to level.
|
||||
self.loggers = {}
|
||||
|
||||
def filter(
|
||||
logger: Union[logging.Logger, logging.PlaceHolder],
|
||||
) -> TypeGuard[logging.Logger]:
|
||||
if not isinstance(logger, logging.Logger):
|
||||
# Assume that Placeholders propagate
|
||||
return False
|
||||
# We only want to track torch._inductor logging
|
||||
if not logger.name.startswith("torch._inductor"):
|
||||
return False
|
||||
# If this logger propagates then assume we'll track its parent
|
||||
if logger.propagate:
|
||||
return False
|
||||
return True
|
||||
|
||||
root = logging.getLogger("torch._inductor")
|
||||
if sys.version_info < (3, 12):
|
||||
# logging.getChildren() doesn't exist until 3.12
|
||||
logging._acquireLock() # type: ignore[attr-defined]
|
||||
try:
|
||||
for logger in root.manager.loggerDict.values():
|
||||
if filter(logger):
|
||||
self.loggers[logger.name] = logger.level
|
||||
finally:
|
||||
logging._releaseLock() # type: ignore[attr-defined]
|
||||
else:
|
||||
q = [root]
|
||||
while q:
|
||||
logger = q.pop()
|
||||
if filter(logger):
|
||||
self.loggers[logger.name] = logger.level
|
||||
q.extend(logger.getChildren())
|
||||
|
||||
def __enter__(self) -> _CapturedLogs:
|
||||
assert self.captured_logs is None
|
||||
self.captured_logs = _CapturedLogs(self)
|
||||
self.captured_logs.apply()
|
||||
return self.captured_logs
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[types.TracebackType],
|
||||
) -> None:
|
||||
assert self.captured_logs is not None
|
||||
self.captured_logs.remove()
|
||||
|
||||
|
||||
class _CapturedLogs:
|
||||
"""
|
||||
Helper for _LoggerState - this class actually attaches to the logger in
|
||||
the child process and grabs the log messages themselves.
|
||||
"""
|
||||
|
||||
state: _LoggerState
|
||||
queue: queue.Queue[logging.LogRecord]
|
||||
handlers: Optional[dict[str, logging.Handler]]
|
||||
|
||||
def __init__(self, state: _LoggerState) -> None:
|
||||
self.state = state
|
||||
# A queue of the log entries
|
||||
# TODO: For memory purposes should we log to a file and then respond with that?
|
||||
self.queue = queue.Queue(-1)
|
||||
# Mapping from name to handler (only valid when applied)
|
||||
self.handlers = None
|
||||
|
||||
def finish(self) -> list[logging.LogRecord]:
|
||||
assert self.handlers is None
|
||||
logs = []
|
||||
try:
|
||||
while True:
|
||||
logs.append(self.queue.get_nowait())
|
||||
except queue.Empty:
|
||||
pass
|
||||
return logs
|
||||
|
||||
def remove(self) -> None:
|
||||
assert self.handlers is not None
|
||||
handlers, self.handlers = self.handlers, None
|
||||
for name, handler in handlers.items():
|
||||
logger = logging.getLogger(name)
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def apply(self) -> None:
|
||||
from logging.handlers import QueueHandler
|
||||
|
||||
assert self.handlers is None
|
||||
self.handlers = {}
|
||||
for name, level in self.state.loggers.items():
|
||||
logger = logging.getLogger(name)
|
||||
handler = QueueHandler(self.queue)
|
||||
self.handlers[name] = handler
|
||||
logger.addHandler(handler)
|
||||
if level != logging.NOTSET:
|
||||
logger.setLevel(level)
|
||||
|
||||
|
||||
class _SerializedFxCompile(FxCompile):
|
||||
"""
|
||||
This is used to represent an FxCompile which occurs across a serialized
|
||||
boundary.
|
||||
"""
|
||||
|
||||
@override
|
||||
def codegen_and_compile(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
example_inputs: Sequence[InputType],
|
||||
inputs_to_check: Sequence[int],
|
||||
graph_kwargs: _CompileFxKwargs,
|
||||
) -> OutputCode:
|
||||
def fallback() -> OutputCode:
|
||||
return _InProcessFxCompile().codegen_and_compile(
|
||||
gm, example_inputs, inputs_to_check, graph_kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
# _check_for_hop raises BypassFxGraphCache when it detects something
|
||||
# we can't cache (or serialize)
|
||||
FxGraphCache._check_for_hop(gm)
|
||||
except BypassFxGraphCache as e:
|
||||
log.debug("Skipping %s compile: %s", type(self), e)
|
||||
return fallback()
|
||||
|
||||
context = torch._guards.TracingContext.try_get()
|
||||
constants = CompiledFxGraphConstantsWithGm(gm)
|
||||
logger_state = _LoggerState()
|
||||
lowering = _LoweringSerializer()
|
||||
|
||||
# If we're running tests then grab the DeterministicGuard (don't want to
|
||||
# import this if it isn't already imported because it has side-effects)
|
||||
deterministic_guard_for_testing: Optional[
|
||||
torch.testing._internal.common_utils.DeterministicGuard
|
||||
] = None
|
||||
try:
|
||||
deterministic_guard_for_testing = (
|
||||
torch.testing._internal.common_utils.DeterministicGuard._current_state()
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
input = _WireProtocolInput(
|
||||
gm,
|
||||
example_inputs,
|
||||
inputs_to_check,
|
||||
graph_kwargs,
|
||||
context,
|
||||
config.save_config_portable(),
|
||||
_VirtualizedSerializer.serialize(),
|
||||
deterministic_guard_for_testing,
|
||||
logger_state,
|
||||
lowering,
|
||||
).serialize()
|
||||
except (AttributeError, BypassFxGraphCache):
|
||||
# For example: AttributeError: Can't pickle local object
|
||||
# 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
|
||||
|
||||
# TODO: scuba record about not being able to do this?
|
||||
log.debug("Unable to pickle input graph or example inputs", exc_info=True)
|
||||
|
||||
return fallback()
|
||||
|
||||
output = self._send_to_child(input).deserialize(constants)
|
||||
|
||||
self._postprocess(output)
|
||||
self._compile_stats[type(self)].codegen_and_compile += 1
|
||||
|
||||
# TODO: Do we need to figure out what changed in TracingContext in the
|
||||
# child and plumb that back up to the parent?
|
||||
|
||||
return output.graph
|
||||
|
||||
@abstractmethod
|
||||
def _send_to_child(
|
||||
self, pickled_input: _WireProtocolPickledInput
|
||||
) -> _WireProtocolPickledOutput:
|
||||
# The implementation of this should transfer `input` to the child, call
|
||||
# `_run_in_child(input)` and transfer the result back.
|
||||
...
|
||||
|
||||
def _postprocess(self, output: _WireProtocolOutput) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _run_in_child(
|
||||
cls,
|
||||
pickled_input: _WireProtocolPickledInput,
|
||||
extra_env: Optional[Mapping[str, str]] = None,
|
||||
) -> _WireProtocolPickledOutput:
|
||||
metrics = CachedMetricsHelper()
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if extra_env is not None:
|
||||
import unittest
|
||||
|
||||
stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env))
|
||||
|
||||
# Save warnings to "replay" in the parent
|
||||
warning_replay = stack.enter_context(warnings.catch_warnings(record=True))
|
||||
|
||||
# TODO: Should we split the input into multiple sections where each
|
||||
# section sets up state for the previous section? (i.e. a Config section
|
||||
# which we decode and apply, followed by a FakeTensorMode section which
|
||||
# we decode and apply, etc)
|
||||
input = pickled_input.deserialize()
|
||||
|
||||
stack.enter_context(input.virtualized.patch())
|
||||
stack.enter_context(input.lowering.patch())
|
||||
stack.enter_context(config.patch(input.config))
|
||||
captured_logs = stack.enter_context(input.logger_state)
|
||||
if input.deterministic_guard_for_testing:
|
||||
stack.enter_context(input.deterministic_guard_for_testing)
|
||||
stack.enter_context(torch._guards.tracing(input.tracing_context))
|
||||
stack.enter_context(DebugContext())
|
||||
|
||||
output_graph = _InProcessFxCompile().codegen_and_compile(
|
||||
input.gm,
|
||||
input.example_inputs,
|
||||
input.inputs_to_check,
|
||||
input.graph_kwargs,
|
||||
)
|
||||
|
||||
logs = captured_logs.finish()
|
||||
|
||||
return _WireProtocolOutput(
|
||||
output_graph, metrics.get_deltas(), logs, warning_replay
|
||||
).serialize()
|
||||
|
||||
|
||||
# This is a debugging/testing implementation of FxCompile which serializes the
|
||||
# input and output but still runs the FxCompile in-process.
|
||||
class _DebugSerdeFxCompile(_SerializedFxCompile):
|
||||
@override
|
||||
def _send_to_child(
|
||||
self, pickled_input: _WireProtocolPickledInput
|
||||
) -> _WireProtocolPickledOutput:
|
||||
# For debugging just serde the input and output but don't run in a
|
||||
# subprocess.
|
||||
return self._run_in_child(pickled_input)
|
||||
|
||||
|
||||
class _OutOfProcessFxCompile(_SerializedFxCompile):
|
||||
"""
|
||||
Represents an FxCompile which is run outside the current process (in
|
||||
either a subprocess or possibly even a separate machine).
|
||||
"""
|
||||
|
||||
def _postprocess(self, output: _WireProtocolOutput) -> None:
|
||||
# Since our metrics were gathered in a subprocess make sure to add them
|
||||
# here.
|
||||
CachedMetricsHelper.apply_deltas(output.metrics)
|
||||
|
||||
# This is used by tests to check the output for specific details. For
|
||||
# remote things (subproc and RE) we need to do the `save_output_code`
|
||||
# here since it didn't happen earlier in-process. In the future if this
|
||||
# doesn't have "source_code" (it's a CompiledAOTI, for example) and we
|
||||
# need it we'll have to grab it and serialize it separately from the
|
||||
# child.
|
||||
if GraphLowering.save_output_code is not None:
|
||||
GraphLowering.save_output_code(output.graph.source_code) # type: ignore[attr-defined]
|
||||
|
||||
# And forward our collected logs. The cache is cleared when the outer
|
||||
# function exits.
|
||||
@functools.lru_cache(None)
|
||||
def getLogger(name: str) -> logging.Logger:
|
||||
return logging.getLogger(name)
|
||||
|
||||
if output.warning_replay:
|
||||
for w in output.warning_replay:
|
||||
warnings.warn_explicit(
|
||||
message=w.message,
|
||||
category=w.category,
|
||||
filename=w.filename,
|
||||
lineno=w.lineno,
|
||||
source=w.source,
|
||||
)
|
||||
|
||||
for record in output.logs:
|
||||
logger = getLogger(record.name)
|
||||
logger.handle(record)
|
||||
|
||||
|
||||
# For debugging - create a _FxCompile which writes the serialized data to a file
|
||||
# and then exits.
|
||||
#
|
||||
# TODO: make this a FxCompileMode value?
|
||||
#
|
||||
# The "child runner" should look something like this:
|
||||
#
|
||||
# import torch
|
||||
# from torch._inductor import compile_fx
|
||||
# idx = 0
|
||||
# with open(f"/tmp/pytorch_compile_fx_tmp_input_{idx}.bin", "rb") as f:
|
||||
# input = compile_fx._WireProtocolPickledInput(f.read())
|
||||
# result = compile_fx._SubprocessFxCompile._run_in_child(input)
|
||||
# with open(f"/tmp/pytorch_compile_fx_tmp_output_{idx}.bin", "wb") as f:
|
||||
# f.write(result.value)
|
||||
#
|
||||
class _DebugFileFxCompile(_OutOfProcessFxCompile):
|
||||
file_index = 0
|
||||
|
||||
@override
|
||||
def _send_to_child(
|
||||
self, pickled_input: _WireProtocolPickledInput
|
||||
) -> _WireProtocolPickledOutput:
|
||||
idx = _DebugFileFxCompile.file_index
|
||||
_DebugFileFxCompile.file_index += 1
|
||||
|
||||
name = f"/tmp/aorenste/pytorch_compile_fx_tmp_input_{idx}.bin"
|
||||
with open(name, "wb") as f:
|
||||
f.write(pickled_input.value)
|
||||
print(f"Wrote to {name}")
|
||||
|
||||
if False:
|
||||
name = f"/tmp/aorenste/pytorch_compile_fx_tmp_actual_{idx}.bin"
|
||||
actual = self._run_in_child(pickled_input)
|
||||
with open(name, "wb") as f:
|
||||
f.write(actual.value)
|
||||
return actual
|
||||
elif False:
|
||||
name = f"/tmp/aorenste/pytorch_compile_fx_tmp_output_{idx}.bin"
|
||||
with open(name, "rb") as f:
|
||||
result = _WireProtocolPickledOutput(f.read())
|
||||
print(f"Read from {name}")
|
||||
return result
|
||||
else:
|
||||
os._exit(-1)
|
102
torch/_inductor/compile_fx_subproc.py
Normal file
@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import functools
|
||||
import os
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
import torch.fx
|
||||
from torch._inductor.compile_worker.subproc_pool import (
|
||||
AnyPool,
|
||||
SubprocKind,
|
||||
SubprocPool,
|
||||
)
|
||||
from torch._inductor.utils import clear_inductor_caches
|
||||
|
||||
from .compile_fx_ext import (
|
||||
_OutOfProcessFxCompile,
|
||||
_WireProtocolPickledInput,
|
||||
_WireProtocolPickledOutput,
|
||||
)
|
||||
from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
class _SubprocessFxCompile(_OutOfProcessFxCompile):
|
||||
@override
|
||||
def _send_to_child(
|
||||
self, input: _WireProtocolPickledInput
|
||||
) -> _WireProtocolPickledOutput:
|
||||
# TODO: Do we need to copy across some kind of logging IDs? (ChromiumEventLogger)
|
||||
|
||||
pool = self.process_pool()
|
||||
|
||||
# TODO: This is probably the wrong thing to do long-term - but for now
|
||||
# let's share the cache so we can identify tests broken by this later.
|
||||
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
|
||||
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
|
||||
|
||||
f = pool.submit(_SubprocessFxCompile._run_in_child_subprocess, input, extra_env)
|
||||
|
||||
# For debugging: If we want to print status updates...
|
||||
# last = time.time()
|
||||
# while not f.done():
|
||||
# print("tick...")
|
||||
# time.sleep(0.125)
|
||||
# now = time.time()
|
||||
# if now - last > 1:
|
||||
# last = now
|
||||
|
||||
output = f.result()
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def process_pool() -> AnyPool:
|
||||
pool = SubprocPool(
|
||||
# TODO: Consider raising this limit if we start using async w/
|
||||
# subprocess and want to compile multiple graphs in parallel.
|
||||
1,
|
||||
kind=SubprocKind.SPAWN,
|
||||
)
|
||||
|
||||
atexit.register(pool.shutdown)
|
||||
|
||||
return pool
|
||||
|
||||
@classmethod
|
||||
def _run_in_child_subprocess(
|
||||
cls,
|
||||
pickled_input: _WireProtocolPickledInput,
|
||||
extra_env: Optional[Mapping[str, str]],
|
||||
) -> _WireProtocolPickledOutput:
|
||||
# TODO: In subprocess mode we need to clear the inductor caches.
|
||||
# The problem:
|
||||
# 1. We compile in worker A which fills stuff in tmpdir
|
||||
# 2. parent clears inductor caches which deletes tmpdirs and tells
|
||||
# cpp_prefix_path() to clear its LRU cache
|
||||
# 3. We compile a second time in subproc A - but since we never told
|
||||
# cpp_prefix_path() in worker A to clear its LRU it thinks the
|
||||
# tmpdir still exists and fails to compile.
|
||||
#
|
||||
# TODO: We probably should be using a separate tmpdir in the worker
|
||||
# anyway... but we should probably still respect clear_inductor_caches()
|
||||
# in the parent... maybe?
|
||||
#
|
||||
# TODO: We could be less aggressive by keeping a clock which gets
|
||||
# incremented when we clear the cache, send the clock to the worker and
|
||||
# only clear caches if the clock changed since last time.
|
||||
#
|
||||
clear_inductor_caches()
|
||||
torch._inductor.metrics.reset()
|
||||
|
||||
# TODO: turn off config.fx_graph_async_compile
|
||||
|
||||
result = cls._run_in_child(pickled_input, extra_env)
|
||||
return result
|
@ -1903,6 +1903,9 @@ def fallback_handler(kernel, add_to_fallback_set=True):
|
||||
wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs)
|
||||
)
|
||||
|
||||
# This lets us detect that a lowering is a fallback handler.
|
||||
handler._is_fallback_handler = True # type: ignore[attr-defined]
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
|
@ -96,6 +96,14 @@ class NullHandler:
|
||||
"""
|
||||
|
||||
|
||||
# If a virtualized value is set to _PoisonedVirtual then any attempt to get the
|
||||
# value will result in an exception being raised. This is useful if we want to
|
||||
# trap uninitialized reads of virtualized globals - for example when compiling
|
||||
# in a subprocess we don't want the child reading globals that weren't copied
|
||||
# from the parent.
|
||||
_PoisonedVirtual = object()
|
||||
|
||||
|
||||
class Virtualized(Generic[T]):
|
||||
"""
|
||||
Implements a global variable that redirects via thread local variable
|
||||
@ -110,11 +118,12 @@ class Virtualized(Generic[T]):
|
||||
"""
|
||||
|
||||
def __init__(self, vname: str, default: Union[Callable[[], T], type[NullHandler]]):
|
||||
self._vname = vname
|
||||
self._key: str = f"__torchinductor_{vname}"
|
||||
self._default = default
|
||||
|
||||
def _set_handler(self, value: T) -> AbstractContextManager[None]:
|
||||
prior = self._get_handler()
|
||||
prior = self._get_handler(False)
|
||||
setattr(threadlocal, self._key, value)
|
||||
|
||||
@contextmanager
|
||||
@ -126,9 +135,14 @@ class Virtualized(Generic[T]):
|
||||
|
||||
return ctx()
|
||||
|
||||
def _get_handler(self) -> T:
|
||||
def _get_handler(self, check_poisoned: bool = True) -> T:
|
||||
try:
|
||||
return getattr(threadlocal, self._key)
|
||||
value = getattr(threadlocal, self._key)
|
||||
if check_poisoned and value is _PoisonedVirtual:
|
||||
raise RuntimeError(
|
||||
f"Attempt to use poisoned virtualized value '{self._vname}'."
|
||||
)
|
||||
return value
|
||||
except AttributeError:
|
||||
# TODO: To be honest, I feel we probably should just error in this
|
||||
# case, instead of making a null handler that will probably error
|
||||
|
@ -2031,20 +2031,24 @@ class DeterministicGuard:
|
||||
self.warn_only = warn_only
|
||||
self.fill_uninitialized_memory = fill_uninitialized_memory
|
||||
|
||||
def __enter__(self):
|
||||
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
|
||||
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
|
||||
self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined]
|
||||
torch.use_deterministic_algorithms(
|
||||
self.deterministic,
|
||||
warn_only=self.warn_only)
|
||||
@classmethod
|
||||
def _current_state(cls):
|
||||
return cls(
|
||||
torch.are_deterministic_algorithms_enabled(),
|
||||
warn_only=torch.is_deterministic_algorithms_warn_only_enabled(),
|
||||
fill_uninitialized_memory=torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
def _update(self):
|
||||
torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
|
||||
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory # type: ignore[attr-defined]
|
||||
|
||||
def __enter__(self):
|
||||
self._restore = self._current_state()
|
||||
self._update()
|
||||
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
torch.use_deterministic_algorithms(
|
||||
self.deterministic_restore,
|
||||
warn_only=self.warn_only_restore)
|
||||
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore # type: ignore[attr-defined]
|
||||
self._restore._update()
|
||||
|
||||
class AlwaysWarnTypedStorageRemoval:
|
||||
def __init__(self, always_warn):
|
||||
|