Compare commits

...

1 Commits

Author SHA1 Message Date
7e9437f5a8 [dynamo] Install guard when branching on empty dictionary
This fixes an internal test failure on guarding NN module hooks, which
started failing after #143997 stopped eagerly guard on dictionary
length.
2025-01-22 13:04:20 -08:00
4 changed files with 32 additions and 0 deletions

View File

@ -725,6 +725,20 @@ class DictTests(torch._dynamo.test_case.TestCase):
foo.scalar = 12
self.assertEqual(fn(d, inp), opt_fn(d, inp))
def test_branch_on_dict(self):
def fn(x, d):
if d:
return x + 1
return x + 2
x = torch.ones(1)
d = {}
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x, d), opt_fn(x, d))
d["a"] = 1
self.assertEqual(fn(x, d), opt_fn(x, d))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -574,10 +574,18 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
return
if value.is_python_constant():
# TODO materializing containers like dict as a constant could force
# installing more guards than necessary (e.g., guards for all keys
# and values), when in theory we only need a `SEQUENCE_LENGTH` guard
# for these objects.
if truth_fn(value.as_python_constant()):
if push:
self.push(value)
self.jump(inst)
# TODO install guards for more types.
if istype(value, ConstDictVariable) and value.source:
install_guard(value.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
elif (
isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
):

View File

@ -918,8 +918,12 @@ def istype(obj: object, allowed_types: Iterable[type]) -> bool:
def istype(obj, allowed_types):
"""isinstance() without subclasses"""
from .variables import LazyVariableTracker
if isinstance(allowed_types, (tuple, list, set)):
return type(obj) in allowed_types
if isinstance(obj, LazyVariableTracker):
obj = obj.realize()
return type(obj) is allowed_types

View File

@ -60,6 +60,12 @@ class LazyVariableTracker(VariableTracker):
assert isinstance(_cache, LazyCache)
super().__init__(**kwargs)
self._cache = _cache
# NOTE: The value of `mutation_type` is decided in the underlying `VT`
# (after it's been realized), thus we remove `mutation_type` from Lazy
# VT's instance __dict__, and have force any `lazy_vt.mutation_type` to
# route through `__getattr__` below, rather than returning the default
# `None` assigned in `VariableTracker.__init__`.
del self.mutation_type
def realize(self) -> VariableTracker:
"""Force construction of the real VariableTracker"""