[dynamo] Support builtin complex with constant args (#160799)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160799
Approved by: https://github.com/guilhermeleobas, https://github.com/mlazos
This commit is contained in:
Rob Timpe
2025-08-18 21:30:42 +00:00
committed by PyTorch MergeBot
parent 66166cf1e7
commit 35e4d97e04
16 changed files with 45 additions and 0 deletions

View File

@ -12848,6 +12848,36 @@ fn
res = opt_f(x) res = opt_f(x)
self.assertEqual(ref, res) self.assertEqual(ref, res)
def test_builtin_complex(self):
def f(x):
c = (
complex(),
complex(1),
complex(2, 3),
complex(imag=2),
complex(real=1),
complex(imag=1, real=2),
complex("1+2j"),
)
return [x + z for z in c]
x = torch.randn(1)
opt_f = torch.compile(f, backend="eager", fullgraph=True)
res = opt_f(x)
ref = f(x)
self.assertEqual(res, ref)
def test_builtin_complex_args(self):
@torch.compile(backend="eager", fullgraph=True)
def f(*args, **kwargs):
return torch.tensor(complex(*args, **kwargs))
self.assertRaises(Unsupported, f, 1, 1, 1)
self.assertRaises(Unsupported, f, 1, 1, fake_arg=1)
self.assertRaises(Unsupported, f, fake_arg=1)
self.assertRaises(Unsupported, f, [])
self.assertRaises(Unsupported, f, "1 + j")
class TestTracer(JitTestCase): class TestTracer(JitTestCase):
def test_jit_save(self): def test_jit_save(self):

View File

@ -1478,6 +1478,21 @@ class BuiltinVariable(VariableTracker):
call_int = _call_int_float call_int = _call_int_float
call_float = _call_int_float call_float = _call_int_float
def call_complex(self, tx: "InstructionTranslator", *args, **kwargs):
if self.constant_args(*args, **kwargs):
try:
c = complex(
*(arg.as_python_constant() for arg in args),
**{k: kwargs[k].as_python_constant() for k in kwargs},
)
except (TypeError, ValueError) as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
return ConstantVariable(c)
def call_bool(self, tx: "InstructionTranslator", arg): def call_bool(self, tx: "InstructionTranslator", arg):
# Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
# https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697