[dynamo] Support method calls on complex ConstantVariables (#161122)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161122
Approved by: https://github.com/mlazos, https://github.com/guilhermeleobas
This commit is contained in:
Rob Timpe
2025-08-22 17:46:34 +00:00
committed by PyTorch MergeBot
parent 9d882fd9ff
commit 4c36c8a994
8 changed files with 42 additions and 23 deletions

View File

@ -1,8 +1,8 @@
diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py
index 6ff1a8ab29d..01295e03efc 100644 index 6ff1a8ab29d..1572433c5ae 100644
--- a/test/dynamo/cpython/3_13/test_complex.py --- a/test/dynamo/cpython/3_13/test_complex.py
+++ b/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py
@@ -1,16 +1,146 @@ @@ -1,16 +1,147 @@
+# ======= BEGIN Dynamo patch ======= +# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"] +# Owner(s): ["module: dynamo"]
+ +
@ -19,6 +19,7 @@ index 6ff1a8ab29d..01295e03efc 100644
+from torch._dynamo.test_case import CPythonTestCase +from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import ( +from torch.testing._internal.common_utils import (
+ run_tests, + run_tests,
+ slowTest,
+ xfailIfTorchDynamo, + xfailIfTorchDynamo,
+) +)
+ +
@ -154,7 +155,7 @@ index 6ff1a8ab29d..01295e03efc 100644
INF = float("inf") INF = float("inf")
NAN = float("nan") NAN = float("nan")
DBL_MAX = sys.float_info.max DBL_MAX = sys.float_info.max
@@ -45,7 +175,40 @@ class WithComplex: @@ -45,7 +176,40 @@ class WithComplex:
def __complex__(self): def __complex__(self):
return self.value return self.value
@ -196,7 +197,7 @@ index 6ff1a8ab29d..01295e03efc 100644
def assertAlmostEqual(self, a, b): def assertAlmostEqual(self, a, b):
if isinstance(a, complex): if isinstance(a, complex):
@@ -74,6 +237,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): @@ -74,6 +238,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
# check that relative difference < eps # check that relative difference < eps
self.assertTrue(abs((x-y)/y) < eps) self.assertTrue(abs((x-y)/y) < eps)
@ -226,7 +227,27 @@ index 6ff1a8ab29d..01295e03efc 100644
def assertClose(self, x, y, eps=1e-9): def assertClose(self, x, y, eps=1e-9):
"""Return true iff complexes x and y "are close".""" """Return true iff complexes x and y "are close"."""
self.assertCloseAbs(x.real, y.real, eps) self.assertCloseAbs(x.real, y.real, eps)
@@ -431,12 +617,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): @@ -93,6 +280,7 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
q = z.__truediv__(y)
self.assertClose(q, x)
+ @slowTest
def test_truediv(self):
simple_real = [float(i) for i in range(-5, 6)]
simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
@@ -338,7 +526,10 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
def test_boolcontext(self):
for i in range(100):
- self.assertTrue(complex(random() + 1e-6, random() + 1e-6))
+ with torch._dynamo.set_fullgraph(False):
+ r1 = random()
+ r2 = random()
+ self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
self.assertTrue(not complex(0.0, 0.0))
self.assertTrue(1j)
@@ -431,12 +622,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaises(TypeError, complex, WithComplex(1), object()) self.assertRaises(TypeError, complex, WithComplex(1), object())
self.assertRaises(TypeError, complex, WithComplex(None), object()) self.assertRaises(TypeError, complex, WithComplex(None), object())
@ -245,7 +266,7 @@ index 6ff1a8ab29d..01295e03efc 100644
self.assertRaises(EvilExc, complex, evilcomplex()) self.assertRaises(EvilExc, complex, evilcomplex())
@@ -460,31 +647,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): @@ -460,31 +652,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
self.assertRaises(TypeError, complex, WithIndex(None), 1.5) self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
self.assertRaises(TypeError, complex, 1.5, WithIndex(None)) self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
@ -299,7 +320,7 @@ index 6ff1a8ab29d..01295e03efc 100644
check(complex(complex0(1j)), 0.0, 42.0) check(complex(complex0(1j)), 0.0, 42.0)
with self.assertWarns(DeprecationWarning): with self.assertWarns(DeprecationWarning):
@@ -855,4 +1044,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): @@ -855,4 +1049,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -14,6 +14,7 @@ import unittest
from torch._dynamo.test_case import CPythonTestCase from torch._dynamo.test_case import CPythonTestCase
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
run_tests, run_tests,
slowTest,
xfailIfTorchDynamo, xfailIfTorchDynamo,
) )
@ -279,6 +280,7 @@ class ComplexTest(__TestCase):
q = z.__truediv__(y) q = z.__truediv__(y)
self.assertClose(q, x) self.assertClose(q, x)
@slowTest
def test_truediv(self): def test_truediv(self):
simple_real = [float(i) for i in range(-5, 6)] simple_real = [float(i) for i in range(-5, 6)]
simple_complex = [complex(x, y) for x in simple_real for y in simple_real] simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
@ -524,7 +526,10 @@ class ComplexTest(__TestCase):
def test_boolcontext(self): def test_boolcontext(self):
for i in range(100): for i in range(100):
self.assertTrue(complex(random() + 1e-6, random() + 1e-6)) with torch._dynamo.set_fullgraph(False):
r1 = random()
r2 = random()
self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
self.assertTrue(not complex(0.0, 0.0)) self.assertTrue(not complex(0.0, 0.0))
self.assertTrue(1j) self.assertTrue(1j)

View File

@ -12858,6 +12858,7 @@ fn
complex(real=1), complex(real=1),
complex(imag=1, real=2), complex(imag=1, real=2),
complex("1+2j"), complex("1+2j"),
complex(1, 2).conjugate(),
) )
return [x + z for z in c] return [x + z for z in c]

View File

@ -308,6 +308,7 @@ class BuiltinVariable(VariableTracker):
bool, bool,
callable, callable,
chr, chr,
complex,
divmod, divmod,
float, float,
getattr, getattr,
@ -1478,21 +1479,6 @@ class BuiltinVariable(VariableTracker):
call_int = _call_int_float call_int = _call_int_float
call_float = _call_int_float call_float = _call_int_float
def call_complex(self, tx: "InstructionTranslator", *args, **kwargs):
if self.constant_args(*args, **kwargs):
try:
c = complex(
*(arg.as_python_constant() for arg in args),
**{k: kwargs[k].as_python_constant() for k in kwargs},
)
except (TypeError, ValueError) as exc:
raise_observed_exception(
type(exc),
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
return ConstantVariable(c)
def call_bool(self, tx: "InstructionTranslator", arg): def call_bool(self, tx: "InstructionTranslator", arg):
# Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
# https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697

View File

@ -206,6 +206,12 @@ its type to `common_constant_types`.
elif isinstance(self.value, bytes) and name == "decode": elif isinstance(self.value, bytes) and name == "decode":
method = getattr(self.value, name) method = getattr(self.value, name)
return ConstantVariable.create(method(*const_args, **const_kwargs)) return ConstantVariable.create(method(*const_args, **const_kwargs))
elif type(self.value) is complex and name in complex.__dict__.keys():
method = getattr(self.value, name)
try:
return ConstantVariable.create(method(*const_args, **const_kwargs))
except Exception as e:
raise_observed_exception(type(e), tx)
if name == "__len__" and not (args or kwargs): if name == "__len__" and not (args or kwargs):
return ConstantVariable.create(len(self.value)) return ConstantVariable.create(len(self.value))