[dynamo] Use Variable Builder to build the property fget object (#165683)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165683
Approved by: https://github.com/ezyang, https://github.com/williamwen42
This commit is contained in:
Animesh Jain
2025-10-16 16:50:46 -07:00
committed by PyTorch MergeBot
parent 9e94ec76b8
commit 24879f0de9
2 changed files with 28 additions and 8 deletions

View File

@ -5173,10 +5173,9 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertEqual(ref, res)
@unittest.expectedFailure
def test_property_class_transmute(self):
class PropertyGetter:
def __call__(self):
def __call__(self, obj):
return True
p = property(PropertyGetter())
@ -5195,6 +5194,31 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
x = torch.randn(1)
self.assertEqual(opt_mod(x), x + 1)
def test_property_functools_partial(self):
def p_getter(obj, *, delta: int):
# Use instance state + a bound constant
return (getattr(obj, "flag", 0) + delta) > 0
class Mod(torch.nn.Module):
def __init__(self, flag: int):
super().__init__()
self.flag = flag
# fget is a functools.partial object
p = property(functools.partial(p_getter, delta=1))
def forward(self, x):
if self.p: # calls p_getter(self, delta=1)
return x + 1
else:
raise RuntimeError("whoops")
mod = Mod(flag=1)
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
x = torch.randn(1)
self.assertEqual(opt_mod(x), x + 1)
instantiate_parametrized_tests(FunctionTests)
instantiate_parametrized_tests(DefaultsTests)

View File

@ -1458,12 +1458,8 @@ class UserDefinedObjectVariable(UserDefinedVariable):
# Get the getter function
source = AttrSource(source, "fget")
# Avoid using UserMethodVariable here because there is no way to
# access the method object here. Direct inline by creating the
# UserFunctionVariable.
return variables.UserFunctionVariable(
subobj.fget, source=source
).call_function(tx, [self], {})
fget_vt = VariableTracker.build(tx, subobj.fget, source=source)
return fget_vt.call_function(tx, [self], {})
elif isinstance(subobj, _collections._tuplegetter):
# namedtuple fields are represented by _tuplegetter, and here we
# emulate its `__get__`, which is implemented in C.