[dynamo 3.11] enable dynamo unittests in 3.11 (#98104)

Enable most dynamo unittests for 3.11. There are a few tests that are skipped due to failures that will be addressed in upcoming PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98104
Approved by: https://github.com/yanboliang, https://github.com/voznesenskym, https://github.com/albanD, https://github.com/jansel, https://github.com/jerryzh168, https://github.com/malfet
This commit is contained in:
William Wen
2023-04-06 17:56:50 +00:00
committed by PyTorch MergeBot
parent dbfc4df075
commit 0066f3405f
11 changed files with 50 additions and 30 deletions

View File

@ -7,7 +7,7 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.onnx.operators
from torch._dynamo.testing import same
from torch._dynamo.testing import same, skipIfPy311
from torch.nn import functional as F
from torch.testing._internal.common_cuda import (
@ -77,6 +77,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
opt_fn(a, b)
self.assertEqual(cnts.frame_count, 2)
@skipIfPy311
def test_nested_grad_mode_graph_break(self):
def fn(x):
before = torch.is_grad_enabled()
@ -99,6 +100,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
opt_fn(a)
self.assertEqual(cnts.frame_count, 3)
@skipIfPy311
def test_torch_profiler(self):
# wrap torch.profiler.* as NullContextVariable and do nothing
def fn(x):
@ -119,6 +121,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 2)
@skipIfPy311
def test_autograd_profiler(self):
# wrap torch.autograd.profiler.* as NullContextVariable and do nothing
def fn(x):
@ -344,6 +347,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
self.assertEqual(exported.device.type, "cpu")
self.assertEqual(exported.dtype, torch.bfloat16)
@skipIfPy311
def test_autocast_cpu_graph_break(self):
class MyModule(torch.nn.Module):
def forward(self, x):
@ -371,6 +375,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
self.assertEqual(res.device.type, "cpu")
self.assertEqual(res.dtype, torch.bfloat16)
@skipIfPy311
def test_autocast_cpu_graph_break_2(self):
# Regression for: https://github.com/pytorch/pytorch/issues/93890
def fn(x):
@ -389,6 +394,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
self.assertEqual(res.dtype, torch.bfloat16)
self.assertEqual(opt_res.dtype, torch.bfloat16)
@skipIfPy311
def test_autocast_cpu_graph_break_inner_fn(self):
class MyModule(torch.nn.Module):
@staticmethod
@ -436,6 +442,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
self.assertEqual(out_32.device.type, "cpu")
self.assertEqual(out_32.dtype, torch.float32)
@skipIfPy311
def test_autocast_graph_break_method(self):
class MyModule(torch.nn.Module):
def __init__(self, bias):

View File

@ -4308,10 +4308,8 @@ def fn():
z *= 3
return z
# TODO remove condition once 3.11 is fully supported
if sys.version_info < (3, 11):
opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
self.assertEqual(opt_f(None, torch.ones(2)), 6)
opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
self.assertEqual(opt_f(None, torch.ones(2)), 6)
if sys.version_info >= (3, 11):
insts = bytecode_transformation.cleaned_instructions(f.__code__)

View File

@ -110,6 +110,11 @@ class End2EndTests(torch._dynamo.test_case.TestCase):
if __name__ == "__main__":
# most optimizer tests are broken on 3.11
# TODO remove when 3.11 is fully supported
import sys
from torch._dynamo.test_case import run_tests
run_tests()
if sys.version_info < (3, 11):
run_tests()

View File

@ -29,7 +29,12 @@ import torch.library
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
from torch._dynamo.testing import (
rand_strided,
requires_static_shapes,
same,
skipIfPy311,
)
from torch._dynamo.utils import ifdyn, ifunspec
from torch.nn import functional as F
@ -995,6 +1000,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))
# see: https://github.com/pytorch/pytorch/issues/80067
@skipIfPy311
@torch._dynamo.config.patch(capture_scalar_outputs=False, dynamic_shapes=True)
def test_maml_no_item_capture(self):
a = torch.randn(5, 1, 28, 28)
@ -1282,6 +1288,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
res = opt_fn3()
self.assertTrue(same(ref, res))
@skipIfPy311
def test_with_on_graph_break_inst(self):
def reversible(x):
print("Hello world") # Cause graph break so inline fails
@ -1306,6 +1313,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertTrue(same(ref, res))
@skipIfPy311
def test_with_on_graph_break_nested(self):
def reversible(x):
torch._dynamo.graph_break() # Cause graph break so inline fails
@ -1333,6 +1341,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res))
# https://github.com/pytorch/torchdynamo/issues/1446
@skipIfPy311
def test_grad_mode_carrying_correct_state_after_graph_break(self):
def fn(x):
with torch.no_grad():
@ -2311,6 +2320,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.frame_count, 2)
self.assertEqual(cnt.op_count, 2)
@skipIfPy311
def test_exception_in_dynamo_handling(self):
hit_handler = False

View File

@ -7,7 +7,7 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import unsupported
from torch._dynamo.testing import skipIfPy311, unsupported
from torch._dynamo.utils import disable_cache_limit, ifunspec
globalmod = torch.nn.ReLU()
@ -517,6 +517,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.frame_count, 7)
self.assertEqual(cnt.op_count, 10)
@skipIfPy311
def test_resume_with_no_grad1(self):
def fn(a, b):
x = a + b
@ -532,6 +533,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
with torch.no_grad():
self._common(fn, 2, 9)
@skipIfPy311
def test_resume_with_no_grad2(self):
def fn(a, b):
x = a + b
@ -546,6 +548,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
self._common(fn, 3, 13)
@skipIfPy311
def test_resume_with_no_grad3(self):
def fn(a, b):
x = a + b

View File

@ -2,7 +2,6 @@
import torch
import torch.nn as nn
import torch._dynamo as torchdynamo
from torch.testing._internal.common_utils import xfailIfPython311
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skip_if_no_torchvision,
@ -38,7 +37,6 @@ from torch._inductor.compile_fx import compile_fx
@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
@xfailIfPython311
def test_qconfig_none(self):
class M(torch.nn.Module):
def __init__(self):
@ -87,7 +85,6 @@ class TestQuantizePT2E(QuantizationTestCase):
self.checkGraphModuleNodes(
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
@xfailIfPython311
def test_qconfig_module_type(self):
class M(torch.nn.Module):
def __init__(self):
@ -135,7 +132,6 @@ class TestQuantizePT2E(QuantizationTestCase):
]
self.checkGraphModuleNodes(m, expected_node_list=node_list)
@xfailIfPython311
def test_simple_quantizer(self):
class M(torch.nn.Module):
def __init__(self):
@ -194,7 +190,6 @@ class TestQuantizePT2E(QuantizationTestCase):
self.checkGraphModuleNodes(
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
@xfailIfPython311
def test_qnnpack_quantizer_conv(self):
class M(torch.nn.Module):
def __init__(self):
@ -238,7 +233,6 @@ class TestQuantizePT2E(QuantizationTestCase):
self.checkGraphModuleNodes(
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
@xfailIfPython311
def test_rearrange_weight_observer_for_decomposed_linear(self):
"""
Check whether weight observer is correctly rearranged for decomposed linear.
@ -300,7 +294,6 @@ class TestQuantizePT2E(QuantizationTestCase):
code_after_recompile = m.code
self.assertTrue(code_before_recompile == code_after_recompile, error_msg)
@xfailIfPython311
def test_transposed_conv_bn_fusion(self):
class M(torch.nn.Module):
def __init__(self):
@ -346,7 +339,6 @@ class TestQuantizePT2E(QuantizationTestCase):
@skipIfNoQNNPACK
class TestQuantizePT2EX86Inductor(QuantizationTestCase):
@skipIfNoX86
@xfailIfPython311
def test_inductor_backend_config_conv(self):
class M(torch.nn.Module):
def __init__(self, use_relu: bool = False, inplace_relu: bool = False):
@ -432,7 +424,6 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
class TestQuantizePT2EModels(QuantizationTestCase):
@skip_if_no_torchvision
@skipIfNoQNNPACK
@xfailIfPython311
def test_resnet18(self):
import torchvision
with override_quantized_engine("qnnpack"):

View File

@ -215,9 +215,8 @@ def main():
f"ROCM version: {rocm_ver}\n"
)
for args in _SANITY_CHECK_ARGS:
# TODO remove check when 3.11 is supported
if sys.version_info >= (3, 11):
warnings.warn("Dynamo not yet supported in Python 3.11. Skipping check.")
if sys.version_info >= (3, 12):
warnings.warn("Dynamo not yet supported in Python 3.12. Skipping check.")
continue
check_dynamo(*args)
print("All required checks passed")

View File

@ -431,8 +431,13 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
def check_if_dynamo_supported():
if sys.platform == "win32":
raise RuntimeError("Windows not yet supported for torch.compile")
if sys.version_info >= (3, 11):
raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
if sys.version_info >= (3, 12):
raise RuntimeError("Python 3.12+ not yet supported for torch.compile")
elif sys.version_info >= (3, 11):
warnings.warn(
"torch.compile support of Python 3.11 is experimental. "
"Program may generate incorrect results or segfault."
)
def is_dynamo_supported():

View File

@ -23,7 +23,7 @@ def run_tests(needs=()):
or IS_WINDOWS
or TEST_WITH_CROSSREF
or TEST_WITH_ROCM
or sys.version_info >= (3, 11)
or sys.version_info >= (3, 12)
):
return # skip testing

View File

@ -3,6 +3,7 @@ import dis
import functools
import logging
import os.path
import sys
import types
import unittest
from unittest.mock import patch
@ -297,3 +298,11 @@ def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches):
setattr(DummyTestClass, new_name, fn)
return DummyTestClass
# temporary decorator to skip failing 3.11 dynamo tests
def skipIfPy311(fn):
if sys.version_info < (3, 11):
return fn
else:
return unittest.skip(fn)

View File

@ -1148,13 +1148,6 @@ def skipIfRocmVersionLessThan(version=None):
return wrap_fn
return dec_fn
# Temporary function to simplify adding support to 3.11
def xfailIfPython311(fn):
if sys.version_info < (3, 11):
return fn
else:
return unittest.expectedFailure(fn)
def skipIfNotMiopenSuggestNHWC(fn):
@wraps(fn)
def wrapper(*args, **kwargs):