[ROCm] enable complex128 in test_addmm_sizes_all_sparse_csr for rocm for trivial (k,n,m) cases (#120504)
This PR enables `test_addmm_sizes_all_sparse_csr_k_*_n_*_m_*_cuda_complex128` on ROCm for the trivial cases (m, n, or k = 0). `CUSPARSE_SPMM_COMPLEX128_SUPPORTED` is also used by `test_addmm_all_sparse_csr` and `test_sparse_matmul`, but both of those are already skipped on ROCm via `@skipIfRocm` or `@skipCUDAIf(not _check_cusparse_spgemm_available())`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120504
Approved by: https://github.com/jithunnair-amd, https://github.com/ezyang
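For context on the mechanism the patch touches: the `@dtypesIfCUDA` decorator receives an unpacked dtype list that is assembled at decoration time from capability flags, so a dtype is simply absent from the test matrix when the backend cannot support it. Below is a minimal sketch of that gating pattern; the flag values are stand-ins (in the real test file `SM80OrLater`, `SM53OrLater`, and the two `*_SPMM_COMPLEX128_SUPPORTED` booleans are imported from torch internals and `test_sparse`), not the actual runtime values.

import torch

# Illustrative stand-ins; the real flags come from
# torch.testing._internal.common_cuda and test_sparse.
SM80OrLater = True
SM53OrLater = True
CUSPARSE_SPMM_COMPLEX128_SUPPORTED = False
HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = True

# complex128 is appended only when either the CUDA (cuSPARSE) or the
# ROCm (hipSPARSE) backend supports complex128 SpMM -- the OR over both
# flags is the change this PR makes.
cuda_dtypes = (
    [torch.float32, torch.float64, torch.complex64]
    + ([torch.bfloat16] if SM80OrLater else [])
    + ([torch.half] if SM53OrLater else [])
    + ([torch.complex128]
       if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
       else [])
)
print(cuda_dtypes)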
Committed by: PyTorch MergeBot
Parent: 86a2d67bb9
Commit: 656134c38f
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -21,7 +21,7 @@ from torch.testing._internal.common_dtype import (
     floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
     all_types_and_complex, floating_and_complex_types_and)
 from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
-from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
+from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
 import operator
 
 if TEST_SCIPY:
@@ -2024,7 +2024,9 @@ class TestSparseCSR(TestCase):
     @dtypesIfCUDA(*floating_types_and(torch.complex64,
                                       *[torch.bfloat16] if SM80OrLater else [],
                                       *[torch.half] if SM53OrLater else [],
-                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+                                      *[torch.complex128]
+                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+                                      else []))
     @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
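The new `HIPSPARSE_SPMM_COMPLEX128_SUPPORTED` flag is imported from `test_sparse` alongside the existing cuSPARSE flag, but its definition is not shown in this diff. A plausible shape for such a flag, mirroring the cuSPARSE version gate, is sketched below; the exact ROCm version cutoff is an assumption (6.0 is illustrative only) and the real definition in test_sparse.py may differ.

from packaging import version
import torch

# Assumption: a ROCm-version gate analogous to the existing cuSPARSE one.
# torch.version.hip is None on non-ROCm builds and a string such as
# "6.0.32830-..." on ROCm, so the build suffix is stripped before parsing.
HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = (
    torch.version.hip is not None
    and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
)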