From 4c36c8a99463c898190a462300ba7f05b5b3384e Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Fri, 22 Aug 2025 17:46:34 +0000 Subject: [PATCH] [dynamo] Support method calls on complex ConstantVariables (#161122) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161122 Approved by: https://github.com/mlazos, https://github.com/guilhermeleobas --- test/dynamo/cpython/3_13/test_complex.diff | 35 +++++++++++++++---- test/dynamo/cpython/3_13/test_complex.py | 7 +++- test/dynamo/test_misc.py | 1 + ...-test_complex-ComplexTest.test_boolcontext | 0 ...13-test_complex-ComplexTest.test_conjugate | 0 ...3-test_complex-ComplexTest.test_getnewargs | 0 torch/_dynamo/variables/builtin.py | 16 +-------- torch/_dynamo/variables/constant.py | 6 ++++ 8 files changed, 42 insertions(+), 23 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_boolcontext delete mode 100644 test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_conjugate delete mode 100644 test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_getnewargs diff --git a/test/dynamo/cpython/3_13/test_complex.diff b/test/dynamo/cpython/3_13/test_complex.diff index feca8fcc9b04..063b9131056e 100644 --- a/test/dynamo/cpython/3_13/test_complex.diff +++ b/test/dynamo/cpython/3_13/test_complex.diff @@ -1,8 +1,8 @@ diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py -index 6ff1a8ab29d..01295e03efc 100644 +index 6ff1a8ab29d..1572433c5ae 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py -@@ -1,16 +1,146 @@ +@@ -1,16 +1,147 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + @@ -19,6 +19,7 @@ index 6ff1a8ab29d..01295e03efc 100644 +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, ++ slowTest, + xfailIfTorchDynamo, +) + @@ -154,7 +155,7 @@ index 6ff1a8ab29d..01295e03efc 100644 INF = float("inf") NAN = float("nan") DBL_MAX = sys.float_info.max -@@ -45,7 +175,40 @@ class WithComplex: +@@ -45,7 +176,40 @@ class WithComplex: def __complex__(self): return self.value @@ -196,7 +197,7 @@ index 6ff1a8ab29d..01295e03efc 100644 def assertAlmostEqual(self, a, b): if isinstance(a, complex): -@@ -74,6 +237,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -74,6 +238,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): # check that relative difference < eps self.assertTrue(abs((x-y)/y) < eps) @@ -226,7 +227,27 @@ index 6ff1a8ab29d..01295e03efc 100644 def assertClose(self, x, y, eps=1e-9): """Return true iff complexes x and y "are close".""" self.assertCloseAbs(x.real, y.real, eps) -@@ -431,12 +617,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -93,6 +280,7 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): + q = z.__truediv__(y) + self.assertClose(q, x) + ++ @slowTest + def test_truediv(self): + simple_real = [float(i) for i in range(-5, 6)] + simple_complex = [complex(x, y) for x in simple_real for y in simple_real] +@@ -338,7 +526,10 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): + + def test_boolcontext(self): + for i in range(100): +- self.assertTrue(complex(random() + 1e-6, random() + 1e-6)) ++ with torch._dynamo.set_fullgraph(False): ++ r1 = random() ++ r2 = random() ++ self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6)) + self.assertTrue(not complex(0.0, 0.0)) + self.assertTrue(1j) + +@@ -431,12 +622,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): self.assertRaises(TypeError, complex, WithComplex(1), object()) self.assertRaises(TypeError, complex, WithComplex(None), object()) @@ -245,7 +266,7 @@ index 6ff1a8ab29d..01295e03efc 100644 self.assertRaises(EvilExc, complex, evilcomplex()) -@@ -460,31 +647,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -460,31 +652,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): self.assertRaises(TypeError, complex, WithIndex(None), 1.5) self.assertRaises(TypeError, complex, 1.5, WithIndex(None)) @@ -299,7 +320,7 @@ index 6ff1a8ab29d..01295e03efc 100644 check(complex(complex0(1j)), 0.0, 42.0) with self.assertWarns(DeprecationWarning): -@@ -855,4 +1044,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -855,4 +1049,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py index 01295e03efc0..1572433c5aef 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py @@ -14,6 +14,7 @@ import unittest from torch._dynamo.test_case import CPythonTestCase from torch.testing._internal.common_utils import ( run_tests, + slowTest, xfailIfTorchDynamo, ) @@ -279,6 +280,7 @@ class ComplexTest(__TestCase): q = z.__truediv__(y) self.assertClose(q, x) + @slowTest def test_truediv(self): simple_real = [float(i) for i in range(-5, 6)] simple_complex = [complex(x, y) for x in simple_real for y in simple_real] @@ -524,7 +526,10 @@ class ComplexTest(__TestCase): def test_boolcontext(self): for i in range(100): - self.assertTrue(complex(random() + 1e-6, random() + 1e-6)) + with torch._dynamo.set_fullgraph(False): + r1 = random() + r2 = random() + self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6)) self.assertTrue(not complex(0.0, 0.0)) self.assertTrue(1j) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 57983cea8e02..e86947aa2c10 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -12858,6 +12858,7 @@ fn complex(real=1), complex(imag=1, real=2), complex("1+2j"), + complex(1, 2).conjugate(), ) return [x + z for z in c] diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_boolcontext b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_boolcontext deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_conjugate b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_conjugate deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_getnewargs b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_getnewargs deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 51c9f2941ceb..74f8864479d4 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -308,6 +308,7 @@ class BuiltinVariable(VariableTracker): bool, callable, chr, + complex, divmod, float, getattr, @@ -1478,21 +1479,6 @@ class BuiltinVariable(VariableTracker): call_int = _call_int_float call_float = _call_int_float - def call_complex(self, tx: "InstructionTranslator", *args, **kwargs): - if self.constant_args(*args, **kwargs): - try: - c = complex( - *(arg.as_python_constant() for arg in args), - **{k: kwargs[k].as_python_constant() for k in kwargs}, - ) - except (TypeError, ValueError) as exc: - raise_observed_exception( - type(exc), - tx, - args=list(map(ConstantVariable.create, exc.args)), - ) - return ConstantVariable(c) - def call_bool(self, tx: "InstructionTranslator", arg): # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 998bef52da4c..90cbb08f5fc8 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -206,6 +206,12 @@ its type to `common_constant_types`. elif isinstance(self.value, bytes) and name == "decode": method = getattr(self.value, name) return ConstantVariable.create(method(*const_args, **const_kwargs)) + elif type(self.value) is complex and name in complex.__dict__.keys(): + method = getattr(self.value, name) + try: + return ConstantVariable.create(method(*const_args, **const_kwargs)) + except Exception as e: + raise_observed_exception(type(e), tx) if name == "__len__" and not (args or kwargs): return ConstantVariable.create(len(self.value))