mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[dynamo] fixes dict changed during runtime error (#87526)
Fixes https://github.com/pytorch/torchdynamo/issues/1744 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87526 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b8889c724
commit
cf04b36ce8
@ -71,7 +71,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
y = torch.randn(3, device="cuda")
|
||||
fn(x, y)
|
||||
|
||||
@patch("torch._dynamo.config.suppress_errors", True)
|
||||
@patch_all()
|
||||
def test_dtoh(self):
|
||||
def model(x, y):
|
||||
@ -105,7 +104,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
y = torch.randn((), device="cpu")
|
||||
fn(x, y)
|
||||
|
||||
@patch("torch._dynamo.config.suppress_errors", True)
|
||||
@patch("functorch._src.config.use_functionalize", True)
|
||||
@patch_all(ok=False) # input mutation not supported yet
|
||||
def test_mutate_input(self):
|
||||
@ -145,7 +143,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
y = torch.randn(1, device="cuda")
|
||||
fn(x, y)
|
||||
|
||||
@patch("torch._dynamo.config.suppress_errors", True)
|
||||
@patch_all()
|
||||
def test_factory(self):
|
||||
def model(y):
|
||||
|
@ -156,7 +156,11 @@ def has_tensor_in_frame(frame):
|
||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj])
|
||||
return seen_ids[obj_id]
|
||||
elif istype(obj, dict):
|
||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()])
|
||||
# Some packages like pytest can be updated during runtime. So, make a
|
||||
# copy of values to avoid issues like "RuntimeError: dictionary
|
||||
# changed size during iteration"
|
||||
values = list(obj.values())
|
||||
seen_ids[obj_id] = any([has_tensor(v) for v in values])
|
||||
return seen_ids[obj_id]
|
||||
elif istype(obj, (str, int, float, type(None), bool)):
|
||||
seen_ids[obj_id] = False
|
||||
@ -164,8 +168,13 @@ def has_tensor_in_frame(frame):
|
||||
elif is_namedtuple(obj):
|
||||
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
|
||||
return seen_ids[obj_id]
|
||||
elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__):
|
||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()])
|
||||
elif (
|
||||
not is_allowed(obj)
|
||||
and not hasattr(obj, "__get__") # overridden get can mutate the object
|
||||
and hasattr(obj, "__dict__")
|
||||
and istype(obj.__dict__, dict)
|
||||
):
|
||||
seen_ids[obj_id] = has_tensor(obj.__dict__)
|
||||
return seen_ids[obj_id]
|
||||
else:
|
||||
# if config.debug:
|
||||
|
Reference in New Issue
Block a user