mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Symintify _gather_sparse_backward (#96591)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96591 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
cb7c796b4b
commit
12735952a0
@@ -2019,25 +2019,25 @@ Tensor& take_along_dim_out(const Tensor& self, const Tensor& indices, c10::optio
|
||||
|
||||
// Backward for gather() that materializes the gradient as a sparse COO tensor
// of the same (symbolic) shape as `self`.
//
// Post-#96591 this routine is "symintified": every size/numel query goes
// through the SymInt APIs (sym_numel, sym_size, sym_sizes) and the symint
// factory overloads (_sparse_coo_tensor_unsafe_symint, empty_symint,
// expand_symint, repeat_symint) so it works under dynamic shapes where
// concrete int64_t sizes are not available.
//
// Args:
//   self  - the forward input tensor (only its shape/options are used)
//   dim   - the gather dimension; may be negative and is normalized below
//   index - the gather index tensor; flattened into the sparse index row
//           for dimension `dim`
//   grad  - incoming dense gradient, same shape as the gather result
// Returns:
//   an uncoalesced sparse COO tensor of shape self.sym_sizes() whose values
//   are grad flattened, with one nnz entry per element of grad.
Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
  // Special case scalar input and/or index: a 0-dim self yields an empty
  // index matrix (0 rows x grad.sym_numel() cols); a 0-dim grad yields a
  // single-column index taken directly from `index`.
  if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(at::empty_symint({0,grad.sym_numel()}, index.options()), grad, self.sym_sizes());
  if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(index.view({1,1}), grad, self.sym_sizes());
  // One index row per dimension of self, one column per element of grad.
  Tensor sparse_ind = at::empty_symint({self.ndimension(), grad.sym_numel()}, self.options().dtype(at::kLong));
  SymInt grad_numel = grad.sym_numel();
  if (grad_numel > 0) {
    // n_above / n_below track the product of grad sizes after / before the
    // current dimension, used to tile the arange pattern for each index row.
    SymInt n_above = grad_numel;
    SymInt n_below = 1;
    if (dim < 0) dim += self.ndimension();
    for (const auto i : c10::irange(self.ndimension())) {
      n_above /= grad.sym_size(i);
      if (i == dim) {
        // The gathered dimension's coordinates come from the user index.
        sparse_ind[i] = index.reshape(-1);
      } else {
        // Other dimensions enumerate their own coordinates, repeated to
        // line up with grad's flattened (row-major) element order.
        sparse_ind[i] = at::arange(grad.sym_size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand_symint({grad.sym_size(i), n_above}).reshape(-1).repeat_symint(n_below);
      }
      n_below *= grad.sym_size(i);
    }
  }
  // "unsafe" variant skips index bounds validation; indices here are
  // constructed in-range by design (except user-provided `index`, which
  // gather's forward already validated).
  return at::_sparse_coo_tensor_unsafe_symint(sparse_ind, grad.reshape(-1), self.sym_sizes());
}
|
||||
|
||||
template <typename scalar_t>
|
||||
|
||||
@@ -42,7 +42,6 @@ importlib.import_module("filelock")
|
||||
test_skips = {
|
||||
"test_cpp_wrapper_dynamic_shapes": ("cpu",),
|
||||
"test_cudnn_rnn_dynamic_shapes": ("cuda",),
|
||||
"test_gather3_dynamic_shapes": ("cpu", "cuda"),
|
||||
"test_kwargs_dynamic_shapes": ("cpu",),
|
||||
# test_roi_align uses torchvision, which doesn't work with dynamic shapes
|
||||
"test_roi_align_dynamic_shapes": ("cpu", "cuda"),
|
||||
|
||||
Reference in New Issue
Block a user