Opt model save and load (#126374)

## Save & load support for OptimizedModule

[Issue Description](https://github.com/pytorch/pytorch/pull/101651)

English is not my native language; please excuse typing errors.

This PR is based on commit b9588101c4d3411b107fdc860acfa8a72c642f91;\
I'll resolve the merge conflicts later.

### Test results for test/dynamo

Conclusion:\
As far as I can tell, the tests behave the same as before this change.

Environment (CPU only):\
platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0\
configfile: pytest.ini\
plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### Before this PR:

[before](https://github.com/pytorch/pytorch/files/15329370/before.md)

#### After this PR:

[after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### Summary of changes

1. Add `test_save_and_load` tests to test/dynamo/test_modules.py, exercised both with and without `backend="inductor"`.
2. Add a `__reduce__` method to `OptimizedModule` and to the derived classes of `_TorchDynamoContext` so that they can be pickled and unpickled (see the usage sketch below).
3. Turn the wrapper functions into wrapper classes (`convert_frame_assert`, `convert_frame`, and `catch_errors_wrapper` in torch/_dynamo/convert_frame.py, plus `wrap_backend_debug` in torch/_dynamo/repro/after_dynamo.py), since module-level class instances can be pickled while closures cannot.
4. In torch/_dynamo/symbolic_convert.py, replace `self.output.compiler_fn` with `innermost_fn(self.output.compiler_fn)` to retrieve the original `compiler_fn` and avoid triggering the "compiler_fn is not eager" condition.
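
The net effect is that a compiled module round-trips through the standard `torch.save` / `torch.load` path. Here is a minimal sketch of the intended usage (the `MyModule` definition and the file name are illustrative, not from this PR):

```python
import torch


class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


mod = MyModule()
opt_mod = torch.compile(mod, backend="inductor")  # returns an OptimizedModule
opt_mod(torch.randn(10))  # trigger compilation once

# __reduce__ lets pickle rebuild the OptimizedModule from
# (_orig_mod, dynamo_ctx), so the ordinary save/load path now works.
torch.save(opt_mod, "model.pt")
loaded = torch.load("model.pt")
loaded(torch.randn(10))  # behaves like a freshly compiled module
```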

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126374
Approved by: https://github.com/msaroufim, https://github.com/jansel
Author: weiyusheng\
Date: 2024-06-05 13:01:16 +00:00 (committed by PyTorch MergeBot)\
Parent: 9a8ab778d3, commit: c3949b20a1\
5 changed files with 203 additions and 75 deletions

#### test/dynamo/test_modules.py

```diff
@@ -3,6 +3,8 @@
 import collections
 import copy
 import itertools
+import os
+import tempfile
 import traceback
 import types
 import unittest
@@ -16,6 +18,7 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch.nn.functional as F
+from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.eval_frame import unsupported
 from torch._dynamo.mutation_guard import GenerationTracker
 from torch._dynamo.testing import expectedFailureDynamic, same
@@ -2739,6 +2742,49 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         self.assertEqual(test_functions._variable, 1)
         self.assertEqual(res, 3 * torch.ones(10))

+    @unittest.skipIf(
+        "inductor" not in torch._dynamo.list_backends(),
+        "inductor backend is not available",
+    )
+    def test_save_and_load_inductor(self):
+        mod = MockModule()
+        opt_mod = torch.compile(mod, backend="inductor")
+        inp = torch.randn(10, 10)
+        opt_mod(inp)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+            loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+        loaded_model(inp)
+        self.assertTrue(same_two_models(loaded_model, mod, [inp]))
+        self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))
+
+        torch._dynamo.reset()  # force recompiles
+        torch._inductor.metrics.generated_kernel_count = 0
+        loaded_model(inp)
+        self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0)
+
+    def test_save_and_load_all_backends(self):
+        mod = MockModule()
+        inp = torch.randn(10, 10)
+        for backend in torch._dynamo.list_backends():
+            try:
+                opt_mod = torch.compile(mod, backend=backend)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
+                    loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                opt_mod(inp)
+                opt_success = torch._inductor.metrics.generated_kernel_count == 0
+                torch._dynamo.reset()  # force recompiles
+                torch._inductor.metrics.generated_kernel_count = 0
+                loaded_model(inp)
+                loaded_success = torch._inductor.metrics.generated_kernel_count == 0
+                self.assertEqual(opt_success, loaded_success)
+            except torch._dynamo.exc.BackendCompilerFailed:
+                pass
+
     def test_monkeypatching_forward(self):
         class FakeModule(torch.nn.Module):
             def forward(self, x):
```

#### torch/_dynamo/backends/common.py

```diff
@@ -14,18 +14,22 @@ from torch.utils._python_dispatch import _disable_current_modes
 log = logging.getLogger(__name__)

-def aot_autograd(**kwargs):
-    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
+class AotAutograd:
+    def __init__(self, **kwargs):
+        self.__name__ = "compiler_fn"
+        self.kwargs = kwargs
+
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
         if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
             return flatten_graph_inputs(
                 gm,
                 example_inputs,
-                compiler_fn,
+                self,
             )

         # Hack to get around circular import problems with aot_eager_decomp_partition
-        if callable(kwargs.get("decompositions")):
-            kwargs["decompositions"] = kwargs["decompositions"]()
+        if callable(self.kwargs.get("decompositions")):
+            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

         # NB: dont delete counter increment
         counters["aot_autograd"]["total"] += 1
@@ -42,10 +46,10 @@ def aot_autograd(**kwargs):
             # stop TorchDynamo from trying to compile our generated backwards pass
             return disable(disable(bw_compiler)(*args, **kwargs))

-        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
-        kwargs["bw_compiler"] = _wrapped_bw_compiler
-        kwargs["inference_compiler"] = (
-            kwargs.get("inference_compiler") or kwargs["fw_compiler"]
+        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
+        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
+        self.kwargs["inference_compiler"] = (
+            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
         )

         from functorch.compile import nop
@@ -54,7 +58,7 @@ def aot_autograd(**kwargs):
         # debug asserts slow down compile time noticeably,
         # So only default them on when the aot_eager backend is used.
-        if kwargs.get("fw_compiler", None) == nop:
+        if self.kwargs.get("fw_compiler", None) == nop:
             patch_config = patch("functorch.compile.config.debug_assert", True)
         else:
             patch_config = contextlib.nullcontext()
@@ -62,14 +66,16 @@ def aot_autograd(**kwargs):
         try:
             # NB: NOT cloned!
             with enable_aot_logging(), patch_config:
-                cg = aot_module_simplified(gm, example_inputs, **kwargs)
+                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                 counters["aot_autograd"]["ok"] += 1
                 return disable(cg)
         except Exception:
             counters["aot_autograd"]["not_ok"] += 1
             raise

-    return compiler_fn
+
+def aot_autograd(**kwargs):
+    return AotAutograd(**kwargs)


 def mem_efficient_fusion_kwargs(use_decomps):
```
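
For background on why the closure had to become a class: pickle serializes functions by qualified name, so a function defined inside another function (like the old `compiler_fn`) cannot be pickled, while an instance of a module-level class can. Below is a minimal standalone illustration of that general principle; the names are made up for the example:

```python
import pickle


def make_adder_closure(n):
    def add(x):  # nested function: not importable by qualified name
        return x + n

    return add


class Adder:  # module-level class: importable, hence picklable
    def __init__(self, n):
        self.n = n

    def __call__(self, x):
        return x + self.n


try:
    pickle.dumps(make_adder_closure(3))
except (AttributeError, pickle.PicklingError) as e:
    print("closure fails to pickle:", e)

restored = pickle.loads(pickle.dumps(Adder(3)))
print(restored(4))  # 7
```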

#### torch/_dynamo/convert_frame.py

```diff
@@ -361,17 +361,34 @@ def cprofile_wrapper(func):
     return profile_wrapper

-def convert_frame_assert(
-    compiler_fn: CompilerFn,
-    one_graph: bool = True,
-    export: bool = False,
-    export_constraints=None,
-):
-    """Fully convert a frame into an FX graph"""
-    reset_graph_break_dup_checker()
-
-    def _convert_frame_assert(
-        frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0
+class ConvertFrameAssert:
+    def __init__(
+        self,
+        compiler_fn: CompilerFn,
+        one_graph: bool = True,
+        export: bool = False,
+        export_constraints=None,
+    ):
+        reset_graph_break_dup_checker()
+        self._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
+        self._one_graph = one_graph
+        self._export = export
+        self._export_constraints = export_constraints
+
+    @property
+    def _clone_with_backend(self):
+        return lambda backend: convert_frame_assert(
+            backend, self._one_graph, self._export, self._export_constraints
+        )
+
+    def __call__(
+        self,
+        frame: types.FrameType,
+        cache_entry,
+        hooks: Hooks,
+        frame_state,
+        *,
+        skip: int = 0,
     ):
         increment_frame()
@@ -458,10 +475,10 @@ def convert_frame_assert(
             frame.f_globals,
             frame.f_locals,
             frame.f_builtins,
-            compiler_fn,
-            one_graph,
-            export,
-            export_constraints,
+            self._torchdynamo_orig_callable,
+            self._one_graph,
+            self._export,
+            self._export_constraints,
             hooks,
             cache_entry,
             cache_size,
@@ -471,13 +488,15 @@ def convert_frame_assert(
             skip=skip + 1,
         )

-    _convert_frame_assert._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
-
-    def _clone_with_backend(backend):
-        return convert_frame_assert(backend, one_graph, export, export_constraints)
-
-    _convert_frame_assert._clone_with_backend = _clone_with_backend  # type: ignore[attr-defined]
-    return _convert_frame_assert
+
+def convert_frame_assert(
+    compiler_fn: CompilerFn,
+    one_graph: bool = True,
+    export: bool = False,
+    export_constraints=None,
+):
+    """Fully convert a frame into an FX graph"""
+    return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)


 from collections import OrderedDict
@@ -907,16 +926,27 @@ def _compile(
         torch._dynamo.callback_handler.run_end_callbacks()

-def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
-    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
-    inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
-
-    def _convert_frame(
-        frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0
+class ConvertFrame:
+    def __init__(self, compiler_fn: CompilerFn, hooks: Hooks):
+        self._torchdynamo_orig_callable = compiler_fn
+        self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
+        self._hooks = hooks
+
+    @property
+    def _clone_with_backend(self):
+        return lambda backend: convert_frame(backend, self._hooks)
+
+    def __call__(
+        self,
+        frame: types.FrameType,
+        cache_entry,
+        hooks: Hooks,
+        frame_state,
+        skip: int = 0,
     ):
         counters["frames"]["total"] += 1
         try:
-            result = inner_convert(
+            result = self._inner_convert(
                 frame, cache_entry, hooks, frame_state, skip=skip + 1
             )
             counters["frames"]["ok"] += 1
@@ -980,9 +1010,10 @@ def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
             log.warning(error_msg, exc_info=True)
         return None

-    _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
-    _convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks)  # type: ignore[attr-defined]
-    return _convert_frame
+
+def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
+    """Try to convert a frame into an FX graph, if error leave frame unmodified"""
+    return ConvertFrame(compiler_fn, hooks)


 # TODO mlazos: add support for same args, or record them
@@ -1023,9 +1054,13 @@ def first_real_inst_idx(code):
         raise RuntimeError("RESUME instruction not found in code")

-def catch_errors_wrapper(callback, hooks: Hooks):
-    @functools.wraps(callback)
-    def catch_errors(frame, cache_entry, frame_state):
+class CatchErrorsWrapper:
+    def __init__(self, callback, hooks):
+        functools.wraps(callback)(self)
+        self._torchdynamo_orig_callable = callback
+        self.hooks = hooks
+
+    def __call__(self, frame, cache_entry, frame_state):
         assert frame_state is not None

         is_skipfile = trace_rules.check(frame.f_code)
@@ -1063,19 +1098,26 @@
                     ddp_optimizer = DDPOptimizer(
                         bucket_bytes_cap=ddp_module.bucket_bytes_cap,
-                        backend_compile_fn=callback._torchdynamo_orig_callable,
+                        backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable,
                     )
                     assert hasattr(
-                        callback, "_clone_with_backend"
+                        self._torchdynamo_orig_callable, "_clone_with_backend"
                     ), "DDPOptimizer only supports callback fns that know how to clone themselves."
-                    hijacked_callback = callback._clone_with_backend(
-                        ddp_optimizer.compile_fn,
+                    hijacked_callback = (
+                        self._torchdynamo_orig_callable._clone_with_backend(
+                            ddp_optimizer.compile_fn,
+                        )
+                    )
+                    return hijacked_callback(
+                        frame, cache_entry, self.hooks, frame_state
                     )
-                    return hijacked_callback(frame, cache_entry, hooks, frame_state)

         with compile_lock, _disable_current_modes():
             # skip=1: skip this frame
-            return callback(frame, cache_entry, hooks, frame_state, skip=1)
+            return self._torchdynamo_orig_callable(
+                frame, cache_entry, self.hooks, frame_state, skip=1
+            )

-    catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]
-    return catch_errors
+
+def catch_errors_wrapper(callback, hooks: Hooks):
+    return CatchErrorsWrapper(callback, hooks)
```
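
One detail worth noting in `CatchErrorsWrapper`: `functools.wraps(callback)(self)` applies the usual `@functools.wraps` metadata copy to an instance rather than to a function, so the wrapper object keeps the callback's `__name__`, `__doc__`, and so on for downstream code that inspects them. Here is a small standalone sketch of that trick; the backend function is a made-up stand-in, not PyTorch code:

```python
import functools


def my_backend(gm, example_inputs):
    """A stand-in for a real Dynamo callback."""
    return gm


class Wrapper:
    def __init__(self, fn):
        # Copy fn's __name__, __doc__, __wrapped__, etc. onto this instance,
        # the same way CatchErrorsWrapper does for its callback.
        functools.wraps(fn)(self)
        self._torchdynamo_orig_callable = fn

    def __call__(self, *args, **kwargs):
        return self._torchdynamo_orig_callable(*args, **kwargs)


w = Wrapper(my_backend)
print(w.__name__)  # my_backend
print(w.__doc__)   # A stand-in for a real Dynamo callback.
```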

#### torch/_dynamo/eval_frame.py

```diff
@@ -168,6 +168,9 @@ class OptimizedModule(torch.nn.Module):
             self._forward = self.forward
             self.forward = self._call_lazy_check

+    def __reduce__(self):
+        return (self.__class__, (self._orig_mod, self.dynamo_ctx))
+
     def __getstate__(self):
         state = dict(self.__dict__)
         state.pop("forward", None)
@@ -273,9 +276,11 @@ class _TorchDynamoContext:
         super().__init__()
         assert callable(callback) or callback is False or callback is None
         self.callback: DynamoCallback = callback
+        self._backend_ctx_ctor = backend_ctx_ctor
         self.prior: Union[Unset, DynamoCallback] = unset
         self.first_ctx = first_ctx
         self.export = export
+        self._dynamic = dynamic
         self.compiler_config = compiler_config
         self.cleanup_fns: List[Callable[[], Any]] = []
         self.enter_exit_hooks = []
@@ -379,7 +384,13 @@
             # call to a builtin without a frame for us to capture
             fn = external_utils.wrap_inline(fn)

-        callback = self.callback
+        def do_nothing(*arg, **kwargs):
+            pass
+
+        if hasattr(self, "callback"):
+            callback = self.callback
+        else:
+            callback = do_nothing

         is_jit_tracing = torch._C._is_tracing
         is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
@@ -522,6 +533,17 @@ class OptimizeContext(_TorchDynamoContext):
             self.enter_exit_hooks.append(call_compiled_autograd)

+    def __reduce__(self):
+        return (
+            self.__class__,
+            (self.callback, self._backend_ctx_ctor, self.first_ctx),
+            {
+                "export": self.export,
+                "dynamic": self._dynamic,
+                "compiler_config": self.compiler_config,
+            },
+        )
+

 class RunOnlyContext(_TorchDynamoContext):
     def __init__(self):
@@ -531,6 +553,9 @@ class RunOnlyContext(_TorchDynamoContext):
         super().__init__(callback=False, on_enter=on_enter)

+    def __reduce__(self):
+        return (self.__class__, ())
+

 class DisableContext(_TorchDynamoContext):
     def __init__(self):
@@ -583,6 +608,9 @@ class DisableContext(_TorchDynamoContext):
         return _fn

+    def __reduce__(self):
+        return (self.__class__, ())
+

 def _optimize_catch_errors(
     compile_fn,
```
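
For context on the `__reduce__` additions above: `__reduce__` tells pickle how to rebuild an object, returning a callable plus an argument tuple (and optionally a state mapping, as in `OptimizeContext`); unpickling then calls that callable with those arguments. Routing reconstruction through the constructor means unpicklable attributes, such as the patched `forward` on `OptimizedModule`, are re-derived rather than serialized. A minimal standalone sketch with an invented class:

```python
import pickle


class Box:
    def __init__(self, payload):
        self.payload = payload
        # Derived, unpicklable state -- analogous to OptimizedModule's
        # patched forward wrapper.
        self._thunk = lambda: payload

    def __reduce__(self):
        # Rebuild via the constructor from picklable args only;
        # __init__ recreates _thunk, so it is never serialized.
        return (self.__class__, (self.payload,))


b = pickle.loads(pickle.dumps(Box(41)))
print(b.payload + 1)  # 42
```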

#### torch/_dynamo/repro/after_dynamo.py

```diff
@@ -56,19 +56,20 @@ def _accuracy_fails(gm, example_inputs, compiler_fn):
     )

-def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
-    """
-    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
-    As opposed to wrap_compiler_debug, this wrapper intercepts at the
-    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
-    level, e.g., it is useful for minifying issues related to Aot Autograd
-    tracing.  If an error is found, we minify and save the minified repro in
-    repro.tar.gz.
-    """
-
-    @functools.wraps(unconfigured_compiler_fn)
-    def debug_wrapper(gm, example_inputs, **kwargs):
-        compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
+class WrapBackendDebug:
+    def __init__(self, unconfigured_compiler_fn, compiler_name: str):
+        functools.wraps(unconfigured_compiler_fn)(self)
+        self._torchdynamo_orig_callable = unconfigured_compiler_fn  # type: ignore[attr-defined]
+        self._compiler_name = compiler_name
+        if hasattr(unconfigured_compiler_fn, "__name__"):
+            self.__name__ = unconfigured_compiler_fn.__name__
+        if hasattr(unconfigured_compiler_fn, "compiler_name"):
+            self.__name__ = unconfigured_compiler_fn.compiler_name
+        if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
+            self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config  # type: ignore[attr-defined]
+
+    def __call__(self, gm, example_inputs, **kwargs):
+        compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs)

         assert config.repro_after in ("dynamo", "aot", None)

         if config.repro_after == "dynamo":
@@ -82,7 +83,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                 )

                 if config.repro_level == 3:
-                    dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
+                    dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name)

                 # Check for either accuracy (level 4) or other type of failures.
                 if config.repro_level == 4:
@@ -95,7 +96,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                         dump_to_minify_after_dynamo(
                             fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                             example_inputs,
-                            compiler_name,
+                            self._compiler_name,
                         )
                         exc = AccuracyError("Bad accuracy detected.")
                         add_paths(exc)
@@ -110,7 +111,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                     )
                     if config.repro_level == 1:
                         dump_state_fn = functools.partial(
-                            dump_backend_state, compiler_name=compiler_name
+                            dump_backend_state, compiler_name=self._compiler_name
                         )
                         dump_state_fn(
                             fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
@@ -119,7 +120,7 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
                         dump_to_minify_after_dynamo(
                             fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                             example_inputs,
-                            compiler_name,
+                            self._compiler_name,
                         )
                     add_paths(exc)
                     raise
@@ -128,12 +129,17 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
         return compiled_gm

-    debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn  # type: ignore[attr-defined]
-    if hasattr(unconfigured_compiler_fn, "compiler_name"):
-        debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name
-    if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
-        debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config  # type: ignore[attr-defined]
-    return debug_wrapper
+
+def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
+    """
+    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
+    As opposed to wrap_compiler_debug, this wrapper intercepts at the
+    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
+    level, e.g., it is useful for minifying issues related to Aot Autograd
+    tracing.  If an error is found, we minify and save the minified repro in
+    repro.tar.gz.
+    """
+    return WrapBackendDebug(unconfigured_compiler_fn, compiler_name)


 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
```
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #