From d14d62b7aa71bdfaed1dcc013c2fed0ac9f1f299 Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 6 Mar 2024 10:17:42 -0800 Subject: [PATCH] [dynamo] add more refleak tests (#120657) Pull Request resolved: https://github.com/pytorch/pytorch/pull/120657 Approved by: https://github.com/jansel --- test/dynamo/test_dynamic_shapes.py | 3 +++ test/dynamo/test_misc.py | 21 +++++++++++++++++++++ torch/_dynamo/testing.py | 6 ++++++ 3 files changed, 30 insertions(+) diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index b49efacec929..cff9149bdd75 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -89,6 +89,9 @@ if TEST_Z3: unittest.expectedFailure( DynamicShapesMiscTests.test_custom_module_free_dynamic_shapes # noqa: F821 ) + unittest.expectedFailure( + DynamicShapesMiscTests.test_sequential_module_free_dynamic_shapes # noqa: F821 + ) unittest.expectedFailure( # Test is only valid without dynamic shapes diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 84b0512de263..77748f7f68e5 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -48,6 +48,7 @@ from torch._dynamo.testing import ( same, skipIfNotPy311, unsupported, + xfailIfPy311, ) from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault from torch._inductor.utils import run_and_get_code @@ -9725,6 +9726,26 @@ fn lambda mod: mod.fc, ) + @xfailIfPy311 + def test_sequential_module_free(self): + self._test_compile_model_free( + lambda: ( + torch.nn.Sequential( + torch.nn.Linear(100, 100), + torch.nn.ReLU(), + ), + torch.randn(100, 100), + ), + lambda mod: mod[0], + ) + + @unittest.expectedFailure + def test_linear_module_free(self): + self._test_compile_model_free( + lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)), + lambda mod: mod, + ) + def test_dynamo_cache_move_to_front(self): class Mod(torch.nn.Module): def __init__(self): diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index b8577b52215a..fac20cf55508 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -342,6 +342,12 @@ def skipIfNotPy311(fn): return unittest.skip(fn) +def xfailIfPy311(fn): + if sys.version_info >= (3, 11): + return unittest.expectedFailure(fn) + return fn + + # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py # and test/dynamo/test_dynamic_shapes.py def expectedFailureDynamic(fn):