mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 23:15:01 +08:00
Fix zeros_like for sparse tensors with batch dimensions. Add opinfo-based tests to like-functions. (#101215)
Fixes #101078 Pull Request resolved: https://github.com/pytorch/pytorch/pull/101215 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
597e2a11a3
commit
cbe270d233
@ -101,20 +101,16 @@ void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
|
||||
refresh_numel();
|
||||
}
|
||||
|
||||
void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size) {
|
||||
void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"resize_and_clear_ called on tensor with symbolic shape");
|
||||
TORCH_CHECK(sparse_dim >= 2, "resize_and_clear_ sparse dimensionality must be at least 2, got ", sparse_dim);
|
||||
TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
|
||||
sparse_dim, "), got ", size.size());
|
||||
auto batch_dim = sparse_dim - 2;
|
||||
TORCH_CHECK(sparse_dim == 2, "resize_and_clear_ sparse dimensionality must be 2, got ", sparse_dim);
|
||||
TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim + dense_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
|
||||
sparse_dim, ") plus dense dimensionality (=", dense_dim, "), got ", size.size());
|
||||
auto batch_dim = size.size() - sparse_dim - dense_dim;
|
||||
auto batchsize = size.slice(0, batch_dim);
|
||||
auto densesize = size.slice(batch_dim + 2, size.size() - batch_dim - 2);
|
||||
|
||||
auto values_size = DimVector(batchsize);
|
||||
values_size.push_back(0); // nse
|
||||
values_size.append(densesize.begin(), densesize.end());
|
||||
auto densesize = size.slice(batch_dim + sparse_dim, dense_dim);
|
||||
|
||||
auto col_indices_size = DimVector(batchsize);
|
||||
col_indices_size.push_back(0); // nse
|
||||
@ -123,14 +119,26 @@ void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, IntArrayRef size
|
||||
[&] () -> int64_t { return size[batch_dim]; },
|
||||
[&] () -> int64_t { return size[batch_dim + 1]; }
|
||||
);
|
||||
auto values_size = DimVector(batchsize);
|
||||
values_size.push_back(0); // nse
|
||||
// WARNING: in the case of block tensors, the block size is defined
|
||||
// by the existing values shape.
|
||||
int64_t block_factor = 1;
|
||||
AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
|
||||
"resize_and_clear_",
|
||||
[] () {},
|
||||
[&] () {
|
||||
auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
|
||||
values_size.append(blocksize.begin(), blocksize.end());
|
||||
n_compressed_indices /= blocksize[(the_layout == kSparseBsr ? 0 : 1)];
|
||||
block_factor = blocksize[(the_layout == kSparseBsr ? 0 : 1)];
|
||||
|
||||
});
|
||||
TORCH_CHECK(n_compressed_indices % block_factor == 0,
|
||||
"The size of the compressed dimension (=", n_compressed_indices,
|
||||
") must be divisible with the corresponding block size (=", block_factor,")");
|
||||
n_compressed_indices /= block_factor;
|
||||
values_size.append(densesize.begin(), densesize.end());
|
||||
|
||||
auto crow_indices_size = DimVector(batchsize);
|
||||
crow_indices_size.push_back(n_compressed_indices + 1);
|
||||
|
||||
|
||||
@ -37,7 +37,10 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
|
||||
const caffe2::TypeMeta);
|
||||
|
||||
void resize_(int64_t nnz, IntArrayRef size);
|
||||
void resize_and_clear_(int64_t sparse_dim, IntArrayRef size);
|
||||
void resize_and_clear_(
|
||||
int64_t sparse_dim,
|
||||
int64_t dense_dim,
|
||||
IntArrayRef size);
|
||||
void resize_as_sparse_compressed_tensor_(const Tensor& src);
|
||||
void set_member_tensors(
|
||||
const Tensor& crow_indices,
|
||||
|
||||
@ -408,7 +408,7 @@ Tensor& zero_sparse_csr_(Tensor& self) {
|
||||
`result = csr.clone(); result.values.zero_();`
|
||||
*/
|
||||
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "zero_sparse_csr_", [](){});
|
||||
get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.sizes());
|
||||
get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.dense_dim(), self.sizes());
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
|
||||
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes)
|
||||
from torch.testing._internal.common_methods_invocations import \
|
||||
(reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
|
||||
(op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
|
||||
from torch.testing._internal.common_dtype import (
|
||||
all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
|
||||
floating_and_complex_types_and, integral_types, floating_types_and,
|
||||
@ -39,6 +39,8 @@ reduction_ops_with_sparse_support = [op for op in reduction_ops if 'masked.' not
|
||||
|
||||
binary_ufuncs_with_sparse_support = [op for op in binary_ufuncs if _op_supports_any_sparse(op)]
|
||||
|
||||
like_fns_with_sparse_support = [op for op in op_db if _op_supports_any_sparse(op) and '_like' in op.name]
|
||||
|
||||
if TEST_SCIPY:
|
||||
import scipy.sparse
|
||||
|
||||
@ -4858,6 +4860,62 @@ class TestSparseAny(TestCase):
|
||||
run_test(m, n, k, device, dtype)
|
||||
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@suppress_warnings
|
||||
@ops(like_fns_with_sparse_support)
|
||||
@all_sparse_layouts('layout', include_strided=False)
|
||||
def test_like_fns(self, layout, device, dtype, op):
|
||||
|
||||
for sample in op.sample_inputs_sparse(layout, device, dtype):
|
||||
t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
|
||||
batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
|
||||
if t_inp.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||
expected_blocksize = t_inp.values().shape[batch_dim + 1:batch_dim + 3]
|
||||
else:
|
||||
expected_blocksize = None
|
||||
expected_dtype = t_kwargs.get('dtype', dtype)
|
||||
expected_device = torch.device(t_kwargs.get('device', device))
|
||||
expected_layout = t_kwargs.get('layout', layout)
|
||||
|
||||
result = op.op(t_inp, *t_args, **t_kwargs)
|
||||
|
||||
self.assertEqual(result.dtype, expected_dtype)
|
||||
self.assertEqual(result.device.type, expected_device.type)
|
||||
self.assertEqual(result.layout, expected_layout)
|
||||
|
||||
if result.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||
result_batch_dim = result.dim() - result.dense_dim() - result.sparse_dim()
|
||||
blocksize = result.values().shape[result_batch_dim + 1:result_batch_dim + 3]
|
||||
self.assertEqual(blocksize, expected_blocksize)
|
||||
|
||||
# Check op(inp).shape == inp.shape
|
||||
self.assertEqual(result.shape, t_inp.shape)
|
||||
|
||||
if expected_layout is torch.strided:
|
||||
self.assertEqual(result.sparse_dim(), 0)
|
||||
# Check op(inp, layout=torch.strided).dense_dim() == inp.dim()
|
||||
self.assertEqual(result.dense_dim(), t_inp.dim())
|
||||
elif expected_layout is torch.sparse_coo:
|
||||
# Check op(inp, layout=torch.sparse_coo).sparse_dim() == batch_dim + inp.sparse_dim()
|
||||
self.assertEqual(result.sparse_dim(), batch_dim + t_inp.sparse_dim())
|
||||
# Check op(inp, layout=torch.sparse_coo).dense_dim() == inp.dense_dim()
|
||||
self.assertEqual(result.dense_dim(), t_inp.dense_dim())
|
||||
|
||||
torch._validate_sparse_coo_tensor_args(result._indices(), result._values(), result.shape)
|
||||
else:
|
||||
# Check op(inp).sparse_dim() == inp.sparse_dim()
|
||||
self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
|
||||
# Check op(inp).dense_dim() == inp.dense_dim()
|
||||
self.assertEqual(result.dense_dim(), t_inp.dense_dim())
|
||||
|
||||
if result.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
compressed_indices, plain_indices = result.crow_indices(), result.col_indices()
|
||||
else:
|
||||
compressed_indices, plain_indices = result.ccol_indices(), result.row_indices()
|
||||
|
||||
torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(),
|
||||
result.shape, result.layout)
|
||||
|
||||
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
|
||||
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
|
||||
|
||||
|
||||
@ -515,6 +515,14 @@ class TestSparseCompressed(TestCase):
|
||||
if len(samples) == 0:
|
||||
raise ValueError("Expected at least one 2 or higher D tensor in samples.")
|
||||
|
||||
# Re-define atol and rtol for operations that result values
|
||||
# are random (and hence, non-comparable) be we still want to
|
||||
# check the shape, dtype, etc attributes of the results:
|
||||
atol = rtol = None
|
||||
if op.name == 'randn_like':
|
||||
atol = 1e300
|
||||
rtol = 1
|
||||
|
||||
for sample, sparse_sample in samples:
|
||||
expected = op(sample.input, *sample.args, **sample.kwargs)
|
||||
assert torch.is_tensor(expected)
|
||||
@ -524,7 +532,7 @@ class TestSparseCompressed(TestCase):
|
||||
if require_mask and sample.kwargs.get('mask') is not None:
|
||||
output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
|
||||
expected.masked_fill_(~output_mask, 0)
|
||||
self.assertEqual(strided_output, expected)
|
||||
self.assertEqual(strided_output, expected, atol=atol, rtol=rtol)
|
||||
|
||||
@skipMeta
|
||||
@all_sparse_compressed_layouts()
|
||||
|
||||
@ -134,6 +134,8 @@ from torch.testing._internal.opinfo.definitions._masked import (
|
||||
sample_inputs_softmax_variant,
|
||||
)
|
||||
from torch.testing._internal.opinfo.definitions.sparse import (
|
||||
error_inputs_sparse_like_fns,
|
||||
sample_inputs_sparse_like_fns,
|
||||
error_inputs_sparse_mul,
|
||||
sample_inputs_sparse_mul,
|
||||
error_inputs_sparse_reduction_sum,
|
||||
@ -15690,6 +15692,12 @@ op_db: List[OpInfo] = [
|
||||
supports_out=False,
|
||||
sample_inputs_func=sample_inputs_like_fns,
|
||||
supports_autograd=False,
|
||||
error_inputs_sparse_func=error_inputs_sparse_like_fns,
|
||||
sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
|
||||
sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
|
||||
sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
|
||||
sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
|
||||
sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
|
||||
skips=(
|
||||
)),
|
||||
OpInfo('ones_like',
|
||||
@ -15732,7 +15740,12 @@ op_db: List[OpInfo] = [
|
||||
supports_out=False,
|
||||
sample_inputs_func=sample_inputs_like_fns,
|
||||
supports_autograd=False,
|
||||
supports_sparse_csr=True,
|
||||
error_inputs_sparse_func=error_inputs_sparse_like_fns,
|
||||
sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
|
||||
sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
|
||||
sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
|
||||
sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
|
||||
sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
|
||||
skips=(
|
||||
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
|
||||
# AssertionError: JIT Test does not execute any logic
|
||||
|
||||
@ -973,18 +973,28 @@ class OpInfo:
|
||||
# corresponding layout support implies the layout support:
|
||||
if self.supports_sparse is None:
|
||||
self.supports_sparse = self.sample_inputs_sparse_coo_func is not None
|
||||
if self.sample_inputs_sparse_coo_func is None:
|
||||
self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified
|
||||
|
||||
if self.supports_sparse_csr is None:
|
||||
self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None
|
||||
if self.sample_inputs_sparse_csr_func is None:
|
||||
self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified
|
||||
|
||||
if self.supports_sparse_csc is None:
|
||||
self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None
|
||||
if self.sample_inputs_sparse_csc_func is None:
|
||||
self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified
|
||||
|
||||
if self.supports_sparse_bsr is None:
|
||||
self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None
|
||||
if self.sample_inputs_sparse_bsr_func is None:
|
||||
self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified
|
||||
|
||||
if self.supports_sparse_bsc is None:
|
||||
self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None
|
||||
if self.sample_inputs_sparse_bsc_func is None:
|
||||
self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified
|
||||
|
||||
# We run the sampling functions without tracking the gradiends of the creation of inputs
|
||||
self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func)
|
||||
@ -1228,6 +1238,21 @@ class OpInfo:
|
||||
sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs),
|
||||
)
|
||||
|
||||
def _sample_inputs_unspecified(self, *args, **kwargs):
|
||||
"""Raises an NotImplemented exception in a OpInfo instance creation
|
||||
that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True
|
||||
without specifying the corresponding sample function as
|
||||
sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func.
|
||||
|
||||
To avoid this, either define the corresponding sample function,
|
||||
or re-map unsupported samples to error inputs in an appropiate
|
||||
|
||||
opinfo/definitions/sparse.py:_validate_sample_input_sparse_<op>
|
||||
|
||||
function.
|
||||
"""
|
||||
raise NotImplementedError("no sample function specified")
|
||||
|
||||
def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs):
|
||||
"""Returns an iterable of SampleInputs that contain inputs with sparse
|
||||
coo layout.
|
||||
|
||||
@ -598,6 +598,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
|
||||
layout is torch.sparse_csr
|
||||
and dtype is torch.complex32
|
||||
and t_inp.numel() > 0
|
||||
and t_inp._nnz() > 0
|
||||
and t_args[0].numel() > 0
|
||||
and t_args[0].ndim > 0
|
||||
):
|
||||
@ -619,6 +620,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
|
||||
elif (
|
||||
layout in {torch.sparse_coo, torch.sparse_csr}
|
||||
and dtype is torch.bool
|
||||
and t_inp._nnz() > 0
|
||||
and t_args[0].ndim > 0
|
||||
and t_inp.is_cpu
|
||||
and t_inp.numel() > 0
|
||||
@ -649,6 +651,7 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
|
||||
elif (
|
||||
layout is torch.sparse_csr
|
||||
and t_inp.dense_dim() > 0
|
||||
and t_inp._nnz() > 0
|
||||
and t_inp.is_cpu
|
||||
and dtype is torch.float16
|
||||
and t_args[0].ndim > 0
|
||||
@ -758,6 +761,150 @@ def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
|
||||
)
|
||||
|
||||
|
||||
def _sample_inputs_sparse_like_fns(
|
||||
op_info, device, dtype, requires_grad, layout, **kwargs
|
||||
):
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
for tensor in TestCase().generate_simple_inputs(
|
||||
layout,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
enable_batch=True,
|
||||
enable_hybrid=True,
|
||||
enable_zero_sized=True,
|
||||
enable_non_contiguous_indices=False,
|
||||
enable_non_contiguous_values=False,
|
||||
):
|
||||
yield SampleInput(tensor, args=(), kwargs={})
|
||||
yield SampleInput(
|
||||
tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
|
||||
)
|
||||
|
||||
if dtype is not torch.float64:
|
||||
yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
|
||||
yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))
|
||||
|
||||
if layout is torch.sparse_csr:
|
||||
other_layout = torch.sparse_csc
|
||||
elif layout is torch.sparse_csc:
|
||||
other_layout = torch.sparse_csr
|
||||
elif layout is torch.sparse_bsr:
|
||||
other_layout = torch.sparse_bsc
|
||||
elif layout is torch.sparse_bsc:
|
||||
other_layout = torch.sparse_bsr
|
||||
else:
|
||||
other_layout = torch.strided
|
||||
yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout))
|
||||
|
||||
if layout is not torch.sparse_coo:
|
||||
yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))
|
||||
|
||||
|
||||
def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
|
||||
if sample.input.layout in {
|
||||
torch.sparse_csr,
|
||||
torch.sparse_csc,
|
||||
torch.sparse_bsr,
|
||||
torch.sparse_bsc,
|
||||
}:
|
||||
if sample.kwargs.get("device", sample.input.device) != sample.input.device:
|
||||
return ErrorInput(
|
||||
sample,
|
||||
error_regex=(
|
||||
"device of (ccol|crow)_indices \\(=(cpu|cuda.*)\\) must"
|
||||
" match device of values \\(=(cuda.*|cpu)\\)"
|
||||
),
|
||||
)
|
||||
if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
|
||||
return ErrorInput(
|
||||
sample,
|
||||
error_regex=(
|
||||
"empty_like with different sparse layout is not supported"
|
||||
" \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
|
||||
),
|
||||
)
|
||||
if sample.input.layout is torch.sparse_coo:
|
||||
return ErrorInput(
|
||||
sample,
|
||||
error_regex=(
|
||||
"Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
|
||||
),
|
||||
)
|
||||
if check_validate:
|
||||
_check_validate(op_info, sample)
|
||||
return sample
|
||||
|
||||
|
||||
def _maybe_failing_sample_inputs_sparse_like_fns(
|
||||
op_info, device, dtype, requires_grad, layout, **kwargs
|
||||
):
|
||||
if torch.cuda.is_available() and layout is not torch.sparse_coo:
|
||||
other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
|
||||
if layout is torch.sparse_csr:
|
||||
other_layout = torch.sparse_csc
|
||||
elif layout is torch.sparse_csc:
|
||||
other_layout = torch.sparse_csr
|
||||
elif layout is torch.sparse_bsr:
|
||||
other_layout = torch.sparse_bsc
|
||||
elif layout is torch.sparse_bsc:
|
||||
other_layout = torch.sparse_bsr
|
||||
else:
|
||||
other_layout = torch.strided
|
||||
|
||||
blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
|
||||
|
||||
yield SampleInput(
|
||||
torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
|
||||
layout=layout, blocksize=blocksize
|
||||
),
|
||||
kwargs=dict(device=other_device),
|
||||
)
|
||||
|
||||
yield SampleInput(
|
||||
torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
|
||||
layout=layout, blocksize=blocksize
|
||||
),
|
||||
kwargs=dict(layout=other_layout),
|
||||
)
|
||||
|
||||
|
||||
def sample_inputs_sparse_like_fns(
|
||||
op_info, device, dtype, requires_grad, layout, **kwargs
|
||||
):
|
||||
"""Sample inputs for like-functions on sparse tensors."""
|
||||
yield from _sample_inputs_sparse(
|
||||
_sample_inputs_sparse_like_fns,
|
||||
_maybe_failing_sample_inputs_sparse_like_fns,
|
||||
_validate_sample_input_sparse_like_fns,
|
||||
op_info,
|
||||
device,
|
||||
dtype,
|
||||
requires_grad,
|
||||
layout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
|
||||
"""Error inputs for like-functions on sparse tensors."""
|
||||
dtype = torch.float64
|
||||
requires_grad = False
|
||||
yield from _error_inputs_sparse(
|
||||
_maybe_failing_sample_inputs_sparse_like_fns,
|
||||
_validate_sample_input_sparse_like_fns,
|
||||
op_info,
|
||||
device,
|
||||
dtype,
|
||||
requires_grad,
|
||||
layout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
|
||||
if op_info.name == "to_sparse":
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user