Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Revert "Reland 3rd try [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#109323)" + Forward fixes + test (#110964)
This reverts commit f786fbdebdd24d3a6807e3b9fbf055836db4ad60.

Forward fixes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110964
Approved by: https://github.com/ezyang, https://github.com/anijain2305
committed by PyTorch MergeBot
parent e49ea87162
commit 1e7947b3e0
@@ -8,7 +8,8 @@ basic_gnn_edgecnn,pass,0
basic_gnn_gcn,pass,6
basic_gnn_gin,pass,0
basic_gnn_sage,pass,0
cm3leon_generate,pass,7
clip,pass_due_to_skip,0
cm3leon_generate,pass,5
dcgan,pass,0
dlrm,pass,0
doctr_det_predictor,pass,2
@@ -8,6 +8,7 @@ basic_gnn_edgecnn,pass,23
basic_gnn_gcn,pass,14
basic_gnn_gin,pass,8
basic_gnn_sage,pass,8
clip,pass_due_to_skip,0
dcgan,pass,8
dlrm,pass,8
drq,pass,7
@@ -18,7 +19,7 @@ hf_Albert,pass,7
hf_Bart,pass,7
hf_DistilBert,pass,7
hf_GPT2,pass,7
hf_Reformer,pass,28
hf_Reformer,pass,27
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,7
lennard_jones,pass,8
@@ -8,7 +8,8 @@ basic_gnn_edgecnn,pass,0
basic_gnn_gcn,pass,6
basic_gnn_gin,pass,0
basic_gnn_sage,pass,0
cm3leon_generate,pass,7
clip,pass_due_to_skip,0
cm3leon_generate,pass,5
dcgan,pass,0
dlrm,pass,0
doctr_det_predictor,pass,2
@@ -8,6 +8,7 @@ basic_gnn_edgecnn,pass,23
basic_gnn_gcn,pass,14
basic_gnn_gin,pass,8
basic_gnn_sage,pass,8
clip,pass_due_to_skip,0
dcgan,pass,8
dlrm,pass,8
drq,pass,7
@@ -18,7 +19,7 @@ hf_Albert,pass,7
hf_Bart,pass,7
hf_DistilBert,pass,7
hf_GPT2,pass,7
hf_Reformer,pass,28
hf_Reformer,pass,27
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,7
lennard_jones,pass,8
@@ -51,5 +52,5 @@ timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,8
tts_angular,pass,10
vgg16,pass,8
vision_maskrcnn,fail_accuracy,39
vision_maskrcnn,fail_accuracy,37
yolov3,pass,10
@@ -2478,9 +2478,6 @@ class BenchmarkRunner:
        # Use distributed wrapping as necessary
        model = self.deepcopy_and_maybe_ddp(model)

        if not hasattr(model, name):
            model.name = name

        self.init_optimizer(name, current_device, model.parameters())
        with self.pick_grad(name, self.args.training):
            ok, total = Stats.reset_counters()
@@ -2540,6 +2537,8 @@ class BenchmarkRunner:
            f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
        )

        if not hasattr(model, name):
            model.name = name
        results.append(experiment(model, example_inputs, **experiment_kwargs))
        return " ".join(map(str, results))
@@ -1,5 +1,5 @@
PyTorch 2.0 nn.Module Support
=============================
PyTorch 2.0 NNModule Support
============================

**Author**: `Will Constable <https://github.com/wconstab>`_

@@ -8,9 +8,12 @@ arbitrary python classes, with the intent of producing faster code by making ass

This doc describes some of the tradeoffs or edge cases that come up due to this specialization.

`nn.Module` Hooks Support
-------------------------
`torch.compile` now has partial support for forward and backward hooks on nn.Modules.
NNModule Hooks Support
----------------------
Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered
they would simply be ignored in the compiled program. Indeed many users do not
use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases
for composing nn.Module hooks with `torch.compile`.

Hooks that are orchestrated via nn.Module.__call__ implementation include `_forward_pre_hooks`,
`forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'.
@@ -22,11 +25,11 @@ unsupported by `torch.compile`.
`nn.Module.__call__` Hooks Usage and limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default, `torch.compile` will trace the contents of `nn.Module.__call__` which means it will encounter
and run forward/pre-forward hooks. `torch.compile` installs guards that detect added and removed hooks,
and will trigger a recompilation if the forward/pre-forward hooks change.
and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove
or alter the hooks later, your use case should be supported by default.

Backward/Pre-backward hooks are generally also supported, with similar caveats: currently graph-breaks in dynamo
occur when accessing backward_hooks dicts, which is probably avoidable with some work. Graph-breaks also impact the
occur when accessing backward_hooks dicts, which is probably avoiable with some work. Graph-breaks also impact the
timing of firing backward hooks, since graph-segments are run as autograd-functions which produce all their grads at
the same time. Assuming it were possible for dynamo to not graph-break on the presence of backward-hooks, we would
still expect the backward hooks for a series of modules to all fire together after the whole compiled graph's backward
@@ -38,6 +41,17 @@ by allowing them to be called opaquely in the dynamo graph instead of traced int
currently trigger a graph-break so that the affected modules run outside of dynamo. Depending on the model, this could
introduce a significant performance regression, and additional work is required to improve this support.

**skip_nnmodule_hook_guards**
By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed
on each nn.Module hook dictionary, improving runtime by reducing guard execution time, at the cost of not noticing
if any hook dict is changed after compilation.

If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
guards.

TODO: confirm if backward/pre_backward hooks are working or not and document accordingly

state_dict Hooks
~~~~~~~~~~~~~~~~
State dict hooks have not yet been supported in `torch.compile`.
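For illustration, here is a minimal sketch of the hook behavior described in this doc (the module and hook below are made up for the example, not taken from the diff): a forward hook installed before compilation is traced into the compiled program, and setting `skip_nnmodule_hook_guards=False` opts into guards that recompile if hooks later change.

```python
import torch
import torch._dynamo

class ToyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x + 1

def forward_hook(module, inputs, output):
    # Runs after forward; traced by torch.compile because it is installed
    # before compilation.
    return output + 10

m = ToyModule()
m.register_forward_hook(forward_hook)

# Opt into hook guards so that later hook changes trigger recompilation
# (at the cost of extra guard-execution time).
torch._dynamo.config.skip_nnmodule_hook_guards = False

compiled = torch.compile(m)
print(compiled(torch.ones(3)))  # hooked result: (2 * 1 + 1) + 10 = 13
```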
@@ -1247,7 +1247,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
        module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
        pre = m(data)
        cnt.clear()
        torch._dynamo.reset()

        with torch._dynamo.optimize(cnt, nopython=False):
            opt_pre = m(data)
@@ -1256,8 +1255,8 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
            out1 = m(data)

        out_post = m(data)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)
        self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
        self.assertTrue(torch._dynamo.testing.same(out1, out_post))
@@ -1731,112 +1730,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
            )
        )

    def test_hooks_recompile(self):
        # Modifying hooks should lead to a recompilation
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x + 1

        def compute_output_and_grad(m, x):
            output = m(x)
            output.sum().backward()
            return x.grad

        def forward_pre_hook(module: torch.nn.Module, inputs: Tuple[torch.Tensor]):
            return (2 * inputs[0] + 1,)

        def forward_hook(
            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
        ):
            return 2 * output + 1

        def backward_hook(module, grad_input, grad_output):
            if len(grad_input) == 1:
                return (grad_input[0] * 3,)
            else:
                return (grad_input[0] * 3, None)

        def backward_pre_hook(module, grad_outputs):
            return (grad_outputs[0] * 5,)

        def run_test_case(hook_type, hook_func, expected_grad):
            torch._dynamo.reset()
            m = TestModule()
            input = torch.ones(10, requires_grad=True)
            cnt = torch._dynamo.testing.CompileCounter()
            opt = torch._dynamo.optimize(cnt)(compute_output_and_grad)

            grad1 = opt(m, input)
            self.assertEqual(cnt.frame_count, 1)
            self.assertEqual(grad1, torch.full_like(grad1, 2))

            input.grad = None
            handle = getattr(m, hook_type)(hook_func)
            grad2 = opt(m, input)
            frame_count2 = cnt.frame_count
            # Some backward hooks lead to graph breaks so frame_count may be 2 or 3
            self.assertGreaterEqual(frame_count2, 2)
            self.assertEqual(grad2, torch.full_like(grad2, expected_grad))

            # Running again should not recompile
            opt(m, input)
            self.assertEqual(cnt.frame_count, frame_count2)

            # Removing handle should lead to original result
            input.grad = None
            handle.remove()
            grad3 = opt(m, input)
            self.assertEqual(grad1, grad3)

        run_test_case("register_forward_pre_hook", forward_pre_hook, 4)
        run_test_case("register_forward_hook", forward_hook, 4)
        run_test_case("register_backward_hook", backward_hook, 6)
        run_test_case("register_full_backward_hook", backward_hook, 6)
        run_test_case("register_full_backward_pre_hook", backward_pre_hook, 10)

    def test_unspecialized_nn_module(self):
        # This test is a little confusing because of the combination of
        # nn_module_guard and unspecialized nn module variable.

        # The graph break in forward causes two graphs
        # 1) The first graph has self.relu which introduces a nn_module_guard
        # 2) The second graph first assumes self to be NNModuleVariable, but then
        # restarts the analysis with self mapping to
        # UnSpecializedNNModuleVariable, on witnessing self.a += 1.

        # Now, when we run the compiled mod the first time, it changes the value
        # of self.a. This is fine for the first run. But, when we run the
        # compiled module again, the first graph recompiles. This is because
        # self.a has changed, changing the ma_version_tag, causing
        # nn_module_guard to fail.

        # At this point, we might feel that this is doomed as we will always
        # keep recompiling on the first graph. But, then Dynamo has already
        # marked the self to be UnspecializedNNModuleVariable (because of self.a
        # in the second graph), and therefore during the recompilation, we do
        # not introduce any nn_module_guard. So, in all we have just one
        # recompilation.
        class Mock(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = 5
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                z = self.relu(x)
                torch._dynamo.graph_break()
                self.a += 1
                return z * self.a

        mod = Mock()
        cnt = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(mod, backend=cnt)

        for _ in range(5):
            opt(torch.randn(4))

        self.assertEqual(cnt.frame_count, 4)

    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
    def test_hooks_outer(self):
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1883,6 +1777,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
        the eval_frame entrypoint to Module.__call__?
        """

    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
    def test_hooks_inner(self):
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1927,7 +1822,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
        handle.remove()
        self.assertEqual(compiled_func(inp), outer_func(inp))
        self.assertEqual(compiled_func(inp).item(), 7)
        self.assertTrue("__nn_module_guard" in failure_reason)
        self.assertTrue("forward_hooks.keys" in failure_reason)
        self.assertEqual(cc.frame_count, 1 + 1)
        self.assertEqual(cc.op_count, 6 + 4)

@@ -1947,7 +1842,55 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
        m._forward_hooks[handle.id] = new_forward_hook
        self.assertEqual(compiled_func(inp), outer_func(inp))
        self.assertEqual(compiled_func(inp).item(), 16)
        self.assertTrue("__nn_module_guard" in failure_reason)
        self.assertTrue("___check_obj_id(L['m']._forward_hooks" in failure_reason)

    @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
    def test_hooks_skip_guards(self):
        class TestModule(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x + 1

        m = TestModule()

        def forward_hook(
            module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
        ) -> torch.Tensor:
            return 2 * output + 1

        handle = m.register_forward_hook(forward_hook)

        def outer_func(tensor):
            x = tensor * 2 + 1
            y = m(x)
            return y

        inp = torch.tensor(1.0, requires_grad=True)

        failure_reason = None

        def guard_fail_fn(failure):
            nonlocal failure_reason
            failure_reason = failure[0]

        cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_func = torch._dynamo.optimize(
            guard_fail_fn=guard_fail_fn,
            backend=cc,
        )(outer_func)

        m = TestModule()
        handle = m.register_forward_hook(forward_hook)
        failure_reason = None
        self.assertEqual(compiled_func(inp), outer_func(inp))
        self.assertEqual(compiled_func(inp).item(), 15)
        self.assertEqual(cc.frame_count, 1)
        self.assertEqual(cc.op_count, 6)

        # if we remove the hook, dynamo shouldn't notice
        handle.remove()
        self.assertNotEqual(compiled_func(inp), outer_func(inp))
        self.assertEqual(compiled_func(inp).item(), 15)
        self.assertEqual(cc.frame_count, 1)

    def _forward_hook_test_helper(self, model):
        forward_handles = {}
@@ -295,6 +295,27 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2)
        self.assertEqual(dynamic_comp_dynamic_param.op_count, 2)

    def test_simple_module_recompile(self):
        class SimpleDropout(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)
                self.linear = torch.nn.Linear(10, 1)

            def forward(self, x):
                return self.dropout(self.linear(x))

        model = SimpleDropout()
        x = torch.randn(10)
        counter = torch._dynamo.testing.CompileCounter()
        model = torch.compile(model, backend=counter, fullgraph=True)
        for _ in range(20):
            model.eval()
            model(x)
            model.train()
            model(x)
        self.assertEqual(counter.frame_count, 2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -233,8 +233,6 @@ skip_fsdp_guards = True
# Make dynamo skip guarding on hooks on nn modules
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
# TODO(janimesh) - Remove once internal is no longer reliant on this flag. This
# flag has no meaning now that nn_module_guard has been introduced.
skip_nnmodule_hook_guards = True

# If True, raises exception if TorchDynamo is called with a context manager
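A rough sketch of the unsafe behavior this comment warns about, mirroring test_hooks_skip_guards above (the toy module and outer function are illustrative, not part of the diff): with the flag left at its default, removing a hook after compilation is not noticed and the stale compiled graph keeps running.

```python
import torch
import torch._dynamo

class ToyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x + 1

m = ToyModule()
handle = m.register_forward_hook(lambda mod, inputs, output: 2 * output + 1)

def outer(x):
    return m(x)

torch._dynamo.config.skip_nnmodule_hook_guards = True  # the default
compiled = torch.compile(outer)

inp = torch.tensor(1.0)
print(compiled(inp).item())  # 7.0: the hook was traced into the graph

handle.remove()
print(outer(inp).item())     # 3.0: eager no longer runs the hook
print(compiled(inp).item())  # still 7.0: no hook guard installed, so no recompilation
```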
@@ -20,7 +20,6 @@ from inspect import currentframe, getframeinfo
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from weakref import ReferenceType

from .allowed_functions import is_allowed

try:
    import numpy as np
@@ -453,52 +452,22 @@ class GuardBuilder(GuardBuilderBase):
        self.EQUALS_MATCH(guard)

    def NN_MODULE(self, guard: Guard):
        # TODO(janimesh) - This id_match can be removed because nn_module_guard
        # checks this in C. However, we need this redundant check to allow
        # flexible cache size policy when we guard on self of nn module
        # instances. See Note in torch/_dynamo/cache_size.py
        self.ID_MATCH(guard)
        ref = self.arg_ref(guard)
        val = self.get(guard.name)

        # The module guard checks for modifications to the Module's type, __dict__,
        # and various nested OrderedDicts, such as _parameters, _buffers, and _modules.
        # This subsumes the check for Module.training.
        if not hasattr(val, "training"):
            unimplemented(f"Guard setup for uninitialized class {type(val)}")
        if config.allow_rnn and isinstance(
            val, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)
        ):
            # TorchDynamo graph breaks on LSTMs, but this is a way if user wants
            # to override it. LSTMs change the module state in every invocation,
            # leading to recompilations.
            log.warning("Skipping nn module guard on LSTMs")
            return

        # Dynamo does not trace inside the inbuilt torch nn modules. Skip
        # guarding on those. More rationale at https://github.com/pytorch/pytorch/issues/110048
        if is_allowed(val.__class__):
            return
        try:
            g = torch._C._dynamo.guards.nn_module_guard(val)
        except AttributeError:
            # We get an attribute error if the module is partially initialized. For example,
            # we might be trying to install a guard before a super().__init__() call when
            # the module is missing _parameters, _modules, and other attributes.
            # For now, we skip installing the guard.
            log.warning(
                "Skipping nn module guard because the module could be partially initialized"
        def setup_guard():
            assert istype(val.training, bool)
            # TODO: Why doesn't this use produce_guard_code?
            self.code.append(
                GuardCodeList([f"{ref}.training == {val.training}"], guard)
            )
            return
        name = self.check_fn_manager.add_extra_closure_var("__nn_module_guard", g)
        if guards_log.isEnabledFor(logging.DEBUG):
            # Avoid debug_msg related python bytecode overhead in the runtime.

            # debug_msg is only for debugging help and goes to kwargs of guard call,
            # which is ignored.
            self._produce_guard_code(guard, [f'{name}({ref}, debug_msg="{g}")'])
        if hasattr(val, "training"):
            # There are cases where a monkeypatched object has a guard made between __new__ and __init__
            setup_guard()
        else:
            self._produce_guard_code(guard, [f"{name}({ref})"])
            unimplemented(f"Guard setup for uninitialized class {type(val)}")

    def FUNCTION_MATCH(self, guard: Guard):
        """things like torch.add and user defined functions"""
@@ -960,9 +929,7 @@ class CheckFunctionManager:
        guards = output_graph.guards if output_graph else None
        self.valid = True
        self._weakrefs: Dict[int, ReferenceType[object]] = {}
        self._extra_closure_vars: Dict[str, object] = {}
        self.output_graph = output_graph
        self.extra_closure_vars_count = 0

        # Note: right overrides left
        def combine_scopes(left, right):
@@ -1012,7 +979,11 @@ class CheckFunctionManager:
            if (
                not config.guard_nn_modules
                and guard.is_nn_module()
                and not must_add_nn_module_guards(guard)
                # Default func args must be guarded on.
                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
            ):
                continue

@@ -1045,18 +1016,14 @@ class CheckFunctionManager:
        code_parts = ["___guarded_code.valid", "___check_global_state()"]

        def add_code_part(code, guard, log_only=False):
            if guards_log.isEnabledFor(logging.DEBUG):
                extra = ""
                if guard is not None:
                    if guard.user_stack:
                        for fs in reversed(guard.user_stack):
                            if fs.filename not in uninteresting_files():
                                break
                        else:
                            fs = guard.user_stack[-1]
                        extra = f" # {format_frame(fs, line=True)}"
                    elif guard.stack:
                        extra = f" # {format_frame(guard.stack.summary()[-1])}"
                if guard.user_stack:
                    for fs in reversed(guard.user_stack):
                        if fs.filename not in uninteresting_files():
                            break
                    else:
                        extra = f" # {format_frame(fs, line=True)}"
                elif guard.stack:
                    extra = f" # {format_frame(guard.stack.summary()[-1])}"

                guards_log.debug("%s", f"{code:<60}{extra}")

@@ -1208,7 +1175,6 @@ class CheckFunctionManager:
            + list(SYMPY_INTERP.items())
        )
        closure_vars.update(CLOSURE_VARS)
        closure_vars.update(self._extra_closure_vars)

        unique_code_parts = list(unique(code_parts))
        make_guard_fn_args = ", ".join(closure_vars.keys())
@@ -1264,12 +1230,6 @@ class CheckFunctionManager:
            return self._weakrefs[id(obj)]
        return None

    def add_extra_closure_var(self, name_hint, obj):
        name = f"{name_hint}_{self.extra_closure_vars_count}"
        self.extra_closure_vars_count += 1
        self._extra_closure_vars[name] = obj
        return name


def build_guard_function(code_parts, closure_args) -> Tuple[str, str]:
    from torch._inductor.utils import IndentedBuffer
@@ -3,7 +3,6 @@
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/extension.h>
@@ -586,180 +585,13 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
  Py_RETURN_TRUE;
}

typedef struct {
  /* Dict for an attr of the nn module */
  PyDictObject* dict; // borrowed reference
  /* version tag of the attr dict to watch mutations */
  uint64_t dict_version_tag;
} AttrTag;

static const char* module_guard_attrs[] = {
    "_parameters",
    "_buffers",
    "_modules",
    "_forward_hooks",
    "_forward_pre_hooks",
    "_backward_hooks",
    "_backward_pre_hooks",
};

typedef struct {
  PyObject_HEAD;
  PyObject* mod; // borrowed reference
  unsigned int version_tag;
  uint64_t dict_version_tag;
  AttrTag attr_tags[sizeof(module_guard_attrs) / sizeof(module_guard_attrs[0])];
} NNModuleGuard;

static void NNModuleGuard_dealloc(NNModuleGuard* self) {
  self->mod = nullptr;
  Py_TYPE(self)->tp_free((PyObject*)self);
}

static PyTypeObject NNModuleGuardType = {
    // NOLINTNEXTLINE
    PyVarObject_HEAD_INIT(nullptr, 0)};

static PyObject* NNModuleGuard_call(
    PyObject* callable,
    PyObject* args,
    PyObject* kwargs) {
  NNModuleGuard* guard = (NNModuleGuard*)callable;

  if (PyTuple_GET_SIZE(args) != 1) {
    PyErr_SetString(
        PyExc_TypeError, "NNModuleGuardType: expected one argument");
    return nullptr;
  }

  PyObject* mod = PyTuple_GET_ITEM(args, 0);
  if (guard->mod != mod) {
    Py_RETURN_FALSE;
  }

  // TODO(sgross): temporarily disable tp_version_tag check due to
  // torch.fx._symbolic_trace patching __getattr__ and __call__. Modifying
  // those attributes on the class changes the tp_version_tag, invalidating
  // the guard.
  // if (Py_TYPE(mod)->tp_version_tag != guard->version_tag) {
  //   Py_RETURN_FALSE;
  // }

  // NOTE: we must check the dict version tag before we check the attributes,
  // because the attributes may be dead references if the dict has been updated.
  PyObject* dict = PyObject_GenericGetDict(mod, nullptr);
  if (((PyDictObject*)dict)->ma_version_tag != guard->dict_version_tag) {
    Py_DECREF(dict);
    Py_RETURN_FALSE;
  }
  Py_DECREF(dict);

  for (auto& attr_tag : guard->attr_tags) {
    if (attr_tag.dict->ma_version_tag != attr_tag.dict_version_tag) {
      Py_RETURN_FALSE;
    }
  }
  Py_RETURN_TRUE;
}

static PyObject* NNModuleGuard_repr(PyObject* self) {
  // Prints versions of the module and the attributes.
  NNModuleGuard* guard = (NNModuleGuard*)self;
  std::ostringstream oss;
  oss << "versions(mod=" << guard->dict_version_tag;

  for (size_t index = 0;
       index < sizeof(module_guard_attrs) / sizeof(module_guard_attrs[0]);
       index++) {
    oss << ", " << module_guard_attrs[index] << "="
        << guard->attr_tags[index].dict_version_tag;
  }

  oss << ")";
  return Py_BuildValue("s", oss.str().c_str());
}

static PyObject* nn_module_guard(PyObject* dummy, PyObject* obj) {
  // Uses a private tag introduced in PEP 509 - ma_version_tag to check if
  // there are any changes in the dict.
  // TODO(jansel,janimesh) Note that this ma_version_tag may be removed/repurposed
  // in Python 3.12 under PEP 699. We can rely on newly introduced dict watchers
  // in 3.12 - https://docs.python.org/3.12/c-api/dict.html#c.PyDict_Watch

  NNModuleGuard* guard =
      (NNModuleGuard*)NNModuleGuardType.tp_alloc(&NNModuleGuardType, 0);
  if (guard == nullptr) {
    return nullptr;
  }

  guard->mod = obj;

  PyObject* dict = PyObject_GenericGetDict(obj, nullptr);
  if (dict == nullptr) {
    Py_DECREF(guard);
    return nullptr;
  }
  guard->dict_version_tag = ((PyDictObject*)dict)->ma_version_tag;

  Py_ssize_t idx = 0;
  for (const char* name : module_guard_attrs) {
    auto& tag = guard->attr_tags[idx];

    PyObject* key = PyUnicode_FromString(name);
    if (key == nullptr) {
      Py_DECREF(dict);
      Py_DECREF(guard);
      return nullptr;
    }

    PyObject* attr_obj = PyDict_GetItemWithError(dict, key);
    if (attr_obj == nullptr) {
      if (!PyErr_Occurred()) {
        // this module doesn't have the specific attribute
        PyErr_Format(
            PyExc_AttributeError,
            "'%s' object has no attribute '%s'",
            Py_TYPE(obj)->tp_name,
            name);
      }
      Py_DECREF(dict);
      Py_DECREF(guard);
      return nullptr;
    }

    tag.dict = (PyDictObject*)attr_obj;
    tag.dict_version_tag = tag.dict->ma_version_tag;
    idx++;
  }
  Py_DECREF(dict);

  if (Py_TYPE(obj)->tp_version_tag == 0) {
    // The tp_version_tag may be lazily set on attribute access. If we don't
    // have a valid tag, perform a property lookup to force the tag to be set.
    PyObject* tmp = PyObject_GetAttrString(obj, "__dict__");
    if (tmp == nullptr) {
      Py_DECREF(guard);
      return nullptr;
    }
    Py_DECREF(tmp);
  }

  guard->version_tag = Py_TYPE(obj)->tp_version_tag;
  if (guard->version_tag == 0) {
    Py_DECREF(guard);
    PyErr_SetString(PyExc_ValueError, "object has no version tag");
    return nullptr;
  }
  return (PyObject*)guard;
}

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef _methods[] = {
    {"check_type_id", check_type_id, METH_VARARGS, NULL},
    {"check_obj_id", check_obj_id, METH_VARARGS, NULL},
    {"assert_size_stride", assert_size_stride, METH_VARARGS, NULL},
    {"nn_module_guard", nn_module_guard, METH_O, NULL},
    {"check_type_id", check_type_id, METH_VARARGS, nullptr},
    {"check_obj_id", check_obj_id, METH_VARARGS, nullptr},
    {"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr},
    {"dict_version", dict_version, METH_VARARGS, NULL},
    {NULL, NULL, 0, NULL}};
    {nullptr, nullptr, 0, nullptr}};

static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
@@ -771,10 +603,6 @@ static struct PyModuleDef _module = {
} // namespace

PyObject* torch_c_dynamo_guards_init() {
  PyObject* m = PyModule_Create(&_module);
  if (m == nullptr)
    return nullptr;

  // initialize TensorGuardsType
  TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
  TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
@@ -786,12 +614,8 @@ PyObject* torch_c_dynamo_guards_init() {
  TensorGuardsType.tp_init = (initproc)TensorGuards_init;
  TensorGuardsType.tp_new = TensorGuards_new;

  NNModuleGuardType.tp_name = "torch._C._dynamo.guards.NNModuleGuard";
  NNModuleGuardType.tp_basicsize = sizeof(NNModuleGuard);
  NNModuleGuardType.tp_call = NNModuleGuard_call;
  NNModuleGuardType.tp_dealloc = (destructor)NNModuleGuard_dealloc;
  NNModuleGuardType.tp_flags = Py_TPFLAGS_DEFAULT;
  NNModuleGuardType.tp_repr = NNModuleGuard_repr;
  if (PyType_Ready(&TensorGuardsType) < 0)
    return nullptr;

  GlobalStateGuardType.tp_name = "torch._C._dynamo.guards.GlobalStateGuard";
  GlobalStateGuardType.tp_basicsize = sizeof(GlobalStateGuard);
@@ -802,13 +626,11 @@ PyObject* torch_c_dynamo_guards_init() {
  GlobalStateGuardType.tp_init = (initproc)GlobalStateGuard_init;
  GlobalStateGuardType.tp_new = PyType_GenericNew;

  if (PyType_Ready(&TensorGuardsType) < 0)
    return nullptr;

  if (PyType_Ready(&GlobalStateGuardType) < 0)
    return nullptr;

  if (PyType_Ready(&NNModuleGuardType) < 0)
  auto m = PyModule_Create(&_module);
  if (m == nullptr)
    return nullptr;

  Py_INCREF(&TensorGuardsType);
@@ -826,12 +648,5 @@ PyObject* torch_c_dynamo_guards_init() {
    return nullptr;
  }

  if (PyModule_AddObject(
          m, "NNModuleGuardType", Py_NewRef(&NNModuleGuardType)) < 0) {
    Py_DECREF(&NNModuleGuardType);
    Py_DECREF(m);
    return nullptr;
  }

  return m;
}
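As a rough illustration of the PEP 509 mechanism this guard relied on, the sketch below reads a dict's version tag through the `dict_version` helper kept in the method table above. Treat the helper and the tag itself as CPython/PyTorch implementation details, not a supported API (the tag is deprecated in Python 3.12 per PEP 699), so this is an assumption-laden sketch.

```python
import torch

guards = torch._C._dynamo.guards  # internal module; not a public API

mod = torch.nn.Linear(4, 4)
before = guards.dict_version(mod.__dict__)

mod.some_flag = True  # any attribute write mutates the module's __dict__
after = guards.dict_version(mod.__dict__)

# The version tag changes on mutation, which is exactly what a
# dict-version-based module guard would detect and fail on.
print(before, after)
```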