diff --git a/aten/src/ATen/native/IndexingUtils.h b/aten/src/ATen/native/IndexingUtils.h
index c442b2232a96..948a6b8320a4 100644
--- a/aten/src/ATen/native/IndexingUtils.h
+++ b/aten/src/ATen/native/IndexingUtils.h
@@ -5,6 +5,13 @@
 #include
 #include

+#ifndef AT_PER_OPERATOR_HEADERS
+#include
+#else
+#include
+#include
+#endif
+
 namespace at::native {

 [[noreturn]]
@@ -15,7 +22,8 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,

 [[maybe_unused]] static std::vector<Tensor> expandTensors(
     const Tensor& self,
-    IOptTensorListRef indices) {
+    IOptTensorListRef indices,
+    bool ensure_same_device = false) {
   // If indices come in as ByteTensor or BoolTensor (masks), expand them into
   // the equivalent indexing by LongTensors
   std::vector<Tensor> result;
@@ -38,10 +46,19 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
          }
        }
        // Replace with nonzeros
-       auto nonzero = index.nonzero();
+       at::Tensor nonzero;
+       if (ensure_same_device && index.device() != self.device()) {
+         bool non_blocking = index.is_cpu() && self.device().is_cuda();
+         auto out = at::empty({0}, index.options().dtype(kLong).pinned_memory(non_blocking));
+         nonzero = at::nonzero_out(out, index).to(self.device(), non_blocking);
+       } else {
+         nonzero = index.nonzero();
+       }
        for (const auto j : c10::irange(index.dim())) {
          result.emplace_back(nonzero.select(1, j));
        }
+     } else if (ensure_same_device && index.device() != self.device()) {
+       result.emplace_back(index.to(self.device()));
      } else {
        result.emplace_back(index);
      }
diff --git a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h
index 05009e96a7c4..0a200f157d51 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h
+++ b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h
@@ -71,7 +71,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
   checkIndexTensorTypes(orig, /*allow_int*/ true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
   // LongTensors
-  auto indices = expandTensors(self, orig);
+  auto indices = expandTensors(self, orig, /*ensure_same_device=*/true);
   // next broadcast all index tensors together
   try {
     indices = expand_outplace(indices);
@@ -91,12 +91,6 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
   if (!hasContiguousSubspace(indices)) {
     std::tie(self, indices) = transposeToFront(self, indices);
   }
-  // Ensure indices are on the same device as self
-  for (auto& indice : indices) {
-    if (indice.defined() && indice.device() != self.device()) {
-      indice = indice.to(self.device());
-    }
-  }
   for (auto& indice : indices) {
     if (indice.defined() && indice.dtype() == at::kInt) {
       indice = indice.to(at::kLong);
diff --git a/test/test_torch.py b/test/test_torch.py
index 4b300efd6c83..6383549eb9cb 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2150,8 +2150,11 @@ else:
         ind_cpu = ind.cpu()
         repeats = torch.full((1,), 2, device=device)
         mask = torch.randint(2, (size,), device=device, dtype=bool)
+        mask_cpu = mask.cpu()
         expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.),
+                          lambda: _ind_put_fn(x, mask_cpu, y),
                           lambda: _ind_put_fn(x, ind, y),
+                          lambda: _ind_get_fn(x, mask_cpu),
                           lambda: _ind_get_fn(x, ind),
                           lambda: torch.nn.functional.one_hot(ind, num_classes=size),
                           lambda: torch.randperm(20000, device=device),
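
For context, a minimal usage sketch (not part of the diff) of the pattern this change targets: advanced indexing of a CUDA tensor with a boolean mask that stays on the CPU. With the new ensure_same_device path, the mask's nonzero indices are written into a pinned CPU buffer via at::nonzero_out and sent to the device with a non-blocking copy, which is why the mask_cpu cases can be added to the expect_no_sync list in test_torch.py. The snippet assumes a CUDA device is available.

    # Minimal sketch (not part of the diff); assumes a CUDA device is present.
    import torch

    size = 1024
    x = torch.randn(size, device="cuda")
    mask_cpu = torch.randint(2, (size,), dtype=torch.bool)  # boolean mask kept on the CPU

    out = x[mask_cpu]  # getter: mask expanded to long indices on the CPU, copied non-blocking
    x[mask_cpu] = 1.0  # setter (index_put_): same path, expected not to force a device sync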