mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[dynamo] fixes dict changed during runtime error (#87526)
Fixes https://github.com/pytorch/torchdynamo/issues/1744 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87526 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b8889c724
commit
cf04b36ce8
@ -71,7 +71,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
|||||||
y = torch.randn(3, device="cuda")
|
y = torch.randn(3, device="cuda")
|
||||||
fn(x, y)
|
fn(x, y)
|
||||||
|
|
||||||
@patch("torch._dynamo.config.suppress_errors", True)
|
|
||||||
@patch_all()
|
@patch_all()
|
||||||
def test_dtoh(self):
|
def test_dtoh(self):
|
||||||
def model(x, y):
|
def model(x, y):
|
||||||
@ -105,7 +104,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
|||||||
y = torch.randn((), device="cpu")
|
y = torch.randn((), device="cpu")
|
||||||
fn(x, y)
|
fn(x, y)
|
||||||
|
|
||||||
@patch("torch._dynamo.config.suppress_errors", True)
|
|
||||||
@patch("functorch._src.config.use_functionalize", True)
|
@patch("functorch._src.config.use_functionalize", True)
|
||||||
@patch_all(ok=False) # input mutation not supported yet
|
@patch_all(ok=False) # input mutation not supported yet
|
||||||
def test_mutate_input(self):
|
def test_mutate_input(self):
|
||||||
@ -145,7 +143,6 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
|||||||
y = torch.randn(1, device="cuda")
|
y = torch.randn(1, device="cuda")
|
||||||
fn(x, y)
|
fn(x, y)
|
||||||
|
|
||||||
@patch("torch._dynamo.config.suppress_errors", True)
|
|
||||||
@patch_all()
|
@patch_all()
|
||||||
def test_factory(self):
|
def test_factory(self):
|
||||||
def model(y):
|
def model(y):
|
||||||
|
@ -156,7 +156,11 @@ def has_tensor_in_frame(frame):
|
|||||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj])
|
seen_ids[obj_id] = any([has_tensor(v) for v in obj])
|
||||||
return seen_ids[obj_id]
|
return seen_ids[obj_id]
|
||||||
elif istype(obj, dict):
|
elif istype(obj, dict):
|
||||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()])
|
# Some packages like pytest can be updated during runtime. So, make a
|
||||||
|
# copy of values to avoid issues like "RuntimeError: dictionary
|
||||||
|
# changed size during iteration"
|
||||||
|
values = list(obj.values())
|
||||||
|
seen_ids[obj_id] = any([has_tensor(v) for v in values])
|
||||||
return seen_ids[obj_id]
|
return seen_ids[obj_id]
|
||||||
elif istype(obj, (str, int, float, type(None), bool)):
|
elif istype(obj, (str, int, float, type(None), bool)):
|
||||||
seen_ids[obj_id] = False
|
seen_ids[obj_id] = False
|
||||||
@ -164,8 +168,13 @@ def has_tensor_in_frame(frame):
|
|||||||
elif is_namedtuple(obj):
|
elif is_namedtuple(obj):
|
||||||
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
|
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
|
||||||
return seen_ids[obj_id]
|
return seen_ids[obj_id]
|
||||||
elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__):
|
elif (
|
||||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()])
|
not is_allowed(obj)
|
||||||
|
and not hasattr(obj, "__get__") # overridden get can mutate the object
|
||||||
|
and hasattr(obj, "__dict__")
|
||||||
|
and istype(obj.__dict__, dict)
|
||||||
|
):
|
||||||
|
seen_ids[obj_id] = has_tensor(obj.__dict__)
|
||||||
return seen_ids[obj_id]
|
return seen_ids[obj_id]
|
||||||
else:
|
else:
|
||||||
# if config.debug:
|
# if config.debug:
|
||||||
|
Reference in New Issue
Block a user