Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-06 00:54:56 +08:00)
Opt model save and load (#126374)
## save&load support for OptimizedModule

[Issue description](https://github.com/pytorch/pytorch/pull/101651). English is not my native language, so please excuse any typos. This PR is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91; I'll resolve the merge conflicts later.

### Test results for test/dynamo

Conclusion: as far as I can see, it performs the same as before.

Environment (CPU only):
- platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0
- configfile: pytest.ini
- plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### Before this PR: [before](https://github.com/pytorch/pytorch/files/15329370/before.md)

#### After this PR: [after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### Changes

1. Add `test_save_and_load` tests to test/dynamo/test_modules.py, with and without `backend="inductor"`.
2. Add a `__reduce__` method to `OptimizedModule` and to the derived classes of `_TorchDynamoContext`, so they can be pickled and unpickled.
3. Change the wrapper functions into wrapper classes (`convert_frame_assert`, `convert_frame`, and `catch_errors_wrapper` in torch/_dynamo/convert_frame.py, and `wrap_backend_debug` in torch/_dynamo/repro/after_dynamo.py).
4. Change `self.output.compiler_fn` to `innermost_fn(self.output.compiler_fn)` in torch/_dynamo/symbolic_convert.py to get the original `compiler_fn` and avoid the "compiler_fn is not eager" condition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374
Approved by: https://github.com/msaroufim, https://github.com/jansel
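As a quick illustration of what the new tests exercise, here is a minimal sketch of the save/load round trip this PR enables (`MyModule` is a hypothetical stand-in for any `torch.nn.Module`, and the `inductor` backend is assumed to be available; the real tests in the diff use `MockModule` and `same_two_models`):

```python
import os
import tempfile

import torch


class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


mod = MyModule()
opt_mod = torch.compile(mod, backend="inductor")  # returns an OptimizedModule wrapper
inp = torch.randn(10, 10)
opt_mod(inp)  # trigger compilation once

# With this PR the OptimizedModule itself can go through torch.save / torch.load,
# because its Dynamo callbacks are now picklable wrapper objects.
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model.pt")
    torch.save(opt_mod, path)
    loaded = torch.load(path)

# The reloaded module still runs and matches the eager module's output.
assert torch.allclose(loaded(inp), mod(inp))
```

Before this change, pickling failed because the compiled wrapper held references to locally defined closure functions; items 2 and 3 above replace those closures with classes that pickle can reconstruct.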
Committed by: PyTorch MergeBot
Parent: 9a8ab778d3
Commit: c3949b20a1
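For background on the diff that follows: the new `__reduce__` methods rely on the standard pickle protocol, where returning a `(callable, args)` pair tells `pickle` to rebuild the object by calling that callable again instead of serializing its attributes. A generic sketch of the pattern (the `Wrapped`, `orig`, and `ctx` names are hypothetical, not the PR's classes):

```python
import pickle


class Wrapped:
    """Toy wrapper in the spirit of OptimizedModule: keeps an original object plus a context."""

    def __init__(self, orig, ctx):
        self.orig = orig
        self.ctx = ctx
        # An attribute pickle cannot serialize on its own, e.g. a lambda.
        self.forward = lambda x: x

    def __reduce__(self):
        # Rebuild from constructor arguments on unpickling; the lambda is never pickled.
        return (self.__class__, (self.orig, self.ctx))


w = Wrapped(orig=[1, 2, 3], ctx="some-context")
w2 = pickle.loads(pickle.dumps(w))
assert w2.orig == [1, 2, 3] and w2.ctx == "some-context"
```

`OptimizedModule.__reduce__` below does exactly this with `(self._orig_mod, self.dynamo_ctx)`, and `OptimizeContext`, `RunOnlyContext`, and `DisableContext` return their constructor arguments in the same way.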
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -3,6 +3,8 @@
 import collections
 import copy
 import itertools
+import os
+import tempfile
 import traceback
 import types
 import unittest
@@ -16,6 +18,7 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch.nn.functional as F
+from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.eval_frame import unsupported
 from torch._dynamo.mutation_guard import GenerationTracker
 from torch._dynamo.testing import expectedFailureDynamic, same
@@ -2739,6 +2742,49 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         self.assertEqual(test_functions._variable, 1)
         self.assertEqual(res, 3 * torch.ones(10))

+    @unittest.skipIf(
+        "inductor" not in torch._dynamo.list_backends(),
+        "inductor backend is not available",
+    )
+    def test_save_and_load_inductor(self):
+        mod = MockModule()
+        opt_mod = torch.compile(mod, backend="inductor")
+        inp = torch.randn(10, 10)
+        opt_mod(inp)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+        loaded_model(inp)
+        self.assertTrue(same_two_models(loaded_model, mod, [inp]))
+        self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
+
+        torch._dynamo.reset()  # force recompiles
+        torch._inductor.metrics.generated_kernel_count = 0
+        loaded_model(inp)
+        self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)
+
+    def test_save_and_load_all_backends(self):
+        mod = MockModule()
+        inp = torch.randn(10, 10)
+        for backend in torch._dynamo.list_backends():
+            try:
+                opt_mod = torch.compile(mod, backend=backend)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+                    loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                opt_mod(inp)
+                opt_success = torch._inductor.metrics.generated_kernel_count == 0
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                loaded_model(inp)
+                loaded_success = torch._inductor.metrics.generated_kernel_count == 0
+                self.assertEqual(opt_success, loaded_success)
+            except torch._dynamo.exc.BackendCompilerFailed:
+                pass
+
     def test_monkeypatching_forward(self):
         class FakeModule(torch.nn.Module):
             def forward(self, x):
--- a/torch/_dynamo/backends/common.py
+++ b/torch/_dynamo/backends/common.py
@@ -14,18 +14,22 @@ from torch.utils._python_dispatch import _disable_current_modes
 log = logging.getLogger(__name__)


-def aot_autograd(**kwargs):
-    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
+class AotAutograd:
+    def __init__(self, **kwargs):
+        self.__name__ = "compiler_fn"
+        self.kwargs = kwargs
+
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
         if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
             return flatten_graph_inputs(
                 gm,
                 example_inputs,
-                compiler_fn,
+                self,
             )

         # Hack to get around circular import problems with aot_eager_decomp_partition
-        if callable(kwargs.get("decompositions")):
-            kwargs["decompositions"] = kwargs["decompositions"]()
+        if callable(self.kwargs.get("decompositions")):
+            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

         # NB: dont delete counter increment
         counters["aot_autograd"]["total"] += 1
@@ -42,10 +46,10 @@ def aot_autograd(**kwargs):
             # stop TorchDynamo from trying to compile our generated backwards pass
             return disable(disable(bw_compiler)(*args, **kwargs))

-        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-        kwargs["bw_compiler"] = _wrapped_bw_compiler
-        kwargs["inference_compiler"] = (
-            kwargs.get("inference_compiler") or kwargs["fw_compiler"]
+        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
+        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
+        self.kwargs["inference_compiler"] = (
+            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
         )

         from functorch.compile import nop
@@ -54,7 +58,7 @@ def aot_autograd(**kwargs):

         # debug asserts slow down compile time noticeably,
         # So only default them on when the aot_eager backend is used.
-        if kwargs.get("fw_compiler", None) == nop:
+        if self.kwargs.get("fw_compiler", None) == nop:
             patch_config = patch("functorch.compile.config.debug_assert", True)
         else:
             patch_config = contextlib.nullcontext()
@@ -62,14 +66,16 @@ def aot_autograd(**kwargs):
         try:
             # NB: NOT cloned!
             with enable_aot_logging(), patch_config:
-                cg = aot_module_simplified(gm, example_inputs, **kwargs)
+                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                 counters["aot_autograd"]["ok"] += 1
                 return disable(cg)
         except Exception:
             counters["aot_autograd"]["not_ok"] += 1
             raise

-    return compiler_fn
+
+def aot_autograd(**kwargs):
+    return AotAutograd(**kwargs)


 def mem_efficient_fusion_kwargs(use_decomps):
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -361,17 +361,34 @@ def cprofile_wrapper(func):
     return profile_wrapper


-def convert_frame_assert(
-    compiler_fn: CompilerFn,
-    one_graph: bool = True,
-    export: bool = False,
-    export_constraints=None,
-):
-    """Fully convert a frame into an FX graph"""
-    reset_graph_break_dup_checker()
-
-    def _convert_frame_assert(
-        frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0
+class ConvertFrameAssert:
+    def __init__(
+        self,
+        compiler_fn: CompilerFn,
+        one_graph: bool = True,
+        export: bool = False,
+        export_constraints=None,
+    ):
+        reset_graph_break_dup_checker()
+        self._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
+        self._one_graph = one_graph
+        self._export = export
+        self._export_constraints = export_constraints
+
+    @property
+    def _clone_with_backend(self):
+        return lambda backend: convert_frame_assert(
+            backend, self._one_graph, self._export, self._export_constraints
+        )
+
+    def __call__(
+        self,
+        frame: types.FrameType,
+        cache_entry,
+        hooks: Hooks,
+        frame_state,
+        *,
+        skip: int = 0,
     ):
         increment_frame()

@@ -458,10 +475,10 @@ def convert_frame_assert(
             frame.f_globals,
             frame.f_locals,
             frame.f_builtins,
-            compiler_fn,
-            one_graph,
-            export,
-            export_constraints,
+            self._torchdynamo_orig_callable,
+            self._one_graph,
+            self._export,
+            self._export_constraints,
             hooks,
             cache_entry,
             cache_size,
@@ -471,13 +488,15 @@ def convert_frame_assert(
             skip=skip + 1,
         )

-    _convert_frame_assert._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]

-    def _clone_with_backend(backend):
-        return convert_frame_assert(backend, one_graph, export, export_constraints)
-
-    _convert_frame_assert._clone_with_backend = _clone_with_backend  # type: ignore[attr-defined]
-    return _convert_frame_assert
+def convert_frame_assert(
+    compiler_fn: CompilerFn,
+    one_graph: bool = True,
+    export: bool = False,
+    export_constraints=None,
+):
+    """Fully convert a frame into an FX graph"""
+    return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)


 from collections import OrderedDict
@@ -907,16 +926,27 @@ def _compile(
     torch._dynamo.callback_handler.run_end_callbacks()


-def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
-    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
-    inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
-
-    def _convert_frame(
-        frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0
+class ConvertFrame:
+    def __init__(self, compiler_fn: CompilerFn, hooks: Hooks):
+        self._torchdynamo_orig_callable = compiler_fn
+        self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
+        self._hooks = hooks
+
+    @property
+    def _clone_with_backend(self):
+        return lambda backend: convert_frame(backend, self._hooks)
+
+    def __call__(
+        self,
+        frame: types.FrameType,
+        cache_entry,
+        hooks: Hooks,
+        frame_state,
+        skip: int = 0,
     ):
         counters["frames"]["total"] += 1
         try:
-            result = inner_convert(
+            result = self._inner_convert(
                 frame, cache_entry, hooks, frame_state, skip=skip + 1
             )
             counters["frames"]["ok"] += 1
@@ -980,9 +1010,10 @@ def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
             log.warning(error_msg, exc_info=True)
         return None

-    _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
-    _convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks)  # type: ignore[attr-defined]
-    return _convert_frame
+
+def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
+    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
+    return ConvertFrame(compiler_fn, hooks)


 # TODO mlazos: add support for same args, or record them
@@ -1023,9 +1054,13 @@ def first_real_inst_idx(code):
     raise RuntimeError("RESUME instruction not found in code")


-def catch_errors_wrapper(callback, hooks: Hooks):
-    @functools.wraps(callback)
-    def catch_errors(frame, cache_entry, frame_state):
+class CatchErrorsWrapper:
+    def __init__(self, callback, hooks):
+        functools.wraps(callback)(self)
+        self._torchdynamo_orig_callable = callback
+        self.hooks = hooks
+
+    def __call__(self, frame, cache_entry, frame_state):
         assert frame_state is not None

         is_skipfile = trace_rules.check(frame.f_code)
@@ -1063,19 +1098,26 @@ def catch_errors_wrapper(callback, hooks: Hooks):

                     ddp_optimizer = DDPOptimizer(
                         bucket_bytes_cap=ddp_module.bucket_bytes_cap,
-                        backend_compile_fn=callback._torchdynamo_orig_callable,
+                        backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable,
                     )
                     assert hasattr(
-                        callback, "_clone_with_backend"
+                        self._torchdynamo_orig_callable, "_clone_with_backend"
                     ), "DDPOptimizer only supports callback fns that know how to clone themselves."
-                    hijacked_callback = callback._clone_with_backend(
-                        ddp_optimizer.compile_fn,
+                    hijacked_callback = (
+                        self._torchdynamo_orig_callable._clone_with_backend(
+                            ddp_optimizer.compile_fn,
+                        )
+                    )
+                    return hijacked_callback(
+                        frame, cache_entry, self.hooks, frame_state
                     )
-                    return hijacked_callback(frame, cache_entry, hooks, frame_state)

         with compile_lock, _disable_current_modes():
             # skip=1: skip this frame
-            return callback(frame, cache_entry, hooks, frame_state, skip=1)
+            return self._torchdynamo_orig_callable(
+                frame, cache_entry, self.hooks, frame_state, skip=1
+            )
+

-    catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]
-    return catch_errors
+def catch_errors_wrapper(callback, hooks: Hooks):
+    return CatchErrorsWrapper(callback, hooks)
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -168,6 +168,9 @@ class OptimizedModule(torch.nn.Module):
         self._forward = self.forward
         self.forward = self._call_lazy_check

+    def __reduce__(self):
+        return (self.__class__, (self._orig_mod, self.dynamo_ctx))
+
     def __getstate__(self):
         state = dict(self.__dict__)
         state.pop("forward", None)
@@ -273,9 +276,11 @@ class _TorchDynamoContext:
         super().__init__()
         assert callable(callback) or callback is False or callback is None
         self.callback: DynamoCallback = callback
+        self._backend_ctx_ctor = backend_ctx_ctor
         self.prior: Union[Unset, DynamoCallback] = unset
         self.first_ctx = first_ctx
         self.export = export
+        self._dynamic = dynamic
         self.compiler_config = compiler_config
         self.cleanup_fns: List[Callable[[], Any]] = []
         self.enter_exit_hooks = []
@@ -379,7 +384,13 @@ class _TorchDynamoContext:
             # call to a builtin without a frame for us to capture
             fn = external_utils.wrap_inline(fn)

-        callback = self.callback
+        def do_nothing(*arg, **kwargs):
+            pass
+
+        if hasattr(self, "callback"):
+            callback = self.callback
+        else:
+            callback = do_nothing

         is_jit_tracing = torch._C._is_tracing
         is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
@@ -522,6 +533,17 @@ class OptimizeContext(_TorchDynamoContext):

             self.enter_exit_hooks.append(call_compiled_autograd)

+    def __reduce__(self):
+        return (
+            self.__class__,
+            (self.callback, self._backend_ctx_ctor, self.first_ctx),
+            {
+                "export": self.export,
+                "dynamic": self._dynamic,
+                "compiler_config": self.compiler_config,
+            },
+        )
+

 class RunOnlyContext(_TorchDynamoContext):
     def __init__(self):
@@ -531,6 +553,9 @@ class RunOnlyContext(_TorchDynamoContext):

         super().__init__(callback=False, on_enter=on_enter)

+    def __reduce__(self):
+        return (self.__class__, ())
+

 class DisableContext(_TorchDynamoContext):
     def __init__(self):
@@ -583,6 +608,9 @@ class DisableContext(_TorchDynamoContext):

         return _fn

+    def __reduce__(self):
+        return (self.__class__, ())
+

 def _optimize_catch_errors(
     compile_fn,
--- a/torch/_dynamo/repro/after_dynamo.py
+++ b/torch/_dynamo/repro/after_dynamo.py
@@ -56,19 +56,20 @@ def _accuracy_fails(gm, example_inputs, compiler_fn):
     )


-def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
-    """
-    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
-    As opposed to wrap_compiler_debug, this wrapper intercepts at the
-    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
-    level, e.g., it is useful for minifying issues related to Aot Autograd
-    tracing. If an error is found, we minify and save the minified repro in
-    repro.tar.gz.
-    """
-
-    @functools.wraps(unconfigured_compiler_fn)
-    def debug_wrapper(gm, example_inputs, **kwargs):
-        compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
+class WrapBackendDebug:
+    def __init__(self, unconfigured_compiler_fn, compiler_name: str):
+        functools.wraps(unconfigured_compiler_fn)(self)
+        self._torchdynamo_orig_callable = unconfigured_compiler_fn  # type: ignore[attr-defined]
+        self._compiler_name = compiler_name
+        if hasattr(unconfigured_compiler_fn, "__name__"):
+            self.__name__ = unconfigured_compiler_fn.__name__
+        if hasattr(unconfigured_compiler_fn, "compiler_name"):
+            self.__name__ = unconfigured_compiler_fn.compiler_name
+        if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
+            self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config  # type: ignore[attr-defined]
+
+    def __call__(self, gm, example_inputs, **kwargs):
+        compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs)
         assert config.repro_after in ("dynamo", "aot", None)

         if config.repro_after == "dynamo":
@@ -82,7 +83,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                 )

             if config.repro_level == 3:
-                dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
+                dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name)

             # Check for either accuracy (level 4) or other type of failures.
             if config.repro_level == 4:
@@ -95,7 +96,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                     dump_to_minify_after_dynamo(
                         fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                         example_inputs,
-                        compiler_name,
+                        self._compiler_name,
                     )
                     exc = AccuracyError("Bad accuracy detected.")
                     add_paths(exc)
@@ -110,7 +111,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                     )
                     if config.repro_level == 1:
                         dump_state_fn = functools.partial(
-                            dump_backend_state, compiler_name=compiler_name
+                            dump_backend_state, compiler_name=self._compiler_name
                         )
                         dump_state_fn(
                             fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
@@ -119,7 +120,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                         dump_to_minify_after_dynamo(
                             fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                             example_inputs,
-                            compiler_name,
+                            self._compiler_name,
                         )
                     add_paths(exc)
                     raise
@@ -128,12 +129,17 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):

         return compiled_gm

-    debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn  # type: ignore[attr-defined]
-    if hasattr(unconfigured_compiler_fn, "compiler_name"):
-        debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name
-    if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
-        debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config  # type: ignore[attr-defined]
-    return debug_wrapper
+
+def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
+    """
+    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
+    As opposed to wrap_compiler_debug, this wrapper intercepts at the
+    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
+    level, e.g., it is useful for minifying issues related to Aot Autograd
+    tracing. If an error is found, we minify and save the minified repro in
+    repro.tar.gz.
+    """
+    return WrapBackendDebug(unconfigured_compiler_fn, compiler_name)


 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #