[ROCm] enable complex128 in test_addmm_sizes_all_sparse_csr for rocm for trivial (k,n,m) cases (#120504)

This PR enables `test_addmm_sizes_all_sparse_csr_k_*_n_*_m_*_cuda_complex128` on ROCm for the trivial cases (m, n, or k = 0).

CUSPARSE_SPMM_COMPLEX128_SUPPORTED is also used by `test_addmm_all_sparse_csr` and `test_sparse_matmul`, but both of those are already skipped on ROCm via `@skipIfRocm` or `@skipCUDAIf(not _check_cusparse_spgemm_available())`, so they are unaffected by this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120504
Approved by: https://github.com/jithunnair-amd, https://github.com/ezyang
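
As an aside, a minimal repro sketch of what the newly enabled trivial case exercises; this snippet is illustrative and not taken from the PR (on ROCm builds the `"cuda"` device string maps to HIP):

```python
# Illustrative sketch (not from the PR): addmm with a complex128 sparse CSR
# operand where one dimension is zero. Such degenerate shapes short-circuit
# before any cuSPARSE/hipSPARSE SpMM kernel has to run, which is why they can
# be enabled on ROCm independently of hipSPARSE complex128 support.
import torch

m, n, k = 0, 3, 2  # trivial case: m == 0
mat1 = torch.zeros(m, k, dtype=torch.complex128, device="cuda").to_sparse_csr()
mat2 = torch.randn(k, n, dtype=torch.complex128, device="cuda")
inp = torch.zeros(m, n, dtype=torch.complex128, device="cuda")
print(torch.addmm(inp, mat1, mat2).shape)  # torch.Size([0, 3])
```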
Author: Dmitry Nikolaev
Date: 2024-03-12 07:29:57 +00:00
Committed by: PyTorch MergeBot
Parent: 86a2d67bb9
Commit: 656134c38f
2 changed files with 6 additions and 2 deletions

test/test_sparse.py

@@ -66,6 +66,8 @@ CUSPARSE_SPMM_COMPLEX128_SUPPORTED = (
     IS_WINDOWS and torch.version.cuda and version.parse(torch.version.cuda) > version.parse("11.2")
 ) or (not IS_WINDOWS and not TEST_WITH_ROCM)
 
+HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
+
 def all_sparse_layouts(test_name='layout', include_strided=False):
     return parametrize(test_name, [
         subtest(torch.strided, name='Strided'),
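
One detail in the added flag: `torch.version.hip` carries a build suffix after a dash, which isn't a valid PEP 440 version string, hence the `split("-")[0]` before comparing. A small illustration (the version string below is a made-up example):

```python
# Why HIPSPARSE_SPMM_COMPLEX128_SUPPORTED splits on "-": torch.version.hip is
# of the form "<release>-<build hash>", and the hash suffix is not a valid
# PEP 440 segment, so only the numeric prefix is compared.
from packaging import version

hip = "6.0.32830-d62f6a171"  # illustrative value, not taken from the PR
print(version.parse(hip.split("-")[0]) >= version.parse("6.0"))  # True
```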

test/test_sparse_csr.py

@@ -21,7 +21,7 @@ from torch.testing._internal.common_dtype import (
     floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
     all_types_and_complex, floating_and_complex_types_and)
 from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
-from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
+from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
 import operator
 
 if TEST_SCIPY:
@@ -2024,7 +2024,9 @@ class TestSparseCSR(TestCase):
     @dtypesIfCUDA(*floating_types_and(torch.complex64,
                                       *[torch.bfloat16] if SM80OrLater else [],
                                       *[torch.half] if SM53OrLater else [],
-                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+                                      *[torch.complex128]
+                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+                                      else []))
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
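
For readers unfamiliar with the pattern in this hunk: `@dtypesIfCUDA` builds its dtype list with a conditional-splat idiom, where `*[x] if cond else []` contributes either one element or none. A standalone sketch with the gates hard-coded (values here are illustrative, and a plain tuple stands in for `floating_types_and`):

```python
# Standalone illustration of the conditional-splat idiom used by the decorator.
# Each gate toggles whether its dtype appears in the final list.
import torch

SM80OrLater = False          # illustrative: pretend bfloat16 is unsupported
SM53OrLater = True
SPMM_COMPLEX128_OK = True    # stands in for CUSPARSE_... or HIPSPARSE_...

dtypes = (
    torch.float32, torch.float64, torch.complex64,
    *([torch.bfloat16] if SM80OrLater else []),
    *([torch.half] if SM53OrLater else []),
    *([torch.complex128] if SPMM_COMPLEX128_OK else []),
)
print(dtypes)  # (torch.float32, torch.float64, torch.complex64, torch.float16, torch.complex128)
```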