Add hasattr for tensor variable (#131008)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131008
Approved by: https://github.com/anijain2305
ghstack dependencies: #131007
This commit is contained in:
Michael Lazos
2024-07-18 14:42:52 -07:00
committed by PyTorch MergeBot
parent 1f961ad495
commit 1b72cf0b09
5 changed files with 35 additions and 0 deletions

View File

@ -1364,6 +1364,20 @@ utils_device.CURRENT_DEVICE == None""".split(
r2 = opt_fn(i)
self.assertEqual(r1, r2)
def test_tensor_hasattr(self):
@torch.compile(fullgraph=True)
def fn(x):
if hasattr(x, "test"):
return x + 2
else:
return x + 1
self.assertEqual(torch.ones(2, 2) + 1, fn(torch.ones(2, 2)))
inp = torch.ones(2, 2)
inp.test = None
self.assertEqual(torch.ones(2, 2) + 2, fn(inp))
def test_shape_unpack(self):
def fn(x):
a, b = x.size()

View File

@ -333,6 +333,27 @@ class TensorVariable(VariableTracker):
tx, [self], {}
)
def call_hasattr(self, tx, name):
from . import GetAttrVariable
from .builtin import BuiltinVariable
try:
var = BuiltinVariable(getattr).call_function(
tx, [self, ConstantVariable(name)], {}
)
# in the event that TensorVariable returns NotImplemented
# BuiltinVariable.call_getattr returns GetAttrVariable
ret_val = not isinstance(var, GetAttrVariable)
except AttributeError:
ret_val = False
if self.source:
install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
)
return ConstantVariable(ret_val)
def var_getattr(self, tx, name):
from . import UserDefinedClassVariable