mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	codemod tensor.type().is_cuda(), tensor.type().is_sparse() (#13590)
Summary: Followup to #12841 Changed these to not require type dispatch: tensor.type().is_cuda() -> tensor.is_cuda() tensor.type().is_sparse() -> tensor.is_sparse() isVariable(tensor.type()) -> tensor.is_variable() This probably does not affect performance very much in most cases but it is nice to have. Pull Request resolved: https://github.com/pytorch/pytorch/pull/13590 Reviewed By: ezyang Differential Revision: D12929301 Pulled By: zou3519 fbshipit-source-id: 8ac5c6200c579dd7a44fb4ee58fc9bb170feb1d7
This commit is contained in:
		
				
					committed by
					
						 Facebook Github Bot
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							e70321ed9e
						
					
				
				
					commit
					e60a7c2c88
				
			| @ -152,7 +152,7 @@ DLManagedTensor* toDLPack(const Tensor& src) { | |||||||
|   atDLMTensor->tensor.deleter = &deleter; |   atDLMTensor->tensor.deleter = &deleter; | ||||||
|   atDLMTensor->tensor.dl_tensor.data = src.data_ptr(); |   atDLMTensor->tensor.dl_tensor.data = src.data_ptr(); | ||||||
|   int64_t device_id = 0; |   int64_t device_id = 0; | ||||||
|   if (src.type().is_cuda()) { |   if (src.is_cuda()) { | ||||||
|     device_id = src.get_device(); |     device_id = src.get_device(); | ||||||
|   } |   } | ||||||
|   atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id); |   atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id); | ||||||
|  | |||||||
| @ -180,7 +180,7 @@ AT_ERROR("${api_name} only supports a 0-dimensional ${check_name} tensor, but go | |||||||
| """) | """) | ||||||
|  |  | ||||||
| SPARSE_CHECK = CodeTemplate("""\ | SPARSE_CHECK = CodeTemplate("""\ | ||||||
| if(${check_name}.type().is_sparse()) { | if(${check_name}.is_sparse()) { | ||||||
|     return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals}); |     return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals}); | ||||||
| }""") | }""") | ||||||
|  |  | ||||||
|  | |||||||
| @ -108,7 +108,7 @@ auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool { | |||||||
|   if (!detail::getCUDAHooks().compiledWithCuDNN()) { |   if (!detail::getCUDAHooks().compiledWithCuDNN()) { | ||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|   if (!input.type().is_cuda() || !cudnn_enabled) { |   if (!input.is_cuda() || !cudnn_enabled) { | ||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|   if (deterministic && is_dilated()) { |   if (deterministic && is_dilated()) { | ||||||
| @ -125,7 +125,7 @@ auto ConvParams::use_miopen(const at::Tensor& input) const -> bool { | |||||||
|  |  | ||||||
|   return ((input.type().scalarType() == at::kFloat) || (input.type().scalarType() == at::kHalf)) |   return ((input.type().scalarType() == at::kFloat) || (input.type().scalarType() == at::kHalf)) | ||||||
|          && detail::getCUDAHooks().compiledWithMIOpen() |          && detail::getCUDAHooks().compiledWithMIOpen() | ||||||
|          && input.type().is_cuda() |          && input.is_cuda() | ||||||
|          && input.dim() <= MIOPEN_DIM_MAX |          && input.dim() <= MIOPEN_DIM_MAX | ||||||
|          && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1 |          && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1 | ||||||
|          && !transposed |          && !transposed | ||||||
| @ -150,7 +150,7 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool { | |||||||
| // a depthwise multiplier) | // a depthwise multiplier) | ||||||
| auto ConvParams::is_depthwise( | auto ConvParams::is_depthwise( | ||||||
|         const at::Tensor& input, const at::Tensor& weight) const -> bool { |         const at::Tensor& input, const at::Tensor& weight) const -> bool { | ||||||
|   return input.type().is_cuda() && |   return input.is_cuda() && | ||||||
|          !transposed && |          !transposed && | ||||||
|          input.ndimension() == 4 && |          input.ndimension() == 4 && | ||||||
|          input.size(1) == groups && |          input.size(1) == groups && | ||||||
| @ -450,7 +450,7 @@ at::Tensor _convolution_nogroup( | |||||||
|             input, weight, kernel_size, bias, |             input, weight, kernel_size, bias, | ||||||
|             stride, padding); |             stride, padding); | ||||||
|       } |       } | ||||||
|     } else if (dim == 5 && (input.type().is_cuda() || dilated)) { |     } else if (dim == 5 && (input.is_cuda() || dilated)) { | ||||||
|       return at::thnn_conv_dilated3d( |       return at::thnn_conv_dilated3d( | ||||||
|           input, weight, kernel_size, bias, |           input, weight, kernel_size, bias, | ||||||
|           stride, padding, dilation); |           stride, padding, dilation); | ||||||
| @ -498,14 +498,14 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( | |||||||
|   // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb |   // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb | ||||||
|   Tensor ggO; |   Tensor ggO; | ||||||
|   if (ggI.defined()) { |   if (ggI.defined()) { | ||||||
|     if (weight.type().is_cuda()) { |     if (weight.is_cuda()) { | ||||||
|       weight = weight.contiguous(); |       weight = weight.contiguous(); | ||||||
|     } |     } | ||||||
|     ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled); |     ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (ggW.defined()) { |   if (ggW.defined()) { | ||||||
|     if (ggW.type().is_cuda()) { |     if (ggW.is_cuda()) { | ||||||
|       ggW = ggW.contiguous(); |       ggW = ggW.contiguous(); | ||||||
|     } |     } | ||||||
|     auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled); |     auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled); | ||||||
| @ -553,7 +553,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( | |||||||
|     Tensor gWt; |     Tensor gWt; | ||||||
|     // Compute conv |     // Compute conv | ||||||
|     if (groups == 1) { |     if (groups == 1) { | ||||||
|       if (gOt.type().is_cuda()) { |       if (gOt.is_cuda()) { | ||||||
|         gOt = gOt.contiguous(); |         gOt = gOt.contiguous(); | ||||||
|       } |       } | ||||||
|  |  | ||||||
| @ -569,7 +569,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( | |||||||
|       for (int g = 0; g < groups; ++g) { |       for (int g = 0; g < groups; ++g) { | ||||||
|         auto ggIt_g = subvariable(ggIt, 0, groups, g); |         auto ggIt_g = subvariable(ggIt, 0, groups, g); | ||||||
|         auto gOt_g = subvariable(gOt, 0, groups, g); |         auto gOt_g = subvariable(gOt, 0, groups, g); | ||||||
|         if (gOt_g.type().is_cuda()) { |         if (gOt_g.is_cuda()) { | ||||||
|           gOt_g = gOt_g.contiguous(); |           gOt_g = gOt_g.contiguous(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @ -609,7 +609,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( | |||||||
|     gi_conv_params.transposed = !params.transposed; |     gi_conv_params.transposed = !params.transposed; | ||||||
|  |  | ||||||
|     if (params.transposed) { |     if (params.transposed) { | ||||||
|       if (gO.type().is_cuda()) { |       if (gO.is_cuda()) { | ||||||
|         gO = gO.contiguous(); |         gO = gO.contiguous(); | ||||||
|       } |       } | ||||||
|       gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled); |       gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled); | ||||||
| @ -662,7 +662,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( | |||||||
|  |  | ||||||
|       Tensor gIt; |       Tensor gIt; | ||||||
|       if (params.groups == 1) { |       if (params.groups == 1) { | ||||||
|         if (gOt.type().is_cuda()) { |         if (gOt.is_cuda()) { | ||||||
|           gOt = gOt.contiguous(); |           gOt = gOt.contiguous(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @ -672,7 +672,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( | |||||||
|         for (int g = 0; g < groups; ++g) { |         for (int g = 0; g < groups; ++g) { | ||||||
|           auto ggWt_g = subvariable(ggWt, 1, groups, g); |           auto ggWt_g = subvariable(ggWt, 1, groups, g); | ||||||
|           auto gOt_g = subvariable(gOt, 0, groups, g); |           auto gOt_g = subvariable(gOt, 0, groups, g); | ||||||
|           if (gOt_g.type().is_cuda()) { |           if (gOt_g.is_cuda()) { | ||||||
|             gOt_g = gOt_g.contiguous(); |             gOt_g = gOt_g.contiguous(); | ||||||
|           } |           } | ||||||
|  |  | ||||||
|  | |||||||
| @ -243,7 +243,7 @@ Tensor batch_norm( | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   bool use_cudnn = false; |   bool use_cudnn = false; | ||||||
|   use_cudnn = (input.type().is_cuda() |   use_cudnn = (input.is_cuda() | ||||||
|                && (input.type().scalarType() != at::kHalf |                && (input.type().scalarType() != at::kHalf | ||||||
|                  || weight.type().scalarType() == at::kFloat) |                  || weight.type().scalarType() == at::kFloat) | ||||||
|                && weight.defined() && bias.defined() |                && weight.defined() && bias.defined() | ||||||
| @ -262,7 +262,7 @@ Tensor batch_norm( | |||||||
|                         training, momentum, eps)); |                         training, momentum, eps)); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   bool use_miopen = (input.type().is_cuda() |   bool use_miopen = (input.is_cuda() | ||||||
|                && input.dim() <= MIOPEN_DIM_MAX |                && input.dim() <= MIOPEN_DIM_MAX | ||||||
|                && input.type().scalarType() != at::kDouble |                && input.type().scalarType() != at::kDouble | ||||||
|                && (input.type().scalarType() == weight.type().scalarType()) |                && (input.type().scalarType() == weight.type().scalarType()) | ||||||
|  | |||||||
| @ -431,7 +431,7 @@ Tensor& norm_out(Tensor &result, const Tensor &self, Scalar p, int64_t dim, bool | |||||||
| } | } | ||||||
|  |  | ||||||
| Tensor _norm(const Tensor &self, Scalar p) { | Tensor _norm(const Tensor &self, Scalar p) { | ||||||
|   if (self.type().is_sparse()) { |   if (self.is_sparse()) { | ||||||
|     return at::native_norm(self, p); |     return at::native_norm(self, p); | ||||||
|   } else { |   } else { | ||||||
|     AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, |     AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, | ||||||
|  | |||||||
| @ -151,7 +151,7 @@ Tensor empty_like(const Tensor& self) { | |||||||
| } | } | ||||||
|  |  | ||||||
| Tensor empty_like(const Tensor& self, const TensorOptions& options) { | Tensor empty_like(const Tensor& self, const TensorOptions& options) { | ||||||
|   if (options.layout() == kSparse && self.type().is_sparse()) { |   if (options.layout() == kSparse && self.is_sparse()) { | ||||||
|     auto res = at::empty({0}, options); // to be resized |     auto res = at::empty({0}, options); // to be resized | ||||||
|     res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim()); |     res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim()); | ||||||
|     return res; |     return res; | ||||||
| @ -523,7 +523,7 @@ Tensor zeros_like(const Tensor& self) { | |||||||
| } | } | ||||||
|  |  | ||||||
| Tensor zeros_like(const Tensor& self, const TensorOptions& options) { | Tensor zeros_like(const Tensor& self, const TensorOptions& options) { | ||||||
|   if (options.layout() == kSparse && self.type().is_sparse()) { |   if (options.layout() == kSparse && self.is_sparse()) { | ||||||
|     auto res = at::empty({0}, options); // to be resized |     auto res = at::empty({0}, options); // to be resized | ||||||
|     res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim()); |     res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim()); | ||||||
|     return res; |     return res; | ||||||
|  | |||||||
| @ -289,7 +289,7 @@ Tensor repeat(const Tensor& self, IntList repeats) { | |||||||
| } | } | ||||||
|  |  | ||||||
| Tensor reshape(const Tensor& self, IntList proposed_shape) { | Tensor reshape(const Tensor& self, IntList proposed_shape) { | ||||||
|   if (self.type().is_sparse()) { |   if (self.is_sparse()) { | ||||||
|     AT_ERROR("reshape is not implemented for sparse tensors"); |     AT_ERROR("reshape is not implemented for sparse tensors"); | ||||||
|   } |   } | ||||||
|   auto shape = infer_size(proposed_shape, self.numel()); |   auto shape = infer_size(proposed_shape, self.numel()); | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ | |||||||
| namespace at { namespace native { | namespace at { namespace native { | ||||||
|  |  | ||||||
| bool is_cuda(const Tensor& self) { | bool is_cuda(const Tensor& self) { | ||||||
|   return self.type().is_cuda(); |   return self.is_cuda(); | ||||||
| } | } | ||||||
|  |  | ||||||
| bool is_distributed(const Tensor& self) { | bool is_distributed(const Tensor& self) { | ||||||
| @ -31,7 +31,7 @@ bool is_signed(const Tensor &self) { | |||||||
| } | } | ||||||
|  |  | ||||||
| bool is_sparse(const Tensor& self) { | bool is_sparse(const Tensor& self) { | ||||||
|   return self.type().is_sparse(); |   return self.is_sparse(); | ||||||
| } | } | ||||||
|  |  | ||||||
| Tensor type_as(const Tensor& self, const Tensor& other) { | Tensor type_as(const Tensor& self, const Tensor& other) { | ||||||
|  | |||||||
| @ -50,7 +50,7 @@ Tensor _weight_norm | |||||||
|   auto v = v_in.contiguous(); |   auto v = v_in.contiguous(); | ||||||
|   auto g = g_in.contiguous(); |   auto g = g_in.contiguous(); | ||||||
|  |  | ||||||
|   bool can_use_fused = v.type().is_cuda() && (dim == 0 || dim == v.dim() - 1); |   bool can_use_fused = v.is_cuda() && (dim == 0 || dim == v.dim() - 1); | ||||||
|  |  | ||||||
|   if (can_use_fused) { |   if (can_use_fused) { | ||||||
|     // weight_norm does not have a derivative defined for it, so this will route back through |     // weight_norm does not have a derivative defined for it, so this will route back through | ||||||
|  | |||||||
| @ -182,7 +182,7 @@ const Variable & VariableType::checked_cast_variable(const Tensor & t, const cha | |||||||
|   if (!t.defined()) { |   if (!t.defined()) { | ||||||
|     AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'"); |     AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'"); | ||||||
|   } |   } | ||||||
|   if (!isVariableType(t.type())) { |   if (!t.is_variable()) { | ||||||
|     AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'"); |     AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'"); | ||||||
|   } |   } | ||||||
|   return as_variable_ref(t); |   return as_variable_ref(t); | ||||||
| @ -192,7 +192,7 @@ Variable & VariableType::checked_cast_variable(Tensor & t, const char * name, in | |||||||
|   if (!t.defined()) { |   if (!t.defined()) { | ||||||
|     AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'"); |     AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'"); | ||||||
|   } |   } | ||||||
|   if (!isVariableType(t.type())) { |   if (!t.is_variable()) { | ||||||
|     AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'"); |     AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'"); | ||||||
|   } |   } | ||||||
|   return as_variable_ref(t); |   return as_variable_ref(t); | ||||||
|  | |||||||
| @ -42,7 +42,7 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list { | |||||||
|   if (!grad.defined()) { |   if (!grad.defined()) { | ||||||
|     // under following condition, we can avoid clone() |     // under following condition, we can avoid clone() | ||||||
|     if (!GradMode::is_enabled() |     if (!GradMode::is_enabled() | ||||||
|         && !new_grad.type().is_sparse() |         && !new_grad.is_sparse() | ||||||
|         && new_grad.is_contiguous() |         && new_grad.is_contiguous() | ||||||
|         && new_grad.use_count() == 1) { |         && new_grad.use_count() == 1) { | ||||||
|       // first check it is in first-order grad only mode |       // first check it is in first-order grad only mode | ||||||
| @ -60,7 +60,7 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list { | |||||||
|     // the users. Thanks to this case we can avoid changing the grad tensor, |     // the users. Thanks to this case we can avoid changing the grad tensor, | ||||||
|     // a thing never promised and documented, but used in some hacks seen |     // a thing never promised and documented, but used in some hacks seen | ||||||
|     // on the internet. |     // on the internet. | ||||||
|     if (grad_variable.type().is_sparse() && !new_grad.type().is_sparse()) { |     if (grad_variable.is_sparse() && !new_grad.is_sparse()) { | ||||||
|       grad_variable.data() = new_grad.data() + grad_variable.data(); |       grad_variable.data() = new_grad.data() + grad_variable.data(); | ||||||
|     } else { |     } else { | ||||||
|       grad_variable.data() += new_grad.data(); |       grad_variable.data() += new_grad.data(); | ||||||
|  | |||||||
| @ -22,7 +22,7 @@ void InputBuffer::add(size_t pos, Variable var) { | |||||||
|   } else { |   } else { | ||||||
|     at::DeviceGuard device_guard(var); |     at::DeviceGuard device_guard(var); | ||||||
|     // ATen doesn't route sparse additions correctly... |     // ATen doesn't route sparse additions correctly... | ||||||
|     if (old_var.type().is_sparse()) { |     if (old_var.is_sparse()) { | ||||||
|       buffer[pos] = var + old_var; |       buffer[pos] = var + old_var; | ||||||
|     } else { |     } else { | ||||||
|       buffer[pos] = old_var + var; |       buffer[pos] = old_var + var; | ||||||
| @ -32,7 +32,7 @@ void InputBuffer::add(size_t pos, Variable var) { | |||||||
|  |  | ||||||
| auto InputBuffer::device() const -> int { | auto InputBuffer::device() const -> int { | ||||||
|   for (auto& var : buffer) { |   for (auto& var : buffer) { | ||||||
|     if (var.defined() && var.type().is_cuda()) { |     if (var.defined() && var.is_cuda()) { | ||||||
|       return var.get_device(); |       return var.get_device(); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -169,11 +169,11 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject | |||||||
|     throw std::runtime_error(ss.str()); |     throw std::runtime_error(ss.str()); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (original.type().is_cuda() != result.type().is_cuda()) { |   if (original.is_cuda() != result.is_cuda()) { | ||||||
|     std::stringstream ss; |     std::stringstream ss; | ||||||
|     auto name = hook_name(hook); |     auto name = hook_name(hook); | ||||||
|     ss << "hook '" << name << "' has changed the type of value"; |     ss << "hook '" << name << "' has changed the type of value"; | ||||||
|     if (original.type().is_cuda()) { |     if (original.is_cuda()) { | ||||||
|       ss << " (was CUDA tensor got CPU tensor)"; |       ss << " (was CUDA tensor got CPU tensor)"; | ||||||
|     } else { |     } else { | ||||||
|       ss << " (was CPU tensor got CUDA tensor)"; |       ss << " (was CPU tensor got CUDA tensor)"; | ||||||
|  | |||||||
| @ -247,7 +247,7 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad) | |||||||
|  |  | ||||||
|   THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse, |   THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse, | ||||||
|       "assigned grad has data of a different type"); |       "assigned grad has data of a different type"); | ||||||
|   if (var.type().is_cuda()) { |   if (var.is_cuda()) { | ||||||
|     THPUtils_assertRet(-1, grad.get_device() == var.get_device(), |     THPUtils_assertRet(-1, grad.get_device() == var.get_device(), | ||||||
|         "assigned grad has data located on a different device"); |         "assigned grad has data located on a different device"); | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -177,7 +177,7 @@ at::Tensor gather( | |||||||
|   std::vector<int64_t> expected_size(first_size.begin(), first_size.end()); |   std::vector<int64_t> expected_size(first_size.begin(), first_size.end()); | ||||||
|   for (const auto& tensor : tensors) { |   for (const auto& tensor : tensors) { | ||||||
|     AT_CHECK( |     AT_CHECK( | ||||||
|         tensor.type().is_cuda(), "Gather expects all inputs to have CUDA type"); |         tensor.is_cuda(), "Gather expects all inputs to have CUDA type"); | ||||||
|     AT_ASSERT(tensor.ndimension() == static_cast<int64_t>(expected_size.size())); |     AT_ASSERT(tensor.ndimension() == static_cast<int64_t>(expected_size.size())); | ||||||
|     expected_size[dim] = tensor.size(dim); |     expected_size[dim] = tensor.size(dim); | ||||||
|     for (size_t dimension = 0; dimension < expected_size.size(); ++dimension) { |     for (size_t dimension = 0; dimension < expected_size.size(); ++dimension) { | ||||||
|  | |||||||
| @ -133,8 +133,8 @@ void _check_inputs( | |||||||
|     auto input = inputs[i]; |     auto input = inputs[i]; | ||||||
|     auto output = outputs[i]; |     auto output = outputs[i]; | ||||||
|  |  | ||||||
|     if (!(input.type().is_cuda() && !input.type().is_sparse() && |     if (!(input.is_cuda() && !input.is_sparse() && | ||||||
|           output.type().is_cuda() && !output.type().is_sparse())) { |           output.is_cuda() && !output.is_sparse())) { | ||||||
|       throw std::runtime_error( |       throw std::runtime_error( | ||||||
|           "input and output elements have to be cuda dense Tensors"); |           "input and output elements have to be cuda dense Tensors"); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -81,7 +81,7 @@ struct ArgumentSpec { | |||||||
|       if ((arg.defined_ = t.defined())) { |       if ((arg.defined_ = t.defined())) { | ||||||
|         arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad(); |         arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad(); | ||||||
|         arg.dim_ = t.dim(); |         arg.dim_ = t.dim(); | ||||||
|         arg.device_ = t.type().is_cuda() ? t.get_device() : -1; |         arg.device_ = t.is_cuda() ? t.get_device() : -1; | ||||||
|         arg.type_ = static_cast<unsigned>(t.type().scalarType()); |         arg.type_ = static_cast<unsigned>(t.type().scalarType()); | ||||||
|       } |       } | ||||||
|  |  | ||||||
| @ -203,7 +203,7 @@ struct CompleteArgumentSpec { | |||||||
|         pod.defined = t.defined(); |         pod.defined = t.defined(); | ||||||
|         if (pod.defined) { |         if (pod.defined) { | ||||||
|           pod.type = static_cast<int>(t.type().scalarType()); |           pod.type = static_cast<int>(t.type().scalarType()); | ||||||
|           pod.device = (!t.type().is_cuda()) ? -1 : t.get_device(); |           pod.device = (!t.is_cuda()) ? -1 : t.get_device(); | ||||||
|           pod.requires_grad = with_grad && autograd::as_variable_ref(t).requires_grad(); |           pod.requires_grad = with_grad && autograd::as_variable_ref(t).requires_grad(); | ||||||
|           total_dims += t.ndimension(); |           total_dims += t.ndimension(); | ||||||
|           auto sizes = t.sizes(); |           auto sizes = t.sizes(); | ||||||
|  | |||||||
| @ -16,7 +16,7 @@ struct IODescriptor { | |||||||
|     VariableMetadata(const autograd::Variable& var) |     VariableMetadata(const autograd::Variable& var) | ||||||
|       : sizes(var.sizes().vec()) |       : sizes(var.sizes().vec()) | ||||||
|       , type(var.type().scalarType()) |       , type(var.type().scalarType()) | ||||||
|       , device(var.type().is_cuda() ? var.get_device() : -1) |       , device(var.is_cuda() ? var.get_device() : -1) | ||||||
|       , requires_grad(var.requires_grad()) {} |       , requires_grad(var.requires_grad()) {} | ||||||
|  |  | ||||||
|     bool operator==(const VariableMetadata& o) const { |     bool operator==(const VariableMetadata& o) const { | ||||||
|  | |||||||
| @ -324,7 +324,7 @@ struct TORCH_API TensorType : public Type { | |||||||
| protected: | protected: | ||||||
|   TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType) |   TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType) | ||||||
|     : TensorType(tensor.type().scalarType(), |     : TensorType(tensor.type().scalarType(), | ||||||
|                  tensor.type().is_cuda() ? tensor.get_device() : -1, |                  tensor.is_cuda() ? tensor.get_device() : -1, | ||||||
|                  tensor.dim(), |                  tensor.dim(), | ||||||
|                  tensor.is_variable() && tensor.requires_grad(), |                  tensor.is_variable() && tensor.requires_grad(), | ||||||
|                  kind) {} |                  kind) {} | ||||||
|  | |||||||
| @ -28,7 +28,7 @@ static inline int get_device(PyObject* args) { | |||||||
|     PyObject* arg = PyTuple_GET_ITEM(args, i); |     PyObject* arg = PyTuple_GET_ITEM(args, i); | ||||||
|     if (THPVariable_Check(arg)) { |     if (THPVariable_Check(arg)) { | ||||||
|       auto& tensor = THPVariable_UnpackData(arg); |       auto& tensor = THPVariable_UnpackData(arg); | ||||||
|       if (tensor.type().is_cuda()) { |       if (tensor.is_cuda()) { | ||||||
|         return tensor.get_device(); |         return tensor.get_device(); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -388,7 +388,7 @@ at::Type& get_default_tensor_type() { | |||||||
| } | } | ||||||
|  |  | ||||||
| Device getDevice(const at::Tensor& tensor) { | Device getDevice(const at::Tensor& tensor) { | ||||||
|   if (tensor.type().is_cuda()) { |   if (tensor.is_cuda()) { | ||||||
|     return at::Device(at::DeviceType::CUDA, tensor.get_device()); |     return at::Device(at::DeviceType::CUDA, tensor.get_device()); | ||||||
|   } |   } | ||||||
|   return at::Device(at::DeviceType::CPU); |   return at::Device(at::DeviceType::CPU); | ||||||
|  | |||||||
| @ -345,8 +345,8 @@ bool DataChannelNccl::_tensorCheckHelper( | |||||||
|  |  | ||||||
|   for (size_t i = 0; i < input.size(); ++i) { |   for (size_t i = 0; i < input.size(); ++i) { | ||||||
|     //  Check to make sure it's a GPU dense tensor |     //  Check to make sure it's a GPU dense tensor | ||||||
|     if (!(input[i].type().is_cuda() && !input[i].type().is_sparse() && |     if (!(input[i].is_cuda() && !input[i].is_sparse() && | ||||||
|           output[i].type().is_cuda() && !output[i].type().is_sparse())) { |           output[i].is_cuda() && !output[i].is_sparse())) { | ||||||
|       throw std::runtime_error( |       throw std::runtime_error( | ||||||
|           "Only CUDA dense tensor is supported for NCCL " |           "Only CUDA dense tensor is supported for NCCL " | ||||||
|           "collective operations"); |           "collective operations"); | ||||||
|  | |||||||
| @ -290,8 +290,8 @@ void ProcessGroupNCCL::tensorCheckHelper( | |||||||
|  |  | ||||||
|   for (size_t i = 0; i < input.size(); ++i) { |   for (size_t i = 0; i < input.size(); ++i) { | ||||||
|     //  Check to make sure it's a GPU dense tensor |     //  Check to make sure it's a GPU dense tensor | ||||||
|     if (!(input[i].type().is_cuda() && !input[i].type().is_sparse() && |     if (!(input[i].is_cuda() && !input[i].is_sparse() && | ||||||
|           output[i].type().is_cuda() && !output[i].type().is_sparse())) { |           output[i].is_cuda() && !output[i].is_sparse())) { | ||||||
|       throw std::runtime_error( |       throw std::runtime_error( | ||||||
|           "Only CUDA dense tensor is supported for NCCL " |           "Only CUDA dense tensor is supported for NCCL " | ||||||
|           "collective operations"); |           "collective operations"); | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user