mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support builtin complex with constant args (#160799)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160799 Approved by: https://github.com/guilhermeleobas, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
66166cf1e7
commit
35e4d97e04
@ -12848,6 +12848,36 @@ fn
|
|||||||
res = opt_f(x)
|
res = opt_f(x)
|
||||||
self.assertEqual(ref, res)
|
self.assertEqual(ref, res)
|
||||||
|
|
||||||
|
def test_builtin_complex(self):
|
||||||
|
def f(x):
|
||||||
|
c = (
|
||||||
|
complex(),
|
||||||
|
complex(1),
|
||||||
|
complex(2, 3),
|
||||||
|
complex(imag=2),
|
||||||
|
complex(real=1),
|
||||||
|
complex(imag=1, real=2),
|
||||||
|
complex("1+2j"),
|
||||||
|
)
|
||||||
|
return [x + z for z in c]
|
||||||
|
|
||||||
|
x = torch.randn(1)
|
||||||
|
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
||||||
|
res = opt_f(x)
|
||||||
|
ref = f(x)
|
||||||
|
self.assertEqual(res, ref)
|
||||||
|
|
||||||
|
def test_builtin_complex_args(self):
|
||||||
|
@torch.compile(backend="eager", fullgraph=True)
|
||||||
|
def f(*args, **kwargs):
|
||||||
|
return torch.tensor(complex(*args, **kwargs))
|
||||||
|
|
||||||
|
self.assertRaises(Unsupported, f, 1, 1, 1)
|
||||||
|
self.assertRaises(Unsupported, f, 1, 1, fake_arg=1)
|
||||||
|
self.assertRaises(Unsupported, f, fake_arg=1)
|
||||||
|
self.assertRaises(Unsupported, f, [])
|
||||||
|
self.assertRaises(Unsupported, f, "1 + j")
|
||||||
|
|
||||||
|
|
||||||
class TestTracer(JitTestCase):
|
class TestTracer(JitTestCase):
|
||||||
def test_jit_save(self):
|
def test_jit_save(self):
|
||||||
|
|||||||
@ -1478,6 +1478,21 @@ class BuiltinVariable(VariableTracker):
|
|||||||
call_int = _call_int_float
|
call_int = _call_int_float
|
||||||
call_float = _call_int_float
|
call_float = _call_int_float
|
||||||
|
|
||||||
|
def call_complex(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||||
|
if self.constant_args(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
c = complex(
|
||||||
|
*(arg.as_python_constant() for arg in args),
|
||||||
|
**{k: kwargs[k].as_python_constant() for k in kwargs},
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise_observed_exception(
|
||||||
|
type(exc),
|
||||||
|
tx,
|
||||||
|
args=list(map(ConstantVariable.create, exc.args)),
|
||||||
|
)
|
||||||
|
return ConstantVariable(c)
|
||||||
|
|
||||||
def call_bool(self, tx: "InstructionTranslator", arg):
|
def call_bool(self, tx: "InstructionTranslator", arg):
|
||||||
# Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
|
# Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
|
||||||
# https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697
|
# https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697
|
||||||
|
|||||||
Reference in New Issue
Block a user