mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Correctly track mutation class source for MutableMappingVariable (#161568)
Fixes https://github.com/pytorch/pytorch/issues/161505 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161568 Approved by: https://github.com/Lucaskabela, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
b9c6aa1e17
commit
68fa882dad
@ -7,7 +7,7 @@ import operator
|
||||
import types
|
||||
import unittest
|
||||
import weakref
|
||||
from collections import defaultdict, namedtuple, OrderedDict
|
||||
from collections import defaultdict, namedtuple, OrderedDict, UserDict
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -31,6 +31,10 @@ class SimpleDict(dict):
|
||||
pass
|
||||
|
||||
|
||||
class DummyUserDict(UserDict):
|
||||
pass
|
||||
|
||||
|
||||
class DictTests(torch._dynamo.test_case.TestCase):
|
||||
def test_dict_subclass_instantiation(self):
|
||||
def fn(x):
|
||||
@ -788,6 +792,17 @@ class DictTests(torch._dynamo.test_case.TestCase):
|
||||
x = torch.randn(4)
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_construct_user_dict_and_return(self):
|
||||
def fn(x):
|
||||
return DummyUserDict({"a": x + 1})
|
||||
|
||||
x = torch.randn(4)
|
||||
res = fn(x)
|
||||
self.assertEqual(res["a"], x + 1)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(res["a"], opt_fn(x)["a"])
|
||||
|
||||
def test_fn_id(self):
|
||||
def fn(x, f):
|
||||
d = {id(f): 3}
|
||||
|
@ -1521,7 +1521,8 @@ class VariableBuilder:
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif issubclass(type(value), MutableMapping):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return MutableMappingVariable(value, source=self.source)
|
||||
result = MutableMappingVariable(value, source=self.source)
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif is_frozen_dataclass(value):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
result = FrozenDataClassVariable.create(self.tx, value, source=self.source)
|
||||
|
@ -92,12 +92,7 @@ from ..utils import (
|
||||
tuple_methods,
|
||||
unpatched_nn_module_getattr,
|
||||
)
|
||||
from .base import (
|
||||
AttributeMutationExisting,
|
||||
AttributeMutationNew,
|
||||
ValueMutationNew,
|
||||
VariableTracker,
|
||||
)
|
||||
from .base import ValueMutationNew, VariableTracker
|
||||
from .dicts import DefaultDictVariable
|
||||
from .lists import SizeVariable
|
||||
|
||||
@ -2157,9 +2152,6 @@ class MutableMappingVariable(UserDefinedObjectVariable):
|
||||
def __init__(self, value, **kwargs):
|
||||
super().__init__(value, **kwargs)
|
||||
self.generic_dict_vt = variables.ConstDictVariable({})
|
||||
self.mutation_type = (
|
||||
AttributeMutationExisting() if self.source else AttributeMutationNew()
|
||||
)
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
# A common pattern in the init code of MutableMapping objects is to
|
||||
|
Reference in New Issue
Block a user