From 1e7947b3e0707b47d0a0c6432fe34031a5a5ae36 Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Tue, 10 Oct 2023 18:01:02 -0700 Subject: [PATCH] Revert "Reland 3rd try [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#109323)" + Forward fixes + test (#110964) This reverts commit f786fbdebdd24d3a6807e3b9fbf055836db4ad60. Forward fixes Pull Request resolved: https://github.com/pytorch/pytorch/pull/110964 Approved by: https://github.com/ezyang, https://github.com/anijain2305 --- .../inductor_torchbench_dynamic_inference.csv | 3 +- .../inductor_torchbench_dynamic_training.csv | 3 +- .../inductor_torchbench_inference.csv | 3 +- .../inductor_torchbench_training.csv | 5 +- benchmarks/dynamo/common.py | 5 +- docs/source/torch.compiler_nn_module.rst | 30 ++- test/dynamo/test_modules.py | 165 +++++--------- test/dynamo/test_recompiles.py | 21 ++ torch/_dynamo/config.py | 2 - torch/_dynamo/guards.py | 84 ++------ torch/csrc/dynamo/guards.cpp | 203 +----------------- 11 files changed, 139 insertions(+), 385 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv index 7cb2d713b73e..2df4a137cfaa 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv @@ -8,7 +8,8 @@ basic_gnn_edgecnn,pass,0 basic_gnn_gcn,pass,6 basic_gnn_gin,pass,0 basic_gnn_sage,pass,0 -cm3leon_generate,pass,7 +clip,pass_due_to_skip,0 +cm3leon_generate,pass,5 dcgan,pass,0 dlrm,pass,0 doctr_det_predictor,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv index 3dd1c6d6230b..057a5a2dd180 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv @@ -8,6 +8,7 @@ basic_gnn_edgecnn,pass,23 basic_gnn_gcn,pass,14 basic_gnn_gin,pass,8 basic_gnn_sage,pass,8 +clip,pass_due_to_skip,0 dcgan,pass,8 dlrm,pass,8 drq,pass,7 @@ -18,7 +19,7 @@ hf_Albert,pass,7 hf_Bart,pass,7 hf_DistilBert,pass,7 hf_GPT2,pass,7 -hf_Reformer,pass,28 +hf_Reformer,pass,27 hf_T5_large,pass_due_to_skip,0 hf_Whisper,pass,7 lennard_jones,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index ae2e94b233b9..d50c310b6bde 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -8,7 +8,8 @@ basic_gnn_edgecnn,pass,0 basic_gnn_gcn,pass,6 basic_gnn_gin,pass,0 basic_gnn_sage,pass,0 -cm3leon_generate,pass,7 +clip,pass_due_to_skip,0 +cm3leon_generate,pass,5 dcgan,pass,0 dlrm,pass,0 doctr_det_predictor,pass,2 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 3f61b8c5acfd..d25d951da550 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -8,6 +8,7 @@ basic_gnn_edgecnn,pass,23 basic_gnn_gcn,pass,14 basic_gnn_gin,pass,8 basic_gnn_sage,pass,8 +clip,pass_due_to_skip,0 dcgan,pass,8 dlrm,pass,8 drq,pass,7 @@ -18,7 
+19,7 @@ hf_Albert,pass,7 hf_Bart,pass,7 hf_DistilBert,pass,7 hf_GPT2,pass,7 -hf_Reformer,pass,28 +hf_Reformer,pass,27 hf_T5_large,pass_due_to_skip,0 hf_Whisper,pass,7 lennard_jones,pass,8 @@ -51,5 +52,5 @@ timm_vision_transformer_large,pass_due_to_skip,0 timm_vovnet,pass,8 tts_angular,pass,10 vgg16,pass,8 -vision_maskrcnn,fail_accuracy,39 +vision_maskrcnn,fail_accuracy,37 yolov3,pass,10 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 3de8e25a4336..e8c0cb9afbd8 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -2478,9 +2478,6 @@ class BenchmarkRunner: # Use distributed wrapping as necessary model = self.deepcopy_and_maybe_ddp(model) - if not hasattr(model, name): - model.name = name - self.init_optimizer(name, current_device, model.parameters()) with self.pick_grad(name, self.args.training): ok, total = Stats.reset_counters() @@ -2540,6 +2537,8 @@ class BenchmarkRunner: f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s" ) + if not hasattr(model, name): + model.name = name results.append(experiment(model, example_inputs, **experiment_kwargs)) return " ".join(map(str, results)) diff --git a/docs/source/torch.compiler_nn_module.rst b/docs/source/torch.compiler_nn_module.rst index 0afd6905d1af..21a8e624a247 100644 --- a/docs/source/torch.compiler_nn_module.rst +++ b/docs/source/torch.compiler_nn_module.rst @@ -1,5 +1,5 @@ -PyTorch 2.0 nn.Module Support -============================= +PyTorch 2.0 NNModule Support +============================ **Author**: `Will Constable `_ @@ -8,9 +8,12 @@ arbitrary python classes, with the intent of producing faster code by making ass This doc describes some of the tradeoffs or edge cases that come up due to this specialization. -`nn.Module` Hooks Support -------------------------- -`torch.compile` now has partial support for forward and backward hooks on nn.Modules. +NNModule Hooks Support +---------------------- +Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered +they would simply be ignored in the compiled program. Indeed many users do not +use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases +for composing nn.Module hooks with `torch.compile`. Hooks that are orchestrated via nn.Module.__call__ implementation include `_forward_pre_hooks`, `forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'. @@ -22,11 +25,11 @@ unsupported by `torch.compile`. `nn.Module.__call__` Hooks Usage and limitations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ By default, `torch.compile` will trace the contents of `nn.Module.__call__` which means it will encounter -and run forward/pre-forward hooks. `torch.compile` installs guards that detect added and removed hooks, -and will trigger a recompilation if the forward/pre-forward hooks change. +and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove +or alter the hooks later, your use case should be supported by default. Backward/Pre-backward hooks are generally also supported, with similar caveats: currently graph-breaks in dynamo -occur when accessing backward_hooks dicts, which is probably avoidable with some work. Graph-breaks also impact the +occur when accessing backward_hooks dicts, which is probably avoiable with some work. 
Graph-breaks also impact the timing of firing backward hooks, since graph-segments are run as autograd-functions which produce all their grads at the same time. Assuming it were possible for dynamo to not graph-break on the presence of backward-hooks, we would still expect the backward hooks for a series of modules to all fire together after the whole compiled graph's backward @@ -38,6 +41,17 @@ by allowing them to be called opaquely in the dynamo graph instead of traced int currently trigger a graph-break so that the affected modules run outside of dynamo. Depending on the model, this could introduce a significant performance regression, and additional work is required to improve this support. +**skip_nnmodule_hook_guards** +By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed +on each nn.Module hook dictionary, improving runtime by reducing guard execution time, at the cost of not noticing +if any hook dict is changed after compilation. + +If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately +(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added +guards. + +TODO: confirm if backward/pre_backward hooks are working or not and document accordingly + state_dict Hooks ~~~~~~~~~~~~~~~~ State dict hooks have not yet been supported in `torch.compile`. diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 50694b57a9ca..b3183c029fec 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1247,7 +1247,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)}) pre = m(data) cnt.clear() - torch._dynamo.reset() with torch._dynamo.optimize(cnt, nopython=False): opt_pre = m(data) @@ -1256,8 +1255,8 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): out1 = m(data) out_post = m(data) - self.assertEqual(cnt.frame_count, 2) - self.assertEqual(cnt.op_count, 2) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) self.assertTrue(torch._dynamo.testing.same(out1, out_post)) @@ -1731,112 +1730,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): ) ) - def test_hooks_recompile(self): - # Modifying hooks should lead to a recompiation - class TestModule(torch.nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - return 2 * x + 1 - - def compute_output_and_grad(m, x): - output = m(x) - output.sum().backward() - return x.grad - - def forward_pre_hook(module: torch.nn.Module, inputs: Tuple[torch.Tensor]): - return (2 * inputs[0] + 1,) - - def forward_hook( - module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor - ): - return 2 * output + 1 - - def backward_hook(module, grad_input, grad_output): - if len(grad_input) == 1: - return (grad_input[0] * 3,) - else: - return (grad_input[0] * 3, None) - - def backward_pre_hook(module, grad_outputs): - return (grad_outputs[0] * 5,) - - def run_test_case(hook_type, hook_func, expected_grad): - torch._dynamo.reset() - m = TestModule() - input = torch.ones(10, requires_grad=True) - cnt = torch._dynamo.testing.CompileCounter() - opt = torch._dynamo.optimize(cnt)(compute_output_and_grad) - - grad1 = opt(m, input) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(grad1, torch.full_like(grad1, 2)) - - input.grad = None - handle = getattr(m, 
hook_type)(hook_func) - grad2 = opt(m, input) - frame_count2 = cnt.frame_count - # Some backward hooks lead to graph breaks so frame_count may be 2 or 3 - self.assertGreaterEqual(frame_count2, 2) - self.assertEqual(grad2, torch.full_like(grad2, expected_grad)) - - # Running again should not recompile - opt(m, input) - self.assertEqual(cnt.frame_count, frame_count2) - - # Removing handle should lead to original result - input.grad = None - handle.remove() - grad3 = opt(m, input) - self.assertEqual(grad1, grad3) - - run_test_case("register_forward_pre_hook", forward_pre_hook, 4) - run_test_case("register_forward_hook", forward_hook, 4) - run_test_case("register_backward_hook", backward_hook, 6) - run_test_case("register_full_backward_hook", backward_hook, 6) - run_test_case("register_full_backward_pre_hook", backward_pre_hook, 10) - - def test_unspecialized_nn_module(self): - # This test is little confusing because of combination of - # nn_module_guard and unspecialized nn module variable. - - # The graph break in forward causes two graphs - # 1) The first graph has self.relu which introduces a nn_module_guard - # 2) The second graph first assumes self to be NNModuleVariable, but the - # restarts the analysis with self mapping to - # UnSpecializedNNModuleVariable, on witnessing self.a += 1. - - # Now, when we run the compiled mod the first time, it changes the value - # of self.a. This is fine for the first run. But, when we run the - # compiled module again, the first graph recompiles. This is because - # self.a has changed, changing the ma_version_tag, causing - # nn_module_guard to fail. - - # At this point, we might feel that this is doomed as we will always - # keep recompiling on the first graph. But, then Dynamo has already - # marked the self to be UnspecializedNNModuleVariable (because of self.a - # in the second graph), and therefore during the recompilation, we do - # not introduce any nn_module_guard. So, in all we have just one - # recompilation. - class Mock(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = 5 - self.relu = torch.nn.ReLU() - - def forward(self, x): - z = self.relu(x) - torch._dynamo.graph_break() - self.a += 1 - return z * self.a - - mod = Mock() - cnt = torch._dynamo.testing.CompileCounter() - opt = torch.compile(mod, backend=cnt) - - for _ in range(5): - opt(torch.randn(4)) - - self.assertEqual(cnt.frame_count, 4) - + @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False) def test_hooks_outer(self): class TestModule(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -1883,6 +1777,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): the eval_frame entrypoint to Module.__call__? 
""" + @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False) def test_hooks_inner(self): class TestModule(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -1927,7 +1822,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): handle.remove() self.assertEqual(compiled_func(inp), outer_func(inp)) self.assertEqual(compiled_func(inp).item(), 7) - self.assertTrue("__nn_module_guard" in failure_reason) + self.assertTrue("forward_hooks.keys" in failure_reason) self.assertEqual(cc.frame_count, 1 + 1) self.assertEqual(cc.op_count, 6 + 4) @@ -1947,7 +1842,55 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): m._forward_hooks[handle.id] = new_forward_hook self.assertEqual(compiled_func(inp), outer_func(inp)) self.assertEqual(compiled_func(inp).item(), 16) - self.assertTrue("__nn_module_guard" in failure_reason) + self.assertTrue("___check_obj_id(L['m']._forward_hooks" in failure_reason) + + @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True) + def test_hooks_skip_guards(self): + class TestModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return 2 * x + 1 + + m = TestModule() + + def forward_hook( + module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor + ) -> torch.Tensor: + return 2 * output + 1 + + handle = m.register_forward_hook(forward_hook) + + def outer_func(tensor): + x = tensor * 2 + 1 + y = m(x) + return y + + inp = torch.tensor(1.0, requires_grad=True) + + failure_reason = None + + def guard_fail_fn(failure): + nonlocal failure_reason + failure_reason = failure[0] + + cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + compiled_func = torch._dynamo.optimize( + guard_fail_fn=guard_fail_fn, + backend=cc, + )(outer_func) + + m = TestModule() + handle = m.register_forward_hook(forward_hook) + failure_reason = None + self.assertEqual(compiled_func(inp), outer_func(inp)) + self.assertEqual(compiled_func(inp).item(), 15) + self.assertEqual(cc.frame_count, 1) + self.assertEqual(cc.op_count, 6) + + # if we remove the hook, dynamo shouldn't notice + handle.remove() + self.assertNotEqual(compiled_func(inp), outer_func(inp)) + self.assertEqual(compiled_func(inp).item(), 15) + self.assertEqual(cc.frame_count, 1) def _forward_hook_test_helper(self, model): forward_handles = {} diff --git a/test/dynamo/test_recompiles.py b/test/dynamo/test_recompiles.py index a30a0b5374bc..ff39d0c8052a 100644 --- a/test/dynamo/test_recompiles.py +++ b/test/dynamo/test_recompiles.py @@ -295,6 +295,27 @@ class RecompileTests(torch._dynamo.test_case.TestCase): self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2) self.assertEqual(dynamic_comp_dynamic_param.op_count, 2) + def test_simple_module_recompile(self): + class SimpleDropout(torch.nn.Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(0.5) + self.linear = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.dropout(self.linear(x)) + + model = SimpleDropout() + x = torch.randn(10) + counter = torch._dynamo.testing.CompileCounter() + model = torch.compile(model, backend=counter, fullgraph=True) + for _ in range(20): + model.eval() + model(x) + model.train() + model(x) + self.assertEqual(counter.frame_count, 2) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 4f8c825af760..7bd172e4db08 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -233,8 +233,6 @@ 
skip_fsdp_guards = True # Make dynamo skip guarding on hooks on nn modules # Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them, # dynamo will not notice and will execute whichever version you first compiled. -# TODO(janimesh) - Remove as once internal is not reliant on this flag. This -# flag has no meaning now after nn_module_guard is introduced. skip_nnmodule_hook_guards = True # If True, raises exception if TorchDynamo is called with a context manager diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 29dd8d8cdaa4..75375116a8c6 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -20,7 +20,6 @@ from inspect import currentframe, getframeinfo from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from weakref import ReferenceType -from .allowed_functions import is_allowed try: import numpy as np @@ -453,52 +452,22 @@ class GuardBuilder(GuardBuilderBase): self.EQUALS_MATCH(guard) def NN_MODULE(self, guard: Guard): - # TODO(janimesh) - This id_match can be removed because nn_module_guard - # checks this in C. However, we need this redundant check to allow - # flexible cache size policy when we guard on self of nn module - # instances. See Note in torch/_dynamo/cache_size.py self.ID_MATCH(guard) ref = self.arg_ref(guard) val = self.get(guard.name) - # The module guard checks for modifications to the Module's type, __dict__, - # and various nested OrderedDicts, such as _parameters, _buffers, and _modules. - # This subsumes the check for Module.training. - if not hasattr(val, "training"): - unimplemented(f"Guard setup for uninitialized class {type(val)}") - if config.allow_rnn and isinstance( - val, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM) - ): - # TorchDynamo graph breaks on LSTMs, but this is a way if user wants - # to override it. LSTMs change the module state in every invocation, - # leading to recompilations. - log.warning("Skipping nn module guard on LSTMs") - return - - # Dynamo does not trace inside the inbuilt torch nn modules. Skip - # guarding on those. More rationale at https://github.com/pytorch/pytorch/issues/110048 - if is_allowed(val.__class__): - return - try: - g = torch._C._dynamo.guards.nn_module_guard(val) - except AttributeError: - # We get an attribute error if the module is partially initialized. For example, - # we might be trying to install a guard before a super().__init__() call when - # the module is missing _parameters, _modules, and other attributes. - # For now, we skip installing the guard. - log.warning( - "Skipping nn module guard because the module could be partially initialized" + def setup_guard(): + assert istype(val.training, bool) + # TODO: Why doesn't this use produce_guard_code? + self.code.append( + GuardCodeList([f"{ref}.training == {val.training}"], guard) ) - return - name = self.check_fn_manager.add_extra_closure_var("__nn_module_guard", g) - if guards_log.isEnabledFor(logging.DEBUG): - # Avoid debug_msg related python bytecode overhead in the runtime. - # debug_msg is only for debugging help and goes to kwargs of guard call, - # which is ignored. 
- self._produce_guard_code(guard, [f'{name}({ref}, debug_msg="{g}")']) + if hasattr(val, "training"): + # There are cases where a monkeypatched object has a guard made between __new__ and __init__ + setup_guard() else: - self._produce_guard_code(guard, [f"{name}({ref})"]) + unimplemented(f"Guard setup for uninitialized class {type(val)}") def FUNCTION_MATCH(self, guard: Guard): """things like torch.add and user defined functions""" @@ -960,9 +929,7 @@ class CheckFunctionManager: guards = output_graph.guards if output_graph else None self.valid = True self._weakrefs: Dict[int, ReferenceType[object]] = {} - self._extra_closure_vars: Dict[str, object] = {} self.output_graph = output_graph - self.extra_closure_vars_count = 0 # Note: right overrides left def combine_scopes(left, right): @@ -1012,7 +979,11 @@ class CheckFunctionManager: if ( not config.guard_nn_modules and guard.is_nn_module() - and not must_add_nn_module_guards(guard) + # Default func args must be guarded on. + # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API + and "__defaults__" not in guard.name + and "__kwdefaults__" not in guard.name + and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name) ): continue @@ -1045,18 +1016,14 @@ class CheckFunctionManager: code_parts = ["___guarded_code.valid", "___check_global_state()"] def add_code_part(code, guard, log_only=False): - if guards_log.isEnabledFor(logging.DEBUG): - extra = "" - if guard is not None: - if guard.user_stack: - for fs in reversed(guard.user_stack): - if fs.filename not in uninteresting_files(): - break - else: - fs = guard.user_stack[-1] - extra = f" # {format_frame(fs, line=True)}" - elif guard.stack: - extra = f" # {format_frame(guard.stack.summary()[-1])}" + if guard.user_stack: + for fs in reversed(guard.user_stack): + if fs.filename not in uninteresting_files(): + break + else: + extra = f" # {format_frame(fs, line=True)}" + elif guard.stack: + extra = f" # {format_frame(guard.stack.summary()[-1])}" guards_log.debug("%s", f"{code:<60}{extra}") @@ -1208,7 +1175,6 @@ class CheckFunctionManager: + list(SYMPY_INTERP.items()) ) closure_vars.update(CLOSURE_VARS) - closure_vars.update(self._extra_closure_vars) unique_code_parts = list(unique(code_parts)) make_guard_fn_args = ", ".join(closure_vars.keys()) @@ -1264,12 +1230,6 @@ class CheckFunctionManager: return self._weakrefs[id(obj)] return None - def add_extra_closure_var(self, name_hint, obj): - name = f"{name_hint}_{self.extra_closure_vars_count}" - self.extra_closure_vars_count += 1 - self._extra_closure_vars[name] = obj - return name - def build_guard_function(code_parts, closure_args) -> Tuple[str, str]: from torch._inductor.utils import IndentedBuffer diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 2406869707b3..45b0e1f77f9d 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -586,180 +585,13 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) { Py_RETURN_TRUE; } -typedef struct { - /* Dict for an attr of the nn module */ - PyDictObject* dict; // borrowed reference - /* version tag of the attr dict to watch mutations */ - uint64_t dict_version_tag; -} AttrTag; - -static const char* module_guard_attrs[] = { - "_parameters", - "_buffers", - "_modules", - "_forward_hooks", - "_forward_pre_hooks", - "_backward_hooks", - "_backward_pre_hooks", -}; - -typedef struct { - PyObject_HEAD; - PyObject* mod; // borrowed 
reference - unsigned int version_tag; - uint64_t dict_version_tag; - AttrTag attr_tags[sizeof(module_guard_attrs) / sizeof(module_guard_attrs[0])]; -} NNModuleGuard; - -static void NNModuleGuard_dealloc(NNModuleGuard* self) { - self->mod = nullptr; - Py_TYPE(self)->tp_free((PyObject*)self); -} - -static PyTypeObject NNModuleGuardType = { - // NOLINTNEXTLINE - PyVarObject_HEAD_INIT(nullptr, 0)}; - -static PyObject* NNModuleGuard_call( - PyObject* callable, - PyObject* args, - PyObject* kwargs) { - NNModuleGuard* guard = (NNModuleGuard*)callable; - - if (PyTuple_GET_SIZE(args) != 1) { - PyErr_SetString( - PyExc_TypeError, "NNModuleGuardType: expected one argument"); - return nullptr; - } - - PyObject* mod = PyTuple_GET_ITEM(args, 0); - if (guard->mod != mod) { - Py_RETURN_FALSE; - } - - // TODO(sgross): temporarily disable tp_version_tag check due to - // torch.fx._symbolic_trace patching __getattr__ and __call__. Modifying - // those attributes on the class changes the tp_version_tag, invalidating - // the guard. - // if (Py_TYPE(mod)->tp_version_tag != guard->version_tag) { - // Py_RETURN_FALSE; - // } - - // NOTE: we must check the dict version tag before we check the attributes, - // because the attributes may be dead references if the dict has been updated. - PyObject* dict = PyObject_GenericGetDict(mod, nullptr); - if (((PyDictObject*)dict)->ma_version_tag != guard->dict_version_tag) { - Py_DECREF(dict); - Py_RETURN_FALSE; - } - Py_DECREF(dict); - - for (auto& attr_tag : guard->attr_tags) { - if (attr_tag.dict->ma_version_tag != attr_tag.dict_version_tag) { - Py_RETURN_FALSE; - } - } - Py_RETURN_TRUE; -} - -static PyObject* NNModuleGuard_repr(PyObject* self) { - // Prints versions of the module and the attributes. - NNModuleGuard* guard = (NNModuleGuard*)self; - std::ostringstream oss; - oss << "versions(mod=" << guard->dict_version_tag; - - for (size_t index = 0; - index < sizeof(module_guard_attrs) / sizeof(module_guard_attrs[0]); - index++) { - oss << ", " << module_guard_attrs[index] << "=" - << guard->attr_tags[index].dict_version_tag; - } - - oss << ")"; - return Py_BuildValue("s", oss.str().c_str()); -} - -static PyObject* nn_module_guard(PyObject* dummy, PyObject* obj) { - // Uses a private tags introduced in PEP 509 - ma_version_tag to check if - // there are any changes in the dict. - // TODO(jansel,janimesh) Note that this ma_version_tag be removed/repurposed - // in Python 3.12 under PEP 699. 
We can rely on newly introduced dict watchers - // in 3.12 - https://docs.python.org/3.12/c-api/dict.html#c.PyDict_Watch - - NNModuleGuard* guard = - (NNModuleGuard*)NNModuleGuardType.tp_alloc(&NNModuleGuardType, 0); - if (guard == nullptr) { - return nullptr; - } - - guard->mod = obj; - - PyObject* dict = PyObject_GenericGetDict(obj, nullptr); - if (dict == nullptr) { - Py_DECREF(guard); - return nullptr; - } - guard->dict_version_tag = ((PyDictObject*)dict)->ma_version_tag; - - Py_ssize_t idx = 0; - for (const char* name : module_guard_attrs) { - auto& tag = guard->attr_tags[idx]; - - PyObject* key = PyUnicode_FromString(name); - if (key == nullptr) { - Py_DECREF(dict); - Py_DECREF(guard); - return nullptr; - } - - PyObject* attr_obj = PyDict_GetItemWithError(dict, key); - if (attr_obj == nullptr) { - if (!PyErr_Occurred()) { - // this module doesn't have the specific attribute - PyErr_Format( - PyExc_AttributeError, - "'%s' object has no attribute '%s'", - Py_TYPE(obj)->tp_name, - name); - } - Py_DECREF(dict); - Py_DECREF(guard); - return nullptr; - } - - tag.dict = (PyDictObject*)attr_obj; - tag.dict_version_tag = tag.dict->ma_version_tag; - idx++; - } - Py_DECREF(dict); - - if (Py_TYPE(obj)->tp_version_tag == 0) { - // The tp_version_tag may be lazily set on attribute access. If we don't - // have a valid tag, perform a property lookup to force the tag to be set. - PyObject* tmp = PyObject_GetAttrString(obj, "__dict__"); - if (tmp == nullptr) { - Py_DECREF(guard); - return nullptr; - } - Py_DECREF(tmp); - } - - guard->version_tag = Py_TYPE(obj)->tp_version_tag; - if (guard->version_tag == 0) { - Py_DECREF(guard); - PyErr_SetString(PyExc_ValueError, "object has no version tag"); - return nullptr; - } - return (PyObject*)guard; -} - +// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) static PyMethodDef _methods[] = { - {"check_type_id", check_type_id, METH_VARARGS, NULL}, - {"check_obj_id", check_obj_id, METH_VARARGS, NULL}, - {"assert_size_stride", assert_size_stride, METH_VARARGS, NULL}, - {"nn_module_guard", nn_module_guard, METH_O, NULL}, + {"check_type_id", check_type_id, METH_VARARGS, nullptr}, + {"check_obj_id", check_obj_id, METH_VARARGS, nullptr}, + {"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr}, {"dict_version", dict_version, METH_VARARGS, NULL}, - {NULL, NULL, 0, NULL}}; + {nullptr, nullptr, 0, nullptr}}; static struct PyModuleDef _module = { PyModuleDef_HEAD_INIT, @@ -771,10 +603,6 @@ static struct PyModuleDef _module = { } // namespace PyObject* torch_c_dynamo_guards_init() { - PyObject* m = PyModule_Create(&_module); - if (m == nullptr) - return nullptr; - // initialize TensorGuardsType TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards"; TensorGuardsType.tp_basicsize = sizeof(TensorGuards); @@ -786,12 +614,8 @@ PyObject* torch_c_dynamo_guards_init() { TensorGuardsType.tp_init = (initproc)TensorGuards_init; TensorGuardsType.tp_new = TensorGuards_new; - NNModuleGuardType.tp_name = "torch._C._dynamo.guards.NNModuleGuard"; - NNModuleGuardType.tp_basicsize = sizeof(NNModuleGuard); - NNModuleGuardType.tp_call = NNModuleGuard_call; - NNModuleGuardType.tp_dealloc = (destructor)NNModuleGuard_dealloc; - NNModuleGuardType.tp_flags = Py_TPFLAGS_DEFAULT; - NNModuleGuardType.tp_repr = NNModuleGuard_repr; + if (PyType_Ready(&TensorGuardsType) < 0) + return nullptr; GlobalStateGuardType.tp_name = "torch._C._dynamo.guards.GlobalStateGuard"; GlobalStateGuardType.tp_basicsize = sizeof(GlobalStateGuard); @@ -802,13 +626,11 @@ 
PyObject* torch_c_dynamo_guards_init() { GlobalStateGuardType.tp_init = (initproc)GlobalStateGuard_init; GlobalStateGuardType.tp_new = PyType_GenericNew; - if (PyType_Ready(&TensorGuardsType) < 0) - return nullptr; - if (PyType_Ready(&GlobalStateGuardType) < 0) return nullptr; - if (PyType_Ready(&NNModuleGuardType) < 0) + auto m = PyModule_Create(&_module); + if (m == nullptr) return nullptr; Py_INCREF(&TensorGuardsType); @@ -826,12 +648,5 @@ PyObject* torch_c_dynamo_guards_init() { return nullptr; } - if (PyModule_AddObject( - m, "NNModuleGuardType", Py_NewRef(&NNModuleGuardType)) < 0) { - Py_DECREF(&NNModuleGuardType); - Py_DECREF(m); - return nullptr; - } - return m; }
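A minimal sketch of the `skip_nnmodule_hook_guards` behavior documented in the restored torch.compiler_nn_module.rst hunk above, assuming only the public `torch.compile` / `register_forward_hook` APIs and the `torch._dynamo.config.skip_nnmodule_hook_guards` flag kept in torch/_dynamo/config.py (illustrative only, not part of the applied diff)::

    import torch

    class TestModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x + 1

    def forward_hook(module, inputs, output):
        # A forward hook that replaces the module output, mirroring the hooks
        # used in test/dynamo/test_modules.py.
        return 2 * output + 1

    # Opt into guards on the nn.Module hook dicts so Dynamo recompiles when hooks
    # change after compilation. The default (True) skips these guards for speed.
    torch._dynamo.config.skip_nnmodule_hook_guards = False

    m = TestModule()
    handle = m.register_forward_hook(forward_hook)
    compiled = torch.compile(m)

    x = torch.ones(3)
    print(compiled(x))  # hook applied: 2 * (2 * 1 + 1) + 1 == 7

    handle.remove()
    print(compiled(x))  # hook-dict guard fails, so Dynamo recompiles: 2 * 1 + 1 == 3

With the default `skip_nnmodule_hook_guards = True`, the second call would keep returning 7 because the hook removal is not guarded on; that is exactly the trade-off the restored documentation describes.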
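Separately, the simplified NN_MODULE guard restored in torch/_dynamo/guards.py checks only `module.training`, which is what the new test_simple_module_recompile in test/dynamo/test_recompiles.py exercises: toggling train()/eval() on a compiled module should cost exactly one recompile per mode. A sketch under the same assumptions, using the `CompileCounter` test backend that the patch's own tests use::

    import torch

    # Any module whose behavior depends on self.training; Dropout is the simplest case.
    model = torch.nn.Sequential(torch.nn.Linear(10, 1), torch.nn.Dropout(0.5))
    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch.compile(model, backend=counter, fullgraph=True)

    x = torch.randn(10)
    for _ in range(20):
        compiled.eval()   # training == False: inference-mode graph
        compiled(x)
        compiled.train()  # training == True: training-mode graph
        compiled(x)

    # Expected 2: the `.training` guard causes one recompile per mode, after which
    # both compiled graphs are reused on every later iteration.
    print(counter.frame_count)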