Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
use more efficient implementation for broadcasted indexing in deterministic scatter_add (#156744)

per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156744
Approved by: https://github.com/suo
Committed by: PyTorch MergeBot
Parent: 9b498d3bb2
Commit: beb52f5c0a
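To make the scenario concrete, here is a minimal Python sketch (not taken from the PR) of the case this change speeds up: a 1-D index expanded across the non-scatter dimension and fed to scatter_add_ while deterministic algorithms are enabled, which is assumed here to route through the index_put_-based fallback on CUDA.

import torch

torch.use_deterministic_algorithms(True)
device = "cuda" if torch.cuda.is_available() else "cpu"

inp = torch.randn(3, 4, device=device)
idx_1d = torch.randint(3, (10,), device=device)
src = torch.randn(10, 4, device=device)

# Expanding gives the index zero stride along dim 1, the non-scatter dimension.
idx = idx_1d.unsqueeze(1).expand(10, 4)

out = inp.clone().scatter_add_(0, idx, src)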
@@ -2153,81 +2153,53 @@ static void _scatter_via_index_put(
     const Tensor& src,
     const Tensor& mut_out,
     bool accumulate) {
-  if (self.dim() == 1) {
-    torch::List<std::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index);
-    mut_out.index_put_(indices, src, accumulate);
-  } else {
-    Tensor mut_out_contig = mut_out.contiguous();
-
-    auto index_coords_sizes = index.sizes().vec();
-    index_coords_sizes.push_back(self.dim());
-    auto index_coords = at::empty(
-        index_coords_sizes,
-        at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));
-
-    for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
-      if (dim_other == dim) {
-        continue;
-      }
-      auto dim_coord_vals = at::arange(
-          index.size(dim_other), at::TensorOptions().device(self.device()));
-
-      for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1;
-           dim_unsqueeze++) {
-        dim_coord_vals =
-            dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
-      }
-
-      auto view_sizes = index.sizes().vec();
-      view_sizes.push_back(1);
-      auto view_strides = index_coords.strides().vec();
-      view_strides[self.dim()] = self.dim();
-
-      at::as_strided(index_coords, view_sizes, view_strides, dim_other)
-          .copy_(dim_coord_vals.unsqueeze(-1));
-    }
-
-    auto view_sizes = index.sizes().vec();
-    view_sizes.push_back(1);
-    auto view_strides = index_coords.strides().vec();
-    view_strides[self.dim()] = self.dim();
-
-    at::as_strided(index_coords, view_sizes, view_strides, dim)
-        .copy_(index.unsqueeze(-1));
-
-    Tensor index_coords_flat = index_coords.flatten(0, -2);
-
-    // Copy mut_out_contig's strides into a tensor
-    // TODO: Is there a utility function that already does this?
-    IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
-    Tensor coord_strides = at::empty(
-        {mut_out_contig.dim()},
-        TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
-    std::memcpy(
-        coord_strides.mutable_data_ptr(),
-        mut_out_contig_strides.data(),
-        coord_strides.nbytes());
-    coord_strides = coord_strides.to(mut_out_contig.device());
-
-    // `index_flat` contains the 1-D indices corresponding with the
-    // flattened `mut_out`
-    Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
-    Tensor mut_out_flat = mut_out_contig.flatten();
-    Tensor src_flat =
-        at::as_strided(src, index.sizes(), src.strides()).flatten();
-
-    torch::List<std::optional<Tensor>> indices;
-    indices.reserve(1);
-    indices.push_back(index_flat);
-
-    mut_out_flat.index_put_(indices, src_flat, accumulate);
-
-    if (!mut_out.is_contiguous()) {
-      mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
-    }
-  }
+  // If index is expanded with zero strides across non-scatter dimensions,
+  // advanced indexing with the index tensor alone achieves the desired
+  // semantics and avoids creating large intermediate tensors.
+  bool broadcast_index = true;
+  for (const auto i : c10::irange(index.dim())) {
+    if (i == dim) {
+      continue;
+    }
+    if (index.stride(i) != 0) {
+      broadcast_index = false;
+      break;
+    }
+  }
+
+  auto src_view = at::as_strided(src, index.sizes(), src.strides());
+  torch::List<std::optional<Tensor>> indices;
+  indices.reserve(self.dim());
+
+  if (self.dim() == 1 || broadcast_index) {
+    Tensor squeezed = index;
+    if (broadcast_index && index.dim() > 1) {
+      for (const auto d : c10::irange(index.dim())) {
+        if (d == dim) {
+          continue;
+        }
+        squeezed = squeezed.select(d, 0);
+      }
+    }
+    for ([[maybe_unused]] const auto d : c10::irange(dim)) {
+      indices.push_back(Tensor());
+    }
+    indices.push_back(squeezed);
+    mut_out.index_put_(indices, src_view, accumulate);
+    return;
+  }
+
+  for (const auto d : c10::irange(self.dim())) {
+    if (d == dim) {
+      indices.push_back(index);
+    } else {
+      auto arange = at::arange(index.size(d), index.options());
+      std::vector<int64_t> shape(index.dim(), 1);
+      shape[d] = index.size(d);
+      indices.push_back(arange.view(shape).expand(index.sizes()));
+    }
+  }
+  mut_out.index_put_(indices, src_view, accumulate);
 }
 
 template <
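A rough Python analogue of the new fast path above (my illustration, not code from the patch; the helper name scatter_add_broadcast_index is made up): when the index is broadcast along every non-scatter dimension, indexing the output with just the squeezed 1-D index and calling index_put_ with accumulate=True reproduces scatter_add_ along dim, without materializing the per-dimension coordinate tensors the old code built. None entries in the indices list leave the leading dimensions un-indexed, mirroring the undefined Tensors the C++ path pushes for them.

import torch

def scatter_add_broadcast_index(out, dim, idx_1d, src):
    # None entries act as full slices on the dimensions before `dim`.
    indices = [None] * dim + [idx_1d]
    out.index_put_(indices, src, accumulate=True)
    return out

inp = torch.randn(3, 4)
idx_1d = torch.randint(4, (10,))
src = torch.randn(3, 10)
idx = idx_1d.unsqueeze(0).expand(3, 10)  # zero stride along dim 0

ref = inp.clone().scatter_add_(1, idx, src)
res = scatter_add_broadcast_index(inp.clone(), 1, idx_1d, src)
torch.testing.assert_close(ref, res)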
@@ -380,6 +380,22 @@ class TestScatterGather(TestCase):
         helper([50, 8, 7], 100)
         helper([50, 3, 4, 5], 100)
 
+    @dtypes(torch.float32)
+    def test_scatter_add_broadcasted_index_deterministic(self, device, dtype):
+        for d in (0, 1):
+            inp = torch.randn(3, 4, device=device, dtype=dtype)
+            idx_1d = torch.randint(3, (10,), device=device)
+            src_shape = list(inp.shape)
+            src_shape[d] = 10
+            src = torch.randn(src_shape, device=device, dtype=dtype)
+            idx = idx_1d.unsqueeze(1 - d).expand(src_shape)
+            print(idx.stride())
+            ref = inp.clone().scatter_add_(d, idx, src)
+            with DeterministicGuard(True):
+                res = inp.clone().scatter_add_(d, idx, src)
+            self.assertEqual(res, ref)
+
+
     @onlyCPU
     @dtypes(torch.float32, torch.float64, torch.bfloat16)
     def test_gather_expanded_index(self, device, dtype):
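As a side note on the test above: the expanded index it constructs has stride 0 along the non-scatter dimension, which is exactly the condition the new broadcast_index check in the C++ code detects (and what the print(idx.stride()) line surfaces). A small standalone check of that property, as my own illustration:

import torch

idx_1d = torch.randint(3, (10,))
idx = idx_1d.unsqueeze(1).expand(10, 4)  # scatter dim 0, broadcast dim 1

print(idx.stride())                           # (1, 0): zero stride on the expanded dim
assert idx.stride(1) == 0                     # the condition the fast path keys on
assert torch.equal(idx.select(1, 0), idx_1d)  # selecting along the broadcast dim recovers the 1-D index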