mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] fixes dict changed during runtime error (#87526)
Fixes https://github.com/pytorch/torchdynamo/issues/1744 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87526 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b8889c724
commit
cf04b36ce8
@ -156,7 +156,11 @@ def has_tensor_in_frame(frame):
|
||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj])
|
||||
return seen_ids[obj_id]
|
||||
elif istype(obj, dict):
|
||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()])
|
||||
# Some packages like pytest can be updated during runtime. So, make a
|
||||
# copy of values to avoid issues like "RuntimeError: dictionary
|
||||
# changed size during iteration"
|
||||
values = list(obj.values())
|
||||
seen_ids[obj_id] = any([has_tensor(v) for v in values])
|
||||
return seen_ids[obj_id]
|
||||
elif istype(obj, (str, int, float, type(None), bool)):
|
||||
seen_ids[obj_id] = False
|
||||
@ -164,8 +168,13 @@ def has_tensor_in_frame(frame):
|
||||
elif is_namedtuple(obj):
|
||||
seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
|
||||
return seen_ids[obj_id]
|
||||
elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__):
|
||||
seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()])
|
||||
elif (
|
||||
not is_allowed(obj)
|
||||
and not hasattr(obj, "__get__") # overridden get can mutate the object
|
||||
and hasattr(obj, "__dict__")
|
||||
and istype(obj.__dict__, dict)
|
||||
):
|
||||
seen_ids[obj_id] = has_tensor(obj.__dict__)
|
||||
return seen_ids[obj_id]
|
||||
else:
|
||||
# if config.debug:
|
||||
|
Reference in New Issue
Block a user