mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support method calls on complex ConstantVariables (#161122)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161122 Approved by: https://github.com/mlazos, https://github.com/guilhermeleobas
This commit is contained in:
committed by
PyTorch MergeBot
parent
9d882fd9ff
commit
4c36c8a994
@ -1,8 +1,8 @@
|
|||||||
diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py
|
diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py
|
||||||
index 6ff1a8ab29d..01295e03efc 100644
|
index 6ff1a8ab29d..1572433c5ae 100644
|
||||||
--- a/test/dynamo/cpython/3_13/test_complex.py
|
--- a/test/dynamo/cpython/3_13/test_complex.py
|
||||||
+++ b/test/dynamo/cpython/3_13/test_complex.py
|
+++ b/test/dynamo/cpython/3_13/test_complex.py
|
||||||
@@ -1,16 +1,146 @@
|
@@ -1,16 +1,147 @@
|
||||||
+# ======= BEGIN Dynamo patch =======
|
+# ======= BEGIN Dynamo patch =======
|
||||||
+# Owner(s): ["module: dynamo"]
|
+# Owner(s): ["module: dynamo"]
|
||||||
+
|
+
|
||||||
@ -19,6 +19,7 @@ index 6ff1a8ab29d..01295e03efc 100644
|
|||||||
+from torch._dynamo.test_case import CPythonTestCase
|
+from torch._dynamo.test_case import CPythonTestCase
|
||||||
+from torch.testing._internal.common_utils import (
|
+from torch.testing._internal.common_utils import (
|
||||||
+ run_tests,
|
+ run_tests,
|
||||||
|
+ slowTest,
|
||||||
+ xfailIfTorchDynamo,
|
+ xfailIfTorchDynamo,
|
||||||
+)
|
+)
|
||||||
+
|
+
|
||||||
@ -154,7 +155,7 @@ index 6ff1a8ab29d..01295e03efc 100644
|
|||||||
INF = float("inf")
|
INF = float("inf")
|
||||||
NAN = float("nan")
|
NAN = float("nan")
|
||||||
DBL_MAX = sys.float_info.max
|
DBL_MAX = sys.float_info.max
|
||||||
@@ -45,7 +175,40 @@ class WithComplex:
|
@@ -45,7 +176,40 @@ class WithComplex:
|
||||||
def __complex__(self):
|
def __complex__(self):
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
@ -196,7 +197,7 @@ index 6ff1a8ab29d..01295e03efc 100644
|
|||||||
|
|
||||||
def assertAlmostEqual(self, a, b):
|
def assertAlmostEqual(self, a, b):
|
||||||
if isinstance(a, complex):
|
if isinstance(a, complex):
|
||||||
@@ -74,6 +237,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
@@ -74,6 +238,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||||
# check that relative difference < eps
|
# check that relative difference < eps
|
||||||
self.assertTrue(abs((x-y)/y) < eps)
|
self.assertTrue(abs((x-y)/y) < eps)
|
||||||
|
|
||||||
@ -226,7 +227,27 @@ index 6ff1a8ab29d..01295e03efc 100644
|
|||||||
def assertClose(self, x, y, eps=1e-9):
|
def assertClose(self, x, y, eps=1e-9):
|
||||||
"""Return true iff complexes x and y "are close"."""
|
"""Return true iff complexes x and y "are close"."""
|
||||||
self.assertCloseAbs(x.real, y.real, eps)
|
self.assertCloseAbs(x.real, y.real, eps)
|
||||||
@@ -431,12 +617,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
@@ -93,6 +280,7 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||||
|
q = z.__truediv__(y)
|
||||||
|
self.assertClose(q, x)
|
||||||
|
|
||||||
|
+ @slowTest
|
||||||
|
def test_truediv(self):
|
||||||
|
simple_real = [float(i) for i in range(-5, 6)]
|
||||||
|
simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
|
||||||
|
@@ -338,7 +526,10 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
def test_boolcontext(self):
|
||||||
|
for i in range(100):
|
||||||
|
- self.assertTrue(complex(random() + 1e-6, random() + 1e-6))
|
||||||
|
+ with torch._dynamo.set_fullgraph(False):
|
||||||
|
+ r1 = random()
|
||||||
|
+ r2 = random()
|
||||||
|
+ self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
|
||||||
|
self.assertTrue(not complex(0.0, 0.0))
|
||||||
|
self.assertTrue(1j)
|
||||||
|
|
||||||
|
@@ -431,12 +622,13 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||||
self.assertRaises(TypeError, complex, WithComplex(1), object())
|
self.assertRaises(TypeError, complex, WithComplex(1), object())
|
||||||
self.assertRaises(TypeError, complex, WithComplex(None), object())
|
self.assertRaises(TypeError, complex, WithComplex(None), object())
|
||||||
|
|
||||||
@ -245,7 +266,7 @@ index 6ff1a8ab29d..01295e03efc 100644
|
|||||||
|
|
||||||
self.assertRaises(EvilExc, complex, evilcomplex())
|
self.assertRaises(EvilExc, complex, evilcomplex())
|
||||||
|
|
||||||
@@ -460,31 +647,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
@@ -460,31 +652,33 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||||
self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
|
self.assertRaises(TypeError, complex, WithIndex(None), 1.5)
|
||||||
self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
|
self.assertRaises(TypeError, complex, 1.5, WithIndex(None))
|
||||||
|
|
||||||
@ -299,7 +320,7 @@ index 6ff1a8ab29d..01295e03efc 100644
|
|||||||
|
|
||||||
check(complex(complex0(1j)), 0.0, 42.0)
|
check(complex(complex0(1j)), 0.0, 42.0)
|
||||||
with self.assertWarns(DeprecationWarning):
|
with self.assertWarns(DeprecationWarning):
|
||||||
@@ -855,4 +1044,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
@@ -855,4 +1049,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -14,6 +14,7 @@ import unittest
|
|||||||
from torch._dynamo.test_case import CPythonTestCase
|
from torch._dynamo.test_case import CPythonTestCase
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
run_tests,
|
run_tests,
|
||||||
|
slowTest,
|
||||||
xfailIfTorchDynamo,
|
xfailIfTorchDynamo,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -279,6 +280,7 @@ class ComplexTest(__TestCase):
|
|||||||
q = z.__truediv__(y)
|
q = z.__truediv__(y)
|
||||||
self.assertClose(q, x)
|
self.assertClose(q, x)
|
||||||
|
|
||||||
|
@slowTest
|
||||||
def test_truediv(self):
|
def test_truediv(self):
|
||||||
simple_real = [float(i) for i in range(-5, 6)]
|
simple_real = [float(i) for i in range(-5, 6)]
|
||||||
simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
|
simple_complex = [complex(x, y) for x in simple_real for y in simple_real]
|
||||||
@ -524,7 +526,10 @@ class ComplexTest(__TestCase):
|
|||||||
|
|
||||||
def test_boolcontext(self):
|
def test_boolcontext(self):
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
self.assertTrue(complex(random() + 1e-6, random() + 1e-6))
|
with torch._dynamo.set_fullgraph(False):
|
||||||
|
r1 = random()
|
||||||
|
r2 = random()
|
||||||
|
self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6))
|
||||||
self.assertTrue(not complex(0.0, 0.0))
|
self.assertTrue(not complex(0.0, 0.0))
|
||||||
self.assertTrue(1j)
|
self.assertTrue(1j)
|
||||||
|
|
||||||
|
@ -12858,6 +12858,7 @@ fn
|
|||||||
complex(real=1),
|
complex(real=1),
|
||||||
complex(imag=1, real=2),
|
complex(imag=1, real=2),
|
||||||
complex("1+2j"),
|
complex("1+2j"),
|
||||||
|
complex(1, 2).conjugate(),
|
||||||
)
|
)
|
||||||
return [x + z for z in c]
|
return [x + z for z in c]
|
||||||
|
|
||||||
|
@ -308,6 +308,7 @@ class BuiltinVariable(VariableTracker):
|
|||||||
bool,
|
bool,
|
||||||
callable,
|
callable,
|
||||||
chr,
|
chr,
|
||||||
|
complex,
|
||||||
divmod,
|
divmod,
|
||||||
float,
|
float,
|
||||||
getattr,
|
getattr,
|
||||||
@ -1478,21 +1479,6 @@ class BuiltinVariable(VariableTracker):
|
|||||||
call_int = _call_int_float
|
call_int = _call_int_float
|
||||||
call_float = _call_int_float
|
call_float = _call_int_float
|
||||||
|
|
||||||
def call_complex(self, tx: "InstructionTranslator", *args, **kwargs):
|
|
||||||
if self.constant_args(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
c = complex(
|
|
||||||
*(arg.as_python_constant() for arg in args),
|
|
||||||
**{k: kwargs[k].as_python_constant() for k in kwargs},
|
|
||||||
)
|
|
||||||
except (TypeError, ValueError) as exc:
|
|
||||||
raise_observed_exception(
|
|
||||||
type(exc),
|
|
||||||
tx,
|
|
||||||
args=list(map(ConstantVariable.create, exc.args)),
|
|
||||||
)
|
|
||||||
return ConstantVariable(c)
|
|
||||||
|
|
||||||
def call_bool(self, tx: "InstructionTranslator", arg):
|
def call_bool(self, tx: "InstructionTranslator", arg):
|
||||||
# Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
|
# Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
|
||||||
# https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697
|
# https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697
|
||||||
|
@ -206,6 +206,12 @@ its type to `common_constant_types`.
|
|||||||
elif isinstance(self.value, bytes) and name == "decode":
|
elif isinstance(self.value, bytes) and name == "decode":
|
||||||
method = getattr(self.value, name)
|
method = getattr(self.value, name)
|
||||||
return ConstantVariable.create(method(*const_args, **const_kwargs))
|
return ConstantVariable.create(method(*const_args, **const_kwargs))
|
||||||
|
elif type(self.value) is complex and name in complex.__dict__.keys():
|
||||||
|
method = getattr(self.value, name)
|
||||||
|
try:
|
||||||
|
return ConstantVariable.create(method(*const_args, **const_kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
raise_observed_exception(type(e), tx)
|
||||||
|
|
||||||
if name == "__len__" and not (args or kwargs):
|
if name == "__len__" and not (args or kwargs):
|
||||||
return ConstantVariable.create(len(self.value))
|
return ConstantVariable.create(len(self.value))
|
||||||
|
Reference in New Issue
Block a user