Add infra to run CPython tests under Dynamo (#150787)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150787
Approved by: https://github.com/zou3519
This commit is contained in:
Guilherme Leobas
2025-05-06 11:54:30 -03:00
committed by PyTorch MergeBot
parent 13fbf21a76
commit ae1e51b6ad
14 changed files with 181 additions and 2196 deletions

View File

@ -18,6 +18,8 @@ exclude_patterns = [
'torch/_inductor/autoheuristic/artifacts/**',
'scripts/**',
'test/generated_type_hints_smoketest.py',
# CPython tests
'test/dynamo/cpython/**',
# Tests from the NumPy test suite
'test/torch_np/numpy_test/**/*.py',
'third_party/**',
@ -398,6 +400,7 @@ exclude_patterns=[
'tools/clang_format_hash/**',
'test/cpp/jit/upgrader_models/*.ptl',
'test/cpp/jit/upgrader_models/*.ptl.ff',
'test/dynamo/cpython/**',
'**/*.png',
'**/*.gz',
'**/*.patch',
@ -936,6 +939,7 @@ include_patterns = [
exclude_patterns = [
'test/run_test.py',
'**/fb/**',
'test/dynamo/cpython/3.13/**',
'test/quantization/**', # should be run through test/test_quantization.py
'test/jit/**', # should be run through test/test_jit.py
'test/ao/sparsity/**', # should be run through test/test_ao_sparsity.py
@ -1131,6 +1135,7 @@ exclude_patterns = [
'caffe2/**/*.pyi',
'fb/**',
'**/fb/**',
'test/dynamo/cpython/**',
'third_party/**/*.py',
'third_party/**/*.pyi',
'torch/_vendor/**',
@ -1536,6 +1541,7 @@ exclude_patterns = [
'functorch/notebooks/**',
'torch/_inductor/fx_passes/serialized_patterns/**',
'torch/_inductor/autoheuristic/artifacts/**',
'test/dynamo/cpython/**',
'scripts/**',
'third_party/**',
'fb/**',

View File

@ -0,0 +1,9 @@
This subdirectory contains a selection of tests from the CPython repository (branch: v3.13.0):\
https://github.com/python/cpython/releases/tag/v3.13.0
Modifications were made to ensure compatibility with the Dynamo infrastructure:
+ Monkey-patched `unittest.TestCase` to `torch._dynamo.test_case.CPythonTestCase`.
+ Replaced `unittest.main()` with `torch._dynamo.test_case.run_tests()`.
+ Assigned test "owners."
+ Annotated CPU-intensive tests with the `@slowTest` decorator.
+ Adjusted imports to use `import module` instead of `from test import module`.

View File

@ -0,0 +1,46 @@
PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
--------------------------------------------
1. This LICENSE AGREEMENT is between the Python Software Foundation
("PSF"), and the Individual or Organization ("Licensee") accessing and
otherwise using this software ("Python") in source or binary form and
its associated documentation.
2. Subject to the terms and conditions of this License Agreement, PSF hereby
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
analyze, test, perform and/or display publicly, prepare derivative works,
distribute, and otherwise use Python alone or in any derivative version,
provided, however, that PSF's License Agreement and PSF's notice of copyright,
i.e., "Copyright (c) 2001 Python Software Foundation; All Rights Reserved"
are retained in Python alone or in any derivative version prepared by Licensee.
3. In the event Licensee prepares a derivative work that is based on
or incorporates Python or any part thereof, and wants to make
the derivative work available to others as provided herein, then
Licensee hereby agrees to include in any such work a brief summary of
the changes made to Python.
4. PSF is making Python available to Licensee on an "AS IS"
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.
7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between PSF and
Licensee. This License Agreement does not grant permission to use PSF
trademarks or trade name in a trademark sense to endorse or promote
products or services of Licensee, or any third party.
8. By copying, installing or otherwise using Python, Licensee
agrees to be bound by the terms and conditions of this License
Agreement.

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: dynamo"]
import contextlib
import sys
import traceback
import unittest
from contextlib import contextmanager
@ -9,18 +8,12 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.exc import InternalTorchDynamoError
from torch._dynamo.testing import (
EagerAndRecordGraphs,
normalize_gm,
same,
skipIfNotPy311,
)
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
from torch._dynamo.utils import counters
from torch.nn import functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
make_dynamo_test,
parametrize,
)
@ -37,6 +30,16 @@ z_glb = 0
k_glb = 0
@contextlib.contextmanager
def set_default_dtype(dtype):
old_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(dtype)
yield
finally:
torch.set_default_dtype(old_dtype)
class CustomizedCtxManager:
def __init__(self, mode):
self.prev = torch.is_grad_enabled()
@ -2700,319 +2703,6 @@ class GraphModule(torch.nn.Module):
self.assertEqual(y, t.sin())
class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py
@make_dynamo_test
def test_contextmanager_plain(self):
state = []
@contextmanager
def woohoo():
state.append(1)
yield 42
state.append(999)
with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
self.assertEqual(state, [1, 42, 999])
@skipIfNotPy311
@make_dynamo_test
def test_contextmanager_finally(self):
state = []
@contextmanager
def woohoo():
state.append(1)
try:
yield 42
finally:
state.append(999)
with self.assertRaises(ZeroDivisionError):
with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
raise ZeroDivisionError
self.assertEqual(state, [1, 42, 999])
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_traceback(self):
@contextmanager
def f():
yield
try:
with f():
1 / 0
except ZeroDivisionError as e:
frames = traceback.extract_tb(e.__traceback__)
self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
self.assertEqual(frames[0].line, "1/0")
# Repeat with RuntimeError (which goes through a different code path)
try:
with f():
raise NotImplementedError(42)
except NotImplementedError as e:
frames = traceback.extract_tb(e.__traceback__)
self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
self.assertEqual(frames[0].line, "raise NotImplementedError(42)")
@make_dynamo_test
def test_contextmanager_no_reraise(self):
@contextmanager
def whee():
yield
ctx = whee()
ctx.__enter__()
# Calling __exit__ should not result in an exception
self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
@make_dynamo_test
def test_contextmanager_trap_yield_after_throw(self):
@contextmanager
def whoo():
try:
yield
except Exception: # noqa: E722
yield
ctx = whoo()
ctx.__enter__()
with self.assertRaises(RuntimeError):
ctx.__exit__(TypeError, TypeError("foo"), None)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
def test_contextmanager_except(self):
state = []
@contextmanager
def woohoo():
state.append(1)
try:
yield 42
except ZeroDivisionError as e:
state.append(e.args[0])
self.assertEqual(state, [1, 42, 999])
with woohoo() as x:
self.assertEqual(state, [1])
self.assertEqual(x, 42)
state.append(x)
raise ZeroDivisionError(999)
self.assertEqual(state, [1, 42, 999])
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_except_stopiter(self):
@contextmanager
def woohoo():
yield
class StopIterationSubclass(StopIteration):
pass
for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")):
with self.subTest(type=type(stop_exc)):
try:
with woohoo():
raise stop_exc
except Exception as ex:
self.assertIs(ex, stop_exc)
else:
self.fail(f"{stop_exc} was suppressed")
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_except_pep479(self):
code = """\
from __future__ import generator_stop
from contextlib import contextmanager
@contextmanager
def woohoo():
yield
"""
locals = {}
exec(code, locals, locals)
woohoo = locals["woohoo"]
stop_exc = StopIteration("spam")
try:
with woohoo():
raise stop_exc
except Exception as ex:
self.assertIs(ex, stop_exc)
else:
self.fail("StopIteration was suppressed")
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
@contextmanager
def test_issue29692():
try:
yield
except Exception as exc:
raise RuntimeError("issue29692:Chained") from exc
try:
with test_issue29692():
raise ZeroDivisionError
except Exception as ex:
self.assertIs(type(ex), RuntimeError)
self.assertEqual(ex.args[0], "issue29692:Chained")
self.assertIsInstance(ex.__cause__, ZeroDivisionError)
try:
with test_issue29692():
raise StopIteration("issue29692:Unchained")
except Exception as ex:
self.assertIs(type(ex), StopIteration)
self.assertEqual(ex.args[0], "issue29692:Unchained")
self.assertIsNone(ex.__cause__)
@unittest.expectedFailure
@make_dynamo_test
def _create_contextmanager_attribs(self):
def attribs(**kw):
def decorate(func):
for k, v in kw.items():
setattr(func, k, v)
return func
return decorate
@contextmanager
@attribs(foo="bar")
def baz(spam):
"""Whee!"""
return baz
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_attribs(self):
baz = self._create_contextmanager_attribs()
self.assertEqual(baz.__name__, "baz")
self.assertEqual(baz.foo, "bar")
@make_dynamo_test
def test_keywords(self):
# Ensure no keyword arguments are inhibited
@contextmanager
def woohoo(self, func, args, kwds):
yield (self, func, args, kwds)
with woohoo(self=11, func=22, args=33, kwds=44) as target:
self.assertEqual(target, (11, 22, 33, 44))
@unittest.expectedFailure
@make_dynamo_test
def test_param_errors(self):
@contextmanager
def woohoo(a, *, b):
yield
with self.assertRaises(TypeError):
woohoo()
with self.assertRaises(TypeError):
woohoo(3, 5)
with self.assertRaises(TypeError):
woohoo(b=3)
@make_dynamo_test
def test_recursive(self):
depth = 0
@contextmanager
def woohoo():
nonlocal depth
before = depth
depth += 1
yield
depth -= 1
self.assertEqual(depth, before)
@woohoo()
def recursive():
if depth < 10:
recursive()
recursive()
self.assertEqual(depth, 0)
@skipIfNotPy311
@make_dynamo_test
def test_contextmanager_trap_no_yield(self):
@contextmanager
def whoo():
if False:
yield
ctx = whoo()
with self.assertRaises(RuntimeError):
ctx.__enter__()
@make_dynamo_test
def test_contextmanager_trap_second_yield(self):
@contextmanager
def whoo():
yield
yield
ctx = whoo()
ctx.__enter__()
with self.assertRaises(RuntimeError):
ctx.__exit__(None, None, None)
@unittest.expectedFailure
@make_dynamo_test
def test_contextmanager_wrap_runtimeerror(self):
@contextmanager
def woohoo():
try:
yield
except Exception as exc:
raise RuntimeError(f"caught {exc}") from exc
with self.assertRaises(RuntimeError):
with woohoo():
1 / 0
# If the context manager wrapped StopIteration in a RuntimeError,
# we also unwrap it, because we can't tell whether the wrapping was
# done by the generator machinery or by the generator itself.
with self.assertRaises(StopIteration):
with woohoo():
raise StopIteration
@make_dynamo_test
def test_contextmanager_non_normalised(self):
@contextmanager
def whoo():
try:
yield
except RuntimeError:
raise SyntaxError # noqa: B904
ctx = whoo()
ctx.__enter__()
with self.assertRaises(SyntaxError):
ctx.__exit__(RuntimeError, None, None)
instantiate_parametrized_tests(CtxManagerTests)
instantiate_parametrized_tests(ContextlibContextManagerTests)

View File

@ -2,7 +2,6 @@
import contextlib
import sys
import unittest
import torch
import torch._dynamo.config
@ -905,238 +904,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
assert exc2.__context__ is None
class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_exceptions.py
def setUp(self):
self._u_prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._u_prev
@make_dynamo_test
def testChainingAttrs(self):
e = Exception()
assert e.__context__ is None
assert e.__cause__ is None
e = TypeError()
assert e.__context__ is None
assert e.__cause__ is None
e = MyException()
assert e.__context__ is None
assert e.__cause__ is None
@make_dynamo_test
def testChainingDescriptors(self):
try:
raise Exception # noqa: TRY002
except Exception as exc:
e = exc
assert e.__context__ is None
assert e.__cause__ is None
assert e.__suppress_context__ is False
e.__context__ = NameError()
e.__cause__ = None
assert isinstance(e.__context__, NameError)
assert e.__cause__ is None
assert e.__suppress_context__ is True
e.__suppress_context__ = False
assert e.__suppress_context__ is False
@make_dynamo_test
def test_context_of_exception_in_try_and_finally(self):
try:
try:
te = TypeError(1)
raise te
finally:
ve = ValueError(2)
raise ve
except Exception as e:
exc = e
assert exc is ve
assert exc.__context__ is te
@make_dynamo_test
def test_context_of_exception_in_except_and_finally(self):
try:
try:
te = TypeError(1)
raise te
except Exception: # noqa: E722
ve = ValueError(2)
raise ve # noqa: B904
finally:
oe = OSError(3)
raise oe
except Exception as e:
exc = e
assert exc is oe
assert exc.__context__ is ve
assert exc.__context__.__context__ is te
@make_dynamo_test
def test_context_of_exception_in_else_and_finally(self):
try:
try:
pass
except Exception: # noqa: E722
pass
else:
ve = ValueError(1)
raise ve
finally:
oe = OSError(2)
raise oe
except Exception as e:
exc = e
assert exc is oe
assert exc.__context__ is ve
@make_dynamo_test
def test_raise_does_not_create_context_chain_cycle(self):
A = AssertionError
B = BytesWarning
C = ConnectionError
# Create a context chain:
# C -> B -> A
# Then raise A in context of C.
try:
try:
raise A
except A as a_:
a = a_
try:
raise B
except B as b_:
b = b_
try:
raise C
except C as c_:
c = c_
self.assertIsInstance(a, A)
self.assertIsInstance(b, B)
self.assertIsInstance(c, C)
self.assertIsNone(a.__context__)
self.assertIs(b.__context__, a)
self.assertIs(c.__context__, b)
raise a # noqa: B904
except A as e:
exc = e
# Expect A -> C -> B, without cycle
self.assertIs(exc, a)
self.assertIs(a.__context__, c)
self.assertIs(c.__context__, b)
self.assertIsNone(b.__context__)
@make_dynamo_test
def test_no_hang_on_context_chain_cycle1(self):
# See issue 25782. Cycle in context chain.
def cycle():
try:
raise ValueError(1)
except ValueError as ex:
ex.__context__ = ex
raise TypeError(2) # noqa: B904
try:
cycle()
except Exception as e:
exc = e
self.assertIsInstance(exc, TypeError)
self.assertIsInstance(exc.__context__, ValueError)
self.assertIs(exc.__context__.__context__, exc.__context__)
@unittest.expectedFailure
@make_dynamo_test
def test_no_hang_on_context_chain_cycle2(self):
# See issue 25782. Cycle at head of context chain.
A = AssertionError
B = BytesWarning
C = ConnectionError
# Context cycle:
# +-----------+
# V |
# C --> B --> A
with self.assertRaises(C) as cm:
try:
raise A() # noqa: RSE102
except A as _a:
a = _a
try:
raise B() # noqa: RSE102
except B as _b:
b = _b
try:
raise C() # noqa: RSE102
except C as _c:
c = _c
a.__context__ = c
raise c # noqa: B904
self.assertIs(cm.exception, c)
# Verify the expected context chain cycle
self.assertIs(c.__context__, b)
self.assertIs(b.__context__, a)
self.assertIs(a.__context__, c)
@make_dynamo_test
def test_no_hang_on_context_chain_cycle3(self):
# See issue 25782. Longer context chain with cycle.
A = AssertionError
B = BytesWarning
C = ConnectionError
D = DeprecationWarning
E = Exception
# Context cycle:
# +-----------+
# V |
# E --> D --> C --> B --> A
with self.assertRaises(E) as cm:
try:
raise A
except A as _a:
a = _a
try:
raise B
except B as _b:
b = _b
try:
raise C
except C as _c:
c = _c
a.__context__ = c
try:
raise D
except D as _d:
d = _d
e = E()
raise e # noqa: B904
self.assertIs(cm.exception, e)
# Verify the expected context chain cycle
self.assertIs(e.__context__, d)
self.assertIs(d.__context__, c)
self.assertIs(c.__context__, b)
self.assertIs(b.__context__, a)
self.assertIs(a.__context__, c)
instantiate_parametrized_tests(ExceptionTests)

View File

@ -1481,331 +1481,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
self._compile_check(fn)
class GeneratorCloseCPythonTests(GeneratorTestsBase):
# Taken from commit
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
# changed the tests a little bit to run them inside dynamo
# + replaced all self.assert* calls to plain assert statements
def test_close_no_return_value(self):
def f():
yield
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = f()
gen.send(None)
assert gen.close() is None
return t.sin()
t = torch.randn(2)
fn(t)
def test_close_return_value(self):
def f():
try:
yield
# close() raises GeneratorExit here, which is caught
except GeneratorExit:
return 0 # noqa: B901
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = f()
gen.send(None)
assert gen.close() == 0
return t.sin()
t = torch.randn(2)
fn(t)
def test_close_not_catching_exit(self):
def f():
yield
# close() raises GeneratorExit here, which isn't caught and
# therefore propagates -- no return value
return 0 # noqa: B901
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = f()
gen.send(None)
assert gen.close() is None
return t.sin()
t = torch.randn(2)
fn(t)
def test_close_not_started(self):
def f():
try:
yield
except GeneratorExit:
return 0 # noqa: B901
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = f()
assert gen.close() is None
return t.sin()
t = torch.randn(2)
fn(t)
def test_close_exhausted(self):
def f():
try:
yield
except GeneratorExit:
return 0 # noqa: B901
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = f()
next(gen)
z = 0
try:
next(gen) # -> StopIteration
except StopIteration:
z = 1
except Exception as e:
# anything other than StopIteration should fail
raise AssertionError from e
assert z == 1
assert gen.close() is None
return t.sin()
t = torch.randn(2)
fn(t)
def test_close_closed(self):
def f():
try:
yield
except GeneratorExit:
return 0 # noqa: B901
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = f()
gen.send(None)
assert gen.close() == 0
assert gen.close() is None
return t.sin()
t = torch.randn(2)
fn(t)
def test_close_raises(self):
def f():
try:
yield
except GeneratorExit:
pass
raise RuntimeError
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
gen = f()
gen.send(None)
z = 0
try:
gen.close() # -> RuntimeError
except RuntimeError:
z = 1
except Exception as e:
raise AssertionError from e
assert z == 1
return t.sin()
t = torch.randn(2)
fn(t)
class GeneratorThrowCpythonTests(GeneratorTestsBase):
# Taken from commit
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
# changed the tests a little bit to run them inside dynamo
# + replaced all self.assert* calls to plain assert statements
def test_exception_context_with_yield(self):
def f():
try:
raise KeyError("a")
except Exception:
yield
def fn(t):
gen = f()
gen.send(None)
try:
gen.throw(ValueError)
except ValueError as e:
context = e.__context__
assert (type(context), context.args) == (KeyError, ("a",))
except Exception as e:
raise AssertionError from e
return t.sin()
self._compile_check(fn)
def test_exception_context_with_yield_inside_generator(self):
# Check that the context is also available from inside the generator
# with yield, as opposed to outside.
def f():
z = 0
try:
raise KeyError("a")
except Exception:
try:
yield
except Exception as exc:
z = 1
assert type(exc) == ValueError
context = exc.__context__
assert (type(context), context.args) == (KeyError, ("a",))
yield "b"
finally:
assert z == 1
def fn(t):
gen = f()
gen.send(None)
actual = gen.throw(ValueError)
# This ensures that the assertions inside were executed.
assert actual == "b"
return t.sin()
self._compile_check(fn)
def test_exception_context_with_yield_from(self):
def f():
yield
def g():
try:
raise KeyError("a")
except Exception:
yield from f()
def fn(t):
gen = g()
gen.send(None)
try:
gen.throw(ValueError)
except ValueError as e:
context = e.__context__
assert (type(context), context.args) == (KeyError, ("a",))
except Exception as e:
raise AssertionError from e
return t.sin()
self._compile_check(fn)
def test_exception_context_with_yield_from_with_context_cycle(self):
# Check trying to create an exception context cycle:
# https://bugs.python.org/issue40696
has_cycle = None
def f():
yield
def g(exc):
nonlocal has_cycle
try:
raise exc
except Exception:
try:
yield from f()
except Exception as exc:
has_cycle = exc is exc.__context__
yield
def fn(t):
exc = KeyError("a")
gen = g(exc)
gen.send(None)
gen.throw(exc)
# This also distinguishes from the initial has_cycle=None.
assert has_cycle is False
return t.sin()
self._compile_check(fn)
def test_throw_after_none_exc_type(self):
def g():
try:
raise KeyError
except KeyError:
pass
try:
yield
except Exception:
raise RuntimeError # noqa: B904
def fn(t):
gen = g()
gen.send(None)
z = 0
try:
gen.throw(ValueError)
except RuntimeError:
z += 1
except Exception:
raise AssertionError # noqa: B904
assert z == 1
return t.sin()
self._compile_check(fn)
class GeneratorCPythonTests(GeneratorTestsBase):
# Taken from commit
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
# changed the tests a little bit to run them inside dynamo
# + replaced all self.assert* calls to plain assert statements
def test_send_non_none_to_new_gen(self):
def f():
yield 1
def fn(t):
g = f()
z = 0
try:
g.send(0)
except TypeError:
z += 1
except Exception as e:
raise AssertionError from e
assert z == 1
assert next(g) == 1
return t.sin()
self._compile_check(fn)
def test_issue103488(self):
def gen_raises():
yield 1
raise ValueError
def loop():
try:
for _ in gen_raises():
if True is False: # noqa: PLR0133
return
except ValueError:
pass
def fn(t):
# This should not raise
loop()
return t.sin()
self._compile_check(fn)
instantiate_parametrized_tests(GeneratorTests)
instantiate_parametrized_tests(TestGeneratorSend)
instantiate_parametrized_tests(TestGeneratorClose)

View File

@ -1,52 +0,0 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
import torch
import torch._dynamo.test_case
from torch.testing._internal.common_utils import make_dynamo_test
class TestPEP479(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_generator_stop.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_generator_stop.py
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
@make_dynamo_test
def test_stopiteration_wrapping(self):
def f():
raise StopIteration
def g():
yield f()
with self.assertRaises(RuntimeError) as cm:
next(g())
self.assertEqual("generator raised StopIteration", str(cm.exception))
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
@make_dynamo_test
def test_stopiteration_wrapping_context(self):
def f():
raise StopIteration
def g():
yield f()
try:
next(g())
except RuntimeError as exc:
self.assertIs(type(exc.__cause__), StopIteration)
self.assertIs(type(exc.__context__), StopIteration)
self.assertTrue(exc.__suppress_context__)
else:
self.fail(
"__cause__, __context__, or __suppress_context__ "
"were not properly set"
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1,563 +0,0 @@
# Owner(s): ["module: dynamo"]
# ruff: noqa
# flake8: noqa
import sys
import types
import unittest
import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
from torch.testing._internal.common_utils import make_dynamo_test
def get_tb():
try:
raise OSError()
except:
return sys.exc_info()[2]
class Context:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_tb):
return True
class MyException(Exception):
def __init__(self):
raise RuntimeError()
class ContextManager:
def __enter__(self):
pass
def __exit__(self, t, v, tb):
raise NameError
class TestRaise(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
@make_dynamo_test
def test_invalid_reraise(self):
try:
raise
except RuntimeError as e:
self.assertIn("No active exception", str(e))
else:
self.fail("No exception raised")
@make_dynamo_test
def test_reraise(self):
try:
try:
raise IndexError
except IndexError as e:
exc1 = e
raise
except IndexError as exc2:
self.assertIs(exc1, exc2)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_except_reraise(self):
def reraise():
try:
raise TypeError("foo")
except:
try:
raise KeyError("caught")
except KeyError:
pass
raise
self.assertRaises(TypeError, reraise)
@make_dynamo_test
def test_finally_reraise(self):
def reraise():
try:
raise TypeError("foo")
except:
try:
raise KeyError("caught")
finally:
raise
self.assertRaises(KeyError, reraise)
@make_dynamo_test
def test_nested_reraise(self):
def nested_reraise():
raise
def reraise():
try:
raise TypeError("foo")
except:
nested_reraise()
self.assertRaises(TypeError, reraise)
@make_dynamo_test
def test_raise_from_None(self):
try:
try:
raise TypeError("foo")
except:
raise ValueError() from None
except ValueError as e:
self.assertIsInstance(e.__context__, TypeError)
self.assertIsNone(e.__cause__)
@make_dynamo_test
def test_with_reraise1(self):
def reraise():
try:
raise TypeError("foo")
except:
with Context():
pass
raise
self.assertRaises(TypeError, reraise)
@make_dynamo_test
def test_with_reraise2(self):
def reraise():
try:
raise TypeError("foo")
except:
with Context():
raise KeyError("caught")
raise
self.assertRaises(TypeError, reraise)
@make_dynamo_test
def test_yield_reraise(self):
def reraise():
try:
raise TypeError("foo")
except:
yield 1
raise
g = reraise()
next(g)
self.assertRaises(TypeError, lambda: next(g))
self.assertRaises(StopIteration, lambda: next(g))
@make_dynamo_test
def test_erroneous_exception(self):
try:
raise MyException
except RuntimeError:
pass
else:
self.fail("No exception raised")
@unittest.expectedFailure # object
@make_dynamo_test
def test_new_returns_invalid_instance(self):
# See issue #11627.
class MyException2(Exception):
def __new__(cls, *args):
return object()
with self.assertRaises(TypeError):
raise MyException2
@unittest.expectedFailure # Assertion with non-string message
@make_dynamo_test
def test_assert_with_tuple_arg(self):
try:
assert False, (3,)
except AssertionError as e:
self.assertEqual(str(e), "(3,)")
class TestCause(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
@make_dynamo_test
def testCauseSyntax(self):
try:
try:
try:
raise TypeError
except Exception:
raise ValueError from None
except ValueError as exc:
self.assertIsNone(exc.__cause__)
self.assertTrue(exc.__suppress_context__)
exc.__suppress_context__ = False
raise exc
except ValueError as exc:
e = exc
self.assertIsNone(e.__cause__)
self.assertFalse(e.__suppress_context__)
self.assertIsInstance(e.__context__, TypeError)
@make_dynamo_test
def test_invalid_cause(self):
try:
raise IndexError from 5
except TypeError as e:
self.assertIn("exception cause", str(e))
else:
self.fail("No exception raised")
@make_dynamo_test
def test_class_cause(self):
try:
raise IndexError from KeyError
except IndexError as e:
self.assertIsInstance(e.__cause__, KeyError)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_instance_cause(self):
cause = KeyError()
try:
raise IndexError from cause
except IndexError as e:
self.assertIs(e.__cause__, cause)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_erroneous_cause(self):
try:
raise IndexError from MyException
except RuntimeError:
pass
else:
self.fail("No exception raised")
class TestTraceback(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
@unittest.expectedFailure # Dynamo doesn't track traceback
@make_dynamo_test
def test_sets_traceback(self):
try:
raise IndexError()
except IndexError as e:
self.assertIsInstance(e.__traceback__, types.TracebackType)
else:
self.fail("No exception raised")
@unittest.expectedFailure # Dynamo doesn't track traceback
@make_dynamo_test
def test_accepts_traceback(self):
tb = get_tb()
try:
raise IndexError().with_traceback(tb)
except IndexError as e:
self.assertNotEqual(e.__traceback__, tb)
self.assertEqual(e.__traceback__.tb_next, tb)
else:
self.fail("No exception raised")
class TestTracebackType(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
def raiser(self):
raise ValueError
@unittest.expectedFailure # Dynamo doesn't track traceback
@make_dynamo_test
def test_attrs(self):
try:
self.raiser()
except Exception as exc:
tb = exc.__traceback__
self.assertIsInstance(tb.tb_next, types.TracebackType)
self.assertIs(tb.tb_frame, sys._getframe())
self.assertIsInstance(tb.tb_lasti, int)
self.assertIsInstance(tb.tb_lineno, int)
self.assertIs(tb.tb_next.tb_next, None)
# Invalid assignments
with self.assertRaises(TypeError):
del tb.tb_next
with self.assertRaises(TypeError):
tb.tb_next = "asdf"
# Loops
with self.assertRaises(ValueError):
tb.tb_next = tb
with self.assertRaises(ValueError):
tb.tb_next.tb_next = tb
# Valid assignments
tb.tb_next = None
self.assertIs(tb.tb_next, None)
new_tb = get_tb()
tb.tb_next = new_tb
self.assertIs(tb.tb_next, new_tb)
@unittest.expectedFailure # Dynamo doesn't track traceback
@make_dynamo_test
def test_constructor(self):
other_tb = get_tb()
frame = sys._getframe()
tb = types.TracebackType(other_tb, frame, 1, 2)
self.assertEqual(tb.tb_next, other_tb)
self.assertEqual(tb.tb_frame, frame)
self.assertEqual(tb.tb_lasti, 1)
self.assertEqual(tb.tb_lineno, 2)
tb = types.TracebackType(None, frame, 1, 2)
self.assertEqual(tb.tb_next, None)
with self.assertRaises(TypeError):
types.TracebackType("no", frame, 1, 2)
with self.assertRaises(TypeError):
types.TracebackType(other_tb, "no", 1, 2)
with self.assertRaises(TypeError):
types.TracebackType(other_tb, frame, "no", 2)
with self.assertRaises(TypeError):
types.TracebackType(other_tb, frame, 1, "nuh-uh")
class TestContext(torch._dynamo.test_case.TestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
def setUp(self):
self._prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
torch._dynamo.config.enable_trace_unittest = self._prev
@unittest.expectedFailure # missing Exception.__eq__
@make_dynamo_test
def test_instance_context_instance_raise(self):
context = IndexError()
try:
try:
raise context
except:
raise OSError()
except OSError as e:
self.assertEqual(e.__context__, context)
else:
self.fail("No exception raised")
@unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__
@make_dynamo_test
def test_class_context_instance_raise(self):
context = IndexError
try:
try:
raise context
except:
raise OSError()
except OSError as e:
self.assertNotEqual(e.__context__, context)
self.assertIsInstance(e.__context__, context)
else:
self.fail("No exception raised")
@unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__
@make_dynamo_test
def test_class_context_class_raise(self):
context = IndexError
try:
try:
raise context
except:
raise OSError
except OSError as e:
self.assertNotEqual(e.__context__, context)
self.assertIsInstance(e.__context__, context)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_c_exception_context(self):
try:
try:
raise ZeroDivisionError
except:
raise OSError
except OSError as e:
self.assertIsInstance(e.__context__, ZeroDivisionError)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_c_exception_raise(self):
try:
try:
raise ZeroDivisionError
except:
raise NameError
except NameError as e:
self.assertIsInstance(e.__context__, ZeroDivisionError)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_noraise_finally(self):
try:
try:
pass
finally:
raise OSError
except OSError as e:
self.assertIsNone(e.__context__)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_raise_finally(self):
try:
try:
raise ZeroDivisionError
finally:
raise OSError
except OSError as e:
self.assertIsInstance(e.__context__, ZeroDivisionError)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_context_manager(self):
try:
with ContextManager():
raise ZeroDivisionError
except NameError as e:
self.assertIsInstance(e.__context__, ZeroDivisionError)
else:
self.fail("No exception raised")
@make_dynamo_test
def test_cycle_broken(self):
# Self-cycles (when re-raising a caught exception) are broken
try:
try:
raise ZeroDivisionError
except ZeroDivisionError as e:
raise e
except ZeroDivisionError as e:
self.assertIsNone(e.__context__)
@make_dynamo_test
def test_reraise_cycle_broken(self):
# Non-trivial context cycles (through re-raising a previous exception)
# are broken too.
try:
try:
raise NameError
except NameError as a:
try:
raise ZeroDivisionError
except ZeroDivisionError:
raise a
except NameError as e:
self.assertIsNone(e.__context__.__context__)
@make_dynamo_test
def test_3118(self):
# deleting the generator caused the __context__ to be cleared
def gen():
try:
yield 1
finally:
pass
def f():
g = gen()
next(g)
try:
try:
raise ValueError
except:
del g
raise KeyError
except Exception as e:
self.assertIsInstance(e.__context__, ValueError)
f()
@unittest.expectedFailure # too CPython specific(?)
@make_dynamo_test
def test_3611(self):
# A re-raised exception in a __del__ caused the __context__
# to be cleared
class C:
def __del__(self):
try:
raise ZeroDivisionError
except:
raise
def f():
x = C()
try:
try:
x.x
except AttributeError:
del x
raise TypeError
except Exception as e:
self.assertNotEqual(e.__context__, None)
self.assertIsInstance(e.__context__, AttributeError)
with support.catch_unraisable_exception() as cm:
f()
self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1,107 +0,0 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
import torch
import torch._dynamo.test_case
from torch.testing._internal.common_utils import make_dynamo_test
class SysTests(torch._dynamo.test_case.TestCase):
def test_exc_info(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
try:
raise ValueError
except Exception:
typ, _, _ = sys.exc_info()
if typ is ValueError:
return t.sin()
else:
return t.cos()
t = torch.randn(2)
y = fn(t)
self.assertEqual(y, t.sin())
class CPythonActiveExceptionTests(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_sys.py
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py
@make_dynamo_test
def test_exc_info_no_exception(self):
self.assertEqual(sys.exc_info(), (None, None, None))
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_sys_exception_no_exception(self):
self.assertEqual(sys.exception(), None)
@make_dynamo_test
def test_exc_info_with_exception_instance(self):
def f():
raise ValueError(42)
try:
f()
except Exception as e_:
e = e_
exc_info = sys.exc_info()
self.assertIsInstance(e, ValueError)
self.assertIs(exc_info[0], ValueError)
self.assertIs(exc_info[1], e)
self.assertIs(exc_info[2], e.__traceback__)
@make_dynamo_test
def test_exc_info_with_exception_type(self):
def f():
raise ValueError
try:
f()
except Exception as e_:
e = e_
exc_info = sys.exc_info()
self.assertIsInstance(e, ValueError)
self.assertIs(exc_info[0], ValueError)
self.assertIs(exc_info[1], e)
self.assertIs(exc_info[2], e.__traceback__)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_sys_exception_with_exception_instance(self):
def f():
raise ValueError(42)
try:
f()
except Exception as e_:
e = e_
exc = sys.exception()
self.assertIsInstance(e, ValueError)
self.assertIs(exc, e)
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
@make_dynamo_test
def test_sys_exception_with_exception_type(self):
def f():
raise ValueError
try:
f()
except Exception as e_:
e = e_
exc = sys.exception()
self.assertIsInstance(e, ValueError)
self.assertIs(exc, e)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1,8 +1,5 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
import warnings
from itertools import product
import torch
import torch._dynamo.test_case
@ -28,591 +25,6 @@ class TestUnittest(torch._dynamo.test_case.TestCase):
self.assertEqual(z, 1)
class CPythonTest_Assertions(torch._dynamo.test_case.CPythonTestCase):
# Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py
# https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py
@make_dynamo_test
def test_AlmostEqual(self):
self.assertAlmostEqual(1.00000001, 1.0)
self.assertNotAlmostEqual(1.0000001, 1.0)
self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0
)
self.assertAlmostEqual(1.1, 1.0, places=0)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1
)
self.assertAlmostEqual(0, 0.1 + 0.1j, places=0)
self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1
)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0
)
self.assertAlmostEqual(float("inf"), float("inf"))
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf")
)
@make_dynamo_test
def test_AmostEqualWithDelta(self):
self.assertAlmostEqual(1.1, 1.0, delta=0.5)
self.assertAlmostEqual(1.0, 1.1, delta=0.5)
self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)
self.assertAlmostEqual(1.0, 1.0, delta=0.5)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5
)
self.assertRaises(
self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05
)
self.assertRaises(
self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5
)
self.assertRaises(
TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2
)
self.assertRaises(
TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2
)
@make_dynamo_test
def test_assertRaises(self):
def _raise(e):
raise e
self.assertRaises(KeyError, _raise, KeyError)
self.assertRaises(KeyError, _raise, KeyError("key"))
try:
self.assertRaises(KeyError, lambda: None)
except self.failureException as e:
self.assertIn("KeyError not raised", str(e))
else:
self.fail("assertRaises() didn't fail")
try:
self.assertRaises(KeyError, _raise, ValueError)
except ValueError:
pass
else:
self.fail("assertRaises() didn't let exception pass through")
with self.assertRaises(KeyError) as cm:
try:
raise KeyError
except Exception as e:
exc = e
raise
self.assertIs(cm.exception, exc)
with self.assertRaises(KeyError):
raise KeyError("key")
try:
with self.assertRaises(KeyError):
pass
except self.failureException as e:
self.assertIn("KeyError not raised", str(e))
else:
self.fail("assertRaises() didn't fail")
try:
with self.assertRaises(KeyError):
raise ValueError
except ValueError:
pass
else:
self.fail("assertRaises() didn't let exception pass through")
@make_dynamo_test
def testAssertNotRegex(self):
self.assertNotRegex("Ala ma kota", r"r+")
try:
self.assertNotRegex("Ala ma kota", r"k.t", "Message")
except self.failureException as e:
self.assertIn("Message", e.args[0])
else:
self.fail("assertNotRegex should have failed.")
class CPythonTestLongMessage(torch._dynamo.test_case.CPythonTestCase):
"""Test that the individual asserts honour longMessage.
This actually tests all the message behaviour for
asserts that use longMessage."""
def setUp(self):
super().setUp()
class TestableTestFalse(unittest.TestCase):
longMessage = False
failureException = self.failureException
def testTest(self):
pass
class TestableTestTrue(unittest.TestCase):
longMessage = True
failureException = self.failureException
def testTest(self):
pass
self.testableTrue = TestableTestTrue("testTest")
self.testableFalse = TestableTestFalse("testTest")
def testDefault(self):
self.assertTrue(unittest.TestCase.longMessage)
def test_formatMsg(self):
self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo")
self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo")
self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo")
self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo")
# This blows up if _formatMessage uses string concatenation
self.testableTrue._formatMessage(object(), "foo")
def test_formatMessage_unicode_error(self):
one = "".join(chr(i) for i in range(255))
# this used to cause a UnicodeDecodeError constructing msg
self.testableTrue._formatMessage(one, "\uFFFD")
def assertMessages(self, methodName, args, errors):
"""
Check that methodName(*args) raises the correct error messages.
errors should be a list of 4 regex that match the error when:
1) longMessage = False and no msg passed;
2) longMessage = False and msg passed;
3) longMessage = True and no msg passed;
4) longMessage = True and msg passed;
"""
def getMethod(i):
useTestableFalse = i < 2
if useTestableFalse:
test = self.testableFalse
else:
test = self.testableTrue
return getattr(test, methodName)
for i, expected_regex in enumerate(errors):
testMethod = getMethod(i)
kwargs = {}
withMsg = i % 2
if withMsg:
kwargs = {"msg": "oops"}
# with self.assertRaisesRegex(
# self.failureException, expected_regex=expected_regex
# ):
# testMethod(*args, **kwargs)
with self.assertRaises(self.failureException) as cm:
testMethod(*args, **kwargs)
self.assertRegex(str(cm.exception), expected_regex)
@make_dynamo_test
def testAssertTrue(self):
self.assertMessages(
"assertTrue",
(False,),
[
"False is not true",
"oops",
"False is not true",
"False is not true : oops",
],
)
@make_dynamo_test
def testAssertFalse(self):
self.assertMessages(
"assertFalse",
(True,),
[
"True is not false",
"oops",
"True is not false",
"True is not false : oops",
],
)
@make_dynamo_test
def testNotEqual(self):
self.assertMessages(
"assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"]
)
@make_dynamo_test
def testAlmostEqual(self):
self.assertMessages(
"assertAlmostEqual",
(1, 2),
[
r"^1 != 2 within 7 places \(1 difference\)$",
"^oops$",
r"^1 != 2 within 7 places \(1 difference\)$",
r"^1 != 2 within 7 places \(1 difference\) : oops$",
],
)
@make_dynamo_test
def testNotAlmostEqual(self):
self.assertMessages(
"assertNotAlmostEqual",
(1, 1),
[
"^1 == 1 within 7 places$",
"^oops$",
"^1 == 1 within 7 places$",
"^1 == 1 within 7 places : oops$",
],
)
@make_dynamo_test
def test_baseAssertEqual(self):
self.assertMessages(
"_baseAssertEqual",
(1, 2),
["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertSequenceEqual(self):
# Error messages are multiline so not testing on full message
# assertTupleEqual and assertListEqual delegate to this method
self.assertMessages(
"assertSequenceEqual",
([], [None]),
[r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"],
)
@make_dynamo_test
def testAssertSetEqual(self):
self.assertMessages(
"assertSetEqual",
(set(), set([None])), # noqa: C405
["None$", "^oops$", "None$", "None : oops$"],
)
@make_dynamo_test
def testAssertIn(self):
self.assertMessages(
"assertIn",
(None, []),
[
r"^None not found in \[\]$",
"^oops$",
r"^None not found in \[\]$",
r"^None not found in \[\] : oops$",
],
)
@make_dynamo_test
def testAssertNotIn(self):
self.assertMessages(
"assertNotIn",
(None, [None]),
[
r"^None unexpectedly found in \[None\]$",
"^oops$",
r"^None unexpectedly found in \[None\]$",
r"^None unexpectedly found in \[None\] : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertDictEqual(self):
self.assertMessages(
"assertDictEqual",
({}, {"key": "value"}),
[
r"\+ \{'key': 'value'\}$",
"^oops$",
r"\+ \{'key': 'value'\}$",
r"\+ \{'key': 'value'\} : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertMultiLineEqual(self):
self.assertMessages(
"assertMultiLineEqual",
("", "foo"),
[r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"],
)
@make_dynamo_test
def testAssertLess(self):
self.assertMessages(
"assertLess",
(2, 1),
[
"^2 not less than 1$",
"^oops$",
"^2 not less than 1$",
"^2 not less than 1 : oops$",
],
)
@make_dynamo_test
def testAssertLessEqual(self):
self.assertMessages(
"assertLessEqual",
(2, 1),
[
"^2 not less than or equal to 1$",
"^oops$",
"^2 not less than or equal to 1$",
"^2 not less than or equal to 1 : oops$",
],
)
@make_dynamo_test
def testAssertGreater(self):
self.assertMessages(
"assertGreater",
(1, 2),
[
"^1 not greater than 2$",
"^oops$",
"^1 not greater than 2$",
"^1 not greater than 2 : oops$",
],
)
@make_dynamo_test
def testAssertGreaterEqual(self):
self.assertMessages(
"assertGreaterEqual",
(1, 2),
[
"^1 not greater than or equal to 2$",
"^oops$",
"^1 not greater than or equal to 2$",
"^1 not greater than or equal to 2 : oops$",
],
)
@make_dynamo_test
def testAssertIsNone(self):
self.assertMessages(
"assertIsNone",
("not None",),
[
"^'not None' is not None$",
"^oops$",
"^'not None' is not None$",
"^'not None' is not None : oops$",
],
)
@make_dynamo_test
def testAssertIsNotNone(self):
self.assertMessages(
"assertIsNotNone",
(None,),
[
"^unexpectedly None$",
"^oops$",
"^unexpectedly None$",
"^unexpectedly None : oops$",
],
)
@make_dynamo_test
def testAssertIs(self):
self.assertMessages(
"assertIs",
(None, "foo"),
[
"^None is not 'foo'$",
"^oops$",
"^None is not 'foo'$",
"^None is not 'foo' : oops$",
],
)
@make_dynamo_test
def testAssertIsNot(self):
self.assertMessages(
"assertIsNot",
(None, None),
[
"^unexpectedly identical: None$",
"^oops$",
"^unexpectedly identical: None$",
"^unexpectedly identical: None : oops$",
],
)
@make_dynamo_test
def testAssertRegex(self):
self.assertMessages(
"assertRegex",
("foo", "bar"),
[
"^Regex didn't match:",
"^oops$",
"^Regex didn't match:",
"^Regex didn't match: (.*) : oops$",
],
)
@make_dynamo_test
def testAssertNotRegex(self):
self.assertMessages(
"assertNotRegex",
("foo", "foo"),
[
"^Regex matched:",
"^oops$",
"^Regex matched:",
"^Regex matched: (.*) : oops$",
],
)
def assertMessagesCM(self, methodName, args, func, errors):
"""
Check that the correct error messages are raised while executing:
with method(*args):
func()
*errors* should be a list of 4 regex that match the error when:
1) longMessage = False and no msg passed;
2) longMessage = False and msg passed;
3) longMessage = True and no msg passed;
4) longMessage = True and msg passed;
"""
p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"}))
for (cls, kwargs), err in zip(p, errors):
method = getattr(cls, methodName)
# with self.assertRaisesRegex(cls.failureException, err):
with self.assertRaises(cls.failureException) as c:
with method(*args, **kwargs) as cm: # noqa: F841
func()
self.assertRegex(str(c.exception), err)
@make_dynamo_test
def testAssertRaises(self):
self.assertMessagesCM(
"assertRaises",
(TypeError,),
lambda: None,
[
"^TypeError not raised$",
"^oops$",
"^TypeError not raised$",
"^TypeError not raised : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertRaisesRegex(self):
self.assertMessagesCM(
"assertRaisesRegex",
(TypeError, "unused regex"),
lambda: None,
[
"^TypeError not raised$",
"^oops$",
"^TypeError not raised$",
"^TypeError not raised : oops$",
],
)
# test error raised but with wrong message
def raise_wrong_message():
raise TypeError("foo")
self.assertMessagesCM(
"assertRaisesRegex",
(TypeError, "regex"),
raise_wrong_message,
[
'^"regex" does not match "foo"$',
"^oops$",
'^"regex" does not match "foo"$',
'^"regex" does not match "foo" : oops$',
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertWarns(self):
self.assertMessagesCM(
"assertWarns",
(UserWarning,),
lambda: None,
[
"^UserWarning not triggered$",
"^oops$",
"^UserWarning not triggered$",
"^UserWarning not triggered : oops$",
],
)
@unittest.expectedFailure
@unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13")
@make_dynamo_test
def test_assertNotWarns(self):
def warn_future():
warnings.warn("xyz", FutureWarning, stacklevel=2)
self.assertMessagesCM(
"_assertNotWarns",
(FutureWarning,),
warn_future,
[
"^FutureWarning triggered$",
"^oops$",
"^FutureWarning triggered$",
"^FutureWarning triggered : oops$",
],
)
@unittest.expectedFailure
@make_dynamo_test
def testAssertWarnsRegex(self):
# test error not raised
self.assertMessagesCM(
"assertWarnsRegex",
(UserWarning, "unused regex"),
lambda: None,
[
"^UserWarning not triggered$",
"^oops$",
"^UserWarning not triggered$",
"^UserWarning not triggered : oops$",
],
)
# test warning raised but with wrong message
def raise_wrong_message():
warnings.warn("foo")
self.assertMessagesCM(
"assertWarnsRegex",
(UserWarning, "regex"),
raise_wrong_message,
[
'^"regex" does not match "foo"$',
"^oops$",
'^"regex" does not match "foo"$',
'^"regex" does not match "foo" : oops$',
],
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1593,6 +1593,13 @@ def get_selected_tests(options) -> list[str]:
]
)
if sys.version_info[:2] < (3, 13):
# Skip tests for older Python versions as they may use syntax or features
# not supported in those versions
options.exclude.extend(
[test for test in selected_tests if test.startswith("dynamo/cpython/3_13/")]
)
selected_tests = exclude_tests(options.exclude, selected_tests)
if sys.platform == "win32" and not options.ignore_win_blocklist:

View File

@ -1,3 +1,5 @@
# mypy: allow-untyped-defs
"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
@ -10,8 +12,13 @@ It includes:
import contextlib
import importlib
import inspect
import logging
import os
import pathlib
import re
import sys
import unittest
from typing import Union
import torch
@ -98,7 +105,70 @@ class TestCase(TorchTestCase):
class CPythonTestCase(TestCase):
"""
Test class for CPython tests located in "test/dynamo/CPython/Py_version/*".
This class enables specific features that are disabled by default, such as
tracing through unittest methods.
"""
_stack: contextlib.ExitStack
dynamo_strict_nopython = True
# Restore original unittest methods to simplify tracing CPython test cases.
assertEqual = unittest.TestCase.assertEqual # type: ignore[assignment]
assertNotEqual = unittest.TestCase.assertNotEqual # type: ignore[assignment]
assertTrue = unittest.TestCase.assertTrue
assertFalse = unittest.TestCase.assertFalse
assertIs = unittest.TestCase.assertIs
assertIsNot = unittest.TestCase.assertIsNot
assertIsNone = unittest.TestCase.assertIsNone
assertIsNotNone = unittest.TestCase.assertIsNotNone
assertIn = unittest.TestCase.assertIn
assertNotIn = unittest.TestCase.assertNotIn
assertIsInstance = unittest.TestCase.assertIsInstance
assertNotIsInstance = unittest.TestCase.assertNotIsInstance
assertAlmostEqual = unittest.TestCase.assertAlmostEqual
assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
assertGreater = unittest.TestCase.assertGreater
assertGreaterEqual = unittest.TestCase.assertGreaterEqual
assertLess = unittest.TestCase.assertLess
assertLessEqual = unittest.TestCase.assertLessEqual
assertRegex = unittest.TestCase.assertRegex
assertNotRegex = unittest.TestCase.assertNotRegex
assertCountEqual = unittest.TestCase.assertCountEqual
assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual
assertSequenceEqual = unittest.TestCase.assertSequenceEqual
assertListEqual = unittest.TestCase.assertListEqual
assertTupleEqual = unittest.TestCase.assertTupleEqual
assertSetEqual = unittest.TestCase.assertSetEqual
assertDictEqual = unittest.TestCase.assertDictEqual
assertRaises = unittest.TestCase.assertRaises
assertRaisesRegex = unittest.TestCase.assertRaisesRegex
assertWarns = unittest.TestCase.assertWarns
assertWarnsRegex = unittest.TestCase.assertWarnsRegex
assertLogs = unittest.TestCase.assertLogs
fail = unittest.TestCase.fail
failureException = unittest.TestCase.failureException
def compile_fn(self, fn, backend, nopython):
# We want to compile only the test function, excluding any setup code
# from unittest
method = getattr(self, self._testMethodName)
method = torch._dynamo.optimize(backend, nopython=nopython)(method)
setattr(self, self._testMethodName, method)
return fn
def _dynamo_test_key(self):
suffix = super()._dynamo_test_key()
test_cls = self.__class__
test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0]
py_ver = re.search(r"/([\d_]+)/", inspect.getfile(test_cls))
if py_ver:
py_ver = py_ver.group().strip(os.sep).replace("_", "") # type: ignore[assignment]
else:
return suffix
return f"CPython{py_ver}-{test_file}-{suffix}"
@classmethod
def tearDownClass(cls) -> None:
@ -107,6 +177,24 @@ class CPythonTestCase(TestCase):
@classmethod
def setUpClass(cls) -> None:
# Skip test if python versions doesn't match
normalized_path = pathlib.PurePath("dynamo/cpython").as_posix()
regex = re.escape(normalized_path) + r"\b\d+_\d{2}\b"
m = re.search(regex, inspect.getfile(cls))
if m:
test_py_ver = tuple(map(int, m.group().split("_")))
py_ver = sys.version_info[:2]
if py_ver != test_py_ver:
expected = ".".join(map(str, test_py_ver))
got = ".".join(map(str, py_ver))
raise unittest.SkipTest(
f"Test requires Python {expected} but got Python {got}"
)
else:
raise unittest.SkipTest(
f"Test requires a specific Python version but not found in path {inspect.getfile(cls)}"
)
super().setUpClass()
cls._stack = contextlib.ExitStack() # type: ignore[attr-defined]
cls._stack.enter_context( # type: ignore[attr-defined]

View File

@ -1989,11 +1989,11 @@ class BuiltinVariable(VariableTracker):
)
):
unimplemented_v2(
gb_type="Failed to trace builtin operator",
gb_type="Failed to trace unittest method",
context=f"function: unittest.TestCase.{name}",
explanation=f"Dynamo does not know how to trace builtin operator `{name}` ",
explanation=f"Dynamo does not know how to trace unittest method `{name}` ",
hints=[
f"Avoid calling builtin `{name}`. "
f"Avoid calling `TestCase.{name}`. "
"Please report an issue to PyTorch.",
],
)

View File

@ -3157,6 +3157,13 @@ class TestCase(expecttest.TestCase):
def wrap_with_cuda_memory_check(self, method):
return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
def _dynamo_test_key(self):
return f"{self.__class__.__name__}.{self._testMethodName}"
def compile_fn(self, fn, backend, nopython):
# Allows subclasses to control compilation
return torch._dynamo.optimize(backend, nopython=nopython)(fn)
def _run_custom(self, result=None):
using_unittest = isinstance(result, unittest.TestResult)
@ -3232,16 +3239,16 @@ class TestCase(expecttest.TestCase):
with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
if TEST_WITH_AOT_EAGER:
super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython)
elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
if TEST_WITH_TORCHINDUCTOR:
super_run = torch._dynamo.optimize("inductor")(super_run)
super_run = self.compile_fn(super_run, "inductor", nopython)
else:
# Assume eager-generated GraphModules will not error out.
# If we do, this is probably a Dynamo bug!
super_run = torch._dynamo.optimize("eager_noexcept", nopython=nopython)(super_run)
super_run = self.compile_fn(super_run, "eager_noexcept", nopython)
key = f"{self.__class__.__name__}.{self._testMethodName}"
key = self._dynamo_test_key()
def expect_failure(f, file_name):
@wraps(f)