Symintify _gather_sparse_backward (#96591)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96591
Approved by: https://github.com/Skylion007
This commit is contained in:
Nikita Karetnikov
2023-03-11 04:49:22 +01:00
committed by PyTorch MergeBot
parent cb7c796b4b
commit 12735952a0
2 changed files with 10 additions and 11 deletions

View File

@ -2019,25 +2019,25 @@ Tensor& take_along_dim_out(const Tensor& self, const Tensor& indices, c10::optio
Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
// special case scalar input and/or index
if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes());
if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(index.view({1,1}), grad, self.sizes());
Tensor sparse_ind = at::empty({self.ndimension(), grad.numel()}, self.options().dtype(at::kLong));
int64_t grad_numel = grad.numel();
if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(at::empty_symint({0,grad.sym_numel()}, index.options()), grad, self.sym_sizes());
if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(index.view({1,1}), grad, self.sym_sizes());
Tensor sparse_ind = at::empty_symint({self.ndimension(), grad.sym_numel()}, self.options().dtype(at::kLong));
SymInt grad_numel = grad.sym_numel();
if (grad_numel > 0) {
int64_t n_above = grad_numel;
int64_t n_below = 1;
SymInt n_above = grad_numel;
SymInt n_below = 1;
if (dim < 0) dim += self.ndimension();
for (const auto i : c10::irange(self.ndimension())) {
n_above /= grad.size(i);
n_above /= grad.sym_size(i);
if (i == dim) {
sparse_ind[i] = index.reshape(-1);
} else {
sparse_ind[i] = at::arange(grad.size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand({grad.size(i), n_above}).reshape(-1).repeat(n_below);
sparse_ind[i] = at::arange(grad.sym_size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand_symint({grad.sym_size(i), n_above}).reshape(-1).repeat_symint(n_below);
}
n_below *= grad.size(i);
n_below *= grad.sym_size(i);
}
}
return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes());
return at::_sparse_coo_tensor_unsafe_symint(sparse_ind, grad.reshape(-1), self.sym_sizes());
}
template <typename scalar_t>

View File

@ -42,7 +42,6 @@ importlib.import_module("filelock")
test_skips = {
"test_cpp_wrapper_dynamic_shapes": ("cpu",),
"test_cudnn_rnn_dynamic_shapes": ("cuda",),
"test_gather3_dynamic_shapes": ("cpu", "cuda"),
"test_kwargs_dynamic_shapes": ("cpu",),
# test_roi_align uses torchvision, which doesn't work with dynamic shapes
"test_roi_align_dynamic_shapes": ("cpu", "cuda"),