Revert "[dynamo][guards] Do not construct entire framelocals dict for LAMBDA_GUARD (#162525)"

This reverts commit 5f630d28d7ff9fdd8bd6cdbe2438e5c821007845.

Reverted https://github.com/pytorch/pytorch/pull/162525 on behalf of https://github.com/anijain2305 due to internal test failures ([comment](https://github.com/pytorch/pytorch/pull/162525#issuecomment-3310748980))
PyTorch MergeBot
2025-09-19 06:15:28 +00:00
parent e0bcd58f57
commit 1302637a23
7 changed files with 19 additions and 128 deletions
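For reference, the API shape this commit restores versus the one it removes: the reverted PR made `add_lambda_guard` take a `required_locals` mapping (local name to framelocals index) and a `construct_partial_framelocals_dict` flag in addition to the user lambda and its verbose code parts. A minimal sketch of the two call shapes as they appear in the test hunks below (each runnable only against a build on the matching side of this commit):

```python
import torch

# Restored (post-revert) signature: user lambda + verbose code parts.
guard_manager = torch._dynamo.guards.RootGuardManager()
guard_manager.add_lambda_guard(
    lambda L: isinstance(L["x"], int),  # L is the framelocals, keyed by variable name
    ["isinstance(L['x'], int)"],        # verbose code parts, used for debug output
)

# Removed (pre-revert) signature, shown commented out for comparison only:
# guard_manager.add_lambda_guard(
#     lambda L: isinstance(L["x"], int),
#     {"x": 0},  # required_locals: local name -> framelocals index
#     True,      # construct_partial_framelocals_dict
#     ["isinstance(L['x'], int)"],
# )
```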

View File

@@ -116,8 +116,6 @@ num_guards_executed=0)
const_guard = guards.LAMBDA_GUARD(
root,
functools.partial(equals_match, expected=5),
{},
False,
equals_match_verbose_code_parts(5),
)
self.assertTrue(const_guard(5))
@@ -407,14 +405,10 @@ num_guards_executed=0)
guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"])
guard_manager.add_lambda_guard(
functools.partial(ge_match, expected=5),
{},
False,
ge_match_verbose_code_parts(expected=5),
)
guard_manager.add_lambda_guard(
functools.partial(less_match, expected=10),
{},
False,
less_match_verbose_code_parts(expected=10),
)
self.assertEqual(len(guard_manager.get_leaf_guards()), 3)
@@ -434,14 +428,10 @@ num_guards_executed=0)
guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
guard_manager.getattr_manager("x", "x", 1, default_mgr_enum).add_lambda_guard(
functools.partial(equals_match, expected=foo.x),
{},
False,
equals_match_verbose_code_parts(foo.x),
)
guard_manager.getattr_manager("y", "y", 2, default_mgr_enum).add_lambda_guard(
functools.partial(equals_match, expected=foo.y),
{},
False,
equals_match_verbose_code_parts(foo.y),
)
self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
@@ -484,14 +474,10 @@ num_guards_executed=0)
guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
guard_manager.getitem_manager(0, "", 1, default_mgr_enum).add_lambda_guard(
functools.partial(equals_match, expected=foo[0]),
{},
False,
equals_match_verbose_code_parts(foo[0]),
)
guard_manager.getitem_manager(1, "", 2, default_mgr_enum).add_lambda_guard(
functools.partial(equals_match, expected=foo[1]),
{},
False,
equals_match_verbose_code_parts(foo[1]),
)
self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
@@ -599,8 +585,6 @@ num_guards_executed=0)
lambda x: isinstance(x, Pair)
and isinstance(x.x, torch.Tensor)
and isinstance(x.y, int),
{},
False,
"global guard fail",
)
@@ -651,8 +635,6 @@ num_guards_executed=0)
)
attr_manager.add_lambda_guard(
lambda x: x == 4,
{},
False,
"Expected value 4",
)
@@ -693,8 +675,6 @@ num_guards_executed=0)
weakref_manager.add_lambda_guard(
lambda x: isinstance(x, torch.Tensor),
{},
False,
"global weakref fail",
)
@@ -714,8 +694,6 @@ num_guards_executed=0)
)
foo_mgr.add_lambda_guard(
lambda x: x == 3,
{},
False,
"Expected value 3",
)
self.assertTrue(guard_manager.check(a))
@@ -801,7 +779,7 @@ num_guards_executed=0)
# Add key-value manager (nothing : {"z" : 3})
self.assertTrue(root.check(f_locals))
dict_mgr.get_key_manager(1, "", nothing, default_mgr_enum).add_lambda_guard(
lambda x: x is nothing, {}, False, ["x is nothing"]
lambda x: x is nothing, ["x is nothing"]
)
self.assertTrue(root.check(f_locals))
value_mgr = dict_mgr.get_value_manager(
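The hunks above all make the same mechanical change: the `{}` and `False` arguments (`required_locals` and `construct_partial_framelocals_dict`) are dropped, so leaf lambda guards are registered with just the callable and its verbose code parts. A condensed, self-contained sketch of the restored pattern, assuming the `torch._C._dynamo.guards` bindings these tests exercise:

```python
import functools

from torch._C._dynamo import guards  # C++ bindings registered by torch_c_dynamo_guards_init

def ge_match(x, expected):
    return x >= expected

def less_match(x, expected):
    return x < expected

root = guards.RootGuardManager()
# Restored two-argument form: user lambda + verbose code parts.
root.add_lambda_guard(functools.partial(ge_match, expected=5), ["x >= 5"])
root.add_lambda_guard(functools.partial(less_match, expected=10), ["x < 10"])

assert len(root.get_leaf_guards()) == 2
assert root.check(7)        # passes both guards: 5 <= 7 < 10
assert not root.check(11)   # fails the "x < 10" guard
```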

View File

@@ -7207,9 +7207,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
return x + 1
guard_manager = torch._dynamo.guards.RootGuardManager()
guard_manager.add_lambda_guard(
lambda L: isinstance(L["x"], int), {"x": 0}, True, []
)
guard_manager.add_lambda_guard(lambda L: isinstance(L["x"], int), [])
def injected(x):
return x + 42
@@ -7234,33 +7232,27 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
return x + 1
guard_manager_bool = torch._dynamo.guards.RootGuardManager()
guard_manager_bool.add_lambda_guard(
lambda L: isinstance(L["x"], bool), {"x": 0}, True, []
)
guard_manager_bool.add_lambda_guard(lambda L: isinstance(L["x"], bool), [])
def injected_bool(x: bool):
return x + 102
guard_manager_int = torch._dynamo.guards.RootGuardManager()
guard_manager_int.add_lambda_guard(
lambda L: isinstance(L["x"], int), {"x": 0}, True, []
)
guard_manager_int.add_lambda_guard(lambda L: isinstance(L["x"], int), [])
def injected_int(x: int):
return x + 42
guard_manager_tensor = torch._dynamo.guards.RootGuardManager()
guard_manager_tensor.add_lambda_guard(
lambda L: isinstance(L["x"], torch.Tensor), {"x": 0}, True, []
lambda L: isinstance(L["x"], torch.Tensor), []
)
def injected_tensor(x: torch.Tensor):
return x + 100
guard_manager_str = torch._dynamo.guards.RootGuardManager()
guard_manager_str.add_lambda_guard(
lambda L: isinstance(L["x"], str), {"x": 0}, True, []
)
guard_manager_str.add_lambda_guard(lambda L: isinstance(L["x"], str), [])
def injected_str(x: str):
return x + "1"
@@ -7337,10 +7329,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
guard_manager_bool = torch._dynamo.guards.RootGuardManager()
guard_manager_bool.add_lambda_guard(
lambda L: isinstance(L["x"], bool),
{"x": 0},
True,
["isinstance(L['x'], bool)"],
lambda L: isinstance(L["x"], bool), ["isinstance(L['x'], bool)"]
)
def injected_bool(x: bool):
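These hunks collapse the guard managers used by the function-injection tests to the restored two-argument call. The guard lambda is invoked with the frame's locals as a dict keyed by variable name (or, in a test, with whatever dict is handed to `check()`), so the dispatch decision can be sketched as below; the dict literals are illustrative stand-ins for real framelocals:

```python
import torch

guard_manager_int = torch._dynamo.guards.RootGuardManager()
guard_manager_int.add_lambda_guard(lambda L: isinstance(L["x"], int), [])

def injected_int(x: int):
    return x + 42

# The guard decides whether the injected variant applies to a frame whose
# locals look like {"x": ...}.
assert guard_manager_int.check({"x": 3})
assert not guard_manager_int.check({"x": "3"})
```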

View File

@@ -222,11 +222,7 @@ class GuardManager:
) -> GuardManager: ...
# Leaf guards
def add_lambda_guard(
self,
user_lambda: Callable[..., Any],
required_locals: dict[str, int],
construct_partial_framelocals_dict: bool,
verbose_code_parts: list[str],
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
) -> None: ...
def add_lambda_guard_no_args(
self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
@@ -359,8 +355,6 @@ class RootGuardManager(GuardManager):
def add_epilogue_lambda_guard(
self,
guard: LeafGuard,
required_locals: dict[str, int],
construct_partial_framelocals_dict: bool,
verbose_code_parts: list[str],
) -> None: ...
def clone_manager(

View File

@@ -381,12 +381,6 @@ use_recursive_dict_tags_for_guards = True
# useful for regional compilation.
max_saved_pointers_for_recursive_dict_tags_check = 256
# Controls whether to construct only a partial framelocals dict for lambda
# guards. This is a temporary flag to allow quick fallback behavior in case of
# unexpected issues. Default is True, i.e., we construct only the partial
# dict, which is the faster path for guards. Set to False to fall back to the old behavior.
construct_partial_framelocals_dict = True
# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True
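The removed block above documented `construct_partial_framelocals_dict` as a temporary escape hatch. While the flag existed (i.e., before this revert), it was toggled like any other `torch._dynamo.config` option; the sketch below is therefore hypothetical against a post-revert build, where the attribute no longer exists:

```python
import torch

# Pre-revert only: fall back from the partial framelocals dict to the old
# behavior of building the full dict for lambda guards. After this commit,
# the flag is gone and this assignment would fail.
torch._dynamo.config.construct_partial_framelocals_dict = False

compiled = torch.compile(lambda x: x + 1)
compiled(torch.ones(3))  # guards for this frame are built under the fallback path
```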

View File

@@ -235,7 +235,7 @@ dunder_attrs_assumed_constants = (
)
def get_framelocals_idx(code: types.CodeType, var_name: str) -> Optional[int]:
def get_framelocals_idx(code: types.CodeType, var_name: str) -> int:
# Refer to index in the frame's localsplus directly.
# NOTE: name order for a code object doesn't change.
# NOTE: we need to find the LAST matching index because <= 3.10 contains
@@ -243,8 +243,6 @@ def get_framelocals_idx(code: types.CodeType, var_name: str) -> Optional[int]:
# and will take up 2 slots of the frame's localsplus. The correct behavior
# is to refer to the cell, which has a higher index.
framelocals_names_reversed = code_framelocals_names_reversed_cached(code)
if var_name not in framelocals_names_reversed:
return None
framelocals_idx = (
len(framelocals_names_reversed) - framelocals_names_reversed.index(var_name) - 1
)
@@ -1362,7 +1360,6 @@ class GuardBuilder(GuardBuilderBase):
# Use istype instead of isinstance to check for exact type of source.
if istype(source, LocalSource):
framelocals_idx = get_framelocals_idx(self.f_code, source.local_name)
assert framelocals_idx is not None
out = root_guard_manager.framelocals_manager(
key=(source.local_name, framelocals_idx),
source=source_name,
@@ -1758,34 +1755,15 @@ class GuardBuilder(GuardBuilderBase):
guards_log.debug("Python shape guard function:\n%s", pycode)
exec(pycode, globals_for_guard_fn, out)
guard_fn = out["___make_guard_fn"](*closure_vars.values())
required_locals = {}
all_locals = self.scope["L"].keys()
for var_name in guard_fn.__code__.co_consts:
if isinstance(var_name, str) and var_name in all_locals:
index = get_framelocals_idx(self.f_code, var_name)
if index is not None:
required_locals[var_name] = index
construct_partial_framelocals_dict = config.construct_partial_framelocals_dict
if is_epilogue:
# Epilogue guards are run after all the other guards have finished.
# If epilogue guards contain a getattr or getitem access, one of the
# other guards would fail, preventing the epilogue guards from running.
self.guard_manager.root.add_epilogue_lambda_guard(
guard_fn,
required_locals,
construct_partial_framelocals_dict,
verbose_code_parts,
guard_fn, verbose_code_parts
)
else:
self.guard_manager.root.add_lambda_guard(
guard_fn,
required_locals,
construct_partial_framelocals_dict,
verbose_code_parts,
)
self.guard_manager.root.add_lambda_guard(guard_fn, verbose_code_parts)
# Warning: use this with care! This lets you access what the current
# value of the value you are guarding on is. You probably don't want
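On the `get_framelocals_idx` hunks above: the revert tightens the return type back to `int` because every `LocalSource` name is expected to be present, dropping the `Optional[int]` path that the removed `required_locals` computation relied on. A standalone sketch of the reverse lookup the NOTE describes, assuming CPython's localsplus ordering of `co_varnames`, then `co_cellvars`, then `co_freevars` (the real helper uses the cached `code_framelocals_names_reversed_cached`):

```python
import types

def get_framelocals_idx_sketch(code: types.CodeType, var_name: str) -> int:
    # Search the localsplus names in reverse so that, on Python <= 3.10, a
    # variable that is both an argument and a cell resolves to the cell slot
    # (the higher index), matching the NOTE in the hunk above.
    names_reversed = list(reversed(code.co_varnames + code.co_cellvars + code.co_freevars))
    return len(names_reversed) - names_reversed.index(var_name) - 1

def f(x, y):
    return x + y

assert get_framelocals_idx_sketch(f.__code__, "x") == 0
assert get_framelocals_idx_sketch(f.__code__, "y") == 1
```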

View File

@@ -2074,7 +2074,7 @@ class OutputGraph(OutputGraphGuardsState):
check_fn_source = inspect.getsource(specialization.check_fn).strip()
# Required because the LAMBDA_GUARD API requires a root guard manager
unused_root_guard_manager = RootGuardManager()
check_fn = guards.LAMBDA_GUARD_NO_FRAMELOCALS( # type: ignore[attr-defined]
check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined]
unused_root_guard_manager,
specialization.check_fn,
[check_fn_source],

View File

@@ -1625,7 +1625,9 @@ class LeafGuard {
// is not exposed to Python and can only be called from C++.
virtual bool check_nopybind(PyObject* value) = 0;
virtual bool check_nopybind(FrameLocalsMapping* map) {
throw std::runtime_error("fallback to python");
// throw std::runtime_error("fallback to python");
// Could fall back to running the check on the Python dict (lazily constructed)
return check_nopybind((PyObject*)map->to_dict());
}
virtual ~LeafGuard() = default;
@@ -1656,13 +1658,8 @@ class LAMBDA_GUARD : public LeafGuard {
LAMBDA_GUARD(
RootGuardManager* root_guard_manager,
py::object guard_check_fn,
py::object required_locals,
bool construct_partial_framelocals_dict,
py::object verbose_code_parts)
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
_required_locals(py::cast<py::dict>(required_locals)),
_construct_partial_framelocals_dict(
construct_partial_framelocals_dict) {
: LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {
if (py::isinstance<py::function>(guard_check_fn)) {
_guard_check_fn = py::cast<py::function>(std::move(guard_check_fn));
} else {
@@ -1699,30 +1696,7 @@ class LAMBDA_GUARD : public LeafGuard {
return GuardDebugInfo(false, verbose_code_parts(), 0);
}
bool check_nopybind(FrameLocalsMapping* map) override {
// TODO (anijain2305) - Get rid of the _construct_partial_framelocals_dict
// once it's stable.
if (_construct_partial_framelocals_dict) {
py::dict partial_dict;
for (auto item : _required_locals) {
partial_dict[item.first] = map->get(item.second.cast<int>());
}
return check_nopybind(partial_dict.ptr());
}
return check_nopybind((PyObject*)map->to_dict());
}
private:
// Dict of (local_name, framelocal_idx) representing the minimum number of
// framelocals needed to construct the dictionary for the lambda guard.
py::dict _required_locals;
// Temporary flag to allow a fallback behavior. With stability, we can remove
// this member.
bool _construct_partial_framelocals_dict;
// The user provided lambda function for check_fn.
py::function _guard_check_fn;
};
@@ -1798,12 +1772,7 @@ class LAMBDA_GUARD_NO_FRAMELOCALS : public LAMBDA_GUARD {
RootGuardManager* root_guard_manager,
py::object guard_check_fn,
py::object verbose_code_parts)
: LAMBDA_GUARD(
root_guard_manager,
guard_check_fn,
py::dict(),
false,
verbose_code_parts) {}
: LAMBDA_GUARD(root_guard_manager, guard_check_fn, verbose_code_parts) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
return LAMBDA_GUARD::check_nopybind(value);
@@ -6802,8 +6771,7 @@ PyObject* torch_c_dynamo_guards_init() {
.def("verbose_code_parts", &LeafGuard::verbose_code_parts);
py::class_<LAMBDA_GUARD, LeafGuard, std::shared_ptr<LAMBDA_GUARD>>(
py_m, "LAMBDA_GUARD")
.def(
py::init<RootGuardManager*, py::function, py::dict, bool, py::list>())
.def(py::init<RootGuardManager*, py::function, py::list>())
.def("__call__", &LAMBDA_GUARD::check);
py::class_<
LAMBDA_GUARD_NO_ARGS,
@@ -7126,14 +7094,10 @@ PyObject* torch_c_dynamo_guards_init() {
"add_lambda_guard",
[](GuardManager& self,
py::object lambda,
py::object required_locals,
bool construct_partial_framelocals_dict,
py::object verbose_code_parts) -> void {
self.add_leaf_guard(std::make_shared<LAMBDA_GUARD>(
self.get_root(),
std::move(lambda),
std::move(required_locals),
construct_partial_framelocals_dict,
std::move(verbose_code_parts)));
})
.def(
@@ -7816,15 +7780,9 @@ PyObject* torch_c_dynamo_guards_init() {
"add_epilogue_lambda_guard",
[](RootGuardManager& self,
py::object lambda,
py::object required_locals,
bool construct_partial_framelocals_dict,
py::object verbose_code_parts) -> void {
self.add_epilogue_lambda_guard(std::make_unique<LAMBDA_GUARD>(
&self,
std::move(lambda),
std::move(required_locals),
construct_partial_framelocals_dict,
std::move(verbose_code_parts)));
&self, std::move(lambda), std::move(verbose_code_parts)));
});
// Dict Guard Manager
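Putting the C++ change in plainer terms: the removed `LAMBDA_GUARD::check_nopybind(FrameLocalsMapping*)` override built a dict holding only the entries in `_required_locals`, while the restored base-class path materializes the entire framelocals dict via `map->to_dict()` before calling the lambda. A Python-level sketch of the two behaviors; `framelocals`, `required_locals`, and `guard_check_fn` are stand-ins for the C++ members, not real APIs:

```python
def check_partial(guard_check_fn, framelocals, required_locals):
    # Removed fast path: index directly into the frame's localsplus-style
    # storage and hand the lambda only the locals it actually references.
    partial_dict = {name: framelocals[idx] for name, idx in required_locals.items()}
    return bool(guard_check_fn(partial_dict))

def check_full(guard_check_fn, framelocals_dict):
    # Restored path: the whole framelocals mapping is converted to a dict
    # (map->to_dict() in C++) before the user lambda runs.
    return bool(guard_check_fn(framelocals_dict))

# Example: a guard that only needs "x" out of locals ("x", "y", "z").
storage = [3, "unused", None]  # localsplus-style positional storage
assert check_partial(lambda L: isinstance(L["x"], int), storage, {"x": 0})
assert check_full(lambda L: isinstance(L["x"], int), {"x": 3, "y": "unused", "z": None})
```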