mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add hasattr for tensor variable (#131008)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131008 Approved by: https://github.com/anijain2305 ghstack dependencies: #131007
This commit is contained in:
committed by
PyTorch MergeBot
parent
1f961ad495
commit
1b72cf0b09
@ -1364,6 +1364,20 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||
r2 = opt_fn(i)
|
||||
self.assertEqual(r1, r2)
|
||||
|
||||
def test_tensor_hasattr(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
def fn(x):
|
||||
if hasattr(x, "test"):
|
||||
return x + 2
|
||||
else:
|
||||
return x + 1
|
||||
|
||||
self.assertEqual(torch.ones(2, 2) + 1, fn(torch.ones(2, 2)))
|
||||
|
||||
inp = torch.ones(2, 2)
|
||||
inp.test = None
|
||||
self.assertEqual(torch.ones(2, 2) + 2, fn(inp))
|
||||
|
||||
def test_shape_unpack(self):
|
||||
def fn(x):
|
||||
a, b = x.size()
|
||||
|
@ -333,6 +333,27 @@ class TensorVariable(VariableTracker):
|
||||
tx, [self], {}
|
||||
)
|
||||
|
||||
def call_hasattr(self, tx, name):
|
||||
from . import GetAttrVariable
|
||||
from .builtin import BuiltinVariable
|
||||
|
||||
try:
|
||||
var = BuiltinVariable(getattr).call_function(
|
||||
tx, [self, ConstantVariable(name)], {}
|
||||
)
|
||||
# in the event that TensorVariable returns NotImplemented
|
||||
# BuiltinVariable.call_getattr returns GetAttrVariable
|
||||
ret_val = not isinstance(var, GetAttrVariable)
|
||||
except AttributeError:
|
||||
ret_val = False
|
||||
|
||||
if self.source:
|
||||
install_guard(
|
||||
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
|
||||
)
|
||||
|
||||
return ConstantVariable(ret_val)
|
||||
|
||||
def var_getattr(self, tx, name):
|
||||
from . import UserDefinedClassVariable
|
||||
|
||||
|
Reference in New Issue
Block a user