Use a better decomposition for split_with_sizes (#135728)

The new decomposition builds each split directly with as_strided instead
of narrow, so it performs fewer checks and improves the performance of
torch.compile.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135728
Approved by: https://github.com/ezyang
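
As a rough illustration of the change (a standalone sketch, not the
decomposition itself; the helper name and test tensor below are made up
for this example), each split can be materialized as a view with
as_strided by reusing the input's strides and advancing the storage
offset, which matches what narrow would produce:

import torch

# Standalone sketch of as_strided-based splitting. The function name and
# inputs are illustrative only, not part of the actual decomposition.
def split_with_sizes_via_as_strided(t, split_sizes, dim):
    assert sum(split_sizes) == t.shape[dim]
    splits = []
    offset = t.storage_offset()
    for split_size in split_sizes:
        new_shape = list(t.shape)
        new_shape[dim] = split_size
        # Each chunk is a view: same strides as the input, with the
        # storage offset advanced past the previous chunks along `dim`.
        splits.append(t.as_strided(new_shape, t.stride(), offset))
        offset += t.stride()[dim] * split_size
    return splits

x = torch.arange(24.0).reshape(4, 6)
# torch.split dispatches to split_with_sizes for a list of sizes.
expected = torch.split(x, [1, 3, 2], dim=1)
actual = split_with_sizes_via_as_strided(x, [1, 3, 2], dim=1)
for a, b in zip(actual, expected):
    assert torch.equal(a, b)

In eager mode the results alias the input's storage, just as with
narrow(); avoiding narrow() here only skips the extra bounds checks its
decomposition performs via slice_in_dim and slice.
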
Author:    Isuru Fernando
Date:      2024-09-11 19:44:41 +00:00
Committed: PyTorch MergeBot
Parent:    7647c398ff
Commit:    dab7d646d5

@@ -1409,14 +1409,17 @@ def split_with_sizes(
         sum(split_sizes) == self.shape[dim],
         lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
     )
-    num_splits = len(split_sizes)
-    splits = []
-    start_idx = 0
-    for i in range(num_splits):
-        length = split_sizes[i]
-        splits.append(self.narrow(dim, start_idx, length))
-        start_idx += length
+    splits = []
+    offset = self.storage_offset()
+    for split_size in split_sizes:
+        new_shape = list(self.shape)
+        new_shape[dim] = split_size
+        # We reimplement narrow here to avoid a lot of checks in the
+        # decomposition of narrow, which calls slice_in_dim and slice.
+        splits.append(self.as_strided(new_shape, self.stride(), offset))
+        offset = offset + self.stride()[dim] * split_size
     return splits
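
Because the new code reuses self.stride() and self.storage_offset()
rather than assuming a contiguous layout, the same offset arithmetic
also holds for non-contiguous inputs. A small hedged check of that
property (using only public PyTorch APIs; the tensor and sizes are
made up for this example):

import torch

# A transposed view is non-contiguous: shape (6, 4), strides (1, 6).
y = torch.arange(24.0).reshape(4, 6).t()
chunks = torch.split(y, [2, 4], dim=0)
manual = [
    y.as_strided((2, 4), y.stride(), y.storage_offset()),
    y.as_strided((4, 4), y.stride(), y.storage_offset() + y.stride()[0] * 2),
]
for a, b in zip(chunks, manual):
    assert torch.equal(a, b)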