use more efficient implementation for broadcasted indexing in deterministic scatter_add (#156744)

per title
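
The message is terse, so here is a brief sketch of the case the change targets (an editorial illustration; the shapes and variable names below are made up, not taken from the PR). Deterministic scatter_add is implemented on top of index_put_, and when the index tensor is an expanded view whose strides are zero along every non-scatter dimension, the whole update can be expressed with the underlying 1-D index alone, with no per-element coordinate tensors:

    import torch

    inp = torch.randn(3, 4)
    src = torch.randn(10, 4)
    idx_1d = torch.randint(3, (10,))
    # Broadcasted index: expanded along the non-scatter dim, strides (1, 0).
    idx = idx_1d.unsqueeze(1).expand(10, 4)

    ref = inp.clone().scatter_add_(0, idx, src)

    # Since idx is constant along dim 1, the same accumulation is just
    # advanced indexing with the 1-D index.
    out = inp.clone()
    out.index_put_((idx_1d,), src, accumulate=True)
    assert torch.allclose(out, ref)

The new fast path in the diff below detects exactly this zero-stride pattern.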

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156744
Approved by: https://github.com/suo
Author: Natalia Gimelshein
Date: 2025-06-25 02:59:45 +00:00
Committed by: PyTorch MergeBot
Parent: 9b498d3bb2
Commit: beb52f5c0a
2 changed files with 60 additions and 72 deletions

@@ -2153,81 +2153,53 @@ static void _scatter_via_index_put(
     const Tensor& src,
     const Tensor& mut_out,
     bool accumulate) {
-  if (self.dim() == 1) {
-    torch::List<std::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index);
-    mut_out.index_put_(indices, src, accumulate);
-  } else {
-    Tensor mut_out_contig = mut_out.contiguous();
-    auto index_coords_sizes = index.sizes().vec();
-    index_coords_sizes.push_back(self.dim());
-    auto index_coords = at::empty(
-        index_coords_sizes,
-        at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));
-    for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
-      if (dim_other == dim) {
-        continue;
-      }
-      auto dim_coord_vals = at::arange(
-          index.size(dim_other), at::TensorOptions().device(self.device()));
-      for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1;
-           dim_unsqueeze++) {
-        dim_coord_vals =
-            dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
-      }
-      auto view_sizes = index.sizes().vec();
-      view_sizes.push_back(1);
-      auto view_strides = index_coords.strides().vec();
-      view_strides[self.dim()] = self.dim();
-      at::as_strided(index_coords, view_sizes, view_strides, dim_other)
-          .copy_(dim_coord_vals.unsqueeze(-1));
-    }
-    auto view_sizes = index.sizes().vec();
-    view_sizes.push_back(1);
-    auto view_strides = index_coords.strides().vec();
-    view_strides[self.dim()] = self.dim();
-    at::as_strided(index_coords, view_sizes, view_strides, dim)
-        .copy_(index.unsqueeze(-1));
-    Tensor index_coords_flat = index_coords.flatten(0, -2);
-    // Copy mut_out_contig's strides into a tensor
-    // TODO: Is there a utility function that already does this?
-    IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
-    Tensor coord_strides = at::empty(
-        {mut_out_contig.dim()},
-        TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
-    std::memcpy(
-        coord_strides.mutable_data_ptr(),
-        mut_out_contig_strides.data(),
-        coord_strides.nbytes());
-    coord_strides = coord_strides.to(mut_out_contig.device());
-    // `index_flat` contains the 1-D indices corresponding with the
-    // flattened `mut_out`
-    Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
-    Tensor mut_out_flat = mut_out_contig.flatten();
-    Tensor src_flat =
-        at::as_strided(src, index.sizes(), src.strides()).flatten();
-    torch::List<std::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index_flat);
-    mut_out_flat.index_put_(indices, src_flat, accumulate);
-    if (!mut_out.is_contiguous()) {
-      mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
-    }
-  }
+  // If index is expanded with zero strides across non-scatter dimensions,
+  // advanced indexing with the index tensor alone achieves the desired
+  // semantics and avoids creating large intermediate tensors.
+  bool broadcast_index = true;
+  for (const auto i : c10::irange(index.dim())) {
+    if (i == dim) {
+      continue;
+    }
+    if (index.stride(i) != 0) {
+      broadcast_index = false;
+      break;
+    }
+  }
+  auto src_view = at::as_strided(src, index.sizes(), src.strides());
+  torch::List<std::optional<Tensor>> indices;
+  indices.reserve(self.dim());
+  if (self.dim() == 1 || broadcast_index) {
+    Tensor squeezed = index;
+    if (broadcast_index && index.dim() > 1) {
+      for (const auto d : c10::irange(index.dim())) {
+        if (d == dim) {
+          continue;
+        }
+        squeezed = squeezed.select(d, 0);
+      }
+    }
+    for ([[maybe_unused]] const auto d : c10::irange(dim)) {
+      indices.push_back(Tensor());
+    }
+    indices.push_back(squeezed);
+    mut_out.index_put_(indices, src_view, accumulate);
+    return;
+  }
+  for (const auto d : c10::irange(self.dim())) {
+    if (d == dim) {
+      indices.push_back(index);
+    } else {
+      auto arange = at::arange(index.size(d), index.options());
+      std::vector<int64_t> shape(index.dim(), 1);
+      shape[d] = index.size(d);
+      indices.push_back(arange.view(shape).expand(index.sizes()));
+    }
+  }
+  mut_out.index_put_(indices, src_view, accumulate);
 }
 template <

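A rough Python analogue of the general (non-broadcast) path added above, for readers who do not want to trace the ATen code (variable names here are illustrative, not from the PR): instead of flattening the output and multiplying coordinate tensors by strides, the new code builds one arange per non-scatter dimension, views it so it broadcasts, expands it to the index's shape, and hands everything to a single index_put_ call. For a 2-D scatter along dim 1 the equivalence looks like:

    import torch

    inp = torch.randn(3, 4)
    idx = torch.randint(4, (3, 5))
    src = torch.randn(3, 5)

    ref = inp.clone().scatter_add_(1, idx, src)

    # One coordinate tensor per non-scatter dimension, broadcast to idx's
    # shape; this mirrors the arange/view/expand loop in the C++ above.
    rows = torch.arange(3).view(3, 1).expand(3, 5)
    out = inp.clone()
    out.index_put_((rows, idx), src, accumulate=True)
    assert torch.allclose(out, ref)

The fast path earlier in the hunk skips even these coordinate tensors and indexes with the squeezed 1-D index alone.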
@@ -380,6 +380,22 @@ class TestScatterGather(TestCase):
         helper([50, 8, 7], 100)
         helper([50, 3, 4, 5], 100)
 
+    @dtypes(torch.float32)
+    def test_scatter_add_broadcasted_index_deterministic(self, device, dtype):
+        for d in (0, 1):
+            inp = torch.randn(3, 4, device=device, dtype=dtype)
+            idx_1d = torch.randint(3, (10,), device=device)
+            src_shape = list(inp.shape)
+            src_shape[d] = 10
+            src = torch.randn(src_shape, device=device, dtype=dtype)
+            idx = idx_1d.unsqueeze(1 - d).expand(src_shape)
+            print(idx.stride())
+            ref = inp.clone().scatter_add_(d, idx, src)
+            with DeterministicGuard(True):
+                res = inp.clone().scatter_add_(d, idx, src)
+            self.assertEqual(res, ref)
+
     @onlyCPU
     @dtypes(torch.float32, torch.float64, torch.bfloat16)
     def test_gather_expanded_index(self, device, dtype):
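
The test above toggles determinism with the test-suite DeterministicGuard helper; in user code the same scatter_add path is selected through the public switch (a sketch, assuming a CUDA device, where the default scatter_add_ would otherwise use atomic adds):

    import torch

    torch.use_deterministic_algorithms(True)
    try:
        inp = torch.zeros(3, 4, device="cuda")
        idx = torch.randint(3, (10,), device="cuda").unsqueeze(1).expand(10, 4)
        src = torch.ones(10, 4, device="cuda")
        # Routed through the index_put_-based deterministic implementation.
        inp.scatter_add_(0, idx, src)
    finally:
        torch.use_deterministic_algorithms(False)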