Revert "Reland 3rd try [finishing colesbury's PR 100642] Guard on nn.Module dicts and type (#109323)" + Forward fixes + test (#110964)

This reverts commit f786fbdebdd24d3a6807e3b9fbf055836db4ad60.

Forward fixes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110964
Approved by: https://github.com/ezyang, https://github.com/anijain2305
Authored by: Michael Voznesensky, 2023-10-10 18:01:02 -07:00
Committed by: PyTorch MergeBot
Parent: e49ea87162
Commit: 1e7947b3e0
11 changed files with 139 additions and 385 deletions

View File

@@ -8,7 +8,8 @@ basic_gnn_edgecnn,pass,0
basic_gnn_gcn,pass,6
basic_gnn_gin,pass,0
basic_gnn_sage,pass,0
cm3leon_generate,pass,7
clip,pass_due_to_skip,0
cm3leon_generate,pass,5
dcgan,pass,0
dlrm,pass,0
doctr_det_predictor,pass,2


View File

@@ -8,6 +8,7 @@ basic_gnn_edgecnn,pass,23
basic_gnn_gcn,pass,14
basic_gnn_gin,pass,8
basic_gnn_sage,pass,8
clip,pass_due_to_skip,0
dcgan,pass,8
dlrm,pass,8
drq,pass,7
@@ -18,7 +19,7 @@ hf_Albert,pass,7
hf_Bart,pass,7
hf_DistilBert,pass,7
hf_GPT2,pass,7
hf_Reformer,pass,28
hf_Reformer,pass,27
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,7
lennard_jones,pass,8


View File

@@ -8,7 +8,8 @@ basic_gnn_edgecnn,pass,0
basic_gnn_gcn,pass,6
basic_gnn_gin,pass,0
basic_gnn_sage,pass,0
cm3leon_generate,pass,7
clip,pass_due_to_skip,0
cm3leon_generate,pass,5
dcgan,pass,0
dlrm,pass,0
doctr_det_predictor,pass,2


View File

@@ -8,6 +8,7 @@ basic_gnn_edgecnn,pass,23
basic_gnn_gcn,pass,14
basic_gnn_gin,pass,8
basic_gnn_sage,pass,8
clip,pass_due_to_skip,0
dcgan,pass,8
dlrm,pass,8
drq,pass,7
@@ -18,7 +19,7 @@ hf_Albert,pass,7
hf_Bart,pass,7
hf_DistilBert,pass,7
hf_GPT2,pass,7
hf_Reformer,pass,28
hf_Reformer,pass,27
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,7
lennard_jones,pass,8
@@ -51,5 +52,5 @@ timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,8
tts_angular,pass,10
vgg16,pass,8
vision_maskrcnn,fail_accuracy,39
vision_maskrcnn,fail_accuracy,37
yolov3,pass,10


View File

@@ -2478,9 +2478,6 @@ class BenchmarkRunner:
# Use distributed wrapping as necessary
model = self.deepcopy_and_maybe_ddp(model)
if not hasattr(model, name):
model.name = name
self.init_optimizer(name, current_device, model.parameters())
with self.pick_grad(name, self.args.training):
ok, total = Stats.reset_counters()
@@ -2540,6 +2537,8 @@ class BenchmarkRunner:
f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
)
if not hasattr(model, name):
model.name = name
results.append(experiment(model, example_inputs, **experiment_kwargs))
return " ".join(map(str, results))

View File

@@ -1,5 +1,5 @@
PyTorch 2.0 nn.Module Support
=============================
PyTorch 2.0 NNModule Support
============================
**Author**: `Will Constable <https://github.com/wconstab>`_
@@ -8,9 +8,12 @@ arbitrary python classes, with the intent of producing faster code by making ass
This doc describes some of the tradeoffs or edge cases that come up due to this specialization.
`nn.Module` Hooks Support
-------------------------
`torch.compile` now has partial support for forward and backward hooks on nn.Modules.
NNModule Hooks Support
----------------------
Previously, `torch.compile` had no support for hooks on nn.Modules, and if hooks were registered
they would simply be ignored in the compiled program. Indeed many users do not
use nn.Module hooks at all, or only use them for debug workflows, but there are valid use cases
for composing nn.Module hooks with `torch.compile`.
Hooks that are orchestrated via the nn.Module.__call__ implementation include `_forward_pre_hooks`,
`_forward_hooks`, `_backward_pre_hooks`, and `_backward_hooks`, and will be referred to as 'call hooks'.
@@ -22,11 +25,11 @@ unsupported by `torch.compile`.
`nn.Module.__call__` Hooks Usage and limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default, `torch.compile` will trace the contents of `nn.Module.__call__` which means it will encounter
and run forward/pre-forward hooks. `torch.compile` installs guards that detect added and removed hooks,
and will trigger a recompilation if the forward/pre-forward hooks change.
and run forward/pre-forward hooks. If you install hooks before calling `torch.compile` and then do not remove
or alter the hooks later, your use case should be supported by default.
Backward/Pre-backward hooks are generally also supported, with similar caveats: currently graph-breaks in dynamo
occur when accessing backward_hooks dicts, which is probably avoidable with some work. Graph-breaks also impact the
occur when accessing backward_hooks dicts, which is probably avoiable with some work. Graph-breaks also impact the
timing of firing backward hooks, since graph-segments are run as autograd-functions which produce all their grads at
the same time. Assuming it were possible for dynamo to not graph-break on the presence of backward-hooks, we would
still expect the backward hooks for a series of modules to all fire together after the whole compiled graph's backward
@@ -38,6 +41,17 @@ by allowing them to be called opaquely in the dynamo graph instead of traced int
currently trigger a graph-break so that the affected modules run outside of dynamo. Depending on the model, this could
introduce a significant performance regression, and additional work is required to improve this support.
**skip_nnmodule_hook_guards**
By default, `torch._dynamo.config.skip_nnmodule_hook_guards` is set to True, meaning no guards will be installed
on each nn.Module hook dictionary, improving runtime by reducing guard execution time, at the cost of not noticing
if any hook dict is changed after compilation.
If you want to be able to remove or modify hooks after compilation and have `torch.compile` react appropriately
(by recompiling), then you need to set `skip_nnmodule_hook_guards=False` and expect a runtime penalty for the added
guards.
TODO: confirm if backward/pre_backward hooks are working or not and document accordingly
state_dict Hooks
~~~~~~~~~~~~~~~~
State dict hooks have not yet been supported in `torch.compile`.
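The behavior described above can be illustrated with a minimal, runnable sketch (not part of this commit; the names TestModule, forward_hook, and fn are illustrative, while CompileCounter and the skip_nnmodule_hook_guards flag appear elsewhere in this diff):

```python
import torch
import torch._dynamo.testing
from typing import Tuple

class TestModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x + 1

def forward_hook(module, inputs: Tuple[torch.Tensor], output: torch.Tensor) -> torch.Tensor:
    return 2 * output + 1

m = TestModule()
handle = m.register_forward_hook(forward_hook)  # installed before compiling: supported by default

cnt = torch._dynamo.testing.CompileCounter()

@torch.compile(backend=cnt)
def fn(x):
    return m(x)

x = torch.ones(3)
fn(x)  # the hook is traced into the compiled graph: 2 * (2 * x + 1) + 1

# With the default skip_nnmodule_hook_guards=True, no guard is installed on
# m._forward_hooks, so removing the hook is not noticed and the stale,
# hook-including graph keeps running without a recompile.
handle.remove()
fn(x)
assert cnt.frame_count == 1

# Setting torch._dynamo.config.skip_nnmodule_hook_guards = False before compiling
# would guard the hook dicts instead, so the removal above would trigger a recompile.
```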

View File

@@ -1247,7 +1247,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
pre = m(data)
cnt.clear()
torch._dynamo.reset()
with torch._dynamo.optimize(cnt, nopython=False):
opt_pre = m(data)
@@ -1256,8 +1255,8 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
out1 = m(data)
out_post = m(data)
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 1)
self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
self.assertTrue(torch._dynamo.testing.same(out1, out_post))
@@ -1731,112 +1730,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
)
)
def test_hooks_recompile(self):
# Modifying hooks should lead to a recompilation
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return 2 * x + 1
def compute_output_and_grad(m, x):
output = m(x)
output.sum().backward()
return x.grad
def forward_pre_hook(module: torch.nn.Module, inputs: Tuple[torch.Tensor]):
return (2 * inputs[0] + 1,)
def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
):
return 2 * output + 1
def backward_hook(module, grad_input, grad_output):
if len(grad_input) == 1:
return (grad_input[0] * 3,)
else:
return (grad_input[0] * 3, None)
def backward_pre_hook(module, grad_outputs):
return (grad_outputs[0] * 5,)
def run_test_case(hook_type, hook_func, expected_grad):
torch._dynamo.reset()
m = TestModule()
input = torch.ones(10, requires_grad=True)
cnt = torch._dynamo.testing.CompileCounter()
opt = torch._dynamo.optimize(cnt)(compute_output_and_grad)
grad1 = opt(m, input)
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(grad1, torch.full_like(grad1, 2))
input.grad = None
handle = getattr(m, hook_type)(hook_func)
grad2 = opt(m, input)
frame_count2 = cnt.frame_count
# Some backward hooks lead to graph breaks so frame_count may be 2 or 3
self.assertGreaterEqual(frame_count2, 2)
self.assertEqual(grad2, torch.full_like(grad2, expected_grad))
# Running again should not recompile
opt(m, input)
self.assertEqual(cnt.frame_count, frame_count2)
# Removing handle should lead to original result
input.grad = None
handle.remove()
grad3 = opt(m, input)
self.assertEqual(grad1, grad3)
run_test_case("register_forward_pre_hook", forward_pre_hook, 4)
run_test_case("register_forward_hook", forward_hook, 4)
run_test_case("register_backward_hook", backward_hook, 6)
run_test_case("register_full_backward_hook", backward_hook, 6)
run_test_case("register_full_backward_pre_hook", backward_pre_hook, 10)
def test_unspecialized_nn_module(self):
# This test is a little confusing because of the combination of
# nn_module_guard and unspecialized nn module variable.
# The graph break in forward causes two graphs
# 1) The first graph has self.relu which introduces a nn_module_guard
# 2) The second graph first assumes self to be NNModuleVariable, but then
# restarts the analysis with self mapping to
# UnspecializedNNModuleVariable upon witnessing self.a += 1.
# Now, when we run the compiled mod the first time, it changes the value
# of self.a. This is fine for the first run. But, when we run the
# compiled module again, the first graph recompiles. This is because
# self.a has changed, changing the ma_version_tag, causing
# nn_module_guard to fail.
# At this point, we might feel that this is doomed as we will always
# keep recompiling on the first graph. But, then Dynamo has already
# marked the self to be UnspecializedNNModuleVariable (because of self.a
# in the second graph), and therefore during the recompilation, we do
# not introduce any nn_module_guard. So, in all we have just one
# recompilation.
class Mock(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = 5
self.relu = torch.nn.ReLU()
def forward(self, x):
z = self.relu(x)
torch._dynamo.graph_break()
self.a += 1
return z * self.a
mod = Mock()
cnt = torch._dynamo.testing.CompileCounter()
opt = torch.compile(mod, backend=cnt)
for _ in range(5):
opt(torch.randn(4))
self.assertEqual(cnt.frame_count, 4)
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_outer(self):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1883,6 +1777,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
the eval_frame entrypoint to Module.__call__?
"""
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_inner(self):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1927,7 +1822,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
handle.remove()
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 7)
self.assertTrue("__nn_module_guard" in failure_reason)
self.assertTrue("forward_hooks.keys" in failure_reason)
self.assertEqual(cc.frame_count, 1 + 1)
self.assertEqual(cc.op_count, 6 + 4)
@@ -1947,7 +1842,55 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
m._forward_hooks[handle.id] = new_forward_hook
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 16)
self.assertTrue("__nn_module_guard" in failure_reason)
self.assertTrue("___check_obj_id(L['m']._forward_hooks" in failure_reason)
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)
def test_hooks_skip_guards(self):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return 2 * x + 1
m = TestModule()
def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor:
return 2 * output + 1
handle = m.register_forward_hook(forward_hook)
def outer_func(tensor):
x = tensor * 2 + 1
y = m(x)
return y
inp = torch.tensor(1.0, requires_grad=True)
failure_reason = None
def guard_fail_fn(failure):
nonlocal failure_reason
failure_reason = failure[0]
cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
compiled_func = torch._dynamo.optimize(
guard_fail_fn=guard_fail_fn,
backend=cc,
)(outer_func)
m = TestModule()
handle = m.register_forward_hook(forward_hook)
failure_reason = None
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 15)
self.assertEqual(cc.frame_count, 1)
self.assertEqual(cc.op_count, 6)
# if we remove the hook, dynamo shouldn't notice
handle.remove()
self.assertNotEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 15)
self.assertEqual(cc.frame_count, 1)
def _forward_hook_test_helper(self, model):
forward_handles = {}

View File

@@ -295,6 +295,27 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2)
self.assertEqual(dynamic_comp_dynamic_param.op_count, 2)
def test_simple_module_recompile(self):
class SimpleDropout(torch.nn.Module):
def __init__(self):
super().__init__()
self.dropout = torch.nn.Dropout(0.5)
self.linear = torch.nn.Linear(10, 1)
def forward(self, x):
return self.dropout(self.linear(x))
model = SimpleDropout()
x = torch.randn(10)
counter = torch._dynamo.testing.CompileCounter()
model = torch.compile(model, backend=counter, fullgraph=True)
for _ in range(20):
model.eval()
model(x)
model.train()
model(x)
self.assertEqual(counter.frame_count, 2)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@@ -233,8 +233,6 @@ skip_fsdp_guards = True
# Make dynamo skip guarding on hooks on nn modules
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
# TODO(janimesh) - Remove as once internal is not reliant on this flag. This
# flag has no meaning now after nn_module_guard is introduced.
skip_nnmodule_hook_guards = True
# If True, raises exception if TorchDynamo is called with a context manager
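As a reference for the flag restored above, here is a hedged sketch of the two ways it is toggled in this diff's docs and tests (the config attribute is real; the test function name below is illustrative):

```python
import torch._dynamo
from unittest.mock import patch

# Globally, before compiling, as the restored docs describe:
torch._dynamo.config.skip_nnmodule_hook_guards = False  # guard hook dicts so hook changes recompile

# Or scoped to a single test, as the updated tests in this diff do:
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", False)
def test_hooks_are_guarded():
    ...  # inside this test, mutating a module's hooks triggers recompilation
```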

View File

@@ -20,7 +20,6 @@ from inspect import currentframe, getframeinfo
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from weakref import ReferenceType
from .allowed_functions import is_allowed
try:
import numpy as np
@@ -453,52 +452,22 @@ class GuardBuilder(GuardBuilderBase):
self.EQUALS_MATCH(guard)
def NN_MODULE(self, guard: Guard):
# TODO(janimesh) - This id_match can be removed because nn_module_guard
# checks this in C. However, we need this redundant check to allow
# flexible cache size policy when we guard on self of nn module
# instances. See Note in torch/_dynamo/cache_size.py
self.ID_MATCH(guard)
ref = self.arg_ref(guard)
val = self.get(guard.name)
# The module guard checks for modifications to the Module's type, __dict__,
# and various nested OrderedDicts, such as _parameters, _buffers, and _modules.
# This subsumes the check for Module.training.
if not hasattr(val, "training"):
unimplemented(f"Guard setup for uninitialized class {type(val)}")
if config.allow_rnn and isinstance(
val, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)
):
# TorchDynamo graph breaks on LSTMs, but this is a way if user wants
# to override it. LSTMs change the module state in every invocation,
# leading to recompilations.
log.warning("Skipping nn module guard on LSTMs")
return
# Dynamo does not trace inside the inbuilt torch nn modules. Skip
# guarding on those. More rationale at https://github.com/pytorch/pytorch/issues/110048
if is_allowed(val.__class__):
return
try:
g = torch._C._dynamo.guards.nn_module_guard(val)
except AttributeError:
# We get an attribute error if the module is partially initialized. For example,
# we might be trying to install a guard before a super().__init__() call when
# the module is missing _parameters, _modules, and other attributes.
# For now, we skip installing the guard.
log.warning(
"Skipping nn module guard because the module could be partially initialized"
def setup_guard():
assert istype(val.training, bool)
# TODO: Why doesn't this use produce_guard_code?
self.code.append(
GuardCodeList([f"{ref}.training == {val.training}"], guard)
)
return
name = self.check_fn_manager.add_extra_closure_var("__nn_module_guard", g)
if guards_log.isEnabledFor(logging.DEBUG):
# Avoid debug_msg related python bytecode overhead in the runtime.
# debug_msg is only for debugging help and goes to kwargs of guard call,
# which is ignored.
self._produce_guard_code(guard, [f'{name}({ref}, debug_msg="{g}")'])
if hasattr(val, "training"):
# There are cases where a monkeypatched object has a guard made between __new__ and __init__
setup_guard()
else:
self._produce_guard_code(guard, [f"{name}({ref})"])
unimplemented(f"Guard setup for uninitialized class {type(val)}")
def FUNCTION_MATCH(self, guard: Guard):
"""things like torch.add and user defined functions"""
@@ -960,9 +929,7 @@ class CheckFunctionManager:
guards = output_graph.guards if output_graph else None
self.valid = True
self._weakrefs: Dict[int, ReferenceType[object]] = {}
self._extra_closure_vars: Dict[str, object] = {}
self.output_graph = output_graph
self.extra_closure_vars_count = 0
# Note: right overrides left
def combine_scopes(left, right):
@@ -1012,7 +979,11 @@ class CheckFunctionManager:
if (
not config.guard_nn_modules
and guard.is_nn_module()
and not must_add_nn_module_guards(guard)
# Default func args must be guarded on.
# TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
and "__defaults__" not in guard.name
and "__kwdefaults__" not in guard.name
and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
):
continue
@@ -1045,18 +1016,14 @@ class CheckFunctionManager:
code_parts = ["___guarded_code.valid", "___check_global_state()"]
def add_code_part(code, guard, log_only=False):
if guards_log.isEnabledFor(logging.DEBUG):
extra = ""
if guard is not None:
if guard.user_stack:
for fs in reversed(guard.user_stack):
if fs.filename not in uninteresting_files():
break
else:
fs = guard.user_stack[-1]
extra = f" # {format_frame(fs, line=True)}"
elif guard.stack:
extra = f" # {format_frame(guard.stack.summary()[-1])}"
if guard.user_stack:
for fs in reversed(guard.user_stack):
if fs.filename not in uninteresting_files():
break
else:
extra = f" # {format_frame(fs, line=True)}"
elif guard.stack:
extra = f" # {format_frame(guard.stack.summary()[-1])}"
guards_log.debug("%s", f"{code:<60}{extra}")
@@ -1208,7 +1175,6 @@ class CheckFunctionManager:
+ list(SYMPY_INTERP.items())
)
closure_vars.update(CLOSURE_VARS)
closure_vars.update(self._extra_closure_vars)
unique_code_parts = list(unique(code_parts))
make_guard_fn_args = ", ".join(closure_vars.keys())
@@ -1264,12 +1230,6 @@ class CheckFunctionManager:
return self._weakrefs[id(obj)]
return None
def add_extra_closure_var(self, name_hint, obj):
name = f"{name_hint}_{self.extra_closure_vars_count}"
self.extra_closure_vars_count += 1
self._extra_closure_vars[name] = obj
return name
def build_guard_function(code_parts, closure_args) -> Tuple[str, str]:
from torch._inductor.utils import IndentedBuffer
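To make the restored guard-filtering condition in CheckFunctionManager easier to follow, here is a standalone sketch (not dynamo internals; the function name and defaults are illustrative) of when an nn.Module guard is kept after this revert:

```python
def keep_nn_module_guard(
    guard_name: str,
    guard_nn_modules: bool = False,          # default of torch._dynamo.config.guard_nn_modules
    skip_nnmodule_hook_guards: bool = True,  # default of torch._dynamo.config.skip_nnmodule_hook_guards
) -> bool:
    # Inverse of the restored `continue` condition: an nn.Module guard survives
    # only for one of the reasons below.
    if guard_nn_modules:
        return True  # user opted into full nn.Module guarding
    if "__defaults__" in guard_name or "__kwdefaults__" in guard_name:
        return True  # default func args must always be guarded on
    if not skip_nnmodule_hook_guards and "hooks" in guard_name:
        return True  # hook dicts are guarded only when requested
    return False

assert keep_nn_module_guard("L['m'].forward.__defaults__")
assert not keep_nn_module_guard("L['m']._forward_hooks")  # default config: hook guards skipped
assert keep_nn_module_guard("L['m']._forward_hooks", skip_nnmodule_hook_guards=False)
```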

View File

@@ -3,7 +3,6 @@
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/extension.h>
@@ -586,180 +585,13 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
Py_RETURN_TRUE;
}
typedef struct {
/* Dict for an attr of the nn module */
PyDictObject* dict; // borrowed reference
/* version tag of the attr dict to watch mutations */
uint64_t dict_version_tag;
} AttrTag;
static const char* module_guard_attrs[] = {
"_parameters",
"_buffers",
"_modules",
"_forward_hooks",
"_forward_pre_hooks",
"_backward_hooks",
"_backward_pre_hooks",
};
typedef struct {
PyObject_HEAD;
PyObject* mod; // borrowed reference
unsigned int version_tag;
uint64_t dict_version_tag;
AttrTag attr_tags[sizeof(module_guard_attrs) / sizeof(module_guard_attrs[0])];
} NNModuleGuard;
static void NNModuleGuard_dealloc(NNModuleGuard* self) {
self->mod = nullptr;
Py_TYPE(self)->tp_free((PyObject*)self);
}
static PyTypeObject NNModuleGuardType = {
// NOLINTNEXTLINE
PyVarObject_HEAD_INIT(nullptr, 0)};
static PyObject* NNModuleGuard_call(
PyObject* callable,
PyObject* args,
PyObject* kwargs) {
NNModuleGuard* guard = (NNModuleGuard*)callable;
if (PyTuple_GET_SIZE(args) != 1) {
PyErr_SetString(
PyExc_TypeError, "NNModuleGuardType: expected one argument");
return nullptr;
}
PyObject* mod = PyTuple_GET_ITEM(args, 0);
if (guard->mod != mod) {
Py_RETURN_FALSE;
}
// TODO(sgross): temporarily disable tp_version_tag check due to
// torch.fx._symbolic_trace patching __getattr__ and __call__. Modifying
// those attributes on the class changes the tp_version_tag, invalidating
// the guard.
// if (Py_TYPE(mod)->tp_version_tag != guard->version_tag) {
// Py_RETURN_FALSE;
// }
// NOTE: we must check the dict version tag before we check the attributes,
// because the attributes may be dead references if the dict has been updated.
PyObject* dict = PyObject_GenericGetDict(mod, nullptr);
if (((PyDictObject*)dict)->ma_version_tag != guard->dict_version_tag) {
Py_DECREF(dict);
Py_RETURN_FALSE;
}
Py_DECREF(dict);
for (auto& attr_tag : guard->attr_tags) {
if (attr_tag.dict->ma_version_tag != attr_tag.dict_version_tag) {
Py_RETURN_FALSE;
}
}
Py_RETURN_TRUE;
}
static PyObject* NNModuleGuard_repr(PyObject* self) {
// Prints versions of the module and the attributes.
NNModuleGuard* guard = (NNModuleGuard*)self;
std::ostringstream oss;
oss << "versions(mod=" << guard->dict_version_tag;
for (size_t index = 0;
index < sizeof(module_guard_attrs) / sizeof(module_guard_attrs[0]);
index++) {
oss << ", " << module_guard_attrs[index] << "="
<< guard->attr_tags[index].dict_version_tag;
}
oss << ")";
return Py_BuildValue("s", oss.str().c_str());
}
static PyObject* nn_module_guard(PyObject* dummy, PyObject* obj) {
// Uses a private tag introduced in PEP 509 - ma_version_tag - to check if
// there are any changes in the dict.
// TODO(jansel,janimesh) Note that this ma_version_tag may be removed/repurposed
// in Python 3.12 under PEP 699. We can rely on newly introduced dict watchers
// in 3.12 - https://docs.python.org/3.12/c-api/dict.html#c.PyDict_Watch
NNModuleGuard* guard =
(NNModuleGuard*)NNModuleGuardType.tp_alloc(&NNModuleGuardType, 0);
if (guard == nullptr) {
return nullptr;
}
guard->mod = obj;
PyObject* dict = PyObject_GenericGetDict(obj, nullptr);
if (dict == nullptr) {
Py_DECREF(guard);
return nullptr;
}
guard->dict_version_tag = ((PyDictObject*)dict)->ma_version_tag;
Py_ssize_t idx = 0;
for (const char* name : module_guard_attrs) {
auto& tag = guard->attr_tags[idx];
PyObject* key = PyUnicode_FromString(name);
if (key == nullptr) {
Py_DECREF(dict);
Py_DECREF(guard);
return nullptr;
}
PyObject* attr_obj = PyDict_GetItemWithError(dict, key);
if (attr_obj == nullptr) {
if (!PyErr_Occurred()) {
// this module doesn't have the specific attribute
PyErr_Format(
PyExc_AttributeError,
"'%s' object has no attribute '%s'",
Py_TYPE(obj)->tp_name,
name);
}
Py_DECREF(dict);
Py_DECREF(guard);
return nullptr;
}
tag.dict = (PyDictObject*)attr_obj;
tag.dict_version_tag = tag.dict->ma_version_tag;
idx++;
}
Py_DECREF(dict);
if (Py_TYPE(obj)->tp_version_tag == 0) {
// The tp_version_tag may be lazily set on attribute access. If we don't
// have a valid tag, perform a property lookup to force the tag to be set.
PyObject* tmp = PyObject_GetAttrString(obj, "__dict__");
if (tmp == nullptr) {
Py_DECREF(guard);
return nullptr;
}
Py_DECREF(tmp);
}
guard->version_tag = Py_TYPE(obj)->tp_version_tag;
if (guard->version_tag == 0) {
Py_DECREF(guard);
PyErr_SetString(PyExc_ValueError, "object has no version tag");
return nullptr;
}
return (PyObject*)guard;
}
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef _methods[] = {
{"check_type_id", check_type_id, METH_VARARGS, NULL},
{"check_obj_id", check_obj_id, METH_VARARGS, NULL},
{"assert_size_stride", assert_size_stride, METH_VARARGS, NULL},
{"nn_module_guard", nn_module_guard, METH_O, NULL},
{"check_type_id", check_type_id, METH_VARARGS, nullptr},
{"check_obj_id", check_obj_id, METH_VARARGS, nullptr},
{"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr},
{"dict_version", dict_version, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL}};
{nullptr, nullptr, 0, nullptr}};
static struct PyModuleDef _module = {
PyModuleDef_HEAD_INIT,
@@ -771,10 +603,6 @@ static struct PyModuleDef _module = {
} // namespace
PyObject* torch_c_dynamo_guards_init() {
PyObject* m = PyModule_Create(&_module);
if (m == nullptr)
return nullptr;
// initialize TensorGuardsType
TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
@@ -786,12 +614,8 @@ PyObject* torch_c_dynamo_guards_init() {
TensorGuardsType.tp_init = (initproc)TensorGuards_init;
TensorGuardsType.tp_new = TensorGuards_new;
NNModuleGuardType.tp_name = "torch._C._dynamo.guards.NNModuleGuard";
NNModuleGuardType.tp_basicsize = sizeof(NNModuleGuard);
NNModuleGuardType.tp_call = NNModuleGuard_call;
NNModuleGuardType.tp_dealloc = (destructor)NNModuleGuard_dealloc;
NNModuleGuardType.tp_flags = Py_TPFLAGS_DEFAULT;
NNModuleGuardType.tp_repr = NNModuleGuard_repr;
if (PyType_Ready(&TensorGuardsType) < 0)
return nullptr;
GlobalStateGuardType.tp_name = "torch._C._dynamo.guards.GlobalStateGuard";
GlobalStateGuardType.tp_basicsize = sizeof(GlobalStateGuard);
@@ -802,13 +626,11 @@ PyObject* torch_c_dynamo_guards_init() {
GlobalStateGuardType.tp_init = (initproc)GlobalStateGuard_init;
GlobalStateGuardType.tp_new = PyType_GenericNew;
if (PyType_Ready(&TensorGuardsType) < 0)
return nullptr;
if (PyType_Ready(&GlobalStateGuardType) < 0)
return nullptr;
if (PyType_Ready(&NNModuleGuardType) < 0)
auto m = PyModule_Create(&_module);
if (m == nullptr)
return nullptr;
Py_INCREF(&TensorGuardsType);
@@ -826,12 +648,5 @@ PyObject* torch_c_dynamo_guards_init() {
return nullptr;
}
if (PyModule_AddObject(
m, "NNModuleGuardType", Py_NewRef(&NNModuleGuardType)) < 0) {
Py_DECREF(&NNModuleGuardType);
Py_DECREF(m);
return nullptr;
}
return m;
}