dynamo: Handle objects in graph that do not support weakref (#163168)

We are seeing crashes of the form
```
Traceback (most recent call last):
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 1487, in run
    while self.step():
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 1348, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 2437, in LOAD_ATTR
    self._load_attr(inst)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/symbolic_convert.py", line 2425, in _load_attr
    result = BuiltinVariable(getattr).call_function(
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builtin.py", line 1347, in call_function
    return handler(tx, args, kwargs)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builtin.py", line 967, in <lambda>
    tx, [v.realize() for v in args], kwargs
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builtin.py", line 967, in <listcomp>
    tx, [v.realize() for v in args], kwargs
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/lazy.py", line 72, in realize
    self._cache.realize()
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/lazy.py", line 33, in realize
    self.vt = builder.VariableBuilder(tx, self.source)(self.value)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builder.py", line 445, in __call__
    vt = self._wrap(value)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/variables/builder.py", line 1043, in _wrap
    torch._dynamo.utils.store_user_object_weakref(value)
  File "/packages/aps_ads_vm/launcher_multiapp-inplace#link-tree/torch/_dynamo/utils.py", line 4694, in store_user_object_weakref
    user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
torch._dynamo.exc.InternalTorchDynamoError: TypeError: cannot create weak reference to 'torch.Event' object
```

This pull request makes us gracefully graph break, vs explicitly crashing.

I've added a test which reproduces the issue. There is a side discussion re:
how did torch.Event support ever work here, since it appears you cannot take a
weakref to a torch.Event

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163168
Approved by: https://github.com/Lucaskabela, https://github.com/jansel
This commit is contained in:
clr
2025-09-17 10:00:10 -07:00
committed by PyTorch MergeBot
parent 60c2bdedcd
commit 33daaad7d0
3 changed files with 41 additions and 1 deletions

View File

@ -4696,7 +4696,18 @@ def get_user_object_from_id(obj_id: int) -> Any:
def store_user_object_weakref(obj: object) -> None:
obj_id = id(obj)
user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
try:
user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
except TypeError as e:
from .exc import unimplemented_v2
unimplemented_v2(
gb_type="Failed to make weakref to User Object",
context=f"user_objected: {obj}",
explanation="Object does not allow us to make a weakref to it",
hints=[],
from_exc=e,
)
class CompileTimeInstructionCounter: