Add masked_grad kw argument to to_dense (#96095)
As in the title. The `masked_grad` kw argument is required for the `to_dense` backward to distinguish the expected gradient semantics of sparse tensors. `masked_grad=True` means that the `to_dense` backward applies a mask to the returned gradient, where the mask is defined by the input indices. For backward compatibility the default implies `masked_grad=True`, but see the [comment](https://github.com/pytorch/pytorch/pull/96095/files#diff-d4df180433a09071e891d552426911c227b30ae9b8a8e56da31046e7ecb1afbeR501-R513) in `to_dense_backward`.

As a consequence, existing code that is run through the autograd engine must replace `.to_dense()` calls with `.to_dense(masked_grad=False)`. For example,

```python
torch.autograd.gradcheck(lambda x: torch.sum(x, [0]).to_dense())
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]).to_dense())
```

(recall that gradcheck has `masked=False` as its default) must be updated to

```python
torch.autograd.gradcheck(lambda x: torch.sum(x, [0]).to_dense(masked_grad=False))
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]).to_dense(masked_grad=True), masked=True)
```

Fixes https://github.com/pytorch/pytorch/issues/95550

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96095
Approved by: https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: 9d80969fa4
Commit: 2abcafcfd8
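A minimal usage sketch (not part of the PR; it assumes a PyTorch build that includes this change) showing how `masked_grad` changes the gradient of `to_dense` for a sparse COO input:

```python
import torch

# A 1-D sparse COO tensor with specified elements at indices 0 and 2.
i = torch.tensor([[0, 2]])
v = torch.tensor([1.0, 2.0])
x = torch.sparse_coo_tensor(i, v, (4,), dtype=torch.float64, requires_grad=True)

# Masked semantics (default): the gradient is restricted to the specified elements.
x.to_dense(masked_grad=True).sum().backward()
print(x.grad)  # nonzero only at indices 0 and 2

x.grad = None

# Non-masked semantics: the gradient also covers unspecified elements.
x.to_dense(masked_grad=False).sum().backward()
print(x.grad)  # ones at every index, returned as a sparse tensor
```

Under masked semantics the unspecified elements receive no gradient, which is why gradcheck must be told `masked=True` when that behavior is intended.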
```diff
@@ -498,23 +498,52 @@ std::vector<Tensor> _to_cpu(TensorList tensors) {
   return cpu_tensors;
 }
 
-Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) {
+Tensor to_dense_backward(const Tensor& grad, const Tensor& input_, c10::optional<bool> masked_grad_) {
+  /*
+    For historical reasons, to_dense backward implements masked
+    semantics for sparse tensors, that is, gradients with respect to
+    unspecified elements are ignored. The masked_grad kw argument of
+    to_dense is introduced to allow to_dense to be used in the
+    non-masked semantics context. However, for BC reasons, the default
+    value to masked_grad kw argument is set True as a first instance.
+    Eventually, we should eliminate the masked_grad kw argument and
+    let to_dense backward to behave according to non-masked
+    semantics. Masked semantics of tensors is implemented in the
+    framework of masked tensors.
+  */
   const auto input_layout = input_.layout();
+  const bool masked_grad = masked_grad_.value_or(true);
   switch (input_layout) {
     case kStrided:
-      return grad.to_dense();
+      // TODO: return grad as it is
+      return grad.to_dense(input_.scalar_type(), masked_grad_);
     case kSparse:
       // Autograd operates on the coalesced assumption, i.e. no duplicate values.
-      return grad.sparse_mask(input_.coalesce());
+      if (masked_grad) {
+        return grad.sparse_mask(input_.coalesce());
+      } else {
+        // TODO: return grad as it is
+        return grad.to_sparse(input_.sparse_dim());
+      }
     case kSparseCsr:
     case kSparseCsc:
       // TODO: add efficient CSR/CSC support for sparse_mask
-      return grad.sparse_mask(input_.to_sparse()).to_sparse(input_layout);
+      if (masked_grad) {
+        return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout);
+      } else {
+        // TODO: return grad as it is
+        return grad.to_sparse(input_layout, /*blocksize=*/c10::nullopt, /*dense_dim=*/input_.dense_dim());
+      }
     case kSparseBsr:
     case kSparseBsc: {
       // TODO: add efficient BSR/BSC support for sparse_mask
-      const auto blocksize = at::DimVector(input_.values().sizes().slice(1, 2));
-      return grad.sparse_mask(input_.to_sparse()).to_sparse(input_layout, blocksize);
+      const auto blocksize = at::sparse_csr::getBlockSize(input_);
+      if (masked_grad) {
+        return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout, blocksize);
+      } else {
+        // TODO: return grad as it is
+        return grad.to_sparse(input_layout, blocksize, input_.dense_dim());
+      }
     }
     case kMkldnn:
       return grad.to_mkldnn(input_.scalar_type());
```
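To make the new dispatch concrete, here is a rough Python equivalent of the COO (`kSparse`) branch of `to_dense_backward` above. It is an illustration only, not the actual implementation, and the helper name `coo_to_dense_backward` is made up for this sketch:

```python
import torch

def coo_to_dense_backward(grad: torch.Tensor, input_: torch.Tensor,
                          masked_grad: bool = True) -> torch.Tensor:
    # Rough Python mirror of the kSparse branch in the diff above.
    if masked_grad:
        # Masked semantics: keep only gradient entries at the input's specified indices.
        return grad.sparse_mask(input_.coalesce())
    # Non-masked semantics: keep the full gradient, merely converted to sparse COO.
    return grad.to_sparse(input_.sparse_dim())
```

For an input with specified indices {0, 2} and an all-ones incoming gradient, the masked variant returns ones only at {0, 2}, while the non-masked variant returns ones at every index.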
```diff
@@ -529,18 +558,18 @@ Tensor to_mkldnn_backward(const Tensor& grad, const Tensor& input_) {
   return grad.to_dense(input_.scalar_type());
 }
 
-Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype) {
+Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype, c10::optional<bool> masked_grad) {
   if (tensor.layout() == c10::kSparse) {
-    return tensor._to_dense(dtype);
+    return tensor._to_dense(dtype, masked_grad);
   }
   if (tensor.layout() == c10::kSparseCsr ||
       tensor.layout() == c10::kSparseCsc ||
       tensor.layout() == c10::kSparseBsr ||
       tensor.layout() == c10::kSparseBsc) {
-    return tensor._to_dense(dtype);
+    return tensor._to_dense(dtype, masked_grad);
   }
   if (tensor.layout() == c10::kMkldnn) {
-    return tensor._to_dense(dtype);
+    return tensor._to_dense(dtype, masked_grad);
   }
   TORCH_CHECK(
       tensor.layout() == c10::kStrided,
@@ -552,7 +581,7 @@ Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype) {
   return tensor;
 }
 
-Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) {
+Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype, c10::optional<bool> masked) {
   TORCH_CHECK(
       !dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
   Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
@@ -561,7 +590,8 @@ Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) {
 
 Tensor sparse_compressed_to_dense(
     const Tensor& self,
-    c10::optional<ScalarType> dtype) {
+    c10::optional<ScalarType> dtype,
+    c10::optional<bool> masked_grad) {
   TORCH_CHECK(
       !dtype.has_value(),
       "dtype argument is not supported by sparse_csr_to_dense");
```
```diff
@@ -1756,7 +1786,7 @@ Tensor sparse_compressed_to_sparse(const Tensor& self, c10::optional<c10::Layout
   }
   switch (layout_) {
     case kStrided:
-      return sparse_compressed_to_dense(self);
+      return sparse_compressed_to_dense(self, /*dtype=*/c10::nullopt, /*masked_grad=*/c10::nullopt);
     case kSparse:
       return sparse_compressed_to_sparse(self, 2);
     case kSparseCsr:
@@ -1801,7 +1831,7 @@ Tensor sparse_coo_to_sparse(const Tensor& self, c10::optional<c10::Layout> layou
   }
   switch (layout_) {
     case kStrided:
-      return self.to_dense();
+      return self.to_dense(c10::nullopt, c10::nullopt);
     case kSparse:
       return self;
     case kSparseCsr:
```
```diff
@@ -23,7 +23,7 @@ namespace at { namespace native {
 
 #if AT_MKLDNN_ENABLED()
 
-Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype) {
+Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
   TORCH_CHECK(mkldnn_tensor.scalar_type() == ScalarType::Float ||
               mkldnn_tensor.scalar_type() == ScalarType::BFloat16,
               "mkldnn_to_dense expects float or bfloat16 tensor input");
@@ -269,7 +269,7 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
 
 #else
 
-Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype) {
+Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
   TORCH_CHECK(false, "MKL-DNN build is disabled");
 }
 
@@ -342,4 +342,5 @@ TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) {
 }
 
 #endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED
+
 }}
```
```diff
@@ -6702,11 +6702,11 @@
 - func: _to_cpu(Tensor[] tensors) -> Tensor[]
   variants: function
 
-- func: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
+- func: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
   variants: method
 
 # Special case of to_dense with custom derivative
-- func: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
+- func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
   variants: method
   dispatch:
     SparseCPU, SparseCUDA: sparse_to_dense
@@ -6714,7 +6714,7 @@
     MkldnnCPU: mkldnn_to_dense
   autogen: _to_dense.out
 
-- func: to_dense_backward(Tensor grad, Tensor input) -> Tensor
+- func: to_dense_backward(Tensor grad, Tensor input, bool? masked_grad=None) -> Tensor
 
 - func: sparse_dim(Tensor self) -> int
   variants: method
```
```diff
@@ -524,7 +524,8 @@ const SparseTensor& resize_as_sparse_(const SparseTensor& self, const SparseTens
 SparseTensor dense_to_sparse(const Tensor& self, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim_opt) {
   if (layout.has_value()) {
     if (blocksize.has_value() && !(*layout == kSparseBsr || *layout == kSparseBsc)) {
-      AT_ERROR("to_sparse for ", self.layout(), " to ", *layout, " conversion does not use specified blocksize");
+      AT_ERROR("to_sparse for ", self.layout(), " to ", *layout,
+               " conversion does not use the specified blocksize ", blocksize.value(), ".");
     }
     if (self.layout() == *layout) {
       return self;
```
```diff
@@ -10,7 +10,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
     load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
     DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
-    parametrize, subtest, is_coalesced_indices, suppress_warnings, is_slow_gradcheck_env
+    parametrize, subtest, is_coalesced_indices, suppress_warnings
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
 from typing import Dict, Any
@@ -413,8 +413,6 @@ class TestSparse(TestSparseBase):
     @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
     @gradcheck_semantics()
     def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):
-        if not gradcheck.masked and is_slow_gradcheck_env():
-            self.skipTest('FIXME: to_dense_backward supports masked semantics only')
 
         def test_tensor(x, res):
             x.to_dense() # Tests triple to_dense for memory corruption
```
```diff
@@ -432,7 +430,7 @@ class TestSparse(TestSparseBase):
                 return
 
             def fn(x):
-                return x.to_dense()
+                return x.to_dense(masked_grad=gradcheck.masked)
             x.requires_grad_(True)
             gradcheck(fn, (x,), check_sparse_nnz=True)
 
@@ -550,8 +548,6 @@ class TestSparse(TestSparseBase):
     @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
     @gradcheck_semantics()
     def test_to_dense_hybrid(self, device, dtype, gradcheck):
-        if not gradcheck.masked and is_slow_gradcheck_env():
-            self.skipTest('FIXME: to_dense_backward supports masked semantics only')
 
         def test_tensor(x, res):
             x.to_dense() # Tests double to_dense for memory corruption
```
```diff
@@ -561,7 +557,7 @@ class TestSparse(TestSparseBase):
             self.assertEqual(res, self.safeToDense(x))
 
             def fn(x):
-                return x.to_dense()
+                return x.to_dense(masked_grad=gradcheck.masked)
             x.requires_grad_(True)
             gradcheck(fn, (x,), check_sparse_nnz=True)
 
@@ -908,8 +904,6 @@ class TestSparse(TestSparseBase):
     @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
     @gradcheck_semantics()
     def test_permute(self, device, dtype, coalesced, gradcheck):
-        if not gradcheck.masked and is_slow_gradcheck_env():
-            self.skipTest('FIXME: to_dense_backward supports masked semantics only')
         # trivial checks
         s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
         with self.assertRaisesRegex(RuntimeError, "does not match the length"):
```
```diff
@@ -941,7 +935,7 @@ class TestSparse(TestSparseBase):
             else:
                 self.assertFalse(s_permuted.is_coalesced())
 
-            gradcheck(lambda t: t.permute(dims).to_dense(), s.requires_grad_(True), check_sparse_nnz=True)
+            gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_())
         else:
             # otherwise check if exception is thrown
             fail_message = "transpositions between sparse and dense dimensions are not allowed"
@@ -1778,10 +1772,7 @@ class TestSparse(TestSparseBase):
                 self.assertEqual(S_sum.item(), D_sum.item())
 
                 def fn(S):
-                    res = torch.sparse.sum(S)
-                    if res.is_sparse:
-                        res = res.to_dense()
-                    return res
+                    return torch.sparse.sum(S)
                 gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)
             else:
                 S_sum = torch.sparse.sum(S, td)
```
```diff
@@ -1790,9 +1781,7 @@ class TestSparse(TestSparseBase):
 
                 def fn(S):
                     res = torch.sparse.sum(S, td)
-                    if res.is_sparse:
-                        res = res.to_dense()
-                    return res
+                    return res.to_dense(masked_grad=True)
                 gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)
 
             nnz = 10
@@ -4012,7 +4001,7 @@ class TestSparseOneOff(TestCase):
 
 def _sparse_to_dense(tensor):
     if tensor.dtype != torch.bool:
-        return tensor.to_dense()
+        return tensor.to_dense(masked_grad=True)
 
     # to_dense uses coalesce which isn't implemented for bool
     return tensor.to(torch.int8).to_dense().to(torch.bool)
```
```diff
@@ -4423,22 +4412,9 @@ class TestSparseAny(TestCase):
                     # TODO: implement batch support in _convert_indices_from_csr_to_coo
                     continue
                 t = t.clone().detach().requires_grad_(True)
-                if is_slow_gradcheck_env() and not gradcheck.masked:
-                    # TODO: remove this if-block when TODO items below are resolved
-                    try:
-                        gradcheck(torch.Tensor.to_dense, t)
-                    except RuntimeError as msg:
-                        # TODO: implement non-masked semantics support in to_dense_backward
-                        with self.assertRaisesRegex(RuntimeError, "Jacobian mismatch"):
-                            gradcheck(torch.Tensor.to_dense, t)
-                        self.skipTest('non-masked semantics not supported')
-                r = gradcheck(torch.Tensor.to_dense, t)
+                r = gradcheck(lambda x: torch.Tensor.to_dense(x, masked_grad=gradcheck.masked), t)
                 self.assertTrue(r)
 
-                # when the following assert fails, it means that the if-block
-                # above and the assertFalse test below can be safely removed
-                self.assertFalse(is_slow_gradcheck_env() and not gradcheck.masked)
-
     @all_sparse_layouts('from_layout', include_strided=True)
     @all_sparse_layouts('to_layout', include_strided=False)
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
```
```diff
@@ -1676,10 +1676,10 @@
 
 # DO NOT define a backward for to_dense
 # See [Note: Sometimes view derivatives]
-# - name: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
+# - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
 #
-- name: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
-  self: to_dense_backward(grad, self)
+- name: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
+  self: to_dense_backward(grad, self, masked_grad)
 
 - name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
   self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())
```
```diff
@@ -5360,10 +5360,16 @@ See :func:`torch.topk`
 add_docstr_all(
     "to_dense",
     r"""
-to_dense() -> Tensor
+to_dense(dtype=None, *, masked_grad=True) -> Tensor
 
 Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`.
 
+Keyword args:
+    {dtype}
+    masked_grad (bool, optional): If set to ``True`` (default) and
+      :attr:`self` has a sparse layout then the backward of
+      :meth:`to_dense` returns ``grad.sparse_mask(self)``.
+
 Example::
 
     >>> s = torch.sparse_coo_tensor(
```
```diff
@@ -770,10 +770,10 @@ def _check_outputs(outputs) -> None:
         # it is easier to call to_dense() on the sparse output than
         # to modify analytical jacobian
         raise ValueError('Sparse output is not supported at gradcheck yet. '
-                         'Please call to_dense() on the output of fn for gradcheck.')
+                         'Please call to_dense(masked_grad=...) on the output of fn for gradcheck.')
     if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)): # type: ignore[attr-defined]
         raise ValueError('MKLDNN output is not supported at gradcheck yet. '
-                         'Please call to_dense() on the output of fn for gradcheck.')
+                         'Please call to_dense(masked_grad=...) on the output of fn for gradcheck.')
 
 
 def _check_no_differentiable_outputs(func, inputs, func_out, eps, *, is_forward_ad) -> bool:
```
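As the updated messages suggest, a function whose output is sparse must be wrapped with `to_dense(masked_grad=...)` before being passed to gradcheck. A minimal sketch echoing the examples from the PR description (the input tensor here is made up for illustration):

```python
import torch

# A small double-precision sparse COO input for gradcheck.
x = torch.eye(3, dtype=torch.float64).to_sparse().requires_grad_(True)

# Non-masked semantics (the gradcheck default, masked=False).
torch.autograd.gradcheck(
    lambda t: torch.sparse.sum(t, [0]).to_dense(masked_grad=False), (x,))

# Masked semantics: pass masked=True to gradcheck as well.
torch.autograd.gradcheck(
    lambda t: torch.sparse.sum(t, [0]).to_dense(masked_grad=True), (x,), masked=True)
```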
```diff
@@ -737,11 +737,11 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
     if (!v->type()->cast<TensorType>()) {
       continue;
     }
-    auto from_mkldnn =
-        graph
-            ->create(
-                c10::Symbol::fromQualString("aten::to_dense"), {v, none_value})
+    auto from_mkldnn = graph
+            ->create(
+                c10::Symbol::fromQualString("aten::to_dense"),
+                {v, none_value, none_value})
             ->insertAfter(subgraph_node);
     v->replaceAllUsesAfterNodeWith(from_mkldnn, from_mkldnn->output());
   }
 
```
```diff
@@ -1335,8 +1335,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.sum_to_size: lambda self, size: -1,
         Tensor.tile: lambda self, *reps: -1,
         Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
-        Tensor.to_dense: lambda self, dtype=None: -1,
-        Tensor._to_dense: lambda self, dtype=None: -1,
+        Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1,
+        Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1,
         Tensor.to_sparse: lambda self: -1,
         Tensor.tolist: lambda self: -1,
         Tensor.to_mkldnn: lambda self: -1,
```