mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo 3.11] enable dynamo unittests in 3.11 (#98104)
Enable most dynamo unittests for 3.11. There are a few tests that are skipped due to failures that will be addressed in upcoming PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98104 Approved by: https://github.com/yanboliang, https://github.com/voznesenskym, https://github.com/albanD, https://github.com/jansel, https://github.com/jerryzh168, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
dbfc4df075
commit
0066f3405f
@ -7,7 +7,7 @@ import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
import torch.onnx.operators
|
||||
from torch._dynamo.testing import same
|
||||
from torch._dynamo.testing import same, skipIfPy311
|
||||
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import (
|
||||
@ -77,6 +77,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||
opt_fn(a, b)
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
@skipIfPy311
|
||||
def test_nested_grad_mode_graph_break(self):
|
||||
def fn(x):
|
||||
before = torch.is_grad_enabled()
|
||||
@ -99,6 +100,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||
opt_fn(a)
|
||||
self.assertEqual(cnts.frame_count, 3)
|
||||
|
||||
@skipIfPy311
|
||||
def test_torch_profiler(self):
|
||||
# wrap torch.profiler.* as NullContextVariable and do nothing
|
||||
def fn(x):
|
||||
@ -119,6 +121,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertTrue(same(ref, res))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
@skipIfPy311
|
||||
def test_autograd_profiler(self):
|
||||
# wrap torch.autograd.profiler.* as NullContextVariable and do nothing
|
||||
def fn(x):
|
||||
@ -344,6 +347,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(exported.device.type, "cpu")
|
||||
self.assertEqual(exported.dtype, torch.bfloat16)
|
||||
|
||||
@skipIfPy311
|
||||
def test_autocast_cpu_graph_break(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -371,6 +375,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(res.device.type, "cpu")
|
||||
self.assertEqual(res.dtype, torch.bfloat16)
|
||||
|
||||
@skipIfPy311
|
||||
def test_autocast_cpu_graph_break_2(self):
|
||||
# Regression for: https://github.com/pytorch/pytorch/issues/93890
|
||||
def fn(x):
|
||||
@ -389,6 +394,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(res.dtype, torch.bfloat16)
|
||||
self.assertEqual(opt_res.dtype, torch.bfloat16)
|
||||
|
||||
@skipIfPy311
|
||||
def test_autocast_cpu_graph_break_inner_fn(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
@staticmethod
|
||||
@ -436,6 +442,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(out_32.device.type, "cpu")
|
||||
self.assertEqual(out_32.dtype, torch.float32)
|
||||
|
||||
@skipIfPy311
|
||||
def test_autocast_graph_break_method(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self, bias):
|
||||
|
@ -4308,10 +4308,8 @@ def fn():
|
||||
z *= 3
|
||||
return z
|
||||
|
||||
# TODO remove condition once 3.11 is fully supported
|
||||
if sys.version_info < (3, 11):
|
||||
opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
|
||||
self.assertEqual(opt_f(None, torch.ones(2)), 6)
|
||||
opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
|
||||
self.assertEqual(opt_f(None, torch.ones(2)), 6)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
insts = bytecode_transformation.cleaned_instructions(f.__code__)
|
||||
|
@ -110,6 +110,11 @@ class End2EndTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# most optimizer tests are broken on 3.11
|
||||
# TODO remove when 3.11 is fully supported
|
||||
import sys
|
||||
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
if sys.version_info < (3, 11):
|
||||
run_tests()
|
||||
|
@ -29,7 +29,12 @@ import torch.library
|
||||
|
||||
from torch import nn
|
||||
from torch._dynamo.debug_utils import same_two_models
|
||||
from torch._dynamo.testing import rand_strided, requires_static_shapes, same
|
||||
from torch._dynamo.testing import (
|
||||
rand_strided,
|
||||
requires_static_shapes,
|
||||
same,
|
||||
skipIfPy311,
|
||||
)
|
||||
from torch._dynamo.utils import ifdyn, ifunspec
|
||||
from torch.nn import functional as F
|
||||
|
||||
@ -995,6 +1000,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))
|
||||
|
||||
# see: https://github.com/pytorch/pytorch/issues/80067
|
||||
@skipIfPy311
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=False, dynamic_shapes=True)
|
||||
def test_maml_no_item_capture(self):
|
||||
a = torch.randn(5, 1, 28, 28)
|
||||
@ -1282,6 +1288,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
res = opt_fn3()
|
||||
self.assertTrue(same(ref, res))
|
||||
|
||||
@skipIfPy311
|
||||
def test_with_on_graph_break_inst(self):
|
||||
def reversible(x):
|
||||
print("Hello world") # Cause graph break so inline fails
|
||||
@ -1306,6 +1313,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
res = opt_fn(x)
|
||||
self.assertTrue(same(ref, res))
|
||||
|
||||
@skipIfPy311
|
||||
def test_with_on_graph_break_nested(self):
|
||||
def reversible(x):
|
||||
torch._dynamo.graph_break() # Cause graph break so inline fails
|
||||
@ -1333,6 +1341,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertTrue(same(ref, res))
|
||||
|
||||
# https://github.com/pytorch/torchdynamo/issues/1446
|
||||
@skipIfPy311
|
||||
def test_grad_mode_carrying_correct_state_after_graph_break(self):
|
||||
def fn(x):
|
||||
with torch.no_grad():
|
||||
@ -2311,6 +2320,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
self.assertEqual(cnt.op_count, 2)
|
||||
|
||||
@skipIfPy311
|
||||
def test_exception_in_dynamo_handling(self):
|
||||
hit_handler = False
|
||||
|
||||
|
@ -7,7 +7,7 @@ import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo import config
|
||||
from torch._dynamo.testing import unsupported
|
||||
from torch._dynamo.testing import skipIfPy311, unsupported
|
||||
from torch._dynamo.utils import disable_cache_limit, ifunspec
|
||||
|
||||
globalmod = torch.nn.ReLU()
|
||||
@ -517,6 +517,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(cnt.frame_count, 7)
|
||||
self.assertEqual(cnt.op_count, 10)
|
||||
|
||||
@skipIfPy311
|
||||
def test_resume_with_no_grad1(self):
|
||||
def fn(a, b):
|
||||
x = a + b
|
||||
@ -532,6 +533,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
|
||||
with torch.no_grad():
|
||||
self._common(fn, 2, 9)
|
||||
|
||||
@skipIfPy311
|
||||
def test_resume_with_no_grad2(self):
|
||||
def fn(a, b):
|
||||
x = a + b
|
||||
@ -546,6 +548,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
self._common(fn, 3, 13)
|
||||
|
||||
@skipIfPy311
|
||||
def test_resume_with_no_grad3(self):
|
||||
def fn(a, b):
|
||||
x = a + b
|
||||
|
@ -2,7 +2,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch._dynamo as torchdynamo
|
||||
from torch.testing._internal.common_utils import xfailIfPython311
|
||||
from torch.testing._internal.common_quantization import (
|
||||
QuantizationTestCase,
|
||||
skip_if_no_torchvision,
|
||||
@ -38,7 +37,6 @@ from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
@skipIfNoQNNPACK
|
||||
class TestQuantizePT2E(QuantizationTestCase):
|
||||
@xfailIfPython311
|
||||
def test_qconfig_none(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -87,7 +85,6 @@ class TestQuantizePT2E(QuantizationTestCase):
|
||||
self.checkGraphModuleNodes(
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
|
||||
|
||||
@xfailIfPython311
|
||||
def test_qconfig_module_type(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -135,7 +132,6 @@ class TestQuantizePT2E(QuantizationTestCase):
|
||||
]
|
||||
self.checkGraphModuleNodes(m, expected_node_list=node_list)
|
||||
|
||||
@xfailIfPython311
|
||||
def test_simple_quantizer(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -194,7 +190,6 @@ class TestQuantizePT2E(QuantizationTestCase):
|
||||
self.checkGraphModuleNodes(
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
|
||||
|
||||
@xfailIfPython311
|
||||
def test_qnnpack_quantizer_conv(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -238,7 +233,6 @@ class TestQuantizePT2E(QuantizationTestCase):
|
||||
self.checkGraphModuleNodes(
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
|
||||
|
||||
@xfailIfPython311
|
||||
def test_rearrange_weight_observer_for_decomposed_linear(self):
|
||||
"""
|
||||
Check whether weight observer is correctly rearranged for decomposed linear.
|
||||
@ -300,7 +294,6 @@ class TestQuantizePT2E(QuantizationTestCase):
|
||||
code_after_recompile = m.code
|
||||
self.assertTrue(code_before_recompile == code_after_recompile, error_msg)
|
||||
|
||||
@xfailIfPython311
|
||||
def test_transposed_conv_bn_fusion(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -346,7 +339,6 @@ class TestQuantizePT2E(QuantizationTestCase):
|
||||
@skipIfNoQNNPACK
|
||||
class TestQuantizePT2EX86Inductor(QuantizationTestCase):
|
||||
@skipIfNoX86
|
||||
@xfailIfPython311
|
||||
def test_inductor_backend_config_conv(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, use_relu: bool = False, inplace_relu: bool = False):
|
||||
@ -432,7 +424,6 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
|
||||
class TestQuantizePT2EModels(QuantizationTestCase):
|
||||
@skip_if_no_torchvision
|
||||
@skipIfNoQNNPACK
|
||||
@xfailIfPython311
|
||||
def test_resnet18(self):
|
||||
import torchvision
|
||||
with override_quantized_engine("qnnpack"):
|
||||
|
@ -215,9 +215,8 @@ def main():
|
||||
f"ROCM version: {rocm_ver}\n"
|
||||
)
|
||||
for args in _SANITY_CHECK_ARGS:
|
||||
# TODO remove check when 3.11 is supported
|
||||
if sys.version_info >= (3, 11):
|
||||
warnings.warn("Dynamo not yet supported in Python 3.11. Skipping check.")
|
||||
if sys.version_info >= (3, 12):
|
||||
warnings.warn("Dynamo not yet supported in Python 3.12. Skipping check.")
|
||||
continue
|
||||
check_dynamo(*args)
|
||||
print("All required checks passed")
|
||||
|
@ -431,8 +431,13 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
|
||||
def check_if_dynamo_supported():
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError("Windows not yet supported for torch.compile")
|
||||
if sys.version_info >= (3, 11):
|
||||
raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
|
||||
if sys.version_info >= (3, 12):
|
||||
raise RuntimeError("Python 3.12+ not yet supported for torch.compile")
|
||||
elif sys.version_info >= (3, 11):
|
||||
warnings.warn(
|
||||
"torch.compile support of Python 3.11 is experimental. "
|
||||
"Program may generate incorrect results or segfault."
|
||||
)
|
||||
|
||||
|
||||
def is_dynamo_supported():
|
||||
|
@ -23,7 +23,7 @@ def run_tests(needs=()):
|
||||
or IS_WINDOWS
|
||||
or TEST_WITH_CROSSREF
|
||||
or TEST_WITH_ROCM
|
||||
or sys.version_info >= (3, 11)
|
||||
or sys.version_info >= (3, 12)
|
||||
):
|
||||
return # skip testing
|
||||
|
||||
|
@ -3,6 +3,7 @@ import dis
|
||||
import functools
|
||||
import logging
|
||||
import os.path
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
@ -297,3 +298,11 @@ def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches):
|
||||
setattr(DummyTestClass, new_name, fn)
|
||||
|
||||
return DummyTestClass
|
||||
|
||||
|
||||
# temporary decorator to skip failing 3.11 dynamo tests
|
||||
def skipIfPy311(fn):
|
||||
if sys.version_info < (3, 11):
|
||||
return fn
|
||||
else:
|
||||
return unittest.skip(fn)
|
||||
|
@ -1148,13 +1148,6 @@ def skipIfRocmVersionLessThan(version=None):
|
||||
return wrap_fn
|
||||
return dec_fn
|
||||
|
||||
# Temporary function to simplify adding support to 3.11
|
||||
def xfailIfPython311(fn):
|
||||
if sys.version_info < (3, 11):
|
||||
return fn
|
||||
else:
|
||||
return unittest.expectedFailure(fn)
|
||||
|
||||
def skipIfNotMiopenSuggestNHWC(fn):
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
Reference in New Issue
Block a user