Compare commits

...

2 Commits

Author SHA1 Message Date
f3825bb983 Add hasattr for tensor variable
Enables tests that are now running successfully under dynamo

ghstack-source-id: 3fcdb33d957c30b3e740e1d83c148c1572279360
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131008
2024-07-18 14:42:52 -07:00
ebfe1dcc80 Graph break on tostring for numpy remapping
ghstack-source-id: ce1ca5fdb92e5e07d5fe52aaa44117ef9df77d31
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131007
2024-07-17 19:16:14 -07:00
5 changed files with 37 additions and 2 deletions

View File

@ -1364,6 +1364,20 @@ utils_device.CURRENT_DEVICE == None""".split(
r2 = opt_fn(i)
self.assertEqual(r1, r2)
def test_tensor_hasattr(self):
    """hasattr on a tensor is resolved under fullgraph compilation.

    A plain tensor has no ``test`` attribute, so the compiled function
    takes the ``+ 1`` branch; after setting ``test`` on the input, the
    ``+ 2`` branch is taken instead.
    """
    def branch_on_attr(x):
        # Guard clause instead of if/else: missing attribute -> +1 path.
        if not hasattr(x, "test"):
            return x + 1
        return x + 2

    compiled = torch.compile(branch_on_attr, fullgraph=True)

    plain = torch.ones(2, 2)
    self.assertEqual(plain + 1, compiled(plain))

    tagged = torch.ones(2, 2)
    tagged.test = None
    self.assertEqual(tagged + 2, compiled(tagged))
def test_shape_unpack(self):
def fn(x):
a, b = x.size()

View File

@ -333,6 +333,27 @@ class TensorVariable(VariableTracker):
tx, [self], {}
)
def call_hasattr(self, tx, name):
    """Model ``hasattr(tensor, name)`` as a constant boolean.

    Probes the attribute via the traced ``getattr`` builtin; an
    ``AttributeError`` (or a ``GetAttrVariable`` fallback, which is what
    ``BuiltinVariable.call_getattr`` returns when ``TensorVariable``
    yields NotImplemented) means the attribute is absent.  When the
    tensor has a source, a HASATTR guard is installed either way so the
    answer stays valid across recompiles.
    """
    from . import GetAttrVariable
    from .builtin import BuiltinVariable

    try:
        attr_var = BuiltinVariable(getattr).call_function(
            tx, [self, ConstantVariable(name)], {}
        )
    except AttributeError:
        has_attr = False
    else:
        # A GetAttrVariable here signals the attribute lookup fell
        # through rather than resolving to a concrete value.
        has_attr = not isinstance(attr_var, GetAttrVariable)

    if self.source:
        install_guard(
            AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
        )
    return ConstantVariable(has_attr)
def var_getattr(self, tx, name):
from . import UserDefinedClassVariable
@ -1171,8 +1192,8 @@ class NumpyNdarrayVariable(TensorVariable):
if name in ["__len__", "size", "tolist"]:
# delegate back to TensorVariable
return super().call_method(tx, name, args, kwargs)
if name == "tobytes":
unimplemented("tobytes is not modelled in torch._numpy")
if name in ("tostring", "tobytes"):
unimplemented(f"{name} is not modelled in torch._numpy")
proxy = tx.output.create_proxy(
"call_function",
numpy_method_wrapper(name),