Implement GetAttrVariable.as_python_constant() (#134216)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134216
Approved by: https://github.com/amjames, https://github.com/williamwen42
This commit is contained in:
Tom Ritchford
2024-09-18 15:02:35 +00:00
committed by PyTorch MergeBot
parent d9aca9914b
commit e3ea5429f2
12 changed files with 30 additions and 0 deletions

View File

@ -11333,6 +11333,29 @@ fn
self.assertEqual(r.y, torch.ones(2, 2) + 1)
self.assertEqual(cnts.frame_count, 1)
def test_getattrvariable_as_python_constant(self):
from torch._dynamo.variables.misc import GetAttrVariable
@torch.compile(backend="eager")
def fn(x, rand1):
random.Random().setstate(rand1.getstate())
return x + rand1.random()
def get_rng():
rand1 = random.Random(1)
orig_random = rand1.random
rand1.random = lambda: orig_random()
return rand1
x = torch.randn(3, 3)
expected = fn.__wrapped__(x, get_rng())
with patch.object(GetAttrVariable, "as_python_constant", autospec=True) as po:
actual = fn(x, get_rng())
self.assertEqual(expected, actual)
self.assertGreater(po.call_count, 0)
class TestTracer(JitTestCase):
def test_jit_save(self):

View File

@ -1003,6 +1003,13 @@ class GetAttrVariable(VariableTracker):
def as_proxy(self):
return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)
def as_python_constant(self):
constant = self.obj.as_python_constant()
try:
return getattr(constant, self.name)
except AttributeError:
raise NotImplementedError(f"{self} is not a constant") from None
def const_getattr(self, tx: "InstructionTranslator", name):
if not isinstance(self.obj, variables.NNModuleVariable):
raise NotImplementedError