mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rename `set_fullgraph` to `error_on_graph_break` for now. There are no semantic differences yet. In a follow-up PR, we will introduce a new `torch.compile` option, `error_on_graph_break`, that has lower priority than `fullgraph`, so that `fullgraph=True` truly guarantees a single graph. I could keep `set_fullgraph` as a deprecated alias for `error_on_graph_break` for now, but I hope that won't be necessary, since it is still a private API (there are no internal call sites yet, and there are no significant OSS call sites yet). cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos @guilhermeleobas @xmfan (primary users of `set_fullgraph`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/161739. Approved by: https://github.com/xmfan, https://github.com/Lucaskabela, https://github.com/anijain2305, https://github.com/mlazos
468 lines
15 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py
|
|
index 48825f46911..731680d82a0 100644
|
|
--- a/test/dynamo/cpython/3_13/test_int.py
|
|
+++ b/test/dynamo/cpython/3_13/test_int.py
|
|
@@ -1,13 +1,140 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+
|
|
+# redirect import statements
|
|
+import sys
|
|
+import importlib.abc
|
|
+
|
|
+redirect_imports = (
|
|
+ "test.mapping_tests",
|
|
+ "test.typinganndata",
|
|
+ "test.test_grammar",
|
|
+ "test.test_math",
|
|
+ "test.test_iter",
|
|
+ "test.typinganndata.ann_module",
|
|
+)
|
|
+
|
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
|
+ def find_spec(self, fullname, path, target=None):
|
|
+ # Check if the import is the problematic one
|
|
+ if fullname in redirect_imports:
|
|
+ try:
|
|
+ # Attempt to import the standalone module
|
|
+ name = fullname.removeprefix("test.")
|
|
+ r = importlib.import_module(name)
|
|
+ # Redirect the module in sys.modules
|
|
+ sys.modules[fullname] = r
|
|
+ # Return a module spec from the found module
|
|
+ return importlib.util.find_spec(name)
|
|
+ except ImportError:
|
|
+ return None
|
|
+ return None
|
|
+
|
|
+# Add the custom finder to sys.meta_path
|
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
|
+
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
import sys
|
|
import time
|
|
|
|
import unittest
|
|
from unittest import mock
|
|
from test import support
|
|
-from test.support.numbers import (
|
|
- VALID_UNDERSCORE_LITERALS,
|
|
- INVALID_UNDERSCORE_LITERALS,
|
|
-)
|
|
+
|
|
+VALID_UNDERSCORE_LITERALS = [
|
|
+ '0_0_0',
|
|
+ '4_2',
|
|
+ '1_0000_0000',
|
|
+ '0b1001_0100',
|
|
+ '0xffff_ffff',
|
|
+ '0o5_7_7',
|
|
+ '1_00_00.5',
|
|
+ '1_00_00.5e5',
|
|
+ '1_00_00e5_1',
|
|
+ '1e1_0',
|
|
+ '.1_4',
|
|
+ '.1_4e1',
|
|
+ '0b_0',
|
|
+ '0x_f',
|
|
+ '0o_5',
|
|
+ '1_00_00j',
|
|
+ '1_00_00.5j',
|
|
+ '1_00_00e5_1j',
|
|
+ '.1_4j',
|
|
+ '(1_2.5+3_3j)',
|
|
+ '(.5_6j)',
|
|
+]
|
|
+INVALID_UNDERSCORE_LITERALS = [
|
|
+ # Trailing underscores:
|
|
+ '0_',
|
|
+ '42_',
|
|
+ '1.4j_',
|
|
+ '0x_',
|
|
+ '0b1_',
|
|
+ '0xf_',
|
|
+ '0o5_',
|
|
+ '0 if 1_Else 1',
|
|
+ # Underscores in the base selector:
|
|
+ '0_b0',
|
|
+ '0_xf',
|
|
+ '0_o5',
|
|
+ # Old-style octal, still disallowed:
|
|
+ '0_7',
|
|
+ '09_99',
|
|
+ # Multiple consecutive underscores:
|
|
+ '4_______2',
|
|
+ '0.1__4',
|
|
+ '0.1__4j',
|
|
+ '0b1001__0100',
|
|
+ '0xffff__ffff',
|
|
+ '0x___',
|
|
+ '0o5__77',
|
|
+ '1e1__0',
|
|
+ '1e1__0j',
|
|
+ # Underscore right before a dot:
|
|
+ '1_.4',
|
|
+ '1_.4j',
|
|
+ # Underscore right after a dot:
|
|
+ '1._4',
|
|
+ '1._4j',
|
|
+ '._5',
|
|
+ '._5j',
|
|
+ # Underscore right after a sign:
|
|
+ '1.0e+_1',
|
|
+ '1.0e+_1j',
|
|
+ # Underscore right before j:
|
|
+ '1.4_j',
|
|
+ '1.4e5_j',
|
|
+ # Underscore right before e:
|
|
+ '1_e1',
|
|
+ '1.4_e1',
|
|
+ '1.4_e1j',
|
|
+ # Underscore right after e:
|
|
+ '1e_1',
|
|
+ '1.4e_1',
|
|
+ '1.4e_1j',
|
|
+ # Complex cases with parens:
|
|
+ '(1+1.5_j_)',
|
|
+ '(1+1.5_j)',
|
|
+]
|
|
|
|
try:
|
|
import _pylong
|
|
@@ -38,7 +165,7 @@ L = [
|
|
class IntSubclass(int):
|
|
pass
|
|
|
|
-class IntTestCases(unittest.TestCase):
|
|
+class IntTestCases(__TestCase):
|
|
|
|
def test_basic(self):
|
|
self.assertEqual(int(314), 314)
|
|
@@ -309,11 +436,13 @@ class IntTestCases(unittest.TestCase):
|
|
int('0', 5.0)
|
|
|
|
def test_int_base_indexable(self):
|
|
- class MyIndexable(object):
|
|
- def __init__(self, value):
|
|
- self.value = value
|
|
- def __index__(self):
|
|
- return self.value
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyIndexable(object):
|
|
+ def __init__(self, value):
|
|
+ self.value = value
|
|
+ def __index__(self):
|
|
+ return self.value
|
|
|
|
# Check out of range bases.
|
|
for base in 2**100, -2**100, 1, 37:
|
|
@@ -328,9 +457,11 @@ class IntTestCases(unittest.TestCase):
|
|
def test_non_numeric_input_types(self):
|
|
# Test possible non-numeric types for the argument x, including
|
|
# subclasses of the explicitly documented accepted types.
|
|
- class CustomStr(str): pass
|
|
- class CustomBytes(bytes): pass
|
|
- class CustomByteArray(bytearray): pass
|
|
+
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class CustomStr(str): pass
|
|
+ class CustomBytes(bytes): pass
|
|
+ class CustomByteArray(bytearray): pass
|
|
|
|
factories = [
|
|
bytes,
|
|
@@ -372,72 +503,82 @@ class IntTestCases(unittest.TestCase):
|
|
|
|
def test_intconversion(self):
|
|
# Test __int__()
|
|
- class ClassicMissingMethods:
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class ClassicMissingMethods:
|
|
+ pass
|
|
self.assertRaises(TypeError, int, ClassicMissingMethods())
|
|
|
|
- class MissingMethods(object):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MissingMethods(object):
|
|
+ pass
|
|
self.assertRaises(TypeError, int, MissingMethods())
|
|
|
|
- class Foo0:
|
|
- def __int__(self):
|
|
- return 42
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Foo0:
|
|
+ def __int__(self):
|
|
+ return 42
|
|
|
|
self.assertEqual(int(Foo0()), 42)
|
|
|
|
- class Classic:
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Classic:
|
|
+ pass
|
|
for base in (object, Classic):
|
|
- class IntOverridesTrunc(base):
|
|
- def __int__(self):
|
|
- return 42
|
|
- def __trunc__(self):
|
|
- return -12
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class IntOverridesTrunc(base):
|
|
+ def __int__(self):
|
|
+ return 42
|
|
+ def __trunc__(self):
|
|
+ return -12
|
|
self.assertEqual(int(IntOverridesTrunc()), 42)
|
|
|
|
- class JustTrunc(base):
|
|
- def __trunc__(self):
|
|
- return 42
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class JustTrunc(base):
|
|
+ def __trunc__(self):
|
|
+ return 42
|
|
with self.assertWarns(DeprecationWarning):
|
|
self.assertEqual(int(JustTrunc()), 42)
|
|
|
|
- class ExceptionalTrunc(base):
|
|
- def __trunc__(self):
|
|
- 1 / 0
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class ExceptionalTrunc(base):
|
|
+ def __trunc__(self):
|
|
+ 1 / 0
|
|
with self.assertRaises(ZeroDivisionError), \
|
|
self.assertWarns(DeprecationWarning):
|
|
int(ExceptionalTrunc())
|
|
|
|
for trunc_result_base in (object, Classic):
|
|
- class Index(trunc_result_base):
|
|
- def __index__(self):
|
|
- return 42
|
|
-
|
|
- class TruncReturnsNonInt(base):
|
|
- def __trunc__(self):
|
|
- return Index()
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Index(trunc_result_base):
|
|
+ def __index__(self):
|
|
+ return 42
|
|
+
|
|
+ class TruncReturnsNonInt(base):
|
|
+ def __trunc__(self):
|
|
+ return Index()
|
|
with self.assertWarns(DeprecationWarning):
|
|
self.assertEqual(int(TruncReturnsNonInt()), 42)
|
|
|
|
- class Intable(trunc_result_base):
|
|
- def __int__(self):
|
|
- return 42
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class Intable(trunc_result_base):
|
|
+ def __int__(self):
|
|
+ return 42
|
|
|
|
- class TruncReturnsNonIndex(base):
|
|
- def __trunc__(self):
|
|
- return Intable()
|
|
+ class TruncReturnsNonIndex(base):
|
|
+ def __trunc__(self):
|
|
+ return Intable()
|
|
with self.assertWarns(DeprecationWarning):
|
|
self.assertEqual(int(TruncReturnsNonInt()), 42)
|
|
|
|
- class NonIntegral(trunc_result_base):
|
|
- def __trunc__(self):
|
|
- # Check that we avoid infinite recursion.
|
|
- return NonIntegral()
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class NonIntegral(trunc_result_base):
|
|
+ def __trunc__(self):
|
|
+ # Check that we avoid infinite recursion.
|
|
+ return NonIntegral()
|
|
|
|
- class TruncReturnsNonIntegral(base):
|
|
- def __trunc__(self):
|
|
- return NonIntegral()
|
|
+ class TruncReturnsNonIntegral(base):
|
|
+ def __trunc__(self):
|
|
+ return NonIntegral()
|
|
try:
|
|
with self.assertWarns(DeprecationWarning):
|
|
int(TruncReturnsNonIntegral())
|
|
@@ -449,27 +590,29 @@ class IntTestCases(unittest.TestCase):
|
|
self.fail("Failed to raise TypeError with %s" %
|
|
((base, trunc_result_base),))
|
|
|
|
- # Regression test for bugs.python.org/issue16060.
|
|
- class BadInt(trunc_result_base):
|
|
- def __int__(self):
|
|
- return 42.0
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ # Regression test for bugs.python.org/issue16060.
|
|
+ class BadInt(trunc_result_base):
|
|
+ def __int__(self):
|
|
+ return 42.0
|
|
|
|
- class TruncReturnsBadInt(base):
|
|
- def __trunc__(self):
|
|
- return BadInt()
|
|
+ class TruncReturnsBadInt(base):
|
|
+ def __trunc__(self):
|
|
+ return BadInt()
|
|
|
|
with self.assertRaises(TypeError), \
|
|
self.assertWarns(DeprecationWarning):
|
|
int(TruncReturnsBadInt())
|
|
|
|
def test_int_subclass_with_index(self):
|
|
- class MyIndex(int):
|
|
- def __index__(self):
|
|
- return 42
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyIndex(int):
|
|
+ def __index__(self):
|
|
+ return 42
|
|
|
|
- class BadIndex(int):
|
|
- def __index__(self):
|
|
- return 42.0
|
|
+ class BadIndex(int):
|
|
+ def __index__(self):
|
|
+ return 42.0
|
|
|
|
my_int = MyIndex(7)
|
|
self.assertEqual(my_int, 7)
|
|
@@ -478,13 +621,14 @@ class IntTestCases(unittest.TestCase):
|
|
self.assertEqual(int(BadIndex()), 0)
|
|
|
|
def test_int_subclass_with_int(self):
|
|
- class MyInt(int):
|
|
- def __int__(self):
|
|
- return 42
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class MyInt(int):
|
|
+ def __int__(self):
|
|
+ return 42
|
|
|
|
- class BadInt(int):
|
|
- def __int__(self):
|
|
- return 42.0
|
|
+ class BadInt(int):
|
|
+ def __int__(self):
|
|
+ return 42.0
|
|
|
|
my_int = MyInt(7)
|
|
self.assertEqual(my_int, 7)
|
|
@@ -495,33 +639,34 @@ class IntTestCases(unittest.TestCase):
|
|
self.assertRaises(TypeError, int, my_int)
|
|
|
|
def test_int_returns_int_subclass(self):
|
|
- class BadIndex:
|
|
- def __index__(self):
|
|
- return True
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class BadIndex:
|
|
+ def __index__(self):
|
|
+ return True
|
|
|
|
- class BadIndex2(int):
|
|
- def __index__(self):
|
|
- return True
|
|
+ class BadIndex2(int):
|
|
+ def __index__(self):
|
|
+ return True
|
|
|
|
- class BadInt:
|
|
- def __int__(self):
|
|
- return True
|
|
+ class BadInt:
|
|
+ def __int__(self):
|
|
+ return True
|
|
|
|
- class BadInt2(int):
|
|
- def __int__(self):
|
|
- return True
|
|
+ class BadInt2(int):
|
|
+ def __int__(self):
|
|
+ return True
|
|
|
|
- class TruncReturnsBadIndex:
|
|
- def __trunc__(self):
|
|
- return BadIndex()
|
|
+ class TruncReturnsBadIndex:
|
|
+ def __trunc__(self):
|
|
+ return BadIndex()
|
|
|
|
- class TruncReturnsBadInt:
|
|
- def __trunc__(self):
|
|
- return BadInt()
|
|
+ class TruncReturnsBadInt:
|
|
+ def __trunc__(self):
|
|
+ return BadInt()
|
|
|
|
- class TruncReturnsIntSubclass:
|
|
- def __trunc__(self):
|
|
- return True
|
|
+ class TruncReturnsIntSubclass:
|
|
+ def __trunc__(self):
|
|
+ return True
|
|
|
|
bad_int = BadIndex()
|
|
with self.assertWarns(DeprecationWarning):
|
|
@@ -566,6 +711,7 @@ class IntTestCases(unittest.TestCase):
|
|
self.assertEqual(n, 1)
|
|
self.assertIs(type(n), IntSubclass)
|
|
|
|
+ @skipIfTorchDynamo("flaky under dynamo")
|
|
def test_error_message(self):
|
|
def check(s, base=None):
|
|
with self.assertRaises(ValueError,
|
|
@@ -607,7 +753,7 @@ class IntTestCases(unittest.TestCase):
|
|
self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807)
|
|
|
|
|
|
-class IntStrDigitLimitsTests(unittest.TestCase):
|
|
+class IntStrDigitLimitsTests(__TestCase):
|
|
|
|
int_class = int # Override this in subclasses to reuse the suite.
|
|
|
|
@@ -818,7 +964,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
|
|
int_class = IntSubclass
|
|
|
|
|
|
-class PyLongModuleTests(unittest.TestCase):
|
|
+class PyLongModuleTests(__TestCase):
|
|
# Tests of the functions in _pylong.py. Those get used when the
|
|
# number of digits in the input values are large enough.
|
|
|
|
@@ -922,4 +1068,4 @@ class PyLongModuleTests(unittest.TestCase):
|
|
bits <<= 1
|
|
|
|
if __name__ == "__main__":
|
|
- unittest.main()
|
|
+ run_tests()
|