codemod tensor.type().is_cuda(), tensor.type().is_sparse() (#13590)

Summary:
Followup to #12841

Changed these call sites so they no longer require type dispatch:
tensor.type().is_cuda() -> tensor.is_cuda()
tensor.type().is_sparse() -> tensor.is_sparse()
isVariableType(tensor.type()) -> tensor.is_variable()

This probably does not affect performance much in most cases, but it is a nice cleanup to have.
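For illustration, a minimal sketch of the call pattern this codemod moves to; the describe_tensor helper below is hypothetical and not part of this change, it only shows the is_cuda()/is_sparse() accessors that call sites now use directly:

#include <ATen/ATen.h>
#include <iostream>

// Hypothetical helper, for illustration only.
void describe_tensor(const at::Tensor& t) {
  // Old style (fetches the Type object first):
  //   if (t.type().is_cuda()) { ... }
  //   if (t.type().is_sparse()) { ... }
  // New style after the codemod (reads the flag from the tensor directly):
  if (t.is_cuda()) {
    std::cout << "CUDA tensor on device " << t.get_device() << std::endl;
  }
  if (t.is_sparse()) {
    std::cout << "sparse tensor" << std::endl;
  }
}

int main() {
  at::Tensor t = at::ones({2, 3});  // CPU dense tensor
  describe_tensor(t);               // prints nothing for this tensor
  return 0;
}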
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13590

Reviewed By: ezyang

Differential Revision: D12929301

Pulled By: zou3519

fbshipit-source-id: 8ac5c6200c579dd7a44fb4ee58fc9bb170feb1d7
Author: Richard Zou
Date: 2018-11-07 07:24:38 -08:00
Committed by: Facebook Github Bot
Parent: e70321ed9e
Commit: e60a7c2c88
23 changed files with 44 additions and 44 deletions


@@ -152,7 +152,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   atDLMTensor->tensor.deleter = &deleter;
   atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
   int64_t device_id = 0;
-  if (src.type().is_cuda()) {
+  if (src.is_cuda()) {
     device_id = src.get_device();
   }
   atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);


@@ -180,7 +180,7 @@ AT_ERROR("${api_name} only supports a 0-dimensional ${check_name} tensor, but go
 """)
 SPARSE_CHECK = CodeTemplate("""\
-if(${check_name}.type().is_sparse()) {
+if(${check_name}.is_sparse()) {
   return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals});
 }""")


@@ -108,7 +108,7 @@ auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool {
   if (!detail::getCUDAHooks().compiledWithCuDNN()) {
     return false;
   }
-  if (!input.type().is_cuda() || !cudnn_enabled) {
+  if (!input.is_cuda() || !cudnn_enabled) {
     return false;
   }
   if (deterministic && is_dilated()) {
@@ -125,7 +125,7 @@ auto ConvParams::use_miopen(const at::Tensor& input) const -> bool {
   return ((input.type().scalarType() == at::kFloat) || (input.type().scalarType() == at::kHalf))
          && detail::getCUDAHooks().compiledWithMIOpen()
-         && input.type().is_cuda()
+         && input.is_cuda()
          && input.dim() <= MIOPEN_DIM_MAX
          && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
          && !transposed
@@ -150,7 +150,7 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
 // a depthwise multiplier)
 auto ConvParams::is_depthwise(
         const at::Tensor& input, const at::Tensor& weight) const -> bool {
-  return input.type().is_cuda() &&
+  return input.is_cuda() &&
          !transposed &&
          input.ndimension() == 4 &&
          input.size(1) == groups &&
@@ -450,7 +450,7 @@ at::Tensor _convolution_nogroup(
           input, weight, kernel_size, bias,
           stride, padding);
     }
-  } else if (dim == 5 && (input.type().is_cuda() || dilated)) {
+  } else if (dim == 5 && (input.is_cuda() || dilated)) {
     return at::thnn_conv_dilated3d(
         input, weight, kernel_size, bias,
         stride, padding, dilation);
@@ -498,14 +498,14 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
   // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
   Tensor ggO;
   if (ggI.defined()) {
-    if (weight.type().is_cuda()) {
+    if (weight.is_cuda()) {
       weight = weight.contiguous();
     }
     ggO = at::_convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled);
   }
   if (ggW.defined()) {
-    if (ggW.type().is_cuda()) {
+    if (ggW.is_cuda()) {
       ggW = ggW.contiguous();
     }
     auto ggW_term = at::_convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups, params.benchmark, params.deterministic, params.cudnn_enabled);
@@ -553,7 +553,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
       Tensor gWt;
       // Compute conv
       if (groups == 1) {
-        if (gOt.type().is_cuda()) {
+        if (gOt.is_cuda()) {
          gOt = gOt.contiguous();
        }
@@ -569,7 +569,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
        for (int g = 0; g < groups; ++g) {
          auto ggIt_g = subvariable(ggIt, 0, groups, g);
          auto gOt_g = subvariable(gOt, 0, groups, g);
-         if (gOt_g.type().is_cuda()) {
+         if (gOt_g.is_cuda()) {
            gOt_g = gOt_g.contiguous();
          }
@@ -609,7 +609,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
     gi_conv_params.transposed = !params.transposed;
     if (params.transposed) {
-      if (gO.type().is_cuda()) {
+      if (gO.is_cuda()) {
        gO = gO.contiguous();
      }
      gI = at::_convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups, gi_conv_params.benchmark, gi_conv_params.deterministic, gi_conv_params.cudnn_enabled);
@@ -662,7 +662,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
       Tensor gIt;
       if (params.groups == 1) {
-        if (gOt.type().is_cuda()) {
+        if (gOt.is_cuda()) {
          gOt = gOt.contiguous();
        }
@@ -672,7 +672,7 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward(
        for (int g = 0; g < groups; ++g) {
          auto ggWt_g = subvariable(ggWt, 1, groups, g);
          auto gOt_g = subvariable(gOt, 0, groups, g);
-         if (gOt_g.type().is_cuda()) {
+         if (gOt_g.is_cuda()) {
            gOt_g = gOt_g.contiguous();
          }


@@ -243,7 +243,7 @@ Tensor batch_norm(
   }
   bool use_cudnn = false;
-  use_cudnn = (input.type().is_cuda()
+  use_cudnn = (input.is_cuda()
                && (input.type().scalarType() != at::kHalf
                  || weight.type().scalarType() == at::kFloat)
                && weight.defined() && bias.defined()
@@ -262,7 +262,7 @@ Tensor batch_norm(
                                 training, momentum, eps));
   }
-  bool use_miopen = (input.type().is_cuda()
+  bool use_miopen = (input.is_cuda()
                      && input.dim() <= MIOPEN_DIM_MAX
                      && input.type().scalarType() != at::kDouble
                      && (input.type().scalarType() == weight.type().scalarType())


@@ -431,7 +431,7 @@ Tensor& norm_out(Tensor &result, const Tensor &self, Scalar p, int64_t dim, bool
 }
 Tensor _norm(const Tensor &self, Scalar p) {
-  if (self.type().is_sparse()) {
+  if (self.is_sparse()) {
     return at::native_norm(self, p);
   } else {
     AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,


@@ -151,7 +151,7 @@ Tensor empty_like(const Tensor& self) {
 }
 Tensor empty_like(const Tensor& self, const TensorOptions& options) {
-  if (options.layout() == kSparse && self.type().is_sparse()) {
+  if (options.layout() == kSparse && self.is_sparse()) {
     auto res = at::empty({0}, options); // to be resized
     res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
     return res;
@@ -523,7 +523,7 @@ Tensor zeros_like(const Tensor& self) {
 }
 Tensor zeros_like(const Tensor& self, const TensorOptions& options) {
-  if (options.layout() == kSparse && self.type().is_sparse()) {
+  if (options.layout() == kSparse && self.is_sparse()) {
     auto res = at::empty({0}, options); // to be resized
     res.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
     return res;


@@ -289,7 +289,7 @@ Tensor repeat(const Tensor& self, IntList repeats) {
 }
 Tensor reshape(const Tensor& self, IntList proposed_shape) {
-  if (self.type().is_sparse()) {
+  if (self.is_sparse()) {
     AT_ERROR("reshape is not implemented for sparse tensors");
   }
   auto shape = infer_size(proposed_shape, self.numel());


@@ -6,7 +6,7 @@
 namespace at { namespace native {
 bool is_cuda(const Tensor& self) {
-  return self.type().is_cuda();
+  return self.is_cuda();
 }
 bool is_distributed(const Tensor& self) {
@@ -31,7 +31,7 @@ bool is_signed(const Tensor &self) {
 }
 bool is_sparse(const Tensor& self) {
-  return self.type().is_sparse();
+  return self.is_sparse();
 }
 Tensor type_as(const Tensor& self, const Tensor& other) {


@@ -50,7 +50,7 @@ Tensor _weight_norm
   auto v = v_in.contiguous();
   auto g = g_in.contiguous();
-  bool can_use_fused = v.type().is_cuda() && (dim == 0 || dim == v.dim() - 1);
+  bool can_use_fused = v.is_cuda() && (dim == 0 || dim == v.dim() - 1);
   if (can_use_fused) {
     // weight_norm does not have a derivative defined for it, so this will route back through


@@ -182,7 +182,7 @@ const Variable & VariableType::checked_cast_variable(const Tensor & t, const cha
   if (!t.defined()) {
     AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
   }
-  if (!isVariableType(t.type())) {
+  if (!t.is_variable()) {
     AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'");
   }
   return as_variable_ref(t);
@@ -192,7 +192,7 @@ Variable & VariableType::checked_cast_variable(Tensor & t, const char * name, in
   if (!t.defined()) {
     AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
   }
-  if (!isVariableType(t.type())) {
+  if (!t.is_variable()) {
     AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " for argument #", pos, " '", name, "'");
   }
   return as_variable_ref(t);


@@ -42,7 +42,7 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
   if (!grad.defined()) {
     // under following condition, we can avoid clone()
     if (!GradMode::is_enabled()
-        && !new_grad.type().is_sparse()
+        && !new_grad.is_sparse()
         && new_grad.is_contiguous()
         && new_grad.use_count() == 1) {
       // first check it is in first-order grad only mode
@@ -60,7 +60,7 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
     // the users. Thanks to this case we can avoid changing the grad tensor,
     // a thing never promised and documented, but used in some hacks seen
     // on the internet.
-    if (grad_variable.type().is_sparse() && !new_grad.type().is_sparse()) {
+    if (grad_variable.is_sparse() && !new_grad.is_sparse()) {
       grad_variable.data() = new_grad.data() + grad_variable.data();
     } else {
       grad_variable.data() += new_grad.data();


@@ -22,7 +22,7 @@ void InputBuffer::add(size_t pos, Variable var) {
   } else {
     at::DeviceGuard device_guard(var);
     // ATen doesn't route sparse additions correctly...
-    if (old_var.type().is_sparse()) {
+    if (old_var.is_sparse()) {
       buffer[pos] = var + old_var;
     } else {
       buffer[pos] = old_var + var;
@@ -32,7 +32,7 @@ void InputBuffer::add(size_t pos, Variable var) {
 auto InputBuffer::device() const -> int {
   for (auto& var : buffer) {
-    if (var.defined() && var.type().is_cuda()) {
+    if (var.defined() && var.is_cuda()) {
       return var.get_device();
     }
   }


@@ -169,11 +169,11 @@ static void check_single_result(PyObject* _original, PyObject* _result, PyObject
     throw std::runtime_error(ss.str());
   }
-  if (original.type().is_cuda() != result.type().is_cuda()) {
+  if (original.is_cuda() != result.is_cuda()) {
     std::stringstream ss;
     auto name = hook_name(hook);
     ss << "hook '" << name << "' has changed the type of value";
-    if (original.type().is_cuda()) {
+    if (original.is_cuda()) {
       ss << " (was CUDA tensor got CPU tensor)";
     } else {
       ss << " (was CPU tensor got CUDA tensor)";


@@ -247,7 +247,7 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
   THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
       "assigned grad has data of a different type");
-  if (var.type().is_cuda()) {
+  if (var.is_cuda()) {
     THPUtils_assertRet(-1, grad.get_device() == var.get_device(),
         "assigned grad has data located on a different device");
   }


@@ -177,7 +177,7 @@ at::Tensor gather(
   std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
   for (const auto& tensor : tensors) {
     AT_CHECK(
-        tensor.type().is_cuda(), "Gather expects all inputs to have CUDA type");
+        tensor.is_cuda(), "Gather expects all inputs to have CUDA type");
     AT_ASSERT(tensor.ndimension() == static_cast<int64_t>(expected_size.size()));
     expected_size[dim] = tensor.size(dim);
     for (size_t dimension = 0; dimension < expected_size.size(); ++dimension) {


@@ -133,8 +133,8 @@ void _check_inputs(
     auto input = inputs[i];
     auto output = outputs[i];
-    if (!(input.type().is_cuda() && !input.type().is_sparse() &&
-          output.type().is_cuda() && !output.type().is_sparse())) {
+    if (!(input.is_cuda() && !input.is_sparse() &&
+          output.is_cuda() && !output.is_sparse())) {
       throw std::runtime_error(
           "input and output elements have to be cuda dense Tensors");
     }


@@ -81,7 +81,7 @@ struct ArgumentSpec {
     if ((arg.defined_ = t.defined())) {
       arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
       arg.dim_ = t.dim();
-      arg.device_ = t.type().is_cuda() ? t.get_device() : -1;
+      arg.device_ = t.is_cuda() ? t.get_device() : -1;
       arg.type_ = static_cast<unsigned>(t.type().scalarType());
     }
@@ -203,7 +203,7 @@ struct CompleteArgumentSpec {
     pod.defined = t.defined();
     if (pod.defined) {
       pod.type = static_cast<int>(t.type().scalarType());
-      pod.device = (!t.type().is_cuda()) ? -1 : t.get_device();
+      pod.device = (!t.is_cuda()) ? -1 : t.get_device();
       pod.requires_grad = with_grad && autograd::as_variable_ref(t).requires_grad();
       total_dims += t.ndimension();
       auto sizes = t.sizes();


@@ -16,7 +16,7 @@ struct IODescriptor {
   VariableMetadata(const autograd::Variable& var)
     : sizes(var.sizes().vec())
     , type(var.type().scalarType())
-    , device(var.type().is_cuda() ? var.get_device() : -1)
+    , device(var.is_cuda() ? var.get_device() : -1)
     , requires_grad(var.requires_grad()) {}
   bool operator==(const VariableMetadata& o) const {


@@ -324,7 +324,7 @@ struct TORCH_API TensorType : public Type {
 protected:
   TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType)
     : TensorType(tensor.type().scalarType(),
-                 tensor.type().is_cuda() ? tensor.get_device() : -1,
+                 tensor.is_cuda() ? tensor.get_device() : -1,
                  tensor.dim(),
                  tensor.is_variable() && tensor.requires_grad(),
                  kind) {}


@@ -28,7 +28,7 @@ static inline int get_device(PyObject* args) {
     PyObject* arg = PyTuple_GET_ITEM(args, i);
     if (THPVariable_Check(arg)) {
       auto& tensor = THPVariable_UnpackData(arg);
-      if (tensor.type().is_cuda()) {
+      if (tensor.is_cuda()) {
        return tensor.get_device();
      }
    }


@@ -388,7 +388,7 @@ at::Type& get_default_tensor_type() {
 }
 Device getDevice(const at::Tensor& tensor) {
-  if (tensor.type().is_cuda()) {
+  if (tensor.is_cuda()) {
     return at::Device(at::DeviceType::CUDA, tensor.get_device());
   }
   return at::Device(at::DeviceType::CPU);


@@ -345,8 +345,8 @@ bool DataChannelNccl::_tensorCheckHelper(
   for (size_t i = 0; i < input.size(); ++i) {
     // Check to make sure it's a GPU dense tensor
-    if (!(input[i].type().is_cuda() && !input[i].type().is_sparse() &&
-          output[i].type().is_cuda() && !output[i].type().is_sparse())) {
+    if (!(input[i].is_cuda() && !input[i].is_sparse() &&
+          output[i].is_cuda() && !output[i].is_sparse())) {
       throw std::runtime_error(
           "Only CUDA dense tensor is supported for NCCL "
           "collective operations");


@@ -290,8 +290,8 @@ void ProcessGroupNCCL::tensorCheckHelper(
   for (size_t i = 0; i < input.size(); ++i) {
     // Check to make sure it's a GPU dense tensor
-    if (!(input[i].type().is_cuda() && !input[i].type().is_sparse() &&
-          output[i].type().is_cuda() && !output[i].type().is_sparse())) {
+    if (!(input[i].is_cuda() && !input[i].is_sparse() &&
+          output[i].is_cuda() && !output[i].is_sparse())) {
       throw std::runtime_error(
           "Only CUDA dense tensor is supported for NCCL "
          "collective operations");