mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Symintify _gather_sparse_backward (#96591)
				
					
				
			Pull Request resolved: https://github.com/pytorch/pytorch/pull/96591 Approved by: https://github.com/Skylion007
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							cb7c796b4b
						
					
				
				
					commit
					12735952a0
				
			| @ -2019,25 +2019,25 @@ Tensor& take_along_dim_out(const Tensor& self, const Tensor& indices, c10::optio | ||||
|  | ||||
| Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){ | ||||
| // special case scalar input and/or index | ||||
|     if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes()); | ||||
|     if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(index.view({1,1}), grad, self.sizes()); | ||||
|     Tensor sparse_ind = at::empty({self.ndimension(), grad.numel()}, self.options().dtype(at::kLong)); | ||||
|     int64_t grad_numel = grad.numel(); | ||||
|     if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(at::empty_symint({0,grad.sym_numel()}, index.options()), grad, self.sym_sizes()); | ||||
|     if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(index.view({1,1}), grad, self.sym_sizes()); | ||||
|     Tensor sparse_ind = at::empty_symint({self.ndimension(), grad.sym_numel()}, self.options().dtype(at::kLong)); | ||||
|     SymInt grad_numel = grad.sym_numel(); | ||||
|     if (grad_numel > 0) { | ||||
|       int64_t n_above = grad_numel; | ||||
|       int64_t n_below = 1; | ||||
|       SymInt n_above = grad_numel; | ||||
|       SymInt n_below = 1; | ||||
|       if (dim < 0) dim += self.ndimension(); | ||||
|       for (const auto i : c10::irange(self.ndimension())) { | ||||
|           n_above /= grad.size(i); | ||||
|           n_above /= grad.sym_size(i); | ||||
|           if (i == dim) { | ||||
|               sparse_ind[i] = index.reshape(-1); | ||||
|           } else { | ||||
|               sparse_ind[i] = at::arange(grad.size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand({grad.size(i), n_above}).reshape(-1).repeat(n_below); | ||||
|               sparse_ind[i] = at::arange(grad.sym_size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand_symint({grad.sym_size(i), n_above}).reshape(-1).repeat_symint(n_below); | ||||
|           } | ||||
|           n_below *= grad.size(i); | ||||
|           n_below *= grad.sym_size(i); | ||||
|       } | ||||
|     } | ||||
|     return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes()); | ||||
|     return at::_sparse_coo_tensor_unsafe_symint(sparse_ind, grad.reshape(-1), self.sym_sizes()); | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
|  | ||||
| @ -42,7 +42,6 @@ importlib.import_module("filelock") | ||||
| test_skips = { | ||||
|     "test_cpp_wrapper_dynamic_shapes": ("cpu",), | ||||
|     "test_cudnn_rnn_dynamic_shapes": ("cuda",), | ||||
|     "test_gather3_dynamic_shapes": ("cpu", "cuda"), | ||||
|     "test_kwargs_dynamic_shapes": ("cpu",), | ||||
|     # test_roi_align uses torchvision, which doesn't work with dynamic shapes | ||||
|     "test_roi_align_dynamic_shapes": ("cpu", "cuda"), | ||||
|  | ||||
		Reference in New Issue
	
	Block a user