[ROCm] Enable a few test_prims UTs for ROCm (#88983)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88983
Approved by: https://github.com/IvanYashchuk, https://github.com/jeffdaily, https://github.com/malfet
Author: Pruthvi Madugundu
Date: 2022-12-07 06:21:31 +00:00
Committed by: PyTorch MergeBot
Parent: 26d1dbc4f8
Commit: 15949fc248
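The change itself is mechanical: each test below loses its skipCUDAIfRocm decorator (plus the now-unused import), so tests marked @onlyCUDA also run on ROCm builds, where AMD GPUs are exposed through the "cuda" device type. As a rough sketch of what the removed decorator does (an approximation for context, not the actual torch.testing._internal.common_device_type implementation):

# Rough approximation of the removed skipCUDAIfRocm decorator; the real one
# lives in torch.testing._internal.common_device_type and is more general.
import unittest

import torch

# torch.version.hip is a version string on ROCm builds and None on CUDA builds.
TEST_WITH_ROCM = torch.version.hip is not None


def skip_cuda_if_rocm(fn):
    """Skip the decorated test case on ROCm (HIP) builds of PyTorch."""
    return unittest.skipIf(TEST_WITH_ROCM, "test doesn't work on ROCm")(fn)

Removing the decorator re-enables a test wherever torch.cuda.is_available() is true, which on ROCm wheels includes AMD GPUs.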


@@ -13,7 +13,6 @@ from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCUDA,
-    skipCUDAIfRocm,
     dtypes,
     OpDTypes,
 )
@@ -39,7 +38,6 @@ GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"
 
 class TestPrims(TestCase):
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_broadcast_in_dim(self, device, dtype):
         def _wrapper(a, b, broadcast_dimensions):
@@ -103,7 +101,6 @@ class TestPrims(TestCase):
         """
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_broadcast_in_dim_sum(self, device, dtype):
         def _wrapper(a):
@@ -144,7 +141,6 @@ class TestPrims(TestCase):
         self.assertEqual(y, y_np, exact_device=False)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_impl_is_used(self, device):
         # This test is to ensure that when the nvfuser implementation exists it is used
         # Assuming one-to-one mapping between prims and nvfuser implementations
@@ -235,7 +231,6 @@ class TestPrims(TestCase):
         self.assertEqual(len(partitions), 1)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_full(self, device, dtype):
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -275,7 +270,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, func(size, value, b))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_empty_fusion(self, device):
         from torch.fx.experimental.proxy_tensor import make_fx
         from torch._prims.executor import execute
@@ -327,7 +321,6 @@ class TestPrims(TestCase):
         self.assertEqual(includes_nvprim_convert_element_type, nvprim_support_flag)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_rand_like_fusion(self, device):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -346,7 +339,6 @@ class TestPrims(TestCase):
 
     @skipCUDAMemoryLeakCheckIf(True)  # https://github.com/pytorch/pytorch/issues/84529
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_no_args(self, device):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -377,7 +369,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, func())
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_constant_tensors(self, device):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -400,7 +391,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, gm(b))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_executor_cached_noncontiguous(self, device):
         # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -503,7 +493,6 @@ class TestPrims(TestCase):
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_executor_parameters(self, device):
         from torch.fx.experimental.proxy_tensor import make_fx
         from torch._prims.executor import execute
@@ -536,7 +525,6 @@ class TestPrims(TestCase):
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_executor_partitioned(self, device):
         # This test is to ensure that nvfuser partitioned executor works correctly
         # It's assumed that digamma is not supported by nvfuser
@@ -565,7 +553,6 @@ class TestPrims(TestCase):
         self.assertEqual(expected, actual)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
         # This test is to ensure that nvfuser partitioned executor works correctly
         # It's assumed that digamma is not supported by nvfuser
@@ -610,7 +597,6 @@ class TestPrims(TestCase):
         self.assertFalse(node.target == torch.ops.aten.add.default)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32, torch.float64)
     def test_native_batch_norm_nvprims(self, device, dtype):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
@@ -673,7 +659,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, gm(sample.input, *sample.args))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32, torch.float64)
     def test_cudnn_batch_norm_nvprims(self, device, dtype):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
@@ -778,7 +763,6 @@ class TestPrims(TestCase):
         self.assertTrue(all_nvprims)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_silu_backward_no_filled_tensor(self, device, dtype):
         # This test verifies a workaround for
@@ -827,7 +811,6 @@ class TestPrims(TestCase):
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     @parametrize("correction", [0, 1])
     def test_var(self, device, dtype, correction):
@@ -848,7 +831,6 @@ class TestPrims(TestCase):
         self.assertEqual(_wrapper(a), result)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     @parametrize("correction", [0, 1])
     @parametrize("keepdim", [True, False])
@@ -873,7 +855,6 @@ class TestPrims(TestCase):
         self.assertTrue(includes_nvprims_var_mean)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     def test_nvprims_view(self, device, dtype):
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -920,7 +901,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, func(a))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     def test_nvprims_view_partitioner(self, device, dtype):
         # This test verifies that views that are not fused with other ops are
@@ -946,7 +926,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, func(a, b))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32, torch.float16)
     def test_cpu_tensor(self, device, dtype):
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -987,7 +966,6 @@ class TestPrims(TestCase):
         self.assertEqual(expected, nvprim_aten_fallback)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_pytree_input_output(self, device, dtype):
         @make_traced
@@ -1144,7 +1122,6 @@ instantiate_device_type_tests(TestRefs, globals())
 
 class TestDecomp(TestCase):
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     def test_decomposition_type_promotion_nvprim_amp(self, device, dtype):
         x = torch.rand(5, device=device).to(dtype)
@@ -1185,7 +1162,6 @@ class TestDecomp(TestCase):
         self.assertFalse(includes_aten_to_copy)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype):
         # masked_fill decomposition extracts cpu scalar tensor value when
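For anyone validating this locally: ROCm builds report AMD GPUs under the "cuda" device type, which is why dropping skipCUDAIfRocm while keeping @onlyCUDA is enough for these tests to run on AMD hardware. A quick way to check which backend a given build uses (standard torch.version attributes; nothing here is specific to this PR):

import torch

# Exactly one of these is set: torch.version.cuda on CUDA builds,
# torch.version.hip on ROCm builds; devices are addressed as "cuda" either way.
print("CUDA:", torch.version.cuda, "HIP:", torch.version.hip)

if torch.cuda.is_available():
    x = torch.rand(5, device="cuda")  # lands on the AMD GPU under ROCm
    print(x.device)  # prints cuda:0 on both backends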