mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 06:11:27 +08:00
Fix zeros_like for sparse tensors with batch dimensions. Add opinfo-based tests to like-functions. (#101215)
Fixes #101078 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101215 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
597e2a11a3
commit
cbe270d233
@ -21,7 +21,7 @@ from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
|
||||
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
|
||||
from torch.testing._internal.common_methods_invocations import \
|
||||
(reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
|
||||
(op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
|
||||
from torch.testing._internal.common_dtype import (
|
||||
all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
|
||||
floating_and_complex_types_and, integral_types, floating_types_and,
|
||||
@ -39,6 +39,8 @@ reduction_ops_with_sparse_support = [op for op in reduction_ops if 'masked.' not
|
||||
|
||||
binary_ufuncs_with_sparse_support = [op for op in binary_ufuncs if _op_supports_any_sparse(op)]
|
||||
|
||||
like_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name]
|
||||
|
||||
if TEST_SCIPY:
|
||||
import scipy.sparse
|
||||
|
||||
@ -4858,6 +4860,62 @@ class TestSparseAny(TestCase):
|
||||
run_test(m, n, k, device, dtype)
|
||||
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@suppress_warnings
|
||||
@ops(like_fns_with_sparse_support)
|
||||
@all_sparse_layouts('layout', include_strided=False)
|
||||
def test_like_fns(self, layout, device, dtype, op):
|
||||
|
||||
for sample in op.sample_inputs_sparse(layout, device, dtype):
|
||||
t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
|
||||
batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
|
||||
if t_inp.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||
expected_blocksize = t_inp.values().shape[batch_dim + 1:batch_dim + 3]
|
||||
else:
|
||||
expected_blocksize = None
|
||||
expected_dtype = t_kwargs.get('dtype', dtype)
|
||||
expected_device = torch.device(t_kwargs.get('device', device))
|
||||
expected_layout = t_kwargs.get('layout', layout)
|
||||
|
||||
result = op.op(t_inp, *t_args, **t_kwargs)
|
||||
|
||||
self.assertEqual(result.dtype, expected_dtype)
|
||||
self.assertEqual(result.device.type, expected_device.type)
|
||||
self.assertEqual(result.layout, expected_layout)
|
||||
|
||||
if result.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||
result_batch_dim = result.dim() - result.dense_dim() - result.sparse_dim()
|
||||
blocksize = result.values().shape[result_batch_dim + 1:result_batch_dim + 3]
|
||||
self.assertEqual(blocksize, expected_blocksize)
|
||||
|
||||
# Check op(inp).shape == inp.shape
|
||||
self.assertEqual(result.shape, t_inp.shape)
|
||||
|
||||
if expected_layout is torch.strided:
|
||||
self.assertEqual(result.sparse_dim(), 0)
|
||||
# Check op(inp, layout=torch.strided).dense_dim() == inp.dim()
|
||||
self.assertEqual(result.dense_dim(), t_inp.dim())
|
||||
elif expected_layout is torch.sparse_coo:
|
||||
# Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim()
|
||||
self.assertEqual(result.sparse_dim(), batch_dim + t_inp.sparse_dim())
|
||||
# Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim()
|
||||
self.assertEqual(result.dense_dim(), t_inp.dense_dim())
|
||||
|
||||
torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape)
|
||||
else:
|
||||
# Check op(inp).sparse_dim() == inp.sparse_dim()
|
||||
self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
|
||||
# Check op(inp).dense_dim() == inp.dense_dim()
|
||||
self.assertEqual(result.dense_dim(), t_inp.dense_dim())
|
||||
|
||||
if result.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
compressed_indices, plain_indices = result.crow_indices(), result.col_indices()
|
||||
else:
|
||||
compressed_indices, plain_indices = result.ccol_indices(), result.row_indices()
|
||||
|
||||
torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
|
||||
result.shape, result.layout)
|
||||
|
||||
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
|
||||
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
|
||||
|
||||
|
Reference in New Issue
Block a user