from __future__ import annotations

import contextlib
import dataclasses
import functools
import logging
import os
import queue
import sys
import warnings
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING, Union
from typing_extensions import final, override, Self, TypeGuard

import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
import torch.fx
from torch._inductor.codecache import BypassFxGraphCache, FxGraphCache
from torch._inductor.metrics import CachedMetricsDeltas, CachedMetricsHelper
from torch._inductor.output_code import (
    CompiledFxGraph,
    CompiledFxGraphConstants,
    CompiledFxGraphConstantsWithGm,
    OutputCode,
)
from torch._subclasses import FakeTensorMode
from torch.utils._ordered_set import OrderedSet

from . import config
from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile, log
from .debug import DebugContext
from .graph import GraphLowering
from .output_code import complex_memory_overlap as complex_memory_overlap  # noqa: F401
from .virtualized import V


if TYPE_CHECKING:
    import types
    from collections.abc import Generator, Mapping, Sequence
    from concurrent.futures import Future

    from torch._inductor.utils import InputType
    from torch.fx import GraphModule


@dataclass
class _VirtualizedSerializer:
    """
    Holds the data needed to serialize torch._inductor.virtualized.V.
    """

    # The values here get serialized. We don't grab everything because some of
    # the fields can't be serialized.
    aot_compilation: Any = None
    choices: Any = None
    local_buffer_context: Any = None
    ops: Any = None
    kernel: Any = None
    current_node: Any = None

    @classmethod
    def serialize(cls) -> _VirtualizedSerializer:
        """
        Turn the current state of torch._inductor.virtualized.V into a
        serializable structure.
        """
        kwargs = {}
        for f in dataclasses.fields(cls):
            kwargs[f.name] = getattr(V, f.name)
        return _VirtualizedSerializer(**kwargs)

    def patch(self) -> _VirtualizedSerializerContextManager:
        """
        Return a context manager which patches the saved values into the
        current environment. While patched, any value not listed above is
        poisoned so that reads raise an error.
        """
        return _VirtualizedSerializerContextManager(self)
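
# A minimal sketch of the intended round trip (illustrative only, not executed
# by this module):
#
#     snapshot = _VirtualizedSerializer.serialize()  # parent: capture V
#     ...transfer `snapshot` to the child process...
#     with snapshot.patch():                         # child: re-apply V
#         V.aot_compilation                          # reads the parent's value
#         V.graph                                    # not a serialized field, so reading it raises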


class _VirtualizedSerializerContextManager(contextlib.ExitStack):
    """
    Helper for _VirtualizedSerializer.patch()
    """

    def __init__(self, virtualized: _VirtualizedSerializer) -> None:
        super().__init__()
        self.virtualized = virtualized

    @override
    def __enter__(self) -> Self:
        super().__enter__()

        for set_name in dir(V):
            if not set_name.startswith("set_"):
                continue
            name = set_name[4:]
            name = name.removesuffix("_handler")
            set_handler = getattr(V, set_name)
            if hasattr(self.virtualized, name):
                value = getattr(self.virtualized, name)
            else:
                # Poison any values that we don't serialize so that any
                # unset accesses are caught.
                value = torch._inductor.virtualized._PoisonedVirtual
            self.enter_context(set_handler(value))

        return self


def _is_fallback_handler(op: object) -> bool:
    try:
        return op._is_fallback_handler  # type: ignore[attr-defined]
    except AttributeError:
        return False


class _LoweringSerializer:
    """
    Handles the data for serializing lowering.lowerings.
    """

    # A full implementation would make sure that all lowerings are copied over
    # (or at least detected and raise a bypass when a non-standard lowering is
    # used). For now we just handle tests by looking for lowerings that were
    # overridden with a forced fallback.
    fallbacks: OrderedSet[str]

    def __init__(self) -> None:
        from . import lowering

        self.fallbacks = OrderedSet(
            str(k) for k, v in lowering.lowerings.items() if _is_fallback_handler(v)
        )

    def patch(self) -> _LoweringSerializerContextManager:
        return _LoweringSerializerContextManager(self)


class _LoweringSerializerContextManager(contextlib.ExitStack):
    """
    Helper for _LoweringSerializer.patch()
    """

    def __init__(self, lowering: _LoweringSerializer) -> None:
        super().__init__()
        self.lowering = lowering

    @override
    def __enter__(self) -> Self:
        super().__enter__()

        from . import lowering

        for k, v in lowering.lowerings.items():
            name = str(k)
            if name in self.lowering.fallbacks:
                if not _is_fallback_handler(v):
                    self.enter_context(lowering.force_fallback(k))  # type: ignore[arg-type]

        return self
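
# Sketch of the fallback round trip (illustrative only): the parent records
# which lowerings were forced to fall back, and patch() re-forces them in the
# child before the compile runs.
#
#     lowering_state = _LoweringSerializer()  # parent: snapshot forced fallbacks
#     with lowering_state.patch():            # child: re-apply force_fallback
#         ...run the in-process compile...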


@dataclass
class _FakeTensorModeSerializer:
    allow_non_fake_inputs: bool

    def __init__(self, fake_mode: FakeTensorMode) -> None:
        self.allow_non_fake_inputs = fake_mode.allow_non_fake_inputs
        self.shape_env = fake_mode.shape_env

    @contextlib.contextmanager
    def patch(self, fake_mode: FakeTensorMode) -> Generator[None, None, None]:
        saved_allow_non_fake_inputs = fake_mode.allow_non_fake_inputs
        fake_mode.allow_non_fake_inputs = self.allow_non_fake_inputs

        try:
            yield
        finally:
            # Restore the saved value even if the wrapped compile raises.
            fake_mode.allow_non_fake_inputs = saved_allow_non_fake_inputs
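
# Minimal usage sketch (illustrative only; `parent_fake_mode` and
# `child_fake_mode` are hypothetical names): carry allow_non_fake_inputs
# across the wire and patch it onto the child's FakeTensorMode for the
# duration of the compile.
#
#     serializer = _FakeTensorModeSerializer(parent_fake_mode)  # parent
#     with serializer.patch(child_fake_mode):                   # child
#         ...compile runs with the parent's allow_non_fake_inputs...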


@dataclass
class _WireProtocolInput:
    """
    For _SerializedFxCompile - encapsulates all the data being transferred
    (sent) from the parent to the child.
    """

    gm: torch.fx.GraphModule
    example_inputs: Sequence[InputType]
    inputs_to_check: Sequence[int]
    graph_kwargs: _CompileFxKwargs
    tracing_context: Optional[torch._guards.TracingContext]
    config: dict[str, object]
    virtualized: _VirtualizedSerializer
    deterministic_guard_for_testing: Optional[  # type: ignore[name-defined] # mypy bug
        torch.testing._internal.common_utils.DeterministicGuard
    ]
    logger_state: _LoggerState
    lowering: _LoweringSerializer
    fake_tensor_mode: _FakeTensorModeSerializer

    def serialize(self) -> _WireProtocolPickledInput:
        """
        Turn this object into a _WireProtocolPickledInput which can be
        directly transferred across a stream.
        """
        from torch.fx._graph_pickler import GraphPickler

        return _WireProtocolPickledInput(GraphPickler.dumps(self))


def _current_fake_mode() -> FakeTensorMode:
    fake_mode = None
    if context := torch._guards.TracingContext.try_get():
        fake_mode = context.fake_mode
    if fake_mode is not None:
        return fake_mode

    shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
    return FakeTensorMode(shape_env=shape_env)


@dataclass
class _WireProtocolPickledInput:
    value: bytes

    def deserialize(self) -> _WireProtocolInput:
        """
        Turn this streamable object back into a _WireProtocolInput.
        """
        from torch.fx._graph_pickler import GraphPickler

        fake_mode = _current_fake_mode()
        result = GraphPickler.loads(self.value, fake_mode)
        assert isinstance(result, _WireProtocolInput)
        return result
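
# Illustrative wire round trip (a sketch, not executed here): the parent
# pickles a _WireProtocolInput into bytes and the child unpickles it under its
# own fake mode.
#
#     pickled = wire_input.serialize()    # parent: -> _WireProtocolPickledInput
#     restored = pickled.deserialize()    # child: -> _WireProtocolInput
#     assert isinstance(restored.gm, torch.fx.GraphModule)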


@dataclass
class _WireProtocolOutput:
    """
    For _SerializedFxCompile - encapsulates all the data being transferred
    (returned) back from the child to the parent.
    """

    graph: OutputCode
    metrics: CachedMetricsDeltas
    logs: list[logging.LogRecord]
    warning_replay: Optional[list[warnings.WarningMessage]]
    shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]

    def serialize(self) -> _WireProtocolPickledOutput:
        """
        Turn this object into a _WireProtocolPickledOutput which can be
        directly transferred across a stream.
        """
        from torch.fx._graph_pickler import GraphPickler

        if isinstance(self.graph, CompiledFxGraph):
            self.graph.prepare_for_serialization()
        return _WireProtocolPickledOutput(GraphPickler.dumps(self))


@dataclass
class _WireProtocolPickledOutput:
    value: bytes

    def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput:
        """
        Turn this streamable object back into a _WireProtocolOutput.
        """
        from torch.fx._graph_pickler import GraphPickler

        fake_mode = _current_fake_mode()
        result = GraphPickler.loads(self.value, fake_mode)
        assert isinstance(result, _WireProtocolOutput)
        if isinstance(result.graph, CompiledFxGraph):
            result.graph.after_deserialization(constants)
        return result


class _LoggerState:
    """
    Tracks the logging that happens during an out-of-process compile so we can
    "replay" those messages in the parent when the compile is done. Used as a
    context manager which returns the captured logs (a _CapturedLogs object).
    """

    loggers: dict[str, int]
    # The actual log capturing mechanism - this should be None when we're not
    # actively capturing logs.
    captured_logs: Optional[_CapturedLogs] = None

    def __init__(self) -> None:
        # Mapping from logger name to level.
        self.loggers = {}

        def filter(
            logger: Union[logging.Logger, logging.PlaceHolder],
        ) -> TypeGuard[logging.Logger]:
            if not isinstance(logger, logging.Logger):
                # Assume that placeholders propagate.
                return False
            # We only want to track torch._inductor logging.
            if not logger.name.startswith("torch._inductor"):
                return False
            # If this logger propagates then assume we'll track its parent.
            if logger.propagate:
                return False
            return True

        root = logging.getLogger("torch._inductor")
        if sys.version_info < (3, 12):
            # Logger.getChildren() doesn't exist until 3.12 so walk the
            # manager's logger dict directly.
            logging._acquireLock()  # type: ignore[attr-defined]
            try:
                for logger in root.manager.loggerDict.values():
                    if filter(logger):
                        self.loggers[logger.name] = logger.level
            finally:
                logging._releaseLock()  # type: ignore[attr-defined]
        else:
            q = [root]
            while q:
                logger = q.pop()
                if filter(logger):
                    self.loggers[logger.name] = logger.level
                q.extend(logger.getChildren())

    def __enter__(self) -> _CapturedLogs:
        assert self.captured_logs is None
        self.captured_logs = _CapturedLogs(self)
        self.captured_logs.apply()
        return self.captured_logs

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> None:
        assert self.captured_logs is not None
        self.captured_logs.remove()


class _CapturedLogs:
    """
    Helper for _LoggerState - this class actually attaches to the loggers in
    the child process and grabs the log messages themselves.
    """

    state: _LoggerState
    queue: queue.Queue[logging.LogRecord]
    handlers: Optional[dict[str, logging.Handler]]

    def __init__(self, state: _LoggerState) -> None:
        self.state = state
        # A queue of the log entries.
        # TODO: For memory purposes should we log to a file and then respond with that?
        self.queue = queue.Queue(-1)
        # Mapping from name to handler (only valid when applied).
        self.handlers = None

    def finish(self) -> list[logging.LogRecord]:
        assert self.handlers is None
        logs = []
        try:
            while True:
                logs.append(self.queue.get_nowait())
        except queue.Empty:
            pass
        return logs

    def remove(self) -> None:
        assert self.handlers is not None
        handlers, self.handlers = self.handlers, None
        for name, handler in handlers.items():
            logger = logging.getLogger(name)
            logger.removeHandler(handler)

    def apply(self) -> None:
        from logging.handlers import QueueHandler

        assert self.handlers is None
        self.handlers = {}
        for name, level in self.state.loggers.items():
            logger = logging.getLogger(name)
            handler = QueueHandler(self.queue)
            self.handlers[name] = handler
            logger.addHandler(handler)
            if level != logging.NOTSET:
                logger.setLevel(level)
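
# Sketch of the capture/replay flow (illustrative only): the child funnels
# torch._inductor log records into a queue and the parent re-emits them.
#
#     state = _LoggerState()            # parent: snapshot logger names/levels
#     with state as captured:           # child: QueueHandlers attached
#         log.warning("logged in the child")
#     for record in captured.finish():  # parent: drain the queue and replay
#         logging.getLogger(record.name).handle(record)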


class _SerializedFxCompile(FxCompile):
    """
    Represents an FxCompile which occurs across a serialization boundary.
    """

    @override
    def codegen_and_compile(
        self,
        gm: GraphModule,
        example_inputs: Sequence[InputType],
        inputs_to_check: Sequence[int],
        graph_kwargs: _CompileFxKwargs,
    ) -> OutputCode:
        # If this code changes it's likely that
        # _AsyncFxCompile.codegen_and_compile() will also need to change to
        # match.

        serialized = self.serialize_compile(
            gm, example_inputs, inputs_to_check, graph_kwargs
        )
        if not serialized:
            return _InProcessFxCompile().codegen_and_compile(
                gm, example_inputs, inputs_to_check, graph_kwargs
            )

        inputs, constants = serialized
        output = self._send_to_child(inputs).deserialize(constants)

        self._postprocess(output)
        self._compile_stats[type(self)].codegen_and_compile += 1

        # TODO: Do we need to figure out what changed in TracingContext in the
        # child and plumb that back up to the parent?

        return output.graph
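
    # End-to-end shape of a round trip through a concrete subclass (a sketch;
    # `MySerializedFxCompile` is hypothetical and only needs to implement
    # _send_to_child()):
    #
    #     compiler = MySerializedFxCompile()
    #     output_code = compiler.codegen_and_compile(
    #         gm, example_inputs, inputs_to_check, graph_kwargs
    #     )
    #     # If `gm` can't be pickled this automatically falls back to
    #     # _InProcessFxCompile.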

    def serialize_compile(
        self,
        gm: GraphModule,
        example_inputs: Sequence[InputType],
        inputs_to_check: Sequence[int],
        graph_kwargs: _CompileFxKwargs,
    ) -> Optional[tuple[_WireProtocolPickledInput, CompiledFxGraphConstantsWithGm]]:
        """
        Prepare a _WireProtocolInput to compile. If None is returned then it
        wasn't possible to serialize and we should fall back to compiling
        in-process.
        """
        try:
            # _check_for_hop raises BypassFxGraphCache when it detects
            # something we can't cache (or serialize).
            FxGraphCache._check_for_hop(gm)
        except BypassFxGraphCache as e:
            log.debug("Skipping %s compile: %s", type(self), e)
            return None

        context = torch._guards.TracingContext.try_get()
        constants = CompiledFxGraphConstantsWithGm(gm)
        logger_state = _LoggerState()
        lowering = _LoweringSerializer()

        # If we're running tests then grab the DeterministicGuard (we don't
        # want to import it if it isn't already imported because the import
        # has side-effects).
        deterministic_guard_for_testing: Optional[  # type: ignore[name-defined] # mypy bug
            torch.testing._internal.common_utils.DeterministicGuard
        ] = None
        try:
            deterministic_guard_for_testing = (
                torch.testing._internal.common_utils.DeterministicGuard._current_state()  # type: ignore[attr-defined] # mypy bug
            )
        except AttributeError:
            pass

        fake_mode = _current_fake_mode()
        fake_tensor_mode = _FakeTensorModeSerializer(fake_mode)

        try:
            input = _WireProtocolInput(
                gm,
                example_inputs,
                inputs_to_check,
                graph_kwargs,
                context,
                config.save_config_portable(),
                _VirtualizedSerializer.serialize(),
                deterministic_guard_for_testing,
                logger_state,
                lowering,
                fake_tensor_mode,
            ).serialize()
            return (input, constants)
        except (AttributeError, BypassFxGraphCache):
            # For example: AttributeError: Can't pickle local object
            # 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'

            # TODO: scuba record about not being able to do this?
            log.warning("Unable to pickle input graph or example inputs", exc_info=True)

            return None

    @abstractmethod
    def _send_to_child(
        self, pickled_input: _WireProtocolPickledInput
    ) -> _WireProtocolPickledOutput:
        # The implementation of this should transfer `pickled_input` to the
        # child, call `_run_in_child(pickled_input)` there, and transfer the
        # result back.
        ...

    def _postprocess(self, output: _WireProtocolOutput) -> None:
        pass

    @classmethod
    def _run_in_child(
        cls,
        pickled_input: _WireProtocolPickledInput,
        extra_env: Optional[Mapping[str, str]] = None,
    ) -> _WireProtocolPickledOutput:
        metrics = CachedMetricsHelper()

        with contextlib.ExitStack() as stack:
            if extra_env is not None:
                import unittest

                stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env))

            # Save warnings to "replay" in the parent.
            warning_replay = stack.enter_context(warnings.catch_warnings(record=True))

            # TODO: Should we split the input into multiple sections where
            # each section sets up state for decoding the following section?
            # (i.e. a Config section which we decode and apply, followed by a
            # FakeTensorMode section which we decode and apply, etc.)
            input = pickled_input.deserialize()

            stack.enter_context(input.virtualized.patch())
            stack.enter_context(input.lowering.patch())
            stack.enter_context(config.patch(input.config))
            captured_logs = stack.enter_context(input.logger_state)
            if input.deterministic_guard_for_testing:
                stack.enter_context(input.deterministic_guard_for_testing)
            stack.enter_context(torch._guards.tracing(input.tracing_context))
            stack.enter_context(DebugContext())

            fake_mode = _current_fake_mode()
            stack.enter_context(input.fake_tensor_mode.patch(fake_mode))

            output_graph = _InProcessFxCompile().codegen_and_compile(
                input.gm,
                input.example_inputs,
                input.inputs_to_check,
                input.graph_kwargs,
            )

        logs = captured_logs.finish()

        return _WireProtocolOutput(
            output_graph,
            metrics.get_deltas(),
            logs,
            warning_replay,
            fake_mode.shape_env,
        ).serialize()


# This is a debugging/testing implementation of FxCompile which serializes the
# input and output but still runs the FxCompile in-process.
@final
class _DebugSerdeFxCompile(_SerializedFxCompile):
    @override
    def _send_to_child(
        self, pickled_input: _WireProtocolPickledInput
    ) -> _WireProtocolPickledOutput:
        # For debugging just serde the input and output but don't run in a
        # subprocess.
        return self._run_in_child(pickled_input)


class _OutOfProcessFxCompile(_SerializedFxCompile):
    """
    Represents an FxCompile which is run outside the current process (in
    either a subprocess or possibly even on a separate machine).
    """

    @override
    @final
    def _send_to_child(
        self, pickled_input: _WireProtocolPickledInput
    ) -> _WireProtocolPickledOutput:
        f = self._send_to_child_async(pickled_input)

        # For debugging: If we want to print status updates...
        # last = time.time()
        # while not f.done():
        #     print("tick...")
        #     time.sleep(0.125)
        #     now = time.time()
        #     if now - last > 1:
        #         last = now

        return f.result()

    @abstractmethod
    def _send_to_child_async(
        self, pickled_input: _WireProtocolPickledInput
    ) -> Future[_WireProtocolPickledOutput]: ...
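
    # A minimal sketch of a concrete subclass, assuming a thread-pool
    # transport (a real implementation would hand the bytes to a subprocess
    # or a remote worker instead; `_ThreadFxCompile` is hypothetical):
    #
    #     from concurrent.futures import ThreadPoolExecutor
    #
    #     class _ThreadFxCompile(_OutOfProcessFxCompile):
    #         _pool = ThreadPoolExecutor(max_workers=1)
    #
    #         @override
    #         def _send_to_child_async(
    #             self, pickled_input: _WireProtocolPickledInput
    #         ) -> Future[_WireProtocolPickledOutput]:
    #             return self._pool.submit(self._run_in_child, pickled_input)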

    def _postprocess(self, output: _WireProtocolOutput) -> None:
        # Since our metrics were gathered in a subprocess make sure to add
        # them here.
        CachedMetricsHelper.apply_deltas(output.metrics)

        # This is used by tests to check the output for specific details. For
        # remote things (subproc and RE) we need to do the `save_output_code`
        # here since it didn't happen earlier in-process. In the future if this
        # doesn't have "source_code" (it's a CompiledAOTI, for example) and we
        # need it we'll have to grab it and serialize it separately from the
        # child.
        if GraphLowering.save_output_code is not None:
            GraphLowering.save_output_code(output.graph.source_code)  # type: ignore[attr-defined]

        # And forward our collected logs. The cache is cleared when the outer
        # function exits.
        @functools.cache
        def getLogger(name: str) -> logging.Logger:
            return logging.getLogger(name)

        if output.warning_replay:
            for w in output.warning_replay:
                warnings.warn_explicit(
                    message=w.message,
                    category=w.category,
                    filename=w.filename,
                    lineno=w.lineno,
                    source=w.source,
                )

        for record in output.logs:
            logger = getLogger(record.name)
            logger.handle(record)


# For debugging - create an FxCompile which writes the serialized data to a
# file and then exits.
#
# TODO: make this a FxCompileMode value?
#
# The "child runner" should look something like this:
#
#     import torch
#     from torch._inductor import compile_fx
#     idx = 0
#     with open(f"/tmp/pytorch_compile_fx_tmp_input_{idx}.bin", "rb") as f:
#         input = compile_fx._WireProtocolPickledInput(f.read())
#     result = compile_fx._SubprocessFxCompile._run_in_child(input)
#     with open(f"/tmp/pytorch_compile_fx_tmp_output_{idx}.bin", "wb") as f:
#         f.write(result.value)
#
@final
class _DebugFileFxCompile(_SerializedFxCompile):
    file_index = 0

    @override
    def _send_to_child(
        self, pickled_input: _WireProtocolPickledInput
    ) -> _WireProtocolPickledOutput:
        idx = _DebugFileFxCompile.file_index
        _DebugFileFxCompile.file_index += 1

        name = f"/tmp/aorenste/pytorch_compile_fx_tmp_input_{idx}.bin"
        with open(name, "wb") as f:
            f.write(pickled_input.value)
        print(f"Wrote to {name}")

        # The branches below are toggled by hand while debugging.
        if False:
            name = f"/tmp/aorenste/pytorch_compile_fx_tmp_actual_{idx}.bin"
            actual = self._run_in_child(pickled_input)
            with open(name, "wb") as f:
                f.write(actual.value)
            return actual
        elif False:
            name = f"/tmp/aorenste/pytorch_compile_fx_tmp_output_{idx}.bin"
            with open(name, "rb") as f:
                result = _WireProtocolPickledOutput(f.read())
            print(f"Read from {name}")
            return result
        else:
            os._exit(-1)