Improve repeat op to a single copy (#163842)

In #163455 , the `reshape` was not a pure view op.

The `permute` before it created an non-contiguous tensor, which would trigger a data copy during the reshape.

This PR improved the implementation by remove the `urtensor` intermediate tensor completely.
By simply expanding the `xtensor` would achieve the `repeat` effect.

Before this PR, there were two data copies (in `urtensor.copy_` and `urtensor.reshape`).
Now, there is only one data copy in the `.copy_()`.
Reshape would not copy data because it is on a contiguous tensor.

One more note is that we do want at one copy because we want to duplicate the elements for the repeats.
User can inplace modify single elements without afffecting others.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163842
Approved by: https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This commit is contained in:
Haifeng Jin
2025-10-01 06:27:53 +00:00
committed by PyTorch MergeBot
parent cc8b14d09a
commit 590224f83c

View File

@ -1880,43 +1880,34 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) {
Tensor xtensor = self.expand(padded_size); Tensor xtensor = self.expand(padded_size);
Tensor urtensor;
if (self.is_quantized()) {
urtensor = at::empty_quantized(target_size, self);
} else {
urtensor = at::empty(target_size, self.options());
}
// return an empty tensor if one of the repeat dimensions is zero
if (zero_tensor) { if (zero_tensor) {
return urtensor; return self.is_quantized() ? at::empty_quantized(target_size, self)
: at::empty(target_size, self.options());
} }
// Create view of shape [r0, s0, r1, s1, ...]
// where ri is repeat[i], si is self.size(i).
Tensor view = xtensor;
auto expand_shape = std::vector<int64_t>();
expand_shape.reserve(xtensor.dim() * 2);
for (const auto i : c10::irange(xtensor.dim())) { for (const auto i : c10::irange(xtensor.dim())) {
// can't unfold with step 0, so make sure step is at least 1 view = view.unsqueeze(2 * i);
// (it doesn't matter what it is in that case, because the size is 0). expand_shape.push_back(repeats[i]);
auto size_i = xtensor.sizes()[i]; expand_shape.push_back(xtensor.size(i));
urtensor = urtensor.unfold(i, size_i, std::max<int64_t>(size_i, 1));
} }
// expanded_view is non-contiguous because .expand set stride to 0.
auto expanded_view = view.expand(expand_shape);
urtensor.copy_(xtensor.expand_as(urtensor)); // copy to contiguous tensor.
auto contiguous_copy = at::empty(
expanded_view.sizes(),
expanded_view.options(),
at::MemoryFormat::Contiguous);
contiguous_copy.copy_(expanded_view);
// Combine the dimensions to produce the target_size. // Reshape to [s0 * r0, s1 * r1, ...].
// xtensor dims: [a0, ..., ad-1] // No extra copy of data during reshape for a contiguous tensor.
// urtensor dims: [a0, ..., ad-1, b0, ..., bd-1] return contiguous_copy.view(target_size);
// b dims are produced by unfold.
// Transform urtensor to [a0 * b0, ..., ad-1 * bd-1]
const int64_t n_dims = xtensor.dim();
auto range_a = at::arange(xtensor.dim(), at::TensorOptions(at::kLong));
auto range_b = range_a + n_dims;
auto stacked = stack({std::move(range_a), std::move(range_b)}, 1).flatten();
auto permutation = IntArrayRef(stacked.data_ptr<int64_t>(), n_dims * 2);
// Permute from [a0, ..., ad-1, b0, ..., bd-1] to [a0, b0, ..., ad-1, bd-1]
urtensor = urtensor.permute(permutation);
// Reshape from [a0, b0, ..., ad-1, bd-1] to [a0 * b0, ..., ad-1 * bd-1]
urtensor = urtensor.reshape(target_size);
return urtensor;
} }
Tensor tile_symint(const Tensor& self, SymIntArrayRef reps) { Tensor tile_symint(const Tensor& self, SymIntArrayRef reps) {