[dynamo] add more refleak tests (#120657)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120657
Approved by: https://github.com/jansel
This commit is contained in:
William Wen
2024-03-06 10:17:42 -08:00
committed by PyTorch MergeBot
parent 6490441d8f
commit d14d62b7aa
3 changed files with 30 additions and 0 deletions

View File

@ -89,6 +89,9 @@ if TEST_Z3:
unittest.expectedFailure( unittest.expectedFailure(
DynamicShapesMiscTests.test_custom_module_free_dynamic_shapes # noqa: F821 DynamicShapesMiscTests.test_custom_module_free_dynamic_shapes # noqa: F821
) )
unittest.expectedFailure(
DynamicShapesMiscTests.test_sequential_module_free_dynamic_shapes # noqa: F821
)
unittest.expectedFailure( unittest.expectedFailure(
# Test is only valid without dynamic shapes # Test is only valid without dynamic shapes

View File

@ -48,6 +48,7 @@ from torch._dynamo.testing import (
same, same,
skipIfNotPy311, skipIfNotPy311,
unsupported, unsupported,
xfailIfPy311,
) )
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
from torch._inductor.utils import run_and_get_code from torch._inductor.utils import run_and_get_code
@ -9725,6 +9726,26 @@ fn
lambda mod: mod.fc, lambda mod: mod.fc,
) )
@xfailIfPy311
def test_sequential_module_free(self):
self._test_compile_model_free(
lambda: (
torch.nn.Sequential(
torch.nn.Linear(100, 100),
torch.nn.ReLU(),
),
torch.randn(100, 100),
),
lambda mod: mod[0],
)
@unittest.expectedFailure
def test_linear_module_free(self):
self._test_compile_model_free(
lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)),
lambda mod: mod,
)
def test_dynamo_cache_move_to_front(self): def test_dynamo_cache_move_to_front(self):
class Mod(torch.nn.Module): class Mod(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -342,6 +342,12 @@ def skipIfNotPy311(fn):
return unittest.skip(fn) return unittest.skip(fn)
def xfailIfPy311(fn):
if sys.version_info >= (3, 11):
return unittest.expectedFailure(fn)
return fn
# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py # and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn): def expectedFailureDynamic(fn):