[dynamo] fixes dict changed during runtime error (#87526)

Fixes https://github.com/pytorch/torchdynamo/issues/1744

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87526
Approved by: https://github.com/ezyang
This commit is contained in:
Animesh Jain
2022-11-10 01:57:17 +00:00
committed by PyTorch MergeBot
parent 0b8889c724
commit cf04b36ce8
2 changed files with 12 additions and 6 deletions

View File

@ -71,7 +71,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
y = torch.randn(3, device="cuda") y = torch.randn(3, device="cuda")
fn(x, y) fn(x, y)
@patch("torch._dynamo.config.suppress_errors", True)
@patch_all() @patch_all()
def test_dtoh(self): def test_dtoh(self):
def model(x, y): def model(x, y):
@ -105,7 +104,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
y = torch.randn((), device="cpu") y = torch.randn((), device="cpu")
fn(x, y) fn(x, y)
@patch("torch._dynamo.config.suppress_errors", True)
@patch("functorch._src.config.use_functionalize", True) @patch("functorch._src.config.use_functionalize", True)
@patch_all(ok=False) # input mutation not supported yet @patch_all(ok=False) # input mutation not supported yet
def test_mutate_input(self): def test_mutate_input(self):
@ -145,7 +143,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
y = torch.randn(1, device="cuda") y = torch.randn(1, device="cuda")
fn(x, y) fn(x, y)
@patch("torch._dynamo.config.suppress_errors", True)
@patch_all() @patch_all()
def test_factory(self): def test_factory(self):
def model(y): def model(y):

View File

@ -156,7 +156,11 @@ def has_tensor_in_frame(frame):
seen_ids[obj_id] = any([has_tensor(v) for v in obj]) seen_ids[obj_id] = any([has_tensor(v) for v in obj])
return seen_ids[obj_id] return seen_ids[obj_id]
elif istype(obj, dict): elif istype(obj, dict):
seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()]) # Some packages like pytest can be updated during runtime. So, make a
# copy of values to avoid issues like "RuntimeError: dictionary
# changed size during iteration"
values = list(obj.values())
seen_ids[obj_id] = any([has_tensor(v) for v in values])
return seen_ids[obj_id] return seen_ids[obj_id]
elif istype(obj, (str, int, float, type(None), bool)): elif istype(obj, (str, int, float, type(None), bool)):
seen_ids[obj_id] = False seen_ids[obj_id] = False
@ -164,8 +168,13 @@ def has_tensor_in_frame(frame):
elif is_namedtuple(obj): elif is_namedtuple(obj):
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields]) seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
return seen_ids[obj_id] return seen_ids[obj_id]
elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__): elif (
seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()]) not is_allowed(obj)
and not hasattr(obj, "__get__") # overridden get can mutate the object
and hasattr(obj, "__dict__")
and istype(obj.__dict__, dict)
):
seen_ids[obj_id] = has_tensor(obj.__dict__)
return seen_ids[obj_id] return seen_ids[obj_id]
else: else:
# if config.debug: # if config.debug: