Fix zeros_like for sparse tensors with batch dimensions. Add opinfo-based tests to like-functions. (#101215)

Fixes #101078

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101215
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2023-06-13 12:28:29 +03:00
committed by PyTorch MergeBot
parent 597e2a11a3
commit cbe270d233
8 changed files with 278 additions and 16 deletions

View File

@ -21,7 +21,7 @@ from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
from torch.testing._internal.common_methods_invocations import \
(reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
(op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
from torch.testing._internal.common_dtype import (
all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
floating_and_complex_types_and, integral_types, floating_types_and,
@ -39,6 +39,8 @@ reduction_ops_with_sparse_support = [op for op in reduction_ops if 'masked.' not
# Binary ufuncs whose OpInfo declares support for at least one sparse layout.
binary_ufuncs_with_sparse_support = [op for op in binary_ufuncs if _op_supports_any_sparse(op)]
# *_like creation ops (e.g. zeros_like) drawn from the full op database that
# support at least one sparse layout; consumed by TestSparseAny.test_like_fns.
like_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name]
if TEST_SCIPY:
import scipy.sparse
@ -4858,6 +4860,62 @@ class TestSparseAny(TestCase):
run_test(m, n, k, device, dtype)
@onlyNativeDeviceTypes
@suppress_warnings
@ops(like_fns_with_sparse_support)
@all_sparse_layouts('layout', include_strided=False)
def test_like_fns(self, layout, device, dtype, op):
    """Verify that a ``*_like`` op applied to a sparse sample preserves the
    input's shape, honors any dtype/device/layout overrides supplied via
    kwargs, and produces a tensor that passes the sparse invariant checks
    for its resulting layout."""

    def blocksize_of(t):
        # For BSR/BSC tensors the block size occupies the two dimensions of
        # values() immediately after the batch dimensions; other layouts
        # have no block size.
        if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
            nbatch = t.dim() - t.dense_dim() - t.sparse_dim()
            return t.values().shape[nbatch + 1:nbatch + 3]
        return None

    for sample in op.sample_inputs_sparse(layout, device, dtype):
        inp, inp_args, inp_kwargs = sample.input, sample.args, sample.kwargs
        batch_dim = inp.dim() - inp.dense_dim() - inp.sparse_dim()
        expected_blocksize = blocksize_of(inp)
        # Sample kwargs may override the parametrized dtype/device/layout;
        # the override is what the result must reflect.
        expected_dtype = inp_kwargs.get('dtype', dtype)
        expected_device = torch.device(inp_kwargs.get('device', device))
        expected_layout = inp_kwargs.get('layout', layout)

        result = op.op(inp, *inp_args, **inp_kwargs)

        self.assertEqual(result.dtype, expected_dtype)
        self.assertEqual(result.device.type, expected_device.type)
        self.assertEqual(result.layout, expected_layout)

        if result.layout in {torch.sparse_bsr, torch.sparse_bsc}:
            self.assertEqual(blocksize_of(result), expected_blocksize)

        # Check op(inp).shape == inp.shape
        self.assertEqual(result.shape, inp.shape)

        if expected_layout is torch.strided:
            self.assertEqual(result.sparse_dim(), 0)
            # Check op(inp, layout=torch.strided).dense_dim() == inp.dim()
            self.assertEqual(result.dense_dim(), inp.dim())
        elif expected_layout is torch.sparse_coo:
            # Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim()
            self.assertEqual(result.sparse_dim(), batch_dim + inp.sparse_dim())
            # Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim()
            self.assertEqual(result.dense_dim(), inp.dense_dim())
            torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape)
        else:
            # Check op(inp).sparse_dim() == inp.sparse_dim()
            self.assertEqual(result.sparse_dim(), inp.sparse_dim())
            # Check op(inp).dense_dim() == inp.dense_dim()
            self.assertEqual(result.dense_dim(), inp.dense_dim())
            # CSR/BSR expose row-compressed indices; CSC/BSC the transposed pair.
            if result.layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices, plain_indices = result.crow_indices(), result.col_indices()
            else:
                compressed_indices, plain_indices = result.ccol_indices(), result.row_indices()
            torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
                                                          result.shape, result.layout)
# Generate per-device concrete test classes from the generic template,
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
# ('meta' device variants are excluded).
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')