[ROCm] enable tests test_sampled_addmm_autograd_cuda_*, test_sample… (#117501)
These tests PASS on ROCm 5.6+ now:
- test_sampled_addmm_autograd_cuda_complex128
- test_sampled_addmm_autograd_cuda_complex64
- test_sampled_addmm_autograd_cuda_float32
- test_sampled_addmm_autograd_cuda_float64
- test_sampled_addmm_cuda_complex128
- test_sampled_addmm_cuda_complex64
- test_sampled_addmm_cuda_float32
- test_sampled_addmm_cuda_float64
- test_autograd_dense_output_addmm_cuda_float64
- test_autograd_dense_output_addmv_cuda_float64
- test_autograd_dense_output_mv_cuda_float64

@pruthvistony @jithunnair-amd

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117501
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 1c1028ac49
Commit: c7328602ed
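For orientation, the names in the commit message are the device- and dtype-parametrized variants that PyTorch's device-type test framework generates from the base tests in test/test_sparse_csr.py. A minimal sketch of running one of them directly is below; the generated class name TestSparseCSRCUDA and the working directory (the repository's test/ folder) are assumptions about the usual setup, not something stated in this commit.

```python
import unittest

# Illustrative sketch only: load and run one of the re-enabled tests by its
# generated name. Assumes the working directory is pytorch/test and that a
# CUDA or ROCm build of torch is importable.
name = "test_sparse_csr.TestSparseCSRCUDA.test_sampled_addmm_cuda_float32"
suite = unittest.defaultTestLoader.loadTestsFromName(name)
unittest.TextTestRunner(verbosity=2).run(suite)
```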
@@ -46,6 +46,8 @@ def _check_cusparse_spgemm_available():
     return not TEST_WITH_ROCM
 
 def _check_cusparse_sddmm_available():
+    if TEST_WITH_ROCM:
+        return True
     version = _get_torch_cuda_version()
     # cusparseSDDMM was added in 11.2.1 but we don't have access to patch version
     min_supported_version = (11, 3)
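Read together, the hunk above makes the SDDMM availability check pass unconditionally on ROCm and fall back to the CUDA-version check otherwise. A sketch of the resulting helper is below; the imports and the final return line sit outside the hunk's context and are reconstructed here from the surrounding test-suite conventions, so treat them as an approximation.

```python
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_utils import TEST_WITH_ROCM


def _check_cusparse_sddmm_available():
    # On ROCm builds the check now passes unconditionally, which is what
    # re-enables the sampled_addmm tests listed in the commit message.
    if TEST_WITH_ROCM:
        return True
    version = _get_torch_cuda_version()
    # cusparseSDDMM was added in 11.2.1 but we don't have access to patch version
    min_supported_version = (11, 3)
    return version >= min_supported_version
```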
@@ -2381,7 +2383,6 @@ class TestSparseCSR(TestCase):
                 itertools.product([True, False], repeat=4)):
             run_test(n, k, upper, unitriangular, transpose, zero)
 
-    @skipCUDAIfRocm
     @skipCUDAIf(
         not _check_cusparse_sddmm_available(),
         "cuSparse Generic API SDDMM is not available"
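This hunk and several of the following ones make the same one-line change: the blanket @skipCUDAIfRocm decorator is dropped so the test is governed only by the remaining skipCUDAIf condition. Roughly speaking (a simplified sketch, not the exact definition in common_device_type.py), the removed decorator is equivalent to:

```python
from torch.testing._internal.common_device_type import skipCUDAIf
from torch.testing._internal.common_utils import TEST_WITH_ROCM


# Simplified sketch of the decorator being removed: skip on CUDA devices
# whenever the build under test is ROCm, regardless of any other condition.
def skipCUDAIfRocm(fn):
    return skipCUDAIf(TEST_WITH_ROCM, "test doesn't work on ROCm")(fn)
```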
@@ -2436,7 +2437,6 @@ class TestSparseCSR(TestCase):
         for op_a, op_b in itertools.product([True, False], repeat=2):
             run_test(c, a, b, op_a, op_b)
 
-    @skipCUDAIfRocm
     @skipCUDAIf(
         not _check_cusparse_sddmm_available(),
         "cuSparse Generic API SDDMM is not available"
@@ -2490,7 +2490,7 @@ class TestSparseCSR(TestCase):
 
     @onlyCUDA
     @skipCUDAIf(
-        not _check_cusparse_sddmm_available(),
+        not (TEST_WITH_ROCM or _check_cusparse_sddmm_available()),
         "cuSparse Generic API SDDMM is not available"
     )
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
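This hunk differs from the others: instead of dropping a decorator, it widens the skip condition so the test also runs on ROCm. As a worked illustration of how the new condition evaluates (the function and variable names below are illustrative, not from the diff):

```python
# skipCUDAIf skips the test only when its first argument is True, i.e. when the
# build is neither ROCm nor a CUDA toolkit whose cuSPARSE has the generic SDDMM API.
def should_skip(test_with_rocm: bool, sddmm_available: bool) -> bool:
    return not (test_with_rocm or sddmm_available)

assert should_skip(False, False)       # plain CUDA, old cuSPARSE: skipped
assert not should_skip(False, True)    # CUDA 11.3+: runs
assert not should_skip(True, False)    # ROCm: now runs instead of skipping
```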
@@ -2772,7 +2772,6 @@ class TestSparseCSR(TestCase):
             dense_output.backward(dense_covector)
             self.assertEqual(sparse_input.grad, dense_input.grad)
 
-    @skipCUDAIfRocm
     @skipCUDAIf(
         not _check_cusparse_sddmm_available(),
         "cuSparse Generic API SDDMM is not available"
@@ -2848,7 +2847,6 @@ class TestSparseCSR(TestCase):
         else:
             self.assertEqual(a.grad, dense_a.grad)
 
-    @skipCUDAIfRocm
     @skipCPUIfNoMklSparse
     @dtypes(torch.float64)
     def test_autograd_dense_output_addmv(self, device, dtype):
@@ -2883,9 +2881,6 @@ class TestSparseCSR(TestCase):
     def test_autograd_dense_output(self, device, dtype, op):
         if op.name == "mv" and no_mkl_sparse and self.device_type == 'cpu':
             self.skipTest("MKL Sparse is not available")
-        if op.name == "mv" and TEST_WITH_ROCM and self.device_type == 'cuda':
-            # mv currently work only on CUDA
-            self.skipTest("ROCm is not supported")
 
         samples = list(op.sample_inputs(device, dtype, requires_grad=True))
 