mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add infra to run CPython tests under Dynamo (#150787)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150787 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
13fbf21a76
commit
ae1e51b6ad
@ -18,6 +18,8 @@ exclude_patterns = [
|
||||
'torch/_inductor/autoheuristic/artifacts/**',
|
||||
'scripts/**',
|
||||
'test/generated_type_hints_smoketest.py',
|
||||
# CPython tests
|
||||
'test/dynamo/cpython/**',
|
||||
# Tests from the NumPy test suite
|
||||
'test/torch_np/numpy_test/**/*.py',
|
||||
'third_party/**',
|
||||
@ -398,6 +400,7 @@ exclude_patterns=[
|
||||
'tools/clang_format_hash/**',
|
||||
'test/cpp/jit/upgrader_models/*.ptl',
|
||||
'test/cpp/jit/upgrader_models/*.ptl.ff',
|
||||
'test/dynamo/cpython/**',
|
||||
'**/*.png',
|
||||
'**/*.gz',
|
||||
'**/*.patch',
|
||||
@ -936,6 +939,7 @@ include_patterns = [
|
||||
exclude_patterns = [
|
||||
'test/run_test.py',
|
||||
'**/fb/**',
|
||||
'test/dynamo/cpython/3.13/**',
|
||||
'test/quantization/**', # should be run through test/test_quantization.py
|
||||
'test/jit/**', # should be run through test/test_jit.py
|
||||
'test/ao/sparsity/**', # should be run through test/test_ao_sparsity.py
|
||||
@ -1131,6 +1135,7 @@ exclude_patterns = [
|
||||
'caffe2/**/*.pyi',
|
||||
'fb/**',
|
||||
'**/fb/**',
|
||||
'test/dynamo/cpython/**',
|
||||
'third_party/**/*.py',
|
||||
'third_party/**/*.pyi',
|
||||
'torch/_vendor/**',
|
||||
@ -1536,6 +1541,7 @@ exclude_patterns = [
|
||||
'functorch/notebooks/**',
|
||||
'torch/_inductor/fx_passes/serialized_patterns/**',
|
||||
'torch/_inductor/autoheuristic/artifacts/**',
|
||||
'test/dynamo/cpython/**',
|
||||
'scripts/**',
|
||||
'third_party/**',
|
||||
'fb/**',
|
||||
|
9
test/dynamo/cpython/3_13/CHANGES.txt
Normal file
9
test/dynamo/cpython/3_13/CHANGES.txt
Normal file
@ -0,0 +1,9 @@
|
||||
This subdirectory contains a selection of tests from the CPython repository (branch: v3.13.0):\
|
||||
https://github.com/python/cpython/releases/tag/v3.13.0
|
||||
|
||||
Modifications were made to ensure compatibility with the Dynamo infrastructure:
|
||||
+ Monkey-patched `unittest.TestCase` to `torch._dynamo.test_case.CPythonTestCase`.
|
||||
+ Replaced `unittest.main()` with `torch._dynamo.test_case.run_tests()`.
|
||||
+ Assigned test "owners."
|
||||
+ Annotated CPU-intensive tests with the `@slowTest` decorator.
|
||||
+ Adjusted imports to use `import module` instead of `from test import module`.
|
46
test/dynamo/cpython/3_13/LICENSE
Normal file
46
test/dynamo/cpython/3_13/LICENSE
Normal file
@ -0,0 +1,46 @@
|
||||
PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
|
||||
--------------------------------------------
|
||||
|
||||
1. This LICENSE AGREEMENT is between the Python Software Foundation
|
||||
("PSF"), and the Individual or Organization ("Licensee") accessing and
|
||||
otherwise using this software ("Python") in source or binary form and
|
||||
its associated documentation.
|
||||
|
||||
2. Subject to the terms and conditions of this License Agreement, PSF hereby
|
||||
grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
|
||||
analyze, test, perform and/or display publicly, prepare derivative works,
|
||||
distribute, and otherwise use Python alone or in any derivative version,
|
||||
provided, however, that PSF's License Agreement and PSF's notice of copyright,
|
||||
i.e., "Copyright (c) 2001 Python Software Foundation; All Rights Reserved"
|
||||
are retained in Python alone or in any derivative version prepared by Licensee.
|
||||
|
||||
3. In the event Licensee prepares a derivative work that is based on
|
||||
or incorporates Python or any part thereof, and wants to make
|
||||
the derivative work available to others as provided herein, then
|
||||
Licensee hereby agrees to include in any such work a brief summary of
|
||||
the changes made to Python.
|
||||
|
||||
4. PSF is making Python available to Licensee on an "AS IS"
|
||||
basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
|
||||
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
|
||||
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
|
||||
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
|
||||
INFRINGE ANY THIRD PARTY RIGHTS.
|
||||
|
||||
5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
|
||||
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
|
||||
A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
|
||||
OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
|
||||
|
||||
6. This License Agreement will automatically terminate upon a material
|
||||
breach of its terms and conditions.
|
||||
|
||||
7. Nothing in this License Agreement shall be deemed to create any
|
||||
relationship of agency, partnership, or joint venture between PSF and
|
||||
Licensee. This License Agreement does not grant permission to use PSF
|
||||
trademarks or trade name in a trademark sense to endorse or promote
|
||||
products or services of Licensee, or any third party.
|
||||
|
||||
8. By copying, installing or otherwise using Python, Licensee
|
||||
agrees to be bound by the terms and conditions of this License
|
||||
Agreement.
|
@ -1,7 +1,6 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import contextlib
|
||||
import sys
|
||||
import traceback
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
|
||||
@ -9,18 +8,12 @@ import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo.exc import InternalTorchDynamoError
|
||||
from torch._dynamo.testing import (
|
||||
EagerAndRecordGraphs,
|
||||
normalize_gm,
|
||||
same,
|
||||
skipIfNotPy311,
|
||||
)
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
make_dynamo_test,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
@ -37,6 +30,16 @@ z_glb = 0
|
||||
k_glb = 0
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_dtype(dtype):
|
||||
old_dtype = torch.get_default_dtype()
|
||||
try:
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
finally:
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
class CustomizedCtxManager:
|
||||
def __init__(self, mode):
|
||||
self.prev = torch.is_grad_enabled()
|
||||
@ -2700,319 +2703,6 @@ class GraphModule(torch.nn.Module):
|
||||
self.assertEqual(y, t.sin())
|
||||
|
||||
|
||||
class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_plain(self):
|
||||
state = []
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
state.append(1)
|
||||
yield 42
|
||||
state.append(999)
|
||||
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@skipIfNotPy311
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_finally(self):
|
||||
state = []
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
state.append(1)
|
||||
try:
|
||||
yield 42
|
||||
finally:
|
||||
state.append(999)
|
||||
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_traceback(self):
|
||||
@contextmanager
|
||||
def f():
|
||||
yield
|
||||
|
||||
try:
|
||||
with f():
|
||||
1 / 0
|
||||
except ZeroDivisionError as e:
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
||||
self.assertEqual(frames[0].line, "1/0")
|
||||
|
||||
# Repeat with RuntimeError (which goes through a different code path)
|
||||
try:
|
||||
with f():
|
||||
raise NotImplementedError(42)
|
||||
except NotImplementedError as e:
|
||||
frames = traceback.extract_tb(e.__traceback__)
|
||||
|
||||
self.assertEqual(len(frames), 1)
|
||||
self.assertEqual(frames[0].name, "test_contextmanager_traceback")
|
||||
self.assertEqual(frames[0].line, "raise NotImplementedError(42)")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_no_reraise(self):
|
||||
@contextmanager
|
||||
def whee():
|
||||
yield
|
||||
|
||||
ctx = whee()
|
||||
ctx.__enter__()
|
||||
# Calling __exit__ should not result in an exception
|
||||
self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_yield_after_throw(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
try:
|
||||
yield
|
||||
except Exception: # noqa: E722
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__exit__(TypeError, TypeError("foo"), None)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
def test_contextmanager_except(self):
|
||||
state = []
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
state.append(1)
|
||||
try:
|
||||
yield 42
|
||||
except ZeroDivisionError as e:
|
||||
state.append(e.args[0])
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
with woohoo() as x:
|
||||
self.assertEqual(state, [1])
|
||||
self.assertEqual(x, 42)
|
||||
state.append(x)
|
||||
raise ZeroDivisionError(999)
|
||||
self.assertEqual(state, [1, 42, 999])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_except_stopiter(self):
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
yield
|
||||
|
||||
class StopIterationSubclass(StopIteration):
|
||||
pass
|
||||
|
||||
for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")):
|
||||
with self.subTest(type=type(stop_exc)):
|
||||
try:
|
||||
with woohoo():
|
||||
raise stop_exc
|
||||
except Exception as ex:
|
||||
self.assertIs(ex, stop_exc)
|
||||
else:
|
||||
self.fail(f"{stop_exc} was suppressed")
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_except_pep479(self):
|
||||
code = """\
|
||||
from __future__ import generator_stop
|
||||
from contextlib import contextmanager
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
yield
|
||||
"""
|
||||
locals = {}
|
||||
exec(code, locals, locals)
|
||||
woohoo = locals["woohoo"]
|
||||
|
||||
stop_exc = StopIteration("spam")
|
||||
try:
|
||||
with woohoo():
|
||||
raise stop_exc
|
||||
except Exception as ex:
|
||||
self.assertIs(ex, stop_exc)
|
||||
else:
|
||||
self.fail("StopIteration was suppressed")
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
|
||||
@contextmanager
|
||||
def test_issue29692():
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
raise RuntimeError("issue29692:Chained") from exc
|
||||
|
||||
try:
|
||||
with test_issue29692():
|
||||
raise ZeroDivisionError
|
||||
except Exception as ex:
|
||||
self.assertIs(type(ex), RuntimeError)
|
||||
self.assertEqual(ex.args[0], "issue29692:Chained")
|
||||
self.assertIsInstance(ex.__cause__, ZeroDivisionError)
|
||||
|
||||
try:
|
||||
with test_issue29692():
|
||||
raise StopIteration("issue29692:Unchained")
|
||||
except Exception as ex:
|
||||
self.assertIs(type(ex), StopIteration)
|
||||
self.assertEqual(ex.args[0], "issue29692:Unchained")
|
||||
self.assertIsNone(ex.__cause__)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def _create_contextmanager_attribs(self):
|
||||
def attribs(**kw):
|
||||
def decorate(func):
|
||||
for k, v in kw.items():
|
||||
setattr(func, k, v)
|
||||
return func
|
||||
|
||||
return decorate
|
||||
|
||||
@contextmanager
|
||||
@attribs(foo="bar")
|
||||
def baz(spam):
|
||||
"""Whee!"""
|
||||
|
||||
return baz
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_attribs(self):
|
||||
baz = self._create_contextmanager_attribs()
|
||||
self.assertEqual(baz.__name__, "baz")
|
||||
self.assertEqual(baz.foo, "bar")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_keywords(self):
|
||||
# Ensure no keyword arguments are inhibited
|
||||
@contextmanager
|
||||
def woohoo(self, func, args, kwds):
|
||||
yield (self, func, args, kwds)
|
||||
|
||||
with woohoo(self=11, func=22, args=33, kwds=44) as target:
|
||||
self.assertEqual(target, (11, 22, 33, 44))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_param_errors(self):
|
||||
@contextmanager
|
||||
def woohoo(a, *, b):
|
||||
yield
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo()
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo(3, 5)
|
||||
with self.assertRaises(TypeError):
|
||||
woohoo(b=3)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_recursive(self):
|
||||
depth = 0
|
||||
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
nonlocal depth
|
||||
before = depth
|
||||
depth += 1
|
||||
yield
|
||||
depth -= 1
|
||||
self.assertEqual(depth, before)
|
||||
|
||||
@woohoo()
|
||||
def recursive():
|
||||
if depth < 10:
|
||||
recursive()
|
||||
|
||||
recursive()
|
||||
self.assertEqual(depth, 0)
|
||||
|
||||
@skipIfNotPy311
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_no_yield(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
if False:
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__enter__()
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_trap_second_yield(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
yield
|
||||
yield
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(RuntimeError):
|
||||
ctx.__exit__(None, None, None)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_wrap_runtimeerror(self):
|
||||
@contextmanager
|
||||
def woohoo():
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"caught {exc}") from exc
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
with woohoo():
|
||||
1 / 0
|
||||
|
||||
# If the context manager wrapped StopIteration in a RuntimeError,
|
||||
# we also unwrap it, because we can't tell whether the wrapping was
|
||||
# done by the generator machinery or by the generator itself.
|
||||
with self.assertRaises(StopIteration):
|
||||
with woohoo():
|
||||
raise StopIteration
|
||||
|
||||
@make_dynamo_test
|
||||
def test_contextmanager_non_normalised(self):
|
||||
@contextmanager
|
||||
def whoo():
|
||||
try:
|
||||
yield
|
||||
except RuntimeError:
|
||||
raise SyntaxError # noqa: B904
|
||||
|
||||
ctx = whoo()
|
||||
ctx.__enter__()
|
||||
with self.assertRaises(SyntaxError):
|
||||
ctx.__exit__(RuntimeError, None, None)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(CtxManagerTests)
|
||||
instantiate_parametrized_tests(ContextlibContextManagerTests)
|
||||
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.config
|
||||
@ -905,238 +904,6 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
assert exc2.__context__ is None
|
||||
|
||||
|
||||
class CPythonExceptionTests(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_exceptions.py
|
||||
def setUp(self):
|
||||
self._u_prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._u_prev
|
||||
|
||||
@make_dynamo_test
|
||||
def testChainingAttrs(self):
|
||||
e = Exception()
|
||||
assert e.__context__ is None
|
||||
assert e.__cause__ is None
|
||||
|
||||
e = TypeError()
|
||||
assert e.__context__ is None
|
||||
assert e.__cause__ is None
|
||||
|
||||
e = MyException()
|
||||
assert e.__context__ is None
|
||||
assert e.__cause__ is None
|
||||
|
||||
@make_dynamo_test
|
||||
def testChainingDescriptors(self):
|
||||
try:
|
||||
raise Exception # noqa: TRY002
|
||||
except Exception as exc:
|
||||
e = exc
|
||||
|
||||
assert e.__context__ is None
|
||||
assert e.__cause__ is None
|
||||
assert e.__suppress_context__ is False
|
||||
|
||||
e.__context__ = NameError()
|
||||
e.__cause__ = None
|
||||
assert isinstance(e.__context__, NameError)
|
||||
assert e.__cause__ is None
|
||||
assert e.__suppress_context__ is True
|
||||
e.__suppress_context__ = False
|
||||
assert e.__suppress_context__ is False
|
||||
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_try_and_finally(self):
|
||||
try:
|
||||
try:
|
||||
te = TypeError(1)
|
||||
raise te
|
||||
finally:
|
||||
ve = ValueError(2)
|
||||
raise ve
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
assert exc is ve
|
||||
assert exc.__context__ is te
|
||||
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_except_and_finally(self):
|
||||
try:
|
||||
try:
|
||||
te = TypeError(1)
|
||||
raise te
|
||||
except Exception: # noqa: E722
|
||||
ve = ValueError(2)
|
||||
raise ve # noqa: B904
|
||||
finally:
|
||||
oe = OSError(3)
|
||||
raise oe
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
assert exc is oe
|
||||
assert exc.__context__ is ve
|
||||
assert exc.__context__.__context__ is te
|
||||
|
||||
@make_dynamo_test
|
||||
def test_context_of_exception_in_else_and_finally(self):
|
||||
try:
|
||||
try:
|
||||
pass
|
||||
except Exception: # noqa: E722
|
||||
pass
|
||||
else:
|
||||
ve = ValueError(1)
|
||||
raise ve
|
||||
finally:
|
||||
oe = OSError(2)
|
||||
raise oe
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
assert exc is oe
|
||||
assert exc.__context__ is ve
|
||||
|
||||
@make_dynamo_test
|
||||
def test_raise_does_not_create_context_chain_cycle(self):
|
||||
A = AssertionError
|
||||
B = BytesWarning
|
||||
C = ConnectionError
|
||||
|
||||
# Create a context chain:
|
||||
# C -> B -> A
|
||||
# Then raise A in context of C.
|
||||
try:
|
||||
try:
|
||||
raise A
|
||||
except A as a_:
|
||||
a = a_
|
||||
try:
|
||||
raise B
|
||||
except B as b_:
|
||||
b = b_
|
||||
try:
|
||||
raise C
|
||||
except C as c_:
|
||||
c = c_
|
||||
self.assertIsInstance(a, A)
|
||||
self.assertIsInstance(b, B)
|
||||
self.assertIsInstance(c, C)
|
||||
self.assertIsNone(a.__context__)
|
||||
self.assertIs(b.__context__, a)
|
||||
self.assertIs(c.__context__, b)
|
||||
raise a # noqa: B904
|
||||
except A as e:
|
||||
exc = e
|
||||
|
||||
# Expect A -> C -> B, without cycle
|
||||
self.assertIs(exc, a)
|
||||
self.assertIs(a.__context__, c)
|
||||
self.assertIs(c.__context__, b)
|
||||
self.assertIsNone(b.__context__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_no_hang_on_context_chain_cycle1(self):
|
||||
# See issue 25782. Cycle in context chain.
|
||||
|
||||
def cycle():
|
||||
try:
|
||||
raise ValueError(1)
|
||||
except ValueError as ex:
|
||||
ex.__context__ = ex
|
||||
raise TypeError(2) # noqa: B904
|
||||
|
||||
try:
|
||||
cycle()
|
||||
except Exception as e:
|
||||
exc = e
|
||||
|
||||
self.assertIsInstance(exc, TypeError)
|
||||
self.assertIsInstance(exc.__context__, ValueError)
|
||||
self.assertIs(exc.__context__.__context__, exc.__context__)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def test_no_hang_on_context_chain_cycle2(self):
|
||||
# See issue 25782. Cycle at head of context chain.
|
||||
|
||||
A = AssertionError
|
||||
B = BytesWarning
|
||||
C = ConnectionError
|
||||
|
||||
# Context cycle:
|
||||
# +-----------+
|
||||
# V |
|
||||
# C --> B --> A
|
||||
with self.assertRaises(C) as cm:
|
||||
try:
|
||||
raise A() # noqa: RSE102
|
||||
except A as _a:
|
||||
a = _a
|
||||
try:
|
||||
raise B() # noqa: RSE102
|
||||
except B as _b:
|
||||
b = _b
|
||||
try:
|
||||
raise C() # noqa: RSE102
|
||||
except C as _c:
|
||||
c = _c
|
||||
a.__context__ = c
|
||||
raise c # noqa: B904
|
||||
|
||||
self.assertIs(cm.exception, c)
|
||||
# Verify the expected context chain cycle
|
||||
self.assertIs(c.__context__, b)
|
||||
self.assertIs(b.__context__, a)
|
||||
self.assertIs(a.__context__, c)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_no_hang_on_context_chain_cycle3(self):
|
||||
# See issue 25782. Longer context chain with cycle.
|
||||
A = AssertionError
|
||||
B = BytesWarning
|
||||
C = ConnectionError
|
||||
D = DeprecationWarning
|
||||
E = Exception
|
||||
|
||||
# Context cycle:
|
||||
# +-----------+
|
||||
# V |
|
||||
# E --> D --> C --> B --> A
|
||||
with self.assertRaises(E) as cm:
|
||||
try:
|
||||
raise A
|
||||
except A as _a:
|
||||
a = _a
|
||||
try:
|
||||
raise B
|
||||
except B as _b:
|
||||
b = _b
|
||||
try:
|
||||
raise C
|
||||
except C as _c:
|
||||
c = _c
|
||||
a.__context__ = c
|
||||
try:
|
||||
raise D
|
||||
except D as _d:
|
||||
d = _d
|
||||
e = E()
|
||||
raise e # noqa: B904
|
||||
|
||||
self.assertIs(cm.exception, e)
|
||||
# Verify the expected context chain cycle
|
||||
self.assertIs(e.__context__, d)
|
||||
self.assertIs(d.__context__, c)
|
||||
self.assertIs(c.__context__, b)
|
||||
self.assertIs(b.__context__, a)
|
||||
self.assertIs(a.__context__, c)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ExceptionTests)
|
||||
|
||||
|
||||
|
@ -1481,331 +1481,6 @@ class TestGeneratorThrow(GeneratorTestsBase):
|
||||
self._compile_check(fn)
|
||||
|
||||
|
||||
class GeneratorCloseCPythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
# changed the tests a little bit to run them inside dynamo
|
||||
# + replaced all self.assert* calls to plain assert statements
|
||||
|
||||
def test_close_no_return_value(self):
|
||||
def f():
|
||||
yield
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_return_value(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
# close() raises GeneratorExit here, which is caught
|
||||
except GeneratorExit:
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
assert gen.close() == 0
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_not_catching_exit(self):
|
||||
def f():
|
||||
yield
|
||||
# close() raises GeneratorExit here, which isn't caught and
|
||||
# therefore propagates -- no return value
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_not_started(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
except GeneratorExit:
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_exhausted(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
except GeneratorExit:
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
next(gen)
|
||||
z = 0
|
||||
try:
|
||||
next(gen) # -> StopIteration
|
||||
except StopIteration:
|
||||
z = 1
|
||||
except Exception as e:
|
||||
# anything other than StopIteration should fail
|
||||
raise AssertionError from e
|
||||
assert z == 1
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_closed(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
except GeneratorExit:
|
||||
return 0 # noqa: B901
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
assert gen.close() == 0
|
||||
assert gen.close() is None
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
def test_close_raises(self):
|
||||
def f():
|
||||
try:
|
||||
yield
|
||||
except GeneratorExit:
|
||||
pass
|
||||
raise RuntimeError
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
z = 0
|
||||
try:
|
||||
gen.close() # -> RuntimeError
|
||||
except RuntimeError:
|
||||
z = 1
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
assert z == 1
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
|
||||
|
||||
class GeneratorThrowCpythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
# changed the tests a little bit to run them inside dynamo
|
||||
# + replaced all self.assert* calls to plain assert statements
|
||||
|
||||
def test_exception_context_with_yield(self):
|
||||
def f():
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
yield
|
||||
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except ValueError as e:
|
||||
context = e.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_exception_context_with_yield_inside_generator(self):
|
||||
# Check that the context is also available from inside the generator
|
||||
# with yield, as opposed to outside.
|
||||
def f():
|
||||
z = 0
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
z = 1
|
||||
assert type(exc) == ValueError
|
||||
context = exc.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
yield "b"
|
||||
finally:
|
||||
assert z == 1
|
||||
|
||||
def fn(t):
|
||||
gen = f()
|
||||
gen.send(None)
|
||||
actual = gen.throw(ValueError)
|
||||
# This ensures that the assertions inside were executed.
|
||||
assert actual == "b"
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_exception_context_with_yield_from(self):
|
||||
def f():
|
||||
yield
|
||||
|
||||
def g():
|
||||
try:
|
||||
raise KeyError("a")
|
||||
except Exception:
|
||||
yield from f()
|
||||
|
||||
def fn(t):
|
||||
gen = g()
|
||||
gen.send(None)
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except ValueError as e:
|
||||
context = e.__context__
|
||||
assert (type(context), context.args) == (KeyError, ("a",))
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_exception_context_with_yield_from_with_context_cycle(self):
|
||||
# Check trying to create an exception context cycle:
|
||||
# https://bugs.python.org/issue40696
|
||||
has_cycle = None
|
||||
|
||||
def f():
|
||||
yield
|
||||
|
||||
def g(exc):
|
||||
nonlocal has_cycle
|
||||
try:
|
||||
raise exc
|
||||
except Exception:
|
||||
try:
|
||||
yield from f()
|
||||
except Exception as exc:
|
||||
has_cycle = exc is exc.__context__
|
||||
yield
|
||||
|
||||
def fn(t):
|
||||
exc = KeyError("a")
|
||||
gen = g(exc)
|
||||
gen.send(None)
|
||||
gen.throw(exc)
|
||||
# This also distinguishes from the initial has_cycle=None.
|
||||
assert has_cycle is False
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_throw_after_none_exc_type(self):
|
||||
def g():
|
||||
try:
|
||||
raise KeyError
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield
|
||||
except Exception:
|
||||
raise RuntimeError # noqa: B904
|
||||
|
||||
def fn(t):
|
||||
gen = g()
|
||||
gen.send(None)
|
||||
z = 0
|
||||
try:
|
||||
gen.throw(ValueError)
|
||||
except RuntimeError:
|
||||
z += 1
|
||||
except Exception:
|
||||
raise AssertionError # noqa: B904
|
||||
assert z == 1
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
|
||||
class GeneratorCPythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
# https://github.com/python/cpython/blob/d51a4ca1123e3e49e5cae4273355bdfd9e419a10
|
||||
# changed the tests a little bit to run them inside dynamo
|
||||
# + replaced all self.assert* calls to plain assert statements
|
||||
|
||||
def test_send_non_none_to_new_gen(self):
|
||||
def f():
|
||||
yield 1
|
||||
|
||||
def fn(t):
|
||||
g = f()
|
||||
z = 0
|
||||
try:
|
||||
g.send(0)
|
||||
except TypeError:
|
||||
z += 1
|
||||
except Exception as e:
|
||||
raise AssertionError from e
|
||||
assert z == 1
|
||||
assert next(g) == 1
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
def test_issue103488(self):
|
||||
def gen_raises():
|
||||
yield 1
|
||||
raise ValueError
|
||||
|
||||
def loop():
|
||||
try:
|
||||
for _ in gen_raises():
|
||||
if True is False: # noqa: PLR0133
|
||||
return
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def fn(t):
|
||||
# This should not raise
|
||||
loop()
|
||||
return t.sin()
|
||||
|
||||
self._compile_check(fn)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(GeneratorTests)
|
||||
instantiate_parametrized_tests(TestGeneratorSend)
|
||||
instantiate_parametrized_tests(TestGeneratorClose)
|
||||
|
@ -1,52 +0,0 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
|
||||
|
||||
class TestPEP479(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_generator_stop.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_generator_stop.py
|
||||
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
|
||||
@make_dynamo_test
|
||||
def test_stopiteration_wrapping(self):
|
||||
def f():
|
||||
raise StopIteration
|
||||
|
||||
def g():
|
||||
yield f()
|
||||
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
next(g())
|
||||
self.assertEqual("generator raised StopIteration", str(cm.exception))
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12")
|
||||
@make_dynamo_test
|
||||
def test_stopiteration_wrapping_context(self):
|
||||
def f():
|
||||
raise StopIteration
|
||||
|
||||
def g():
|
||||
yield f()
|
||||
|
||||
try:
|
||||
next(g())
|
||||
except RuntimeError as exc:
|
||||
self.assertIs(type(exc.__cause__), StopIteration)
|
||||
self.assertIs(type(exc.__context__), StopIteration)
|
||||
self.assertTrue(exc.__suppress_context__)
|
||||
else:
|
||||
self.fail(
|
||||
"__cause__, __context__, or __suppress_context__ "
|
||||
"were not properly set"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
@ -1,563 +0,0 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
# ruff: noqa
|
||||
# flake8: noqa
|
||||
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.config
|
||||
import torch._dynamo.test_case
|
||||
import torch._functorch.config
|
||||
import torch.nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
|
||||
|
||||
def get_tb():
|
||||
try:
|
||||
raise OSError()
|
||||
except:
|
||||
return sys.exc_info()[2]
|
||||
|
||||
|
||||
class Context:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||
return True
|
||||
|
||||
|
||||
class MyException(Exception):
|
||||
def __init__(self):
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
class ContextManager:
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, t, v, tb):
|
||||
raise NameError
|
||||
|
||||
|
||||
class TestRaise(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
@make_dynamo_test
|
||||
def test_invalid_reraise(self):
|
||||
try:
|
||||
raise
|
||||
except RuntimeError as e:
|
||||
self.assertIn("No active exception", str(e))
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_reraise(self):
|
||||
try:
|
||||
try:
|
||||
raise IndexError
|
||||
except IndexError as e:
|
||||
exc1 = e
|
||||
raise
|
||||
except IndexError as exc2:
|
||||
self.assertIs(exc1, exc2)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_except_reraise(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
try:
|
||||
raise KeyError("caught")
|
||||
except KeyError:
|
||||
pass
|
||||
raise
|
||||
|
||||
self.assertRaises(TypeError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_finally_reraise(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
try:
|
||||
raise KeyError("caught")
|
||||
finally:
|
||||
raise
|
||||
|
||||
self.assertRaises(KeyError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_nested_reraise(self):
|
||||
def nested_reraise():
|
||||
raise
|
||||
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
nested_reraise()
|
||||
|
||||
self.assertRaises(TypeError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_raise_from_None(self):
|
||||
try:
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
raise ValueError() from None
|
||||
except ValueError as e:
|
||||
self.assertIsInstance(e.__context__, TypeError)
|
||||
self.assertIsNone(e.__cause__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_with_reraise1(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
with Context():
|
||||
pass
|
||||
raise
|
||||
|
||||
self.assertRaises(TypeError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_with_reraise2(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
with Context():
|
||||
raise KeyError("caught")
|
||||
raise
|
||||
|
||||
self.assertRaises(TypeError, reraise)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_yield_reraise(self):
|
||||
def reraise():
|
||||
try:
|
||||
raise TypeError("foo")
|
||||
except:
|
||||
yield 1
|
||||
raise
|
||||
|
||||
g = reraise()
|
||||
next(g)
|
||||
self.assertRaises(TypeError, lambda: next(g))
|
||||
self.assertRaises(StopIteration, lambda: next(g))
|
||||
|
||||
@make_dynamo_test
|
||||
def test_erroneous_exception(self):
|
||||
try:
|
||||
raise MyException
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@unittest.expectedFailure # object
|
||||
@make_dynamo_test
|
||||
def test_new_returns_invalid_instance(self):
|
||||
# See issue #11627.
|
||||
class MyException2(Exception):
|
||||
def __new__(cls, *args):
|
||||
return object()
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
raise MyException2
|
||||
|
||||
@unittest.expectedFailure # Assertion with non-string message
|
||||
@make_dynamo_test
|
||||
def test_assert_with_tuple_arg(self):
|
||||
try:
|
||||
assert False, (3,)
|
||||
except AssertionError as e:
|
||||
self.assertEqual(str(e), "(3,)")
|
||||
|
||||
|
||||
class TestCause(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
def setUp(self):
|
||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
||||
|
||||
@make_dynamo_test
|
||||
def testCauseSyntax(self):
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
raise TypeError
|
||||
except Exception:
|
||||
raise ValueError from None
|
||||
except ValueError as exc:
|
||||
self.assertIsNone(exc.__cause__)
|
||||
self.assertTrue(exc.__suppress_context__)
|
||||
exc.__suppress_context__ = False
|
||||
raise exc
|
||||
except ValueError as exc:
|
||||
e = exc
|
||||
|
||||
self.assertIsNone(e.__cause__)
|
||||
self.assertFalse(e.__suppress_context__)
|
||||
self.assertIsInstance(e.__context__, TypeError)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_invalid_cause(self):
|
||||
try:
|
||||
raise IndexError from 5
|
||||
except TypeError as e:
|
||||
self.assertIn("exception cause", str(e))
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_class_cause(self):
|
||||
try:
|
||||
raise IndexError from KeyError
|
||||
except IndexError as e:
|
||||
self.assertIsInstance(e.__cause__, KeyError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_instance_cause(self):
|
||||
cause = KeyError()
|
||||
try:
|
||||
raise IndexError from cause
|
||||
except IndexError as e:
|
||||
self.assertIs(e.__cause__, cause)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_erroneous_cause(self):
|
||||
try:
|
||||
raise IndexError from MyException
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
|
||||
class TestTraceback(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
def setUp(self):
|
||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
||||
|
||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
||||
@make_dynamo_test
|
||||
def test_sets_traceback(self):
|
||||
try:
|
||||
raise IndexError()
|
||||
except IndexError as e:
|
||||
self.assertIsInstance(e.__traceback__, types.TracebackType)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
||||
@make_dynamo_test
|
||||
def test_accepts_traceback(self):
|
||||
tb = get_tb()
|
||||
try:
|
||||
raise IndexError().with_traceback(tb)
|
||||
except IndexError as e:
|
||||
self.assertNotEqual(e.__traceback__, tb)
|
||||
self.assertEqual(e.__traceback__.tb_next, tb)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
|
||||
class TestTracebackType(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
def setUp(self):
|
||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
||||
|
||||
def raiser(self):
|
||||
raise ValueError
|
||||
|
||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
||||
@make_dynamo_test
|
||||
def test_attrs(self):
|
||||
try:
|
||||
self.raiser()
|
||||
except Exception as exc:
|
||||
tb = exc.__traceback__
|
||||
|
||||
self.assertIsInstance(tb.tb_next, types.TracebackType)
|
||||
self.assertIs(tb.tb_frame, sys._getframe())
|
||||
self.assertIsInstance(tb.tb_lasti, int)
|
||||
self.assertIsInstance(tb.tb_lineno, int)
|
||||
|
||||
self.assertIs(tb.tb_next.tb_next, None)
|
||||
|
||||
# Invalid assignments
|
||||
with self.assertRaises(TypeError):
|
||||
del tb.tb_next
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
tb.tb_next = "asdf"
|
||||
|
||||
# Loops
|
||||
with self.assertRaises(ValueError):
|
||||
tb.tb_next = tb
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tb.tb_next.tb_next = tb
|
||||
|
||||
# Valid assignments
|
||||
tb.tb_next = None
|
||||
self.assertIs(tb.tb_next, None)
|
||||
|
||||
new_tb = get_tb()
|
||||
tb.tb_next = new_tb
|
||||
self.assertIs(tb.tb_next, new_tb)
|
||||
|
||||
@unittest.expectedFailure # Dynamo doesn't track traceback
|
||||
@make_dynamo_test
|
||||
def test_constructor(self):
|
||||
other_tb = get_tb()
|
||||
frame = sys._getframe()
|
||||
|
||||
tb = types.TracebackType(other_tb, frame, 1, 2)
|
||||
self.assertEqual(tb.tb_next, other_tb)
|
||||
self.assertEqual(tb.tb_frame, frame)
|
||||
self.assertEqual(tb.tb_lasti, 1)
|
||||
self.assertEqual(tb.tb_lineno, 2)
|
||||
|
||||
tb = types.TracebackType(None, frame, 1, 2)
|
||||
self.assertEqual(tb.tb_next, None)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
types.TracebackType("no", frame, 1, 2)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
types.TracebackType(other_tb, "no", 1, 2)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
types.TracebackType(other_tb, frame, "no", 2)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
types.TracebackType(other_tb, frame, 1, "nuh-uh")
|
||||
|
||||
|
||||
class TestContext(torch._dynamo.test_case.TestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_raise.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py
|
||||
def setUp(self):
|
||||
self._prev = torch._dynamo.config.enable_trace_unittest
|
||||
torch._dynamo.config.enable_trace_unittest = True
|
||||
|
||||
def tearDown(self):
|
||||
torch._dynamo.config.enable_trace_unittest = self._prev
|
||||
|
||||
@unittest.expectedFailure # missing Exception.__eq__
|
||||
@make_dynamo_test
|
||||
def test_instance_context_instance_raise(self):
|
||||
context = IndexError()
|
||||
try:
|
||||
try:
|
||||
raise context
|
||||
except:
|
||||
raise OSError()
|
||||
except OSError as e:
|
||||
self.assertEqual(e.__context__, context)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__
|
||||
@make_dynamo_test
|
||||
def test_class_context_instance_raise(self):
|
||||
context = IndexError
|
||||
try:
|
||||
try:
|
||||
raise context
|
||||
except:
|
||||
raise OSError()
|
||||
except OSError as e:
|
||||
self.assertNotEqual(e.__context__, context)
|
||||
self.assertIsInstance(e.__context__, context)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__
|
||||
@make_dynamo_test
|
||||
def test_class_context_class_raise(self):
|
||||
context = IndexError
|
||||
try:
|
||||
try:
|
||||
raise context
|
||||
except:
|
||||
raise OSError
|
||||
except OSError as e:
|
||||
self.assertNotEqual(e.__context__, context)
|
||||
self.assertIsInstance(e.__context__, context)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_c_exception_context(self):
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except:
|
||||
raise OSError
|
||||
except OSError as e:
|
||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_c_exception_raise(self):
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except:
|
||||
raise NameError
|
||||
except NameError as e:
|
||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_noraise_finally(self):
|
||||
try:
|
||||
try:
|
||||
pass
|
||||
finally:
|
||||
raise OSError
|
||||
except OSError as e:
|
||||
self.assertIsNone(e.__context__)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_raise_finally(self):
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
finally:
|
||||
raise OSError
|
||||
except OSError as e:
|
||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_context_manager(self):
|
||||
try:
|
||||
with ContextManager():
|
||||
raise ZeroDivisionError
|
||||
except NameError as e:
|
||||
self.assertIsInstance(e.__context__, ZeroDivisionError)
|
||||
else:
|
||||
self.fail("No exception raised")
|
||||
|
||||
@make_dynamo_test
|
||||
def test_cycle_broken(self):
|
||||
# Self-cycles (when re-raising a caught exception) are broken
|
||||
try:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError as e:
|
||||
raise e
|
||||
except ZeroDivisionError as e:
|
||||
self.assertIsNone(e.__context__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_reraise_cycle_broken(self):
|
||||
# Non-trivial context cycles (through re-raising a previous exception)
|
||||
# are broken too.
|
||||
try:
|
||||
try:
|
||||
raise NameError
|
||||
except NameError as a:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError:
|
||||
raise a
|
||||
except NameError as e:
|
||||
self.assertIsNone(e.__context__.__context__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_3118(self):
|
||||
# deleting the generator caused the __context__ to be cleared
|
||||
def gen():
|
||||
try:
|
||||
yield 1
|
||||
finally:
|
||||
pass
|
||||
|
||||
def f():
|
||||
g = gen()
|
||||
next(g)
|
||||
try:
|
||||
try:
|
||||
raise ValueError
|
||||
except:
|
||||
del g
|
||||
raise KeyError
|
||||
except Exception as e:
|
||||
self.assertIsInstance(e.__context__, ValueError)
|
||||
|
||||
f()
|
||||
|
||||
@unittest.expectedFailure # too CPython specific(?)
|
||||
@make_dynamo_test
|
||||
def test_3611(self):
|
||||
# A re-raised exception in a __del__ caused the __context__
|
||||
# to be cleared
|
||||
class C:
|
||||
def __del__(self):
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except:
|
||||
raise
|
||||
|
||||
def f():
|
||||
x = C()
|
||||
try:
|
||||
try:
|
||||
x.x
|
||||
except AttributeError:
|
||||
del x
|
||||
raise TypeError
|
||||
except Exception as e:
|
||||
self.assertNotEqual(e.__context__, None)
|
||||
self.assertIsInstance(e.__context__, AttributeError)
|
||||
|
||||
with support.catch_unraisable_exception() as cm:
|
||||
f()
|
||||
|
||||
self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
@ -1,107 +0,0 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch.testing._internal.common_utils import make_dynamo_test
|
||||
|
||||
|
||||
class SysTests(torch._dynamo.test_case.TestCase):
|
||||
def test_exc_info(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
try:
|
||||
raise ValueError
|
||||
except Exception:
|
||||
typ, _, _ = sys.exc_info()
|
||||
if typ is ValueError:
|
||||
return t.sin()
|
||||
else:
|
||||
return t.cos()
|
||||
|
||||
t = torch.randn(2)
|
||||
y = fn(t)
|
||||
self.assertEqual(y, t.sin())
|
||||
|
||||
|
||||
class CPythonActiveExceptionTests(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_sys.py
|
||||
# https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py
|
||||
|
||||
@make_dynamo_test
|
||||
def test_exc_info_no_exception(self):
|
||||
self.assertEqual(sys.exc_info(), (None, None, None))
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_sys_exception_no_exception(self):
|
||||
self.assertEqual(sys.exception(), None)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_exc_info_with_exception_instance(self):
|
||||
def f():
|
||||
raise ValueError(42)
|
||||
|
||||
try:
|
||||
f()
|
||||
except Exception as e_:
|
||||
e = e_
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc_info[0], ValueError)
|
||||
self.assertIs(exc_info[1], e)
|
||||
self.assertIs(exc_info[2], e.__traceback__)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_exc_info_with_exception_type(self):
|
||||
def f():
|
||||
raise ValueError
|
||||
|
||||
try:
|
||||
f()
|
||||
except Exception as e_:
|
||||
e = e_
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc_info[0], ValueError)
|
||||
self.assertIs(exc_info[1], e)
|
||||
self.assertIs(exc_info[2], e.__traceback__)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_sys_exception_with_exception_instance(self):
|
||||
def f():
|
||||
raise ValueError(42)
|
||||
|
||||
try:
|
||||
f()
|
||||
except Exception as e_:
|
||||
e = e_
|
||||
exc = sys.exception()
|
||||
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc, e)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+")
|
||||
@make_dynamo_test
|
||||
def test_sys_exception_with_exception_type(self):
|
||||
def f():
|
||||
raise ValueError
|
||||
|
||||
try:
|
||||
f()
|
||||
except Exception as e_:
|
||||
e = e_
|
||||
exc = sys.exception()
|
||||
|
||||
self.assertIsInstance(e, ValueError)
|
||||
self.assertIs(exc, e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
@ -1,8 +1,5 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import sys
|
||||
import unittest
|
||||
import warnings
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
@ -28,591 +25,6 @@ class TestUnittest(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(z, 1)
|
||||
|
||||
|
||||
class CPythonTest_Assertions(torch._dynamo.test_case.CPythonTestCase):
|
||||
# Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py
|
||||
# https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py
|
||||
|
||||
@make_dynamo_test
|
||||
def test_AlmostEqual(self):
|
||||
self.assertAlmostEqual(1.00000001, 1.0)
|
||||
self.assertNotAlmostEqual(1.0000001, 1.0)
|
||||
self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(1.1, 1.0, places=0)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(0, 0.1 + 0.1j, places=0)
|
||||
self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1
|
||||
)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0
|
||||
)
|
||||
|
||||
self.assertAlmostEqual(float("inf"), float("inf"))
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf")
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_AmostEqualWithDelta(self):
|
||||
self.assertAlmostEqual(1.1, 1.0, delta=0.5)
|
||||
self.assertAlmostEqual(1.0, 1.1, delta=0.5)
|
||||
self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
|
||||
self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)
|
||||
|
||||
self.assertAlmostEqual(1.0, 1.0, delta=0.5)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5
|
||||
)
|
||||
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05
|
||||
)
|
||||
self.assertRaises(
|
||||
self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5
|
||||
)
|
||||
|
||||
self.assertRaises(
|
||||
TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2
|
||||
)
|
||||
self.assertRaises(
|
||||
TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_assertRaises(self):
|
||||
def _raise(e):
|
||||
raise e
|
||||
|
||||
self.assertRaises(KeyError, _raise, KeyError)
|
||||
self.assertRaises(KeyError, _raise, KeyError("key"))
|
||||
try:
|
||||
self.assertRaises(KeyError, lambda: None)
|
||||
except self.failureException as e:
|
||||
self.assertIn("KeyError not raised", str(e))
|
||||
else:
|
||||
self.fail("assertRaises() didn't fail")
|
||||
try:
|
||||
self.assertRaises(KeyError, _raise, ValueError)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
self.fail("assertRaises() didn't let exception pass through")
|
||||
with self.assertRaises(KeyError) as cm:
|
||||
try:
|
||||
raise KeyError
|
||||
except Exception as e:
|
||||
exc = e
|
||||
raise
|
||||
self.assertIs(cm.exception, exc)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
raise KeyError("key")
|
||||
try:
|
||||
with self.assertRaises(KeyError):
|
||||
pass
|
||||
except self.failureException as e:
|
||||
self.assertIn("KeyError not raised", str(e))
|
||||
else:
|
||||
self.fail("assertRaises() didn't fail")
|
||||
try:
|
||||
with self.assertRaises(KeyError):
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
self.fail("assertRaises() didn't let exception pass through")
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertNotRegex(self):
|
||||
self.assertNotRegex("Ala ma kota", r"r+")
|
||||
try:
|
||||
self.assertNotRegex("Ala ma kota", r"k.t", "Message")
|
||||
except self.failureException as e:
|
||||
self.assertIn("Message", e.args[0])
|
||||
else:
|
||||
self.fail("assertNotRegex should have failed.")
|
||||
|
||||
|
||||
class CPythonTestLongMessage(torch._dynamo.test_case.CPythonTestCase):
|
||||
"""Test that the individual asserts honour longMessage.
|
||||
This actually tests all the message behaviour for
|
||||
asserts that use longMessage."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
class TestableTestFalse(unittest.TestCase):
|
||||
longMessage = False
|
||||
failureException = self.failureException
|
||||
|
||||
def testTest(self):
|
||||
pass
|
||||
|
||||
class TestableTestTrue(unittest.TestCase):
|
||||
longMessage = True
|
||||
failureException = self.failureException
|
||||
|
||||
def testTest(self):
|
||||
pass
|
||||
|
||||
self.testableTrue = TestableTestTrue("testTest")
|
||||
self.testableFalse = TestableTestFalse("testTest")
|
||||
|
||||
def testDefault(self):
|
||||
self.assertTrue(unittest.TestCase.longMessage)
|
||||
|
||||
def test_formatMsg(self):
|
||||
self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo")
|
||||
self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo")
|
||||
|
||||
self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo")
|
||||
self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo")
|
||||
|
||||
# This blows up if _formatMessage uses string concatenation
|
||||
self.testableTrue._formatMessage(object(), "foo")
|
||||
|
||||
def test_formatMessage_unicode_error(self):
|
||||
one = "".join(chr(i) for i in range(255))
|
||||
# this used to cause a UnicodeDecodeError constructing msg
|
||||
self.testableTrue._formatMessage(one, "\uFFFD")
|
||||
|
||||
def assertMessages(self, methodName, args, errors):
|
||||
"""
|
||||
Check that methodName(*args) raises the correct error messages.
|
||||
errors should be a list of 4 regex that match the error when:
|
||||
1) longMessage = False and no msg passed;
|
||||
2) longMessage = False and msg passed;
|
||||
3) longMessage = True and no msg passed;
|
||||
4) longMessage = True and msg passed;
|
||||
"""
|
||||
|
||||
def getMethod(i):
|
||||
useTestableFalse = i < 2
|
||||
if useTestableFalse:
|
||||
test = self.testableFalse
|
||||
else:
|
||||
test = self.testableTrue
|
||||
return getattr(test, methodName)
|
||||
|
||||
for i, expected_regex in enumerate(errors):
|
||||
testMethod = getMethod(i)
|
||||
kwargs = {}
|
||||
withMsg = i % 2
|
||||
if withMsg:
|
||||
kwargs = {"msg": "oops"}
|
||||
|
||||
# with self.assertRaisesRegex(
|
||||
# self.failureException, expected_regex=expected_regex
|
||||
# ):
|
||||
# testMethod(*args, **kwargs)
|
||||
with self.assertRaises(self.failureException) as cm:
|
||||
testMethod(*args, **kwargs)
|
||||
self.assertRegex(str(cm.exception), expected_regex)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertTrue(self):
|
||||
self.assertMessages(
|
||||
"assertTrue",
|
||||
(False,),
|
||||
[
|
||||
"False is not true",
|
||||
"oops",
|
||||
"False is not true",
|
||||
"False is not true : oops",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertFalse(self):
|
||||
self.assertMessages(
|
||||
"assertFalse",
|
||||
(True,),
|
||||
[
|
||||
"True is not false",
|
||||
"oops",
|
||||
"True is not false",
|
||||
"True is not false : oops",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testNotEqual(self):
|
||||
self.assertMessages(
|
||||
"assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"]
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAlmostEqual(self):
|
||||
self.assertMessages(
|
||||
"assertAlmostEqual",
|
||||
(1, 2),
|
||||
[
|
||||
r"^1 != 2 within 7 places \(1 difference\)$",
|
||||
"^oops$",
|
||||
r"^1 != 2 within 7 places \(1 difference\)$",
|
||||
r"^1 != 2 within 7 places \(1 difference\) : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testNotAlmostEqual(self):
|
||||
self.assertMessages(
|
||||
"assertNotAlmostEqual",
|
||||
(1, 1),
|
||||
[
|
||||
"^1 == 1 within 7 places$",
|
||||
"^oops$",
|
||||
"^1 == 1 within 7 places$",
|
||||
"^1 == 1 within 7 places : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def test_baseAssertEqual(self):
|
||||
self.assertMessages(
|
||||
"_baseAssertEqual",
|
||||
(1, 2),
|
||||
["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertSequenceEqual(self):
|
||||
# Error messages are multiline so not testing on full message
|
||||
# assertTupleEqual and assertListEqual delegate to this method
|
||||
self.assertMessages(
|
||||
"assertSequenceEqual",
|
||||
([], [None]),
|
||||
[r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertSetEqual(self):
|
||||
self.assertMessages(
|
||||
"assertSetEqual",
|
||||
(set(), set([None])), # noqa: C405
|
||||
["None$", "^oops$", "None$", "None : oops$"],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIn(self):
|
||||
self.assertMessages(
|
||||
"assertIn",
|
||||
(None, []),
|
||||
[
|
||||
r"^None not found in \[\]$",
|
||||
"^oops$",
|
||||
r"^None not found in \[\]$",
|
||||
r"^None not found in \[\] : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertNotIn(self):
|
||||
self.assertMessages(
|
||||
"assertNotIn",
|
||||
(None, [None]),
|
||||
[
|
||||
r"^None unexpectedly found in \[None\]$",
|
||||
"^oops$",
|
||||
r"^None unexpectedly found in \[None\]$",
|
||||
r"^None unexpectedly found in \[None\] : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertDictEqual(self):
|
||||
self.assertMessages(
|
||||
"assertDictEqual",
|
||||
({}, {"key": "value"}),
|
||||
[
|
||||
r"\+ \{'key': 'value'\}$",
|
||||
"^oops$",
|
||||
r"\+ \{'key': 'value'\}$",
|
||||
r"\+ \{'key': 'value'\} : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertMultiLineEqual(self):
|
||||
self.assertMessages(
|
||||
"assertMultiLineEqual",
|
||||
("", "foo"),
|
||||
[r"\+ foo\n$", "^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertLess(self):
|
||||
self.assertMessages(
|
||||
"assertLess",
|
||||
(2, 1),
|
||||
[
|
||||
"^2 not less than 1$",
|
||||
"^oops$",
|
||||
"^2 not less than 1$",
|
||||
"^2 not less than 1 : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertLessEqual(self):
|
||||
self.assertMessages(
|
||||
"assertLessEqual",
|
||||
(2, 1),
|
||||
[
|
||||
"^2 not less than or equal to 1$",
|
||||
"^oops$",
|
||||
"^2 not less than or equal to 1$",
|
||||
"^2 not less than or equal to 1 : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertGreater(self):
|
||||
self.assertMessages(
|
||||
"assertGreater",
|
||||
(1, 2),
|
||||
[
|
||||
"^1 not greater than 2$",
|
||||
"^oops$",
|
||||
"^1 not greater than 2$",
|
||||
"^1 not greater than 2 : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertGreaterEqual(self):
|
||||
self.assertMessages(
|
||||
"assertGreaterEqual",
|
||||
(1, 2),
|
||||
[
|
||||
"^1 not greater than or equal to 2$",
|
||||
"^oops$",
|
||||
"^1 not greater than or equal to 2$",
|
||||
"^1 not greater than or equal to 2 : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIsNone(self):
|
||||
self.assertMessages(
|
||||
"assertIsNone",
|
||||
("not None",),
|
||||
[
|
||||
"^'not None' is not None$",
|
||||
"^oops$",
|
||||
"^'not None' is not None$",
|
||||
"^'not None' is not None : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIsNotNone(self):
|
||||
self.assertMessages(
|
||||
"assertIsNotNone",
|
||||
(None,),
|
||||
[
|
||||
"^unexpectedly None$",
|
||||
"^oops$",
|
||||
"^unexpectedly None$",
|
||||
"^unexpectedly None : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIs(self):
|
||||
self.assertMessages(
|
||||
"assertIs",
|
||||
(None, "foo"),
|
||||
[
|
||||
"^None is not 'foo'$",
|
||||
"^oops$",
|
||||
"^None is not 'foo'$",
|
||||
"^None is not 'foo' : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertIsNot(self):
|
||||
self.assertMessages(
|
||||
"assertIsNot",
|
||||
(None, None),
|
||||
[
|
||||
"^unexpectedly identical: None$",
|
||||
"^oops$",
|
||||
"^unexpectedly identical: None$",
|
||||
"^unexpectedly identical: None : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertRegex(self):
|
||||
self.assertMessages(
|
||||
"assertRegex",
|
||||
("foo", "bar"),
|
||||
[
|
||||
"^Regex didn't match:",
|
||||
"^oops$",
|
||||
"^Regex didn't match:",
|
||||
"^Regex didn't match: (.*) : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertNotRegex(self):
|
||||
self.assertMessages(
|
||||
"assertNotRegex",
|
||||
("foo", "foo"),
|
||||
[
|
||||
"^Regex matched:",
|
||||
"^oops$",
|
||||
"^Regex matched:",
|
||||
"^Regex matched: (.*) : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
def assertMessagesCM(self, methodName, args, func, errors):
|
||||
"""
|
||||
Check that the correct error messages are raised while executing:
|
||||
with method(*args):
|
||||
func()
|
||||
*errors* should be a list of 4 regex that match the error when:
|
||||
1) longMessage = False and no msg passed;
|
||||
2) longMessage = False and msg passed;
|
||||
3) longMessage = True and no msg passed;
|
||||
4) longMessage = True and msg passed;
|
||||
"""
|
||||
p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"}))
|
||||
for (cls, kwargs), err in zip(p, errors):
|
||||
method = getattr(cls, methodName)
|
||||
# with self.assertRaisesRegex(cls.failureException, err):
|
||||
with self.assertRaises(cls.failureException) as c:
|
||||
with method(*args, **kwargs) as cm: # noqa: F841
|
||||
func()
|
||||
self.assertRegex(str(c.exception), err)
|
||||
|
||||
@make_dynamo_test
|
||||
def testAssertRaises(self):
|
||||
self.assertMessagesCM(
|
||||
"assertRaises",
|
||||
(TypeError,),
|
||||
lambda: None,
|
||||
[
|
||||
"^TypeError not raised$",
|
||||
"^oops$",
|
||||
"^TypeError not raised$",
|
||||
"^TypeError not raised : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertRaisesRegex(self):
|
||||
self.assertMessagesCM(
|
||||
"assertRaisesRegex",
|
||||
(TypeError, "unused regex"),
|
||||
lambda: None,
|
||||
[
|
||||
"^TypeError not raised$",
|
||||
"^oops$",
|
||||
"^TypeError not raised$",
|
||||
"^TypeError not raised : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
# test error raised but with wrong message
|
||||
def raise_wrong_message():
|
||||
raise TypeError("foo")
|
||||
|
||||
self.assertMessagesCM(
|
||||
"assertRaisesRegex",
|
||||
(TypeError, "regex"),
|
||||
raise_wrong_message,
|
||||
[
|
||||
'^"regex" does not match "foo"$',
|
||||
"^oops$",
|
||||
'^"regex" does not match "foo"$',
|
||||
'^"regex" does not match "foo" : oops$',
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertWarns(self):
|
||||
self.assertMessagesCM(
|
||||
"assertWarns",
|
||||
(UserWarning,),
|
||||
lambda: None,
|
||||
[
|
||||
"^UserWarning not triggered$",
|
||||
"^oops$",
|
||||
"^UserWarning not triggered$",
|
||||
"^UserWarning not triggered : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13")
|
||||
@make_dynamo_test
|
||||
def test_assertNotWarns(self):
|
||||
def warn_future():
|
||||
warnings.warn("xyz", FutureWarning, stacklevel=2)
|
||||
|
||||
self.assertMessagesCM(
|
||||
"_assertNotWarns",
|
||||
(FutureWarning,),
|
||||
warn_future,
|
||||
[
|
||||
"^FutureWarning triggered$",
|
||||
"^oops$",
|
||||
"^FutureWarning triggered$",
|
||||
"^FutureWarning triggered : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@make_dynamo_test
|
||||
def testAssertWarnsRegex(self):
|
||||
# test error not raised
|
||||
self.assertMessagesCM(
|
||||
"assertWarnsRegex",
|
||||
(UserWarning, "unused regex"),
|
||||
lambda: None,
|
||||
[
|
||||
"^UserWarning not triggered$",
|
||||
"^oops$",
|
||||
"^UserWarning not triggered$",
|
||||
"^UserWarning not triggered : oops$",
|
||||
],
|
||||
)
|
||||
|
||||
# test warning raised but with wrong message
|
||||
def raise_wrong_message():
|
||||
warnings.warn("foo")
|
||||
|
||||
self.assertMessagesCM(
|
||||
"assertWarnsRegex",
|
||||
(UserWarning, "regex"),
|
||||
raise_wrong_message,
|
||||
[
|
||||
'^"regex" does not match "foo"$',
|
||||
"^oops$",
|
||||
'^"regex" does not match "foo"$',
|
||||
'^"regex" does not match "foo" : oops$',
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
@ -1593,6 +1593,13 @@ def get_selected_tests(options) -> list[str]:
|
||||
]
|
||||
)
|
||||
|
||||
if sys.version_info[:2] < (3, 13):
|
||||
# Skip tests for older Python versions as they may use syntax or features
|
||||
# not supported in those versions
|
||||
options.exclude.extend(
|
||||
[test for test in selected_tests if test.startswith("dynamo/cpython/3_13/")]
|
||||
)
|
||||
|
||||
selected_tests = exclude_tests(options.exclude, selected_tests)
|
||||
|
||||
if sys.platform == "win32" and not options.ignore_win_blocklist:
|
||||
|
@ -1,3 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
|
||||
|
||||
This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
|
||||
@ -10,8 +12,13 @@ It includes:
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
@ -98,7 +105,70 @@ class TestCase(TorchTestCase):
|
||||
|
||||
|
||||
class CPythonTestCase(TestCase):
|
||||
"""
|
||||
Test class for CPython tests located in "test/dynamo/CPython/Py_version/*".
|
||||
|
||||
This class enables specific features that are disabled by default, such as
|
||||
tracing through unittest methods.
|
||||
"""
|
||||
|
||||
_stack: contextlib.ExitStack
|
||||
dynamo_strict_nopython = True
|
||||
|
||||
# Restore original unittest methods to simplify tracing CPython test cases.
|
||||
assertEqual = unittest.TestCase.assertEqual # type: ignore[assignment]
|
||||
assertNotEqual = unittest.TestCase.assertNotEqual # type: ignore[assignment]
|
||||
assertTrue = unittest.TestCase.assertTrue
|
||||
assertFalse = unittest.TestCase.assertFalse
|
||||
assertIs = unittest.TestCase.assertIs
|
||||
assertIsNot = unittest.TestCase.assertIsNot
|
||||
assertIsNone = unittest.TestCase.assertIsNone
|
||||
assertIsNotNone = unittest.TestCase.assertIsNotNone
|
||||
assertIn = unittest.TestCase.assertIn
|
||||
assertNotIn = unittest.TestCase.assertNotIn
|
||||
assertIsInstance = unittest.TestCase.assertIsInstance
|
||||
assertNotIsInstance = unittest.TestCase.assertNotIsInstance
|
||||
assertAlmostEqual = unittest.TestCase.assertAlmostEqual
|
||||
assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual
|
||||
assertGreater = unittest.TestCase.assertGreater
|
||||
assertGreaterEqual = unittest.TestCase.assertGreaterEqual
|
||||
assertLess = unittest.TestCase.assertLess
|
||||
assertLessEqual = unittest.TestCase.assertLessEqual
|
||||
assertRegex = unittest.TestCase.assertRegex
|
||||
assertNotRegex = unittest.TestCase.assertNotRegex
|
||||
assertCountEqual = unittest.TestCase.assertCountEqual
|
||||
assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual
|
||||
assertSequenceEqual = unittest.TestCase.assertSequenceEqual
|
||||
assertListEqual = unittest.TestCase.assertListEqual
|
||||
assertTupleEqual = unittest.TestCase.assertTupleEqual
|
||||
assertSetEqual = unittest.TestCase.assertSetEqual
|
||||
assertDictEqual = unittest.TestCase.assertDictEqual
|
||||
assertRaises = unittest.TestCase.assertRaises
|
||||
assertRaisesRegex = unittest.TestCase.assertRaisesRegex
|
||||
assertWarns = unittest.TestCase.assertWarns
|
||||
assertWarnsRegex = unittest.TestCase.assertWarnsRegex
|
||||
assertLogs = unittest.TestCase.assertLogs
|
||||
fail = unittest.TestCase.fail
|
||||
failureException = unittest.TestCase.failureException
|
||||
|
||||
def compile_fn(self, fn, backend, nopython):
|
||||
# We want to compile only the test function, excluding any setup code
|
||||
# from unittest
|
||||
method = getattr(self, self._testMethodName)
|
||||
method = torch._dynamo.optimize(backend, nopython=nopython)(method)
|
||||
setattr(self, self._testMethodName, method)
|
||||
return fn
|
||||
|
||||
def _dynamo_test_key(self):
|
||||
suffix = super()._dynamo_test_key()
|
||||
test_cls = self.__class__
|
||||
test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0]
|
||||
py_ver = re.search(r"/([\d_]+)/", inspect.getfile(test_cls))
|
||||
if py_ver:
|
||||
py_ver = py_ver.group().strip(os.sep).replace("_", "") # type: ignore[assignment]
|
||||
else:
|
||||
return suffix
|
||||
return f"CPython{py_ver}-{test_file}-{suffix}"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
@ -107,6 +177,24 @@ class CPythonTestCase(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# Skip test if python versions doesn't match
|
||||
normalized_path = pathlib.PurePath("dynamo/cpython").as_posix()
|
||||
regex = re.escape(normalized_path) + r"\b\d+_\d{2}\b"
|
||||
m = re.search(regex, inspect.getfile(cls))
|
||||
if m:
|
||||
test_py_ver = tuple(map(int, m.group().split("_")))
|
||||
py_ver = sys.version_info[:2]
|
||||
if py_ver != test_py_ver:
|
||||
expected = ".".join(map(str, test_py_ver))
|
||||
got = ".".join(map(str, py_ver))
|
||||
raise unittest.SkipTest(
|
||||
f"Test requires Python {expected} but got Python {got}"
|
||||
)
|
||||
else:
|
||||
raise unittest.SkipTest(
|
||||
f"Test requires a specific Python version but not found in path {inspect.getfile(cls)}"
|
||||
)
|
||||
|
||||
super().setUpClass()
|
||||
cls._stack = contextlib.ExitStack() # type: ignore[attr-defined]
|
||||
cls._stack.enter_context( # type: ignore[attr-defined]
|
||||
|
@ -1989,11 +1989,11 @@ class BuiltinVariable(VariableTracker):
|
||||
)
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to trace builtin operator",
|
||||
gb_type="Failed to trace unittest method",
|
||||
context=f"function: unittest.TestCase.{name}",
|
||||
explanation=f"Dynamo does not know how to trace builtin operator `{name}` ",
|
||||
explanation=f"Dynamo does not know how to trace unittest method `{name}` ",
|
||||
hints=[
|
||||
f"Avoid calling builtin `{name}`. "
|
||||
f"Avoid calling `TestCase.{name}`. "
|
||||
"Please report an issue to PyTorch.",
|
||||
],
|
||||
)
|
||||
|
@ -3157,6 +3157,13 @@ class TestCase(expecttest.TestCase):
|
||||
def wrap_with_cuda_memory_check(self, method):
|
||||
return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
|
||||
|
||||
def _dynamo_test_key(self):
|
||||
return f"{self.__class__.__name__}.{self._testMethodName}"
|
||||
|
||||
def compile_fn(self, fn, backend, nopython):
|
||||
# Allows subclasses to control compilation
|
||||
return torch._dynamo.optimize(backend, nopython=nopython)(fn)
|
||||
|
||||
def _run_custom(self, result=None):
|
||||
using_unittest = isinstance(result, unittest.TestResult)
|
||||
|
||||
@ -3232,16 +3239,16 @@ class TestCase(expecttest.TestCase):
|
||||
|
||||
with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
|
||||
if TEST_WITH_AOT_EAGER:
|
||||
super_run = torch._dynamo.optimize("aot_eager_decomp_partition")(super_run)
|
||||
super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython)
|
||||
elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
|
||||
if TEST_WITH_TORCHINDUCTOR:
|
||||
super_run = torch._dynamo.optimize("inductor")(super_run)
|
||||
super_run = self.compile_fn(super_run, "inductor", nopython)
|
||||
else:
|
||||
# Assume eager-generated GraphModules will not error out.
|
||||
# If we do, this is probably a Dynamo bug!
|
||||
super_run = torch._dynamo.optimize("eager_noexcept", nopython=nopython)(super_run)
|
||||
super_run = self.compile_fn(super_run, "eager_noexcept", nopython)
|
||||
|
||||
key = f"{self.__class__.__name__}.{self._testMethodName}"
|
||||
key = self._dynamo_test_key()
|
||||
|
||||
def expect_failure(f, file_name):
|
||||
@wraps(f)
|
||||
|
Reference in New Issue
Block a user