Prevent cudaStreamSync when indexing GPU tensors with boolean CPU mask (#156384)

`index_put` with a boolean mask (`target[mask] = src`) causes a `cudaStreamSynchronize`. When both `mask` and `target` are on the GPU this is expected: the mask is converted to integer indices via `nonzero()`, and on CUDA the number of matches has to be copied back to the host to size the result.
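For reference, that expected sync with an all-GPU mask can be observed with `torch.cuda.set_sync_debug_mode` (a minimal sketch; tensor sizes and names are illustrative):
```python
import torch

target = torch.zeros(1000, device="cuda")
src = torch.ones(1, device="cuda")            # broadcastable source value
mask = torch.rand(1000, device="cuda") < 0.5  # boolean mask on the GPU

torch.cuda.set_sync_debug_mode("warn")  # report any synchronizing call
target[mask] = src  # mask.nonzero() runs on the GPU and must sync to size its output
torch.cuda.set_sync_debug_mode("default")
```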

However, the sync can be avoided when `mask` is a CPU tensor.
Internally a new index tensor is created with `mask.nonzero()`, so we can transfer it to the GPU with a non-blocking copy: the user has no way to mutate it between its creation and the device copy. @ngimel Let me know if I'm missing something.
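Roughly, the new internal path corresponds to the following user-level sketch (illustrative only; the actual change lives in the C++ `expandTensors` helper shown in the diff below):
```python
import torch

target = torch.zeros(1000, device="cuda")
mask = torch.rand(1000) < 0.5  # boolean mask kept on the CPU

# nonzero() runs on the CPU, so computing the indices needs no device sync.
# Staging them in pinned memory lets the host-to-device copy be genuinely
# non-blocking (the C++ change allocates the nonzero output as pinned memory).
indices = mask.nonzero().squeeze(1).pin_memory().to("cuda", non_blocking=True)
target[indices] = 1.0  # same result as target[mask] = 1.0
```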

I think this is useful since, unlike with other ops, users can't avoid the sync simply by making sure all tensors are on the same device. Instead they would need to write something like the following, which is much less readable:
```python
indices = mask.nonzero().squeeze(1).to("cuda", non_blocking=True)
target[indices] = src
```
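With this change the plain masked assignment should stay sync-free even when the mask lives on the CPU, which can be checked the same way (again a sketch, assuming a build that includes this PR):
```python
import torch

target = torch.zeros(1000, device="cuda")
src = torch.ones(1, device="cuda")
mask_cpu = torch.rand(1000) < 0.5        # boolean mask on the CPU

torch.cuda.set_sync_debug_mode("error")  # raise if anything synchronizes
target[mask_cpu] = src                   # masked write, no sync
values = target[mask_cpu]                # masked read avoids the sync too
torch.cuda.set_sync_debug_mode("default")
```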
Fixes #12461

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156384
Approved by: https://github.com/ngimel
Author: Lukas Geiger
Date: 2025-06-28 05:41:13 +00:00
Committed by: PyTorch MergeBot
Parent: 5692cbb818
Commit: a92b24cd83
3 changed files with 23 additions and 9 deletions

File 1 of 3:

@@ -5,6 +5,13 @@
 #include <ATen/core/IListRef.h>
 #include <c10/util/irange.h>
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/empty.h>
+#include <ATen/ops/nonzero.h>
+#endif
 namespace at::native {
 [[noreturn]]
@@ -15,7 +22,8 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
 [[maybe_unused]] static std::vector<Tensor> expandTensors(
     const Tensor& self,
-    IOptTensorListRef indices) {
+    IOptTensorListRef indices,
+    bool ensure_same_device = false) {
   // If indices come in as ByteTensor or BoolTensor (masks), expand them into
   // the equivalent indexing by LongTensors
   std::vector<Tensor> result;
@@ -38,10 +46,19 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
       }
     }
     // Replace with nonzeros
-    auto nonzero = index.nonzero();
+    at::Tensor nonzero;
+    if (ensure_same_device && index.device() != self.device()) {
+      bool non_blocking = index.is_cpu() && self.device().is_cuda();
+      auto out = at::empty({0}, index.options().dtype(kLong).pinned_memory(non_blocking));
+      nonzero = at::nonzero_out(out, index).to(self.device(), non_blocking);
+    } else {
+      nonzero = index.nonzero();
+    }
     for (const auto j : c10::irange(index.dim())) {
       result.emplace_back(nonzero.select(1, j));
     }
+  } else if (ensure_same_device && index.device() != self.device()) {
+    result.emplace_back(index.to(self.device()));
   } else {
     result.emplace_back(index);
   }

File 2 of 3:

@@ -71,7 +71,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
   checkIndexTensorTypes(orig, /*allow_int*/ true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
   // LongTensors
-  auto indices = expandTensors(self, orig);
+  auto indices = expandTensors(self, orig, /*ensure_same_device=*/true);
   // next broadcast all index tensors together
   try {
     indices = expand_outplace(indices);
@@ -91,12 +91,6 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
   if (!hasContiguousSubspace(indices)) {
     std::tie(self, indices) = transposeToFront(self, indices);
   }
-  // Ensure indices are on the same device as self
-  for (auto& indice : indices) {
-    if (indice.defined() && indice.device() != self.device()) {
-      indice = indice.to(self.device());
-    }
-  }
   for (auto& indice : indices) {
     if (indice.defined() && indice.dtype() == at::kInt) {
       indice = indice.to(at::kLong);

File 3 of 3:

@@ -2150,8 +2150,11 @@ else:
         ind_cpu = ind.cpu()
         repeats = torch.full((1,), 2, device=device)
         mask = torch.randint(2, (size,), device=device, dtype=bool)
+        mask_cpu = mask.cpu()
         expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.),
+                          lambda: _ind_put_fn(x, mask_cpu, y),
                           lambda: _ind_put_fn(x, ind, y),
+                          lambda: _ind_get_fn(x, mask_cpu),
                           lambda: _ind_get_fn(x, ind),
                           lambda: torch.nn.functional.one_hot(ind, num_classes=size),
                           lambda: torch.randperm(20000, device=device),