Prevent cudaStreamSync when indexing GPU tensors with boolean CPU mask (#156384)
`index_put` with a boolean mask (`target[mask] = src`) causes a `cudaStreamSynchronize`. When both `mask` and `target` are on the GPU this is expected. However, the sync can be avoided when `mask` is a CPU tensor: internally a new index tensor is created with `mask.nonzero()`, so it can be transferred to the GPU with a non-blocking copy, since it cannot be accidentally mutated by the user between its creation and the device copy.

@ngimel Let me know if I'm missing something. I think this is useful because users cannot prevent the sync simply by keeping all tensors on the same device, as they can with other ops. Instead, one would need to do something like the following, which is much less readable:

```python
indices = mask.nonzero().squeeze(1).to("cuda", non_blocking=True)
target[indices] = src
```

Fixes #12461

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156384
Approved by: https://github.com/ngimel
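Below is a minimal sketch (not part of the PR) of how the behavior can be checked locally. It assumes a CUDA-enabled build and relies on `torch.cuda.set_sync_debug_mode`, the same mechanism the updated test uses to flag synchronizing calls:

```python
# Minimal sketch, not from the PR: check that index_put with a CPU boolean mask
# no longer triggers a cudaStreamSynchronize. Assumes a CUDA-enabled PyTorch build.
import torch

target = torch.zeros(1024, device="cuda")
mask = torch.rand(1024) > 0.5  # boolean mask deliberately kept on the CPU

# "error" turns any synchronizing CUDA call into a RuntimeError,
# so this block completing silently means no stream sync happened.
torch.cuda.set_sync_debug_mode("error")
try:
    target[mask] = 1.0
finally:
    torch.cuda.set_sync_debug_mode("default")
```

With the change in this PR, the assignment completes without tripping the sync check even though the mask lives on the CPU.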
Committed by: PyTorch MergeBot
Parent: 5692cbb818
Commit: a92b24cd83
```diff
@@ -5,6 +5,13 @@
 #include <ATen/core/IListRef.h>
 #include <c10/util/irange.h>
 
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/empty.h>
+#include <ATen/ops/nonzero.h>
+#endif
+
 namespace at::native {
 
 [[noreturn]]
@@ -15,7 +22,8 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
 
 [[maybe_unused]] static std::vector<Tensor> expandTensors(
     const Tensor& self,
-    IOptTensorListRef indices) {
+    IOptTensorListRef indices,
+    bool ensure_same_device = false) {
   // If indices come in as ByteTensor or BoolTensor (masks), expand them into
   // the equivalent indexing by LongTensors
   std::vector<Tensor> result;
@@ -38,10 +46,19 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
         }
       }
       // Replace with nonzeros
-      auto nonzero = index.nonzero();
+      at::Tensor nonzero;
+      if (ensure_same_device && index.device() != self.device()) {
+        bool non_blocking = index.is_cpu() && self.device().is_cuda();
+        auto out = at::empty({0}, index.options().dtype(kLong).pinned_memory(non_blocking));
+        nonzero = at::nonzero_out(out, index).to(self.device(), non_blocking);
+      } else {
+        nonzero = index.nonzero();
+      }
       for (const auto j : c10::irange(index.dim())) {
         result.emplace_back(nonzero.select(1, j));
       }
+    } else if (ensure_same_device && index.device() != self.device()) {
+      result.emplace_back(index.to(self.device()));
     } else {
       result.emplace_back(index);
     }
```
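The heart of the hunk above is staging the `nonzero()` result in pinned host memory so the host-to-device copy can be issued non-blocking. A rough Python analogue of that pattern (illustration only, with assumed names and sizes, not code from the PR):

```python
# Rough Python analogue of the C++ change: write the mask's indices into pinned
# host memory, then copy them to the GPU asynchronously instead of blocking.
import torch

mask = torch.rand(1024) > 0.5                     # CPU boolean mask
out = torch.empty(0, dtype=torch.long, pin_memory=True)
indices = torch.nonzero(mask, out=out)            # indices land in pinned memory
indices = indices.to("cuda", non_blocking=True)   # async H2D copy, no stream sync
```

The asynchronous copy is safe here because the freshly created index tensor is never exposed to user code before the copy completes, which is the argument made in the PR description.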
```diff
@@ -71,7 +71,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
   checkIndexTensorTypes(orig, /*allow_int*/ true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
   // LongTensors
-  auto indices = expandTensors(self, orig);
+  auto indices = expandTensors(self, orig, /*ensure_same_device=*/true);
   // next broadcast all index tensors together
   try {
     indices = expand_outplace(indices);
@@ -91,12 +91,6 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
   if (!hasContiguousSubspace(indices)) {
     std::tie(self, indices) = transposeToFront(self, indices);
   }
-  // Ensure indices are on the same device as self
-  for (auto& indice : indices) {
-    if (indice.defined() && indice.device() != self.device()) {
-      indice = indice.to(self.device());
-    }
-  }
   for (auto& indice : indices) {
     if (indice.defined() && indice.dtype() == at::kInt) {
       indice = indice.to(at::kLong);
```
```diff
@@ -2150,8 +2150,11 @@ else:
             ind_cpu = ind.cpu()
             repeats = torch.full((1,), 2, device=device)
             mask = torch.randint(2, (size,), device=device, dtype=bool)
+            mask_cpu = mask.cpu()
             expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.),
+                              lambda: _ind_put_fn(x, mask_cpu, y),
                               lambda: _ind_put_fn(x, ind, y),
+                              lambda: _ind_get_fn(x, mask_cpu),
                               lambda: _ind_get_fn(x, ind),
                               lambda: torch.nn.functional.one_hot(ind, num_classes=size),
                               lambda: torch.randperm(20000, device=device),
```