Fix zeros_like for sparse tensors with batch dimensions. Add opinfo-based tests to like-functions. (#101215)

Fixes #101078

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101215
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2023-06-13 12:28:29 +03:00
committed by PyTorch MergeBot
parent 597e2a11a3
commit cbe270d233
8 changed files with 278 additions and 16 deletions

View File

@ -101,20 +101,16 @@ void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
refresh_numel();
}
void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size) {
void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
TORCH_CHECK(
!has_symbolic_sizes_strides_,
"resize_and_clear_ called on tensor with symbolic shape");
TORCH_CHECK(sparse_dim >= 2, "resize_and_clear_ sparse dimensionality must be at least 2, got ", sparse_dim);
TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
sparse_dim, "), got ", size.size());
auto batch_dim = sparse_dim - 2;
TORCH_CHECK(sparse_dim == 2, "resize_and_clear_ sparse dimensionality must be 2, got ", sparse_dim);
TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim + dense_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
sparse_dim, ") plus dense dimensionality (=", dense_dim, "), got ", size.size());
auto batch_dim = size.size() - sparse_dim - dense_dim;
auto batchsize = size.slice(0, batch_dim);
auto densesize = size.slice(batch_dim + 2, size.size() - batch_dim - 2);
auto values_size = DimVector(batchsize);
values_size.push_back(0); // nse
values_size.append(densesize.begin(), densesize.end());
auto densesize = size.slice(batch_dim + sparse_dim, dense_dim);
auto col_indices_size = DimVector(batchsize);
col_indices_size.push_back(0); // nse
@ -123,14 +119,26 @@ void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size
[&] () -> int64_t { return size[batch_dim]; },
[&] () -> int64_t { return size[batch_dim + 1]; }
);
auto values_size = DimVector(batchsize);
values_size.push_back(0); // nse
// WARNING: in the case of block tensors, the block size is defined
// by the existing values shape.
int64_t block_factor = 1;
AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
"resize_and_clear_",
[] () {},
[&] () {
auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
values_size.append(blocksize.begin(), blocksize.end());
n_compressed_indices /= blocksize[(the_layout == kSparseBsr ? 0 : 1)];
block_factor = blocksize[(the_layout == kSparseBsr ? 0 : 1)];
});
TORCH_CHECK(n_compressed_indices % block_factor == 0,
"The size of the compressed dimension (=", n_compressed_indices,
") must be divisible with the corresponding block size (=", block_factor,")");
n_compressed_indices /= block_factor;
values_size.append(densesize.begin(), densesize.end());
auto crow_indices_size = DimVector(batchsize);
crow_indices_size.push_back(n_compressed_indices + 1);

View File

@ -37,7 +37,10 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
const caffe2::TypeMeta);
void resize_(int64_t nnz, IntArrayRef size);
void resize_and_clear_(int64_t sparse_dim, IntArrayRef size);
void resize_and_clear_(
int64_t sparse_dim,
int64_t dense_dim,
IntArrayRef size);
void resize_as_sparse_compressed_tensor_(const Tensor& src);
void set_member_tensors(
const Tensor& crow_indices,

View File

@ -408,7 +408,7 @@ Tensor& zero_sparse_csr_(Tensor& self) {
`result = csr.clone(); result.values.zero_();`
*/
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "zero_sparse_csr_", [](){});
get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.sizes());
get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.dense_dim(), self.sizes());
return self;
}

View File

@ -21,7 +21,7 @@ from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
from torch.testing._internal.common_methods_invocations import \
(reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
(op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
from torch.testing._internal.common_dtype import (
all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
floating_and_complex_types_and, integral_types, floating_types_and,
@ -39,6 +39,8 @@ reduction_ops_with_sparse_support = [op for op in reduction_ops if 'masked.' not
binary_ufuncs_with_sparse_support = [op for op in binary_ufuncs if _op_supports_any_sparse(op)]
like_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name]
if TEST_SCIPY:
import scipy.sparse
@ -4858,6 +4860,62 @@ class TestSparseAny(TestCase):
run_test(m, n, k, device, dtype)
@onlyNativeDeviceTypes
@suppress_warnings
@ops(like_fns_with_sparse_support)
@all_sparse_layouts('layout', include_strided=False)
def test_like_fns(self, layout, device, dtype, op):
"""Check the metadata of like-function results computed from sparse inputs.

For each sample returned by op.sample_inputs_sparse(layout, device, dtype)
the op is called and the result's dtype, device type, layout, block size
(for BSR/BSC results), shape, sparse_dim and dense_dim are compared with
the values expected from the input and the sample kwargs. Sparse results
are additionally passed to the torch._validate_sparse_*_tensor_args
checkers. The values of the result are not compared.
"""
for sample in op.sample_inputs_sparse(layout, device, dtype):
t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
# Number of leading batch dimensions: whatever is neither sparse nor dense.
batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
if t_inp.layout in {torch.sparse_bsr, torch.sparse_bsc}:
# The block size is taken from the two values dimensions that follow
# the batch dimensions and one more dimension.
expected_blocksize = t_inp.values().shape[batch_dim + 1:batch_dim + 3]
else:
expected_blocksize = None
# dtype/device/layout given in the sample kwargs take precedence over
# the test parameters.
expected_dtype = t_kwargs.get('dtype', dtype)
expected_device = torch.device(t_kwargs.get('device', device))
expected_layout = t_kwargs.get('layout', layout)
result = op.op(t_inp, *t_args, **t_kwargs)
self.assertEqual(result.dtype, expected_dtype)
# Only the device type (not the device index) is compared.
self.assertEqual(result.device.type, expected_device.type)
self.assertEqual(result.layout, expected_layout)
if result.layout in {torch.sparse_bsr, torch.sparse_bsc}:
result_batch_dim = result.dim() - result.dense_dim() - result.sparse_dim()
blocksize = result.values().shape[result_batch_dim + 1:result_batch_dim + 3]
self.assertEqual(blocksize, expected_blocksize)
# Check op(inp).shape == inp.shape
self.assertEqual(result.shape, t_inp.shape)
if expected_layout is torch.strided:
self.assertEqual(result.sparse_dim(), 0)
# Check op(inp, layout=torch.strided).dense_dim() == inp.dim()
self.assertEqual(result.dense_dim(), t_inp.dim())
elif expected_layout is torch.sparse_coo:
# Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim()
self.assertEqual(result.sparse_dim(), batch_dim + t_inp.sparse_dim())
# Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim()
self.assertEqual(result.dense_dim(), t_inp.dense_dim())
torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape)
else:
# Check op(inp).sparse_dim() == inp.sparse_dim()
self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
# Check op(inp).dense_dim() == inp.dense_dim()
self.assertEqual(result.dense_dim(), t_inp.dense_dim())
# CSR/BSR expose (crow_indices, col_indices); the remaining layouts
# reaching this branch are read through (ccol_indices, row_indices).
if result.layout in {torch.sparse_csr, torch.sparse_bsr}:
compressed_indices, plain_indices = result.crow_indices(), result.col_indices()
else:
compressed_indices, plain_indices = result.ccol_indices(), result.row_indices()
torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
result.shape, result.layout)
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')

View File

@ -515,6 +515,14 @@ class TestSparseCompressed(TestCase):
if len(samples) == 0:
raise ValueError("Expected at least one 2 or higher D tensor in samples.")
# Re-define atol and rtol for operations whose result values
# are random (and hence, non-comparable) but we still want to
# check the shape, dtype, etc. attributes of the results:
atol = rtol = None
if op.name == 'randn_like':
atol = 1e300
rtol = 1
for sample, sparse_sample in samples:
expected = op(sample.input, *sample.args, **sample.kwargs)
assert torch.is_tensor(expected)
@ -524,7 +532,7 @@ class TestSparseCompressed(TestCase):
if require_mask and sample.kwargs.get('mask') is not None:
output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
expected.masked_fill_(~output_mask, 0)
self.assertEqual(strided_output, expected)
self.assertEqual(strided_output, expected, atol=atol, rtol=rtol)
@skipMeta
@all_sparse_compressed_layouts()

View File

@ -134,6 +134,8 @@ from torch.testing._internal.opinfo.definitions._masked import (
sample_inputs_softmax_variant,
)
from torch.testing._internal.opinfo.definitions.sparse import (
error_inputs_sparse_like_fns,
sample_inputs_sparse_like_fns,
error_inputs_sparse_mul,
sample_inputs_sparse_mul,
error_inputs_sparse_reduction_sum,
@ -15690,6 +15692,12 @@ op_db: List[OpInfo] = [
supports_out=False,
sample_inputs_func=sample_inputs_like_fns,
supports_autograd=False,
error_inputs_sparse_func=error_inputs_sparse_like_fns,
sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
skips=(
)),
OpInfo('ones_like',
@ -15732,7 +15740,12 @@ op_db: List[OpInfo] = [
supports_out=False,
sample_inputs_func=sample_inputs_like_fns,
supports_autograd=False,
supports_sparse_csr=True,
error_inputs_sparse_func=error_inputs_sparse_like_fns,
sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
skips=(
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
# AssertionError: JIT Test does not execute any logic

View File

@ -973,18 +973,28 @@ class OpInfo:
# corresponding layout support implies the layout support:
if self.supports_sparse is None:
self.supports_sparse = self.sample_inputs_sparse_coo_func is not None
if self.sample_inputs_sparse_coo_func is None:
self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified
if self.supports_sparse_csr is None:
self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None
if self.sample_inputs_sparse_csr_func is None:
self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified
if self.supports_sparse_csc is None:
self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None
if self.sample_inputs_sparse_csc_func is None:
self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified
if self.supports_sparse_bsr is None:
self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None
if self.sample_inputs_sparse_bsr_func is None:
self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified
if self.supports_sparse_bsc is None:
self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None
if self.sample_inputs_sparse_bsc_func is None:
self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified
# We run the sampling functions without tracking the gradients of the creation of inputs
self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func)
@ -1228,6 +1238,21 @@ class OpInfo:
sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs),
)
def _sample_inputs_unspecified(self, *args, **kwargs):
"""Placeholder sample function that always raises NotImplementedError.

It is assigned to sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func
when an OpInfo instance does not specify that sample function, so
it is reached, for example, when an OpInfo specifies
supports_sparse(|_csr|_csc|_bsr|_bsc)=True without specifying the
corresponding sample function.

To avoid this, either define the corresponding sample function,
or re-map unsupported samples to error inputs in an appropriate
opinfo/definitions/sparse.py:_validate_sample_input_sparse_<op>
function.

All positional and keyword arguments are accepted and ignored.
"""
raise NotImplementedError("no sample function specified")
def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs):
"""Returns an iterable of SampleInputs that contain inputs with sparse
coo layout.

View File

@ -598,6 +598,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
layout is torch.sparse_csr
and dtype is torch.complex32
and t_inp.numel() > 0
and t_inp._nnz() > 0
and t_args[0].numel() > 0
and t_args[0].ndim > 0
):
@ -619,6 +620,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
elif (
layout in {torch.sparse_coo, torch.sparse_csr}
and dtype is torch.bool
and t_inp._nnz() > 0
and t_args[0].ndim > 0
and t_inp.is_cpu
and t_inp.numel() > 0
@ -649,6 +651,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
elif (
layout is torch.sparse_csr
and t_inp.dense_dim() > 0
and t_inp._nnz() > 0
and t_inp.is_cpu
and dtype is torch.float16
and t_args[0].ndim > 0
@ -758,6 +761,150 @@ def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
)
def _sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Yield SampleInput objects for like-functions applied to sparse tensors.

    Input tensors come from ``TestCase().generate_simple_inputs`` with batch,
    hybrid and zero-sized inputs enabled and non-contiguous indices/values
    disabled. For every tensor the following kwargs variants are yielded,
    in this order:

    1. no kwargs;
    2. ``device``, ``dtype`` and ``layout`` passed explicitly;
    3. ``dtype=torch.float64`` (skipped when ``dtype`` already is float64);
    4. the other device, cpu <-> cuda (only when CUDA is available);
    5. a different layout: CSR <-> CSC, BSR <-> BSC, otherwise strided;
    6. ``layout=torch.sparse_coo`` (skipped when ``layout`` is sparse COO).

    ``op_info``, ``requires_grad`` and ``kwargs`` are accepted but not used.
    """
    from torch.testing._internal.common_utils import TestCase

    # Layout requested by the "different layout" variant; any layout without
    # an entry here falls back to strided.
    counterpart_layout = {
        torch.sparse_csr: torch.sparse_csc,
        torch.sparse_csc: torch.sparse_csr,
        torch.sparse_bsr: torch.sparse_bsc,
        torch.sparse_bsc: torch.sparse_bsr,
    }
    other_layout = counterpart_layout.get(layout, torch.strided)

    def kwargs_variants(inp):
        # Lazily produce the kwargs of each sample built from ``inp``.
        yield {}
        yield dict(device=device, dtype=dtype, layout=layout)
        if dtype is not torch.float64:
            yield dict(dtype=torch.float64)
        if torch.cuda.is_available():
            yield dict(device="cuda" if inp.device.type == "cpu" else "cpu")
        yield dict(layout=other_layout)
        if layout is not torch.sparse_coo:
            yield dict(layout=torch.sparse_coo)

    simple_inputs = TestCase().generate_simple_inputs(
        layout,
        device=device,
        dtype=dtype,
        enable_batch=True,
        enable_hybrid=True,
        enable_zero_sized=True,
        enable_non_contiguous_indices=False,
        enable_non_contiguous_values=False,
    )
    for tensor in simple_inputs:
        for sample_kwargs in kwargs_variants(tensor):
            yield SampleInput(tensor, args=(), kwargs=sample_kwargs)
def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
"""Map a like-function sample to an ErrorInput when it matches a known
failing configuration; otherwise return the sample itself.

Args:
op_info: the OpInfo the sample belongs to; only forwarded to _check_validate.
sample: the sample to classify.
check_validate: when true, _check_validate(op_info, sample) is called
before the sample is returned.

Returns:
An ErrorInput wrapping the sample together with a regex for the expected
error message, or the sample object that was passed in.
"""
# Sparse compressed (CSR/CSC/BSR/BSC) inputs.
if sample.input.layout in {
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
}:
# A device kwarg that differs from the input's device is expected to fail.
if sample.kwargs.get("device", sample.input.device) != sample.input.device:
return ErrorInput(
sample,
error_regex=(
"device of (ccol|crow)_indices \\(=(cpu|cuda.*)\\) must"
" match device of values \\(=(cuda.*|cpu)\\)"
),
)
# A layout kwarg that differs from the input's layout is expected to fail.
if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
return ErrorInput(
sample,
error_regex=(
"empty_like with different sparse layout is not supported"
" \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
),
)
# Sparse COO inputs are mapped to an error input regardless of the kwargs.
# NOTE(review): the regex names aten::normal_, which presumably originates
# from randn_like -- confirm it is the expected failure for every
# like-function that uses this validator.
if sample.input.layout is torch.sparse_coo:
return ErrorInput(
sample,
error_regex=(
"Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
),
)
if check_validate:
_check_validate(op_info, sample)
return sample
def _maybe_failing_sample_inputs_sparse_like_fns(
op_info, device, dtype, requires_grad, layout, **kwargs
):
"""Yield like-function samples that request a different device or layout.

When CUDA is available and layout is not sparse COO, the 2x2 tensor
[[0, 1], [2, 3]] is converted to the given layout and yielded once with
kwargs=dict(device=<the other device>) and once with
kwargs=dict(layout=<a different layout>). These are the kwargs that
_validate_sample_input_sparse_like_fns checks when deciding whether a
sample with a sparse compressed input is an error input.

op_info, requires_grad and kwargs are accepted but not used.
"""
if torch.cuda.is_available() and layout is not torch.sparse_coo:
# The device the input tensor is not on: cpu <-> cuda.
other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
# A layout different from the input layout: CSR <-> CSC, BSR <-> BSC,
# and strided for anything else.
if layout is torch.sparse_csr:
other_layout = torch.sparse_csc
elif layout is torch.sparse_csc:
other_layout = torch.sparse_csr
elif layout is torch.sparse_bsr:
other_layout = torch.sparse_bsc
elif layout is torch.sparse_bsc:
other_layout = torch.sparse_bsr
else:
other_layout = torch.strided
# A blocksize is passed to to_sparse only for the block layouts.
blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
# Sample requesting a result on the other device.
yield SampleInput(
torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
layout=layout, blocksize=blocksize
),
kwargs=dict(device=other_device),
)
# Sample requesting a result with a different layout.
yield SampleInput(
torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
layout=layout, blocksize=blocksize
),
kwargs=dict(layout=other_layout),
)
def sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for like-functions on sparse tensors.

    Delegates to ``_sample_inputs_sparse``, handing it the like-function
    sample generator, the maybe-failing sample generator and the sample
    validator, followed by this function's own arguments.
    """
    samples = _sample_inputs_sparse(
        _sample_inputs_sparse_like_fns,
        _maybe_failing_sample_inputs_sparse_like_fns,
        _validate_sample_input_sparse_like_fns,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )
    yield from samples
def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
    """Error inputs for like-functions on sparse tensors.

    Delegates to ``_error_inputs_sparse`` with the maybe-failing sample
    generator and the sample validator. The dtype is fixed to
    ``torch.float64`` and ``requires_grad`` to ``False``.
    """
    error_inputs = _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_like_fns,
        _validate_sample_input_sparse_like_fns,
        op_info,
        device,
        torch.float64,  # dtype
        False,  # requires_grad
        layout,
        **kwargs,
    )
    yield from error_inputs
def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
if op_info.name == "to_sparse":
if (