Files
pytorch/test/dynamo/cpython/3_13/test_math.diff
William Wen 8678d831c4 [dynamo] rename set_fullgraph to error_on_graph_break (#161739)
Renaming `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph` truly returns a single graph.

I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I'm hoping that won't be necessary since it's still a private API (there are no internal call sites yet, and no significant OSS call sites yet).

 cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan as primary users for `set_fullgraph`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739
Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
2025-09-04 01:15:06 +00:00

414 lines
14 KiB
Diff

diff --git a/test/dynamo/cpython/3_13/test_math.py b/test/dynamo/cpython/3_13/test_math.py
index 5ee3055c871..5402cdc4a6c 100644
--- a/test/dynamo/cpython/3_13/test_math.py
+++ b/test/dynamo/cpython/3_13/test_math.py
@@ -1,3 +1,61 @@
+# ======= BEGIN Dynamo patch =======
+# Owner(s): ["module: dynamo"]
+
+# ruff: noqa
+# flake8: noqa
+
+# Test copied from
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_math.py
+
+import sys
+import torch
+import torch._dynamo.test_case
+import unittest
+from torch._dynamo.test_case import CPythonTestCase
+from torch.testing._internal.common_utils import (
+ slowTest,
+ run_tests,
+ skipIfTorchDynamo,
+)
+
+__TestCase = CPythonTestCase
+
+
+# redirect import statements
+import sys
+import importlib.abc
+
+redirect_imports = (
+ "test.mapping_tests",
+ "test.typinganndata",
+ "test.test_grammar",
+ "test.test_math",
+ "test.test_iter",
+ "test.typinganndata.ann_module",
+)
+
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
+ def find_spec(self, fullname, path, target=None):
+ # Check if the import is the problematic one
+ if fullname in redirect_imports:
+ try:
+ # Attempt to import the standalone module
+ name = fullname.removeprefix("test.")
+ r = importlib.import_module(name)
+ # Redirect the module in sys.modules
+ sys.modules[fullname] = r
+ # Return a module spec from the found module
+ return importlib.util.find_spec(name)
+ except ImportError:
+ return None
+ return None
+
+# Add the custom finder to sys.meta_path
+sys.meta_path.insert(0, RedirectImportFinder())
+
+
+# ======= END DYNAMO PATCH =======
+
# Python test set -- math module
# XXXX Should not do tests around zero only
@@ -242,7 +300,7 @@ class BadDescr:
def __get__(self, obj, objtype=None):
raise ValueError
-class MathTests(unittest.TestCase):
+class MathTests(__TestCase):
def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0):
"""Compare arguments expected and got, as floats, if either
@@ -417,16 +475,17 @@ class MathTests(unittest.TestCase):
#self.assertEqual(math.ceil(NINF), NINF)
#self.assertTrue(math.isnan(math.ceil(NAN)))
- class TestCeil:
- def __ceil__(self):
- return 42
- class FloatCeil(float):
- def __ceil__(self):
- return 42
- class TestNoCeil:
- pass
- class TestBadCeil:
- __ceil__ = BadDescr()
+ with torch._dynamo.error_on_graph_break(False):
+ class TestCeil:
+ def __ceil__(self):
+ return 42
+ class FloatCeil(float):
+ def __ceil__(self):
+ return 42
+ class TestNoCeil:
+ pass
+ class TestBadCeil:
+ __ceil__ = BadDescr()
self.assertEqual(math.ceil(TestCeil()), 42)
self.assertEqual(math.ceil(FloatCeil()), 42)
self.assertEqual(math.ceil(FloatLike(42.5)), 43)
@@ -533,6 +592,7 @@ class MathTests(unittest.TestCase):
self.ftest('fabs(0)', math.fabs(0), 0)
self.ftest('fabs(1)', math.fabs(1), 1)
+ @skipIfTorchDynamo("infinite loop")
def testFactorial(self):
self.assertEqual(math.factorial(0), 1)
total = 1
@@ -573,16 +633,17 @@ class MathTests(unittest.TestCase):
#self.assertEqual(math.ceil(NINF), NINF)
#self.assertTrue(math.isnan(math.floor(NAN)))
- class TestFloor:
- def __floor__(self):
- return 42
- class FloatFloor(float):
- def __floor__(self):
- return 42
- class TestNoFloor:
- pass
- class TestBadFloor:
- __floor__ = BadDescr()
+ with torch._dynamo.error_on_graph_break(False):
+ class TestFloor:
+ def __floor__(self):
+ return 42
+ class FloatFloor(float):
+ def __floor__(self):
+ return 42
+ class TestNoFloor:
+ pass
+ class TestBadFloor:
+ __floor__ = BadDescr()
self.assertEqual(math.floor(TestFloor()), 42)
self.assertEqual(math.floor(FloatFloor()), 42)
self.assertEqual(math.floor(FloatLike(41.9)), 41)
@@ -995,8 +1056,9 @@ class MathTests(unittest.TestCase):
)
# Verify tuple subclasses are allowed
- class T(tuple):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class T(tuple):
+ pass
self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0)
# Test handling of bad arguments
@@ -1028,8 +1090,9 @@ class MathTests(unittest.TestCase):
with self.assertRaises(TypeError):
dist([1], 2)
- class BadFloat:
- __float__ = BadDescr()
+ with torch._dynamo.error_on_graph_break(False):
+ class BadFloat:
+ __float__ = BadDescr()
with self.assertRaises(ValueError):
dist([1], [BadFloat()])
@@ -1072,6 +1135,7 @@ class MathTests(unittest.TestCase):
with self.assertRaises(ValueError):
math.dist([1, 2], [3, 4, 5])
+ @slowTest
def testIsqrt(self):
# Test a variety of inputs, large and small.
test_values = (
@@ -1101,12 +1165,13 @@ class MathTests(unittest.TestCase):
self.assertIs(type(s), int)
self.assertEqual(s, 0)
- class IntegerLike(object):
- def __init__(self, value):
- self.value = value
+ with torch._dynamo.error_on_graph_break(False):
+ class IntegerLike(object):
+ def __init__(self, value):
+ self.value = value
- def __index__(self):
- return self.value
+ def __index__(self):
+ return self.value
s = math.isqrt(IntegerLike(1729))
self.assertIs(type(s), int)
@@ -1202,12 +1267,6 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.ldexp(NINF, n), NINF)
self.assertTrue(math.isnan(math.ldexp(NAN, n)))
- @requires_IEEE_754
- def testLdexp_denormal(self):
- # Denormal output incorrectly rounded (truncated)
- # on some Windows.
- self.assertEqual(math.ldexp(6993274598585239, -1126), 1e-323)
-
def testLog(self):
self.assertRaises(TypeError, math.log)
self.assertRaises(TypeError, math.log, 1, 2, 3)
@@ -1233,6 +1292,7 @@ class MathTests(unittest.TestCase):
self.assertRaises(ValueError, math.log1p, -1)
self.assertEqual(math.log1p(INF), INF)
+ @skipIfTorchDynamo("Infinite loop")
@requires_IEEE_754
def testLog2(self):
self.assertRaises(TypeError, math.log2)
@@ -1251,6 +1311,7 @@ class MathTests(unittest.TestCase):
self.assertRaises(ValueError, math.log2, NINF)
self.assertTrue(math.isnan(math.log2(NAN)))
+ @skipIfTorchDynamo("Infinite loop")
@requires_IEEE_754
# log2() is not accurate enough on Mac OS X Tiger (10.4)
@support.requires_mac_ver(10, 5)
@@ -1332,17 +1393,18 @@ class MathTests(unittest.TestCase):
with self.assertRaises(RuntimeError):
sumprod(raise_after(5), range(10))
- from test.test_iter import BasicIterClass
+ from test_iter import BasicIterClass
self.assertEqual(sumprod(BasicIterClass(1), [1]), 0)
self.assertEqual(sumprod([1], BasicIterClass(1)), 0)
# Error in multiplication
- class BadMultiply:
- def __mul__(self, other):
- raise RuntimeError
- def __rmul__(self, other):
- raise RuntimeError
+ with torch._dynamo.error_on_graph_break(False):
+ class BadMultiply:
+ def __mul__(self, other):
+ raise RuntimeError
+ def __rmul__(self, other):
+ raise RuntimeError
with self.assertRaises(RuntimeError):
sumprod([10, BadMultiply(), 30], [1, 2, 3])
with self.assertRaises(RuntimeError):
@@ -1387,25 +1449,26 @@ class MathTests(unittest.TestCase):
Decimal = decimal.Decimal
Fraction = fractions.Fraction
- class Int(int):
- def __add__(self, other):
- return Int(int(self) + int(other))
- def __mul__(self, other):
- return Int(int(self) * int(other))
- __radd__ = __add__
- __rmul__ = __mul__
- def __repr__(self):
- return f'Int({int(self)})'
-
- class Flt(float):
- def __add__(self, other):
- return Int(int(self) + int(other))
- def __mul__(self, other):
- return Int(int(self) * int(other))
- __radd__ = __add__
- __rmul__ = __mul__
- def __repr__(self):
- return f'Flt({int(self)})'
+ with torch._dynamo.error_on_graph_break(False):
+ class Int(int):
+ def __add__(self, other):
+ return Int(int(self) + int(other))
+ def __mul__(self, other):
+ return Int(int(self) * int(other))
+ __radd__ = __add__
+ __rmul__ = __mul__
+ def __repr__(self):
+ return f'Int({int(self)})'
+
+ class Flt(float):
+ def __add__(self, other):
+ return Int(int(self) + int(other))
+ def __mul__(self, other):
+ return Int(int(self) * int(other))
+ __radd__ = __add__
+ __rmul__ = __mul__
+ def __repr__(self):
+ return f'Flt({int(self)})'
def baseline_sumprod(p, q):
"""This defines the target behavior including exceptions and special values.
@@ -1925,16 +1988,17 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.trunc(-0.999999), -0)
self.assertEqual(math.trunc(-100.999), -100)
- class TestTrunc:
- def __trunc__(self):
- return 23
- class FloatTrunc(float):
- def __trunc__(self):
- return 23
- class TestNoTrunc:
- pass
- class TestBadTrunc:
- __trunc__ = BadDescr()
+ with torch._dynamo.error_on_graph_break(False):
+ class TestTrunc:
+ def __trunc__(self):
+ return 23
+ class FloatTrunc(float):
+ def __trunc__(self):
+ return 23
+ class TestNoTrunc:
+ pass
+ class TestBadTrunc:
+ __trunc__ = BadDescr()
self.assertEqual(math.trunc(TestTrunc()), 23)
self.assertEqual(math.trunc(FloatTrunc()), 23)
@@ -2167,9 +2231,10 @@ class MathTests(unittest.TestCase):
self.assertEqual(prod([1., F(3, 2)]), 1.5)
# Error in multiplication
- class BadMultiply:
- def __rmul__(self, other):
- raise RuntimeError
+ with torch._dynamo.error_on_graph_break(False):
+ class BadMultiply:
+ def __rmul__(self, other):
+ raise RuntimeError
with self.assertRaises(RuntimeError):
prod([10., BadMultiply()])
@@ -2252,6 +2317,7 @@ class MathTests(unittest.TestCase):
self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
decimal.Decimal)
+ @skipIfTorchDynamo("Infinite loop")
def testPerm(self):
perm = math.perm
factorial = math.factorial
@@ -2316,6 +2382,7 @@ class MathTests(unittest.TestCase):
self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int)
self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int)
+ @skipIfTorchDynamo("infinite loop")
def testComb(self):
comb = math.comb
factorial = math.factorial
@@ -2446,6 +2513,7 @@ class MathTests(unittest.TestCase):
math.nextafter(1.0, INF, steps=-1)
+ @unittest.skip("flaky test under torch dynamo") # works on pytest and crashes on unittest
@requires_IEEE_754
def test_ulp(self):
self.assertEqual(math.ulp(1.0), sys.float_info.epsilon)
@@ -2472,10 +2540,11 @@ class MathTests(unittest.TestCase):
def test_issue39871(self):
# A SystemError should not be raised if the first arg to atan2(),
# copysign(), or remainder() cannot be converted to a float.
- class F:
- def __float__(self):
- self.converted = True
- 1/0
+ with torch._dynamo.error_on_graph_break(False):
+ class F:
+ def __float__(self):
+ self.converted = True
+ 1/0
for func in math.atan2, math.copysign, math.remainder:
y = F()
with self.assertRaises(TypeError):
@@ -2508,7 +2577,7 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.copysign(1.0, x), math.copysign(1.0, y))
-class IsCloseTests(unittest.TestCase):
+class IsCloseTests(__TestCase):
isclose = math.isclose # subclasses should override this
def assertIsClose(self, a, b, *args, **kwargs):
@@ -2631,7 +2700,7 @@ class IsCloseTests(unittest.TestCase):
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
-class FMATests(unittest.TestCase):
+class FMATests(__TestCase):
""" Tests for math.fma. """
def test_fma_nan_results(self):
@@ -2719,8 +2788,7 @@ class FMATests(unittest.TestCase):
# properly: it doesn't use the right sign when the result is zero.
@unittest.skipIf(
sys.platform.startswith(("freebsd", "wasi", "netbsd", "emscripten"))
- or (sys.platform == "android" and platform.machine() == "x86_64")
- or support.linked_to_musl(), # gh-131032
+ or (sys.platform == "android" and platform.machine() == "x86_64"),
f"this platform doesn't implement IEE 754-2008 properly")
def test_fma_zero_result(self):
nonnegative_finites = [0.0, 1e-300, 2.3, 1e300]
@@ -2879,10 +2947,5 @@ class FMATests(unittest.TestCase):
)
-def load_tests(loader, tests, pattern):
- from doctest import DocFileSuite
- tests.addTest(DocFileSuite(os.path.join("mathdata", "ieee754.txt")))
- return tests
-
-if __name__ == '__main__':
- unittest.main()
+if __name__ == "__main__":
+ run_tests()