Remove unnecessary skipIfTorchDynamo from test_jit_fuser_te (#118728)

And add some expected failures.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118728
Approved by: https://github.com/bdhirsh
This commit is contained in:
rzou
2024-02-12 11:25:34 -08:00
committed by PyTorch MergeBot
parent 28c30f29be
commit 7eecbf8a30
3 changed files with 53 additions and 10 deletions

View File

@ -82,7 +82,6 @@ def inline_fusion_groups():
torch._C._debug_set_fusion_group_inlining(old_inlining)
@skipIfTorchDynamo()
class TestTEFuser(JitTestCase):
def setUp(self):
super().setUp()
@ -1924,6 +1923,7 @@ class TestTEFuser(JitTestCase):
t = torch.rand(8, dtype=torch.float, device=device)
scripted = self.checkScript(eager, (t, t, t, t, 0.1))
@skipIfTorchDynamo("too slow")
def test_chunk_mul_one(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
@ -2200,6 +2200,7 @@ class TestTEFuser(JitTestCase):
x = torch.ones((8, 1))
torch.testing.assert_close(eager(x), script(x))
@skipIfTorchDynamo("too slow")
@unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan")
def test_batch_norm(self):
def test(fn, args):
@ -2626,7 +2627,6 @@ def get_name(op):
# super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works.
# super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope.
# super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation
@skipIfTorchDynamo()
class TestNNCOpInfoParent(JitCommonTestCase):
pass
@ -2681,7 +2681,6 @@ def f({', '.join(param_names)}):
self.assertEqual(kernel.fallback(tuple(param_values)), correct_val)
@onlyCPU
@skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
@unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
@ops([op for op in op_db if get_name(op) in works_list], allowed_dtypes=(torch.float,))
def test_working(self, device, dtype, op):
@ -2745,7 +2744,6 @@ only_for = ("cuda") if IS_FBCODE else ("cpu", "cuda")
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
# Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent)
@skipIfTorchDynamo()
class TestLoopnestRandomizationParent(JitTestCase):
pass