Add masked_grad kw argument to to_dense (#96095)

As in the title.

The `masked_grad` keyword argument is required for the `to_dense` backward to distinguish between the masked and non-masked semantics of sparse tensors. `masked_grad=True` means that the `to_dense` backward will apply a mask to the returned gradient, where the mask is defined by the input indices. For backward compatibility (BC), the default semantics implies `masked_grad==True`, but see the [comment](https://github.com/pytorch/pytorch/pull/96095/files#diff-d4df180433a09071e891d552426911c227b30ae9b8a8e56da31046e7ecb1afbeR501-R513) in `to_dense_backward`.

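As a consequence, existing code that is run through the autograd engine and expects non-masked semantics must replace `.to_dense()` calls with `.to_dense(masked_grad=False)`, while code that relies on masked semantics should pass `masked_grad=True` explicitly. For example,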
```python
torch.autograd.gradcheck(lambda x: torch.sum(x, [0]).to_dense())
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]).to_dense())
```
(recall that gradcheck uses `masked=False` by default) must be updated to
```python
torch.autograd.gradcheck(lambda x: torch.sum(x, [0]).to_dense(masked_grad=False))
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]).to_dense(masked_grad=True), masked=True)
```

Fixes https://github.com/pytorch/pytorch/issues/95550

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96095
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2023-03-16 18:45:55 +02:00
committed by PyTorch MergeBot
parent 9d80969fa4
commit 2abcafcfd8
10 changed files with 79 additions and 65 deletions

View File

@ -498,23 +498,52 @@ std::vector<Tensor> _to_cpu(TensorList tensors) {
return cpu_tensors; return cpu_tensors;
} }
Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) { Tensor to_dense_backward(const Tensor& grad, const Tensor& input_, c10::optional<bool> masked_grad_) {
/*
For historical reasons, to_dense backward implements masked
semantics for sparse tensors, that is, gradients with respect to
unspecified elements are ignored. The masked_grad kw argument of
to_dense is introduced to allow to_dense to be used in the
non-masked semantics context. However, for BC reasons, the default
value to masked_grad kw argument is set True as a first instance.
Eventually, we should eliminate the masked_grad kw argument and
let to_dense backward to behave according to non-masked
semantics. Masked semantics of tensors is implemented in the
framework of masked tensors.
*/
const auto input_layout = input_.layout(); const auto input_layout = input_.layout();
const bool masked_grad = masked_grad_.value_or(true);
switch (input_layout) { switch (input_layout) {
case kStrided: case kStrided:
return grad.to_dense(); // TODO: return grad as it is
return grad.to_dense(input_.scalar_type(), masked_grad_);
case kSparse: case kSparse:
// Autograd operates on the coalesced assumption, i.e. no duplicate values. // Autograd operates on the coalesced assumption, i.e. no duplicate values.
return grad.sparse_mask(input_.coalesce()); if (masked_grad) {
return grad.sparse_mask(input_.coalesce());
} else {
// TODO: return grad as it is
return grad.to_sparse(input_.sparse_dim());
}
case kSparseCsr: case kSparseCsr:
case kSparseCsc: case kSparseCsc:
// TODO: add efficient CSR/CSC support for sparse_mask // TODO: add efficient CSR/CSC support for sparse_mask
return grad.sparse_mask(input_.to_sparse()).to_sparse(input_layout); if (masked_grad) {
return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout);
} else {
// TODO: return grad as it is
return grad.to_sparse(input_layout, /*blocksize=*/c10::nullopt, /*dense_dim=*/input_.dense_dim());
}
case kSparseBsr: case kSparseBsr:
case kSparseBsc: { case kSparseBsc: {
// TODO: add efficient BSR/BSC support for sparse_mask // TODO: add efficient BSR/BSC support for sparse_mask
const auto blocksize = at::DimVector(input_.values().sizes().slice(1, 2)); const auto blocksize = at::sparse_csr::getBlockSize(input_);
return grad.sparse_mask(input_.to_sparse()).to_sparse(input_layout, blocksize); if (masked_grad) {
return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout, blocksize);
} else {
// TODO: return grad as it is
return grad.to_sparse(input_layout, blocksize, input_.dense_dim());
}
} }
case kMkldnn: case kMkldnn:
return grad.to_mkldnn(input_.scalar_type()); return grad.to_mkldnn(input_.scalar_type());
@ -529,18 +558,18 @@ Tensor to_mkldnn_backward(const Tensor& grad, const Tensor& input_) {
return grad.to_dense(input_.scalar_type()); return grad.to_dense(input_.scalar_type());
} }
Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype) { Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype, c10::optional<bool> masked_grad) {
if (tensor.layout() == c10::kSparse) { if (tensor.layout() == c10::kSparse) {
return tensor._to_dense(dtype); return tensor._to_dense(dtype, masked_grad);
} }
if (tensor.layout() == c10::kSparseCsr || if (tensor.layout() == c10::kSparseCsr ||
tensor.layout() == c10::kSparseCsc || tensor.layout() == c10::kSparseCsc ||
tensor.layout() == c10::kSparseBsr || tensor.layout() == c10::kSparseBsr ||
tensor.layout() == c10::kSparseBsc) { tensor.layout() == c10::kSparseBsc) {
return tensor._to_dense(dtype); return tensor._to_dense(dtype, masked_grad);
} }
if (tensor.layout() == c10::kMkldnn) { if (tensor.layout() == c10::kMkldnn) {
return tensor._to_dense(dtype); return tensor._to_dense(dtype, masked_grad);
} }
TORCH_CHECK( TORCH_CHECK(
tensor.layout() == c10::kStrided, tensor.layout() == c10::kStrided,
@ -552,7 +581,7 @@ Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype) {
return tensor; return tensor;
} }
Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) { Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype, c10::optional<bool> masked) {
TORCH_CHECK( TORCH_CHECK(
!dtype.has_value(), "dtype argument is not supported by sparse_to_dense"); !dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided)); Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
@ -561,7 +590,8 @@ Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) {
Tensor sparse_compressed_to_dense( Tensor sparse_compressed_to_dense(
const Tensor& self, const Tensor& self,
c10::optional<ScalarType> dtype) { c10::optional<ScalarType> dtype,
c10::optional<bool> masked_grad) {
TORCH_CHECK( TORCH_CHECK(
!dtype.has_value(), !dtype.has_value(),
"dtype argument is not supported by sparse_csr_to_dense"); "dtype argument is not supported by sparse_csr_to_dense");
@ -1756,7 +1786,7 @@ Tensor sparse_compressed_to_sparse(const Tensor& self, c10::optional<c10::Layout
} }
switch (layout_) { switch (layout_) {
case kStrided: case kStrided:
return sparse_compressed_to_dense(self); return sparse_compressed_to_dense(self, /*dtype=*/c10::nullopt, /*masked_grad=*/c10::nullopt);
case kSparse: case kSparse:
return sparse_compressed_to_sparse(self, 2); return sparse_compressed_to_sparse(self, 2);
case kSparseCsr: case kSparseCsr:
@ -1801,7 +1831,7 @@ Tensor sparse_coo_to_sparse(const Tensor& self, c10::optional<c10::Layout> layou
} }
switch (layout_) { switch (layout_) {
case kStrided: case kStrided:
return self.to_dense(); return self.to_dense(c10::nullopt, c10::nullopt);
case kSparse: case kSparse:
return self; return self;
case kSparseCsr: case kSparseCsr:

View File

@ -23,7 +23,7 @@ namespace at { namespace native {
#if AT_MKLDNN_ENABLED() #if AT_MKLDNN_ENABLED()
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype) { Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
TORCH_CHECK(mkldnn_tensor.scalar_type() == ScalarType::Float || TORCH_CHECK(mkldnn_tensor.scalar_type() == ScalarType::Float ||
mkldnn_tensor.scalar_type() == ScalarType::BFloat16, mkldnn_tensor.scalar_type() == ScalarType::BFloat16,
"mkldnn_to_dense expects float or bfloat16 tensor input"); "mkldnn_to_dense expects float or bfloat16 tensor input");
@ -269,7 +269,7 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
#else #else
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype) { Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
TORCH_CHECK(false, "MKL-DNN build is disabled"); TORCH_CHECK(false, "MKL-DNN build is disabled");
} }
@ -342,4 +342,5 @@ TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) {
} }
#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED #endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED
}} }}

View File

@ -6702,11 +6702,11 @@
- func: _to_cpu(Tensor[] tensors) -> Tensor[] - func: _to_cpu(Tensor[] tensors) -> Tensor[]
variants: function variants: function
- func: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor - func: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
variants: method variants: method
# Special case of to_dense with custom derivative # Special case of to_dense with custom derivative
- func: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor - func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
variants: method variants: method
dispatch: dispatch:
SparseCPU, SparseCUDA: sparse_to_dense SparseCPU, SparseCUDA: sparse_to_dense
@ -6714,7 +6714,7 @@
MkldnnCPU: mkldnn_to_dense MkldnnCPU: mkldnn_to_dense
autogen: _to_dense.out autogen: _to_dense.out
- func: to_dense_backward(Tensor grad, Tensor input) -> Tensor - func: to_dense_backward(Tensor grad, Tensor input, bool? masked_grad=None) -> Tensor
- func: sparse_dim(Tensor self) -> int - func: sparse_dim(Tensor self) -> int
variants: method variants: method

View File

@ -524,7 +524,8 @@ const SparseTensor& resize_as_sparse_(const SparseTensor& self, const SparseTens
SparseTensor dense_to_sparse(const Tensor& self, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim_opt) { SparseTensor dense_to_sparse(const Tensor& self, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim_opt) {
if (layout.has_value()) { if (layout.has_value()) {
if (blocksize.has_value() && !(*layout == kSparseBsr || *layout == kSparseBsc)) { if (blocksize.has_value() && !(*layout == kSparseBsr || *layout == kSparseBsc)) {
AT_ERROR("to_sparse for ", self.layout(), " to ", *layout, " conversion does not use specified blocksize"); AT_ERROR("to_sparse for ", self.layout(), " to ", *layout,
" conversion does not use the specified blocksize ", blocksize.value(), ".");
} }
if (self.layout() == *layout) { if (self.layout() == *layout) {
return self; return self;

View File

@ -10,7 +10,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \ load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \ DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
parametrize, subtest, is_coalesced_indices, suppress_warnings, is_slow_gradcheck_env parametrize, subtest, is_coalesced_indices, suppress_warnings
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number from numbers import Number
from typing import Dict, Any from typing import Dict, Any
@ -413,8 +413,6 @@ class TestSparse(TestSparseBase):
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
@gradcheck_semantics() @gradcheck_semantics()
def test_to_dense_with_gradcheck(self, device, dtype, gradcheck): def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')
def test_tensor(x, res): def test_tensor(x, res):
x.to_dense() # Tests triple to_dense for memory corruption x.to_dense() # Tests triple to_dense for memory corruption
@ -432,7 +430,7 @@ class TestSparse(TestSparseBase):
return return
def fn(x): def fn(x):
return x.to_dense() return x.to_dense(masked_grad=gradcheck.masked)
x.requires_grad_(True) x.requires_grad_(True)
gradcheck(fn, (x,), check_sparse_nnz=True) gradcheck(fn, (x,), check_sparse_nnz=True)
@ -550,8 +548,6 @@ class TestSparse(TestSparseBase):
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
@gradcheck_semantics() @gradcheck_semantics()
def test_to_dense_hybrid(self, device, dtype, gradcheck): def test_to_dense_hybrid(self, device, dtype, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')
def test_tensor(x, res): def test_tensor(x, res):
x.to_dense() # Tests double to_dense for memory corruption x.to_dense() # Tests double to_dense for memory corruption
@ -561,7 +557,7 @@ class TestSparse(TestSparseBase):
self.assertEqual(res, self.safeToDense(x)) self.assertEqual(res, self.safeToDense(x))
def fn(x): def fn(x):
return x.to_dense() return x.to_dense(masked_grad=gradcheck.masked)
x.requires_grad_(True) x.requires_grad_(True)
gradcheck(fn, (x,), check_sparse_nnz=True) gradcheck(fn, (x,), check_sparse_nnz=True)
@ -908,8 +904,6 @@ class TestSparse(TestSparseBase):
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
@gradcheck_semantics() @gradcheck_semantics()
def test_permute(self, device, dtype, coalesced, gradcheck): def test_permute(self, device, dtype, coalesced, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')
# trivial checks # trivial checks
s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse() s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
with self.assertRaisesRegex(RuntimeError, "does not match the length"): with self.assertRaisesRegex(RuntimeError, "does not match the length"):
@ -941,7 +935,7 @@ class TestSparse(TestSparseBase):
else: else:
self.assertFalse(s_permuted.is_coalesced()) self.assertFalse(s_permuted.is_coalesced())
gradcheck(lambda t: t.permute(dims).to_dense(), s.requires_grad_(True), check_sparse_nnz=True) gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_())
else: else:
# otherwise check if exception is thrown # otherwise check if exception is thrown
fail_message = "transpositions between sparse and dense dimensions are not allowed" fail_message = "transpositions between sparse and dense dimensions are not allowed"
@ -1778,10 +1772,7 @@ class TestSparse(TestSparseBase):
self.assertEqual(S_sum.item(), D_sum.item()) self.assertEqual(S_sum.item(), D_sum.item())
def fn(S): def fn(S):
res = torch.sparse.sum(S) return torch.sparse.sum(S)
if res.is_sparse:
res = res.to_dense()
return res
gradcheck(fn, (S,), check_sparse_nnz=True, masked=True) gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)
else: else:
S_sum = torch.sparse.sum(S, td) S_sum = torch.sparse.sum(S, td)
@ -1790,9 +1781,7 @@ class TestSparse(TestSparseBase):
def fn(S): def fn(S):
res = torch.sparse.sum(S, td) res = torch.sparse.sum(S, td)
if res.is_sparse: return res.to_dense(masked_grad=True)
res = res.to_dense()
return res
gradcheck(fn, (S,), check_sparse_nnz=True, masked=True) gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)
nnz = 10 nnz = 10
@ -4012,7 +4001,7 @@ class TestSparseOneOff(TestCase):
def _sparse_to_dense(tensor): def _sparse_to_dense(tensor):
if tensor.dtype != torch.bool: if tensor.dtype != torch.bool:
return tensor.to_dense() return tensor.to_dense(masked_grad=True)
# to_dense uses coalesce which isn't implemented for bool # to_dense uses coalesce which isn't implemented for bool
return tensor.to(torch.int8).to_dense().to(torch.bool) return tensor.to(torch.int8).to_dense().to(torch.bool)
@ -4423,22 +4412,9 @@ class TestSparseAny(TestCase):
# TODO: implement batch support in _convert_indices_from_csr_to_coo # TODO: implement batch support in _convert_indices_from_csr_to_coo
continue continue
t = t.clone().detach().requires_grad_(True) t = t.clone().detach().requires_grad_(True)
if is_slow_gradcheck_env() and not gradcheck.masked: r = gradcheck(lambda x: torch.Tensor.to_dense(x, masked_grad=gradcheck.masked), t)
# TODO: remove this if-block when TODO items below are resolved
try:
gradcheck(torch.Tensor.to_dense, t)
except RuntimeError as msg:
# TODO: implement non-masked semantics support in to_dense_backward
with self.assertRaisesRegex(RuntimeError, "Jacobian mismatch"):
gradcheck(torch.Tensor.to_dense, t)
self.skipTest('non-masked semantics not supported')
r = gradcheck(torch.Tensor.to_dense, t)
self.assertTrue(r) self.assertTrue(r)
# when the following assert fails, it means that the if-block
# above and the assertFalse test below can be safely removed
self.assertFalse(is_slow_gradcheck_env() and not gradcheck.masked)
@all_sparse_layouts('from_layout', include_strided=True) @all_sparse_layouts('from_layout', include_strided=True)
@all_sparse_layouts('to_layout', include_strided=False) @all_sparse_layouts('to_layout', include_strided=False)
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))

View File

@ -1676,10 +1676,10 @@
# DO NOT define a backward for to_dense # DO NOT define a backward for to_dense
# See [Note: Sometimes view derivatives] # See [Note: Sometimes view derivatives]
# - name: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor # - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
# #
- name: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor - name: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
self: to_dense_backward(grad, self) self: to_dense_backward(grad, self, masked_grad)
- name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor - name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())

View File

@ -5360,10 +5360,16 @@ See :func:`torch.topk`
add_docstr_all( add_docstr_all(
"to_dense", "to_dense",
r""" r"""
to_dense() -> Tensor to_dense(dtype=None, *, masked_grad=True) -> Tensor
Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`. Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`.
Keyword args:
{dtype}
masked_grad (bool, optional): If set to ``True`` (default) and
:attr:`self` has a sparse layout then the backward of
:meth:`to_dense` returns ``grad.sparse_mask(self)``.
Example:: Example::
>>> s = torch.sparse_coo_tensor( >>> s = torch.sparse_coo_tensor(

View File

@ -770,10 +770,10 @@ def _check_outputs(outputs) -> None:
# it is easier to call to_dense() on the sparse output than # it is easier to call to_dense() on the sparse output than
# to modify analytical jacobian # to modify analytical jacobian
raise ValueError('Sparse output is not supported at gradcheck yet. ' raise ValueError('Sparse output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.') 'Please call to_dense(masked_grad=...) on the output of fn for gradcheck.')
if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)): # type: ignore[attr-defined] if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)): # type: ignore[attr-defined]
raise ValueError('MKLDNN output is not supported at gradcheck yet. ' raise ValueError('MKLDNN output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.') 'Please call to_dense(masked_grad=...) on the output of fn for gradcheck.')
def _check_no_differentiable_outputs(func, inputs, func_out, eps, *, is_forward_ad) -> bool: def _check_no_differentiable_outputs(func, inputs, func_out, eps, *, is_forward_ad) -> bool:

View File

@ -737,11 +737,11 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
if (!v->type()->cast<TensorType>()) { if (!v->type()->cast<TensorType>()) {
continue; continue;
} }
auto from_mkldnn = auto from_mkldnn = graph
graph ->create(
->create( c10::Symbol::fromQualString("aten::to_dense"),
c10::Symbol::fromQualString("aten::to_dense"), {v, none_value}) {v, none_value, none_value})
->insertAfter(subgraph_node); ->insertAfter(subgraph_node);
v->replaceAllUsesAfterNodeWith(from_mkldnn, from_mkldnn->output()); v->replaceAllUsesAfterNodeWith(from_mkldnn, from_mkldnn->output());
} }

View File

@ -1335,8 +1335,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.sum_to_size: lambda self, size: -1, Tensor.sum_to_size: lambda self, size: -1,
Tensor.tile: lambda self, *reps: -1, Tensor.tile: lambda self, *reps: -1,
Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1, Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
Tensor.to_dense: lambda self, dtype=None: -1, Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1,
Tensor._to_dense: lambda self, dtype=None: -1, Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1,
Tensor.to_sparse: lambda self: -1, Tensor.to_sparse: lambda self: -1,
Tensor.tolist: lambda self: -1, Tensor.tolist: lambda self: -1,
Tensor.to_mkldnn: lambda self: -1, Tensor.to_mkldnn: lambda self: -1,