[dynamo] fixes dict changed during runtime error (#87526)

Fixes https://github.com/pytorch/torchdynamo/issues/1744

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87526
Approved by: https://github.com/ezyang
This commit is contained in:
Animesh Jain
2022-11-10 01:57:17 +00:00
committed by PyTorch MergeBot
parent 0b8889c724
commit cf04b36ce8
2 changed files with 12 additions and 6 deletions

View File

@ -156,7 +156,11 @@ def has_tensor_in_frame(frame):
seen_ids[obj_id] = any([has_tensor(v) for v in obj])
return seen_ids[obj_id]
elif istype(obj, dict):
seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()])
# Some packages like pytest can be updated during runtime. So, make a
# copy of values to avoid issues like "RuntimeError: dictionary
# changed size during iteration"
values = list(obj.values())
seen_ids[obj_id] = any([has_tensor(v) for v in values])
return seen_ids[obj_id]
elif istype(obj, (str, int, float, type(None), bool)):
seen_ids[obj_id] = False
@ -164,8 +168,13 @@ def has_tensor_in_frame(frame):
elif is_namedtuple(obj):
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
return seen_ids[obj_id]
elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__):
seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()])
elif (
not is_allowed(obj)
and not hasattr(obj, "__get__") # overridden get can mutate the object
and hasattr(obj, "__dict__")
and istype(obj.__dict__, dict)
):
seen_ids[obj_id] = has_tensor(obj.__dict__)
return seen_ids[obj_id]
else:
# if config.debug: