Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ROCm] Enable few test_prim UTs for ROCm (#88983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88983 Approved by: https://github.com/IvanYashchuk, https://github.com/jeffdaily, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 26d1dbc4f8
Commit: 15949fc248
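
For context, here is a minimal sketch (not part of this commit) of how PyTorch's device-type test decorators interact. On ROCm builds the HIP backend reports itself as the CUDA device type, so @onlyCUDA tests are eligible to run there, while @skipCUDAIfRocm additionally skips a test on ROCm. This PR drops that extra decorator from the tests below so they now run on ROCm as well. The ExampleTests class and its test bodies are hypothetical illustrations; only the decorator and helper imports mirror those used in the affected file.

import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    skipCUDAIfRocm,
    dtypes,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleTests(TestCase):  # hypothetical test class, for illustration only
    @onlyCUDA
    @skipCUDAIfRocm  # old pattern: runs on CUDA, skipped on ROCm builds
    @dtypes(torch.float32)
    def test_cuda_only(self, device, dtype):
        self.assertEqual(torch.ones(1, device=device, dtype=dtype).item(), 1.0)

    @onlyCUDA  # new pattern after this PR: runs on both CUDA and ROCm
    @dtypes(torch.float32)
    def test_cuda_and_rocm(self, device, dtype):
        self.assertEqual(torch.zeros(1, device=device, dtype=dtype).item(), 0.0)


instantiate_device_type_tests(ExampleTests, globals())

if __name__ == "__main__":
    run_tests()

In the diff below, each removed line is simply the @skipCUDAIfRocm decorator on an existing CUDA-only test.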
@@ -13,7 +13,6 @@ from torch.testing._internal.common_utils import (parametrize, run_tests, TestCa
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCUDA,
-    skipCUDAIfRocm,
     dtypes,
     OpDTypes,
 )
@@ -39,7 +38,6 @@ GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decompositi
 
 class TestPrims(TestCase):
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_broadcast_in_dim(self, device, dtype):
         def _wrapper(a, b, broadcast_dimensions):
@@ -103,7 +101,6 @@ class TestPrims(TestCase):
         """
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_broadcast_in_dim_sum(self, device, dtype):
         def _wrapper(a):
@@ -144,7 +141,6 @@ class TestPrims(TestCase):
         self.assertEqual(y, y_np, exact_device=False)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_impl_is_used(self, device):
         # This test is to ensure that when the nvfuser implementation exists it is used
         # Assuming one-to-one mapping between prims and nvfuser implementations
@@ -235,7 +231,6 @@ class TestPrims(TestCase):
         self.assertEqual(len(partitions), 1)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_full(self, device, dtype):
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -275,7 +270,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, func(size, value, b))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_empty_fusion(self, device):
         from torch.fx.experimental.proxy_tensor import make_fx
         from torch._prims.executor import execute
@@ -327,7 +321,6 @@ class TestPrims(TestCase):
         self.assertEqual(includes_nvprim_convert_element_type, nvprim_support_flag)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_rand_like_fusion(self, device):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -346,7 +339,6 @@ class TestPrims(TestCase):
 
     @skipCUDAMemoryLeakCheckIf(True) # https://github.com/pytorch/pytorch/issues/84529
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_no_args(self, device):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -377,7 +369,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, func())
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_constant_tensors(self, device):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -400,7 +391,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, gm(b))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_executor_cached_noncontiguous(self, device):
         # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -503,7 +493,6 @@ class TestPrims(TestCase):
 
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_executor_parameters(self, device):
         from torch.fx.experimental.proxy_tensor import make_fx
         from torch._prims.executor import execute
@@ -536,7 +525,6 @@ class TestPrims(TestCase):
 
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_executor_partitioned(self, device):
         # This test is to ensure that nvfuser partitioned executor works correctly
         # It's assumed that digamma is not supported by nvfuser
@@ -565,7 +553,6 @@ class TestPrims(TestCase):
         self.assertEqual(expected, actual)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
         # This test is to ensure that nvfuser partitioned executor works correctly
         # It's assumed that digamma is not supported by nvfuser
@@ -610,7 +597,6 @@ class TestPrims(TestCase):
         self.assertFalse(node.target == torch.ops.aten.add.default)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32, torch.float64)
     def test_native_batch_norm_nvprims(self, device, dtype):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
@@ -673,7 +659,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, gm(sample.input, *sample.args))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32, torch.float64)
     def test_cudnn_batch_norm_nvprims(self, device, dtype):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
@@ -778,7 +763,6 @@ class TestPrims(TestCase):
         self.assertTrue(all_nvprims)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_silu_backward_no_filled_tensor(self, device, dtype):
         # This test verifies a workaround for
@@ -827,7 +811,6 @@ class TestPrims(TestCase):
 
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     @parametrize("correction", [0, 1])
     def test_var(self, device, dtype, correction):
@@ -848,7 +831,6 @@ class TestPrims(TestCase):
         self.assertEqual(_wrapper(a), result)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     @parametrize("correction", [0, 1])
     @parametrize("keepdim", [True, False])
@@ -873,7 +855,6 @@ class TestPrims(TestCase):
         self.assertTrue(includes_nvprims_var_mean)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     def test_nvprims_view(self, device, dtype):
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -920,7 +901,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, func(a))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     def test_nvprims_view_partitioner(self, device, dtype):
         # This test verifies that views that are not fused with other ops are
@@ -946,7 +926,6 @@ class TestPrims(TestCase):
         self.assertEqual(out, func(a, b))
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32, torch.float16)
     def test_cpu_tensor(self, device, dtype):
         from torch.fx.experimental.proxy_tensor import make_fx
@@ -987,7 +966,6 @@ class TestPrims(TestCase):
         self.assertEqual(expected, nvprim_aten_fallback)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float32)
     def test_pytree_input_output(self, device, dtype):
         @make_traced
@@ -1144,7 +1122,6 @@ instantiate_device_type_tests(TestRefs, globals())
 
 class TestDecomp(TestCase):
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     def test_decomposition_type_promotion_nvprim_amp(self, device, dtype):
         x = torch.rand(5, device=device).to(dtype)
@@ -1185,7 +1162,6 @@ class TestDecomp(TestCase):
         self.assertFalse(includes_aten_to_copy)
 
     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.float16, torch.float32)
     def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype):
         # masked_fill decomposition extracts cpu scalar tensor value when