mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Use Variable Builder to build the property fget object (#165683)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165683 Approved by: https://github.com/ezyang, https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
9e94ec76b8
commit
24879f0de9
@ -5173,10 +5173,9 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_property_class_transmute(self):
|
||||
class PropertyGetter:
|
||||
def __call__(self):
|
||||
def __call__(self, obj):
|
||||
return True
|
||||
|
||||
p = property(PropertyGetter())
|
||||
@ -5195,6 +5194,31 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
||||
x = torch.randn(1)
|
||||
self.assertEqual(opt_mod(x), x + 1)
|
||||
|
||||
def test_property_functools_partial(self):
|
||||
def p_getter(obj, *, delta: int):
|
||||
# Use instance state + a bound constant
|
||||
return (getattr(obj, "flag", 0) + delta) > 0
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self, flag: int):
|
||||
super().__init__()
|
||||
self.flag = flag
|
||||
|
||||
# fget is a functools.partial object
|
||||
p = property(functools.partial(p_getter, delta=1))
|
||||
|
||||
def forward(self, x):
|
||||
if self.p: # calls p_getter(self, delta=1)
|
||||
return x + 1
|
||||
else:
|
||||
raise RuntimeError("whoops")
|
||||
|
||||
mod = Mod(flag=1)
|
||||
|
||||
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
||||
x = torch.randn(1)
|
||||
self.assertEqual(opt_mod(x), x + 1)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(FunctionTests)
|
||||
instantiate_parametrized_tests(DefaultsTests)
|
||||
|
||||
@ -1458,12 +1458,8 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
# Get the getter function
|
||||
source = AttrSource(source, "fget")
|
||||
|
||||
# Avoid using UserMethodVariable here because there is no way to
|
||||
# access the method object here. Direct inline by creating the
|
||||
# UserFunctionVariable.
|
||||
return variables.UserFunctionVariable(
|
||||
subobj.fget, source=source
|
||||
).call_function(tx, [self], {})
|
||||
fget_vt = VariableTracker.build(tx, subobj.fget, source=source)
|
||||
return fget_vt.call_function(tx, [self], {})
|
||||
elif isinstance(subobj, _collections._tuplegetter):
|
||||
# namedtuple fields are represented by _tuplegetter, and here we
|
||||
# emulate its `__get__`, which is implemented in C.
|
||||
|
||||
Reference in New Issue
Block a user