Implement GetAttrVariable.as_python_constant() (#134216)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134216 Approved by: https://github.com/amjames, https://github.com/williamwen42
Committed by: PyTorch MergeBot
Parent: d9aca9914b
Commit: e3ea5429f2
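The change itself is small: GetAttrVariable gains an as_python_constant() that first resolves its base object to a Python constant and then performs a plain getattr on it, turning a missing attribute into NotImplementedError (the same signal const_getattr already uses in the second hunk). Below is a minimal, standalone sketch of that behavior; ConstVar and AttrVar are hypothetical stand-ins for illustration only, not Dynamo's real VariableTracker classes.

import random

# Hypothetical stand-ins for illustration; only the getattr /
# NotImplementedError logic mirrors the patch below.

class ConstVar:
    """Wraps a value that is already known to be a Python constant."""
    def __init__(self, value):
        self.value = value

    def as_python_constant(self):
        return self.value


class AttrVar:
    """Models an attribute access (obj.name) on another tracked value."""
    def __init__(self, obj, name):
        self.obj, self.name = obj, name

    def as_python_constant(self):
        # Resolve the base to a constant, then look the attribute up on it;
        # a missing attribute means this expression is not a constant.
        constant = self.obj.as_python_constant()
        try:
            return getattr(constant, self.name)
        except AttributeError:
            raise NotImplementedError(f"{self} is not a constant") from None


rng = random.Random(1)
getstate = AttrVar(ConstVar(rng), "getstate").as_python_constant()
print(callable(getstate))  # True: the bound method folds to a constant
try:
    AttrVar(ConstVar(rng), "missing_attr").as_python_constant()
except NotImplementedError as exc:
    print(exc)             # "<...AttrVar object...> is not a constant"

The test added in the first hunk exercises exactly this path: it patches GetAttrVariable.as_python_constant with an autospec mock, compiles fn with the eager backend, compares the compiled result against the uncompiled fn.__wrapped__, and asserts the mock was called at least once.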
@@ -11333,6 +11333,29 @@ fn
         self.assertEqual(r.y, torch.ones(2, 2) + 1)
         self.assertEqual(cnts.frame_count, 1)

+    def test_getattrvariable_as_python_constant(self):
+        from torch._dynamo.variables.misc import GetAttrVariable
+
+        @torch.compile(backend="eager")
+        def fn(x, rand1):
+            random.Random().setstate(rand1.getstate())
+            return x + rand1.random()
+
+        def get_rng():
+            rand1 = random.Random(1)
+            orig_random = rand1.random
+            rand1.random = lambda: orig_random()
+            return rand1
+
+        x = torch.randn(3, 3)
+        expected = fn.__wrapped__(x, get_rng())
+
+        with patch.object(GetAttrVariable, "as_python_constant", autospec=True) as po:
+            actual = fn(x, get_rng())
+
+        self.assertEqual(expected, actual)
+        self.assertGreater(po.call_count, 0)
+
+
 class TestTracer(JitTestCase):
     def test_jit_save(self):

@@ -1003,6 +1003,13 @@ class GetAttrVariable(VariableTracker):
     def as_proxy(self):
         return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

+    def as_python_constant(self):
+        constant = self.obj.as_python_constant()
+        try:
+            return getattr(constant, self.name)
+        except AttributeError:
+            raise NotImplementedError(f"{self} is not a constant") from None
+
     def const_getattr(self, tx: "InstructionTranslator", name):
         if not isinstance(self.obj, variables.NNModuleVariable):
             raise NotImplementedError