[dynamo] Correctly track mutation class source for MutableMappingVariable (#161568)

Fixes https://github.com/pytorch/pytorch/issues/161505

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161568
Approved by: https://github.com/Lucaskabela, https://github.com/malfet
This commit is contained in:
Animesh Jain
2025-08-26 21:57:02 -07:00
committed by PyTorch MergeBot
parent b9c6aa1e17
commit 68fa882dad
3 changed files with 19 additions and 11 deletions

View File

@ -7,7 +7,7 @@ import operator
import types
import unittest
import weakref
from collections import defaultdict, namedtuple, OrderedDict
from collections import defaultdict, namedtuple, OrderedDict, UserDict
from typing import Any
import torch
@ -31,6 +31,10 @@ class SimpleDict(dict):
pass
class DummyUserDict(UserDict):
pass
class DictTests(torch._dynamo.test_case.TestCase):
def test_dict_subclass_instantiation(self):
def fn(x):
@ -788,6 +792,17 @@ class DictTests(torch._dynamo.test_case.TestCase):
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
def test_construct_user_dict_and_return(self):
def fn(x):
return DummyUserDict({"a": x + 1})
x = torch.randn(4)
res = fn(x)
self.assertEqual(res["a"], x + 1)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(res["a"], opt_fn(x)["a"])
def test_fn_id(self):
def fn(x, f):
d = {id(f): 3}

View File

@ -1521,7 +1521,8 @@ class VariableBuilder:
return self.tx.output.side_effects.track_object_existing(value, result)
elif issubclass(type(value), MutableMapping):
self.install_guards(GuardBuilder.TYPE_MATCH)
return MutableMappingVariable(value, source=self.source)
result = MutableMappingVariable(value, source=self.source)
return self.tx.output.side_effects.track_object_existing(value, result)
elif is_frozen_dataclass(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
result = FrozenDataClassVariable.create(self.tx, value, source=self.source)

View File

@ -92,12 +92,7 @@ from ..utils import (
tuple_methods,
unpatched_nn_module_getattr,
)
from .base import (
AttributeMutationExisting,
AttributeMutationNew,
ValueMutationNew,
VariableTracker,
)
from .base import ValueMutationNew, VariableTracker
from .dicts import DefaultDictVariable
from .lists import SizeVariable
@ -2157,9 +2152,6 @@ class MutableMappingVariable(UserDefinedObjectVariable):
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
self.generic_dict_vt = variables.ConstDictVariable({})
self.mutation_type = (
AttributeMutationExisting() if self.source else AttributeMutationNew()
)
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
# A common pattern in the init code of MutableMapping objects is to