Use a better decomposition for split_with_sizes (#135728)

This decomposition performs fewer checks and improves the performance
of torch.compile.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135728
Approved by: https://github.com/ezyang
Author: Isuru Fernando
Date: 2024-09-11 19:44:41 +00:00
Committed by: PyTorch MergeBot
Parent: 7647c398ff
Commit: dab7d646d5

@@ -1409,14 +1409,17 @@ def split_with_sizes(
         sum(split_sizes) == self.shape[dim],
         lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
     )
-    num_splits = len(split_sizes)
-    splits = []
-    start_idx = 0
-    for i in range(num_splits):
-        length = split_sizes[i]
-        splits.append(self.narrow(dim, start_idx, length))
-        start_idx += length
+    splits = []
+    offset = self.storage_offset()
+
+    for split_size in split_sizes:
+        new_shape = list(self.shape)
+        new_shape[dim] = split_size
+        # We reimplement narrow here to avoid a lot of checks in the
+        # decomposition of narrow which calls slice_in_dim and slice
+        splits.append(self.as_strided(new_shape, self.stride(), offset))
+        offset = offset + self.stride()[dim] * split_size
     return splits
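
For context, below is a minimal standalone sketch of the new decomposition.
It is not part of the patch, and the helper name split_with_sizes_decomp is
hypothetical; it just checks that the as_strided-based views match the eager
torch.split result on a small tensor.

import torch

def split_with_sizes_decomp(self, split_sizes, dim=0):
    torch._check(
        sum(split_sizes) == self.shape[dim],
        lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
    )
    splits = []
    offset = self.storage_offset()

    for split_size in split_sizes:
        new_shape = list(self.shape)
        new_shape[dim] = split_size
        # Build each split as a view directly via as_strided, skipping the
        # extra checks done by narrow -> slice_in_dim -> slice.
        splits.append(self.as_strided(new_shape, self.stride(), offset))
        # Advance the storage offset past the elements consumed along dim.
        offset = offset + self.stride()[dim] * split_size
    return splits

x = torch.arange(24.0).reshape(4, 6)
for got, expected in zip(split_with_sizes_decomp(x, [2, 3, 1], dim=1),
                         torch.split(x, [2, 3, 1], dim=1)):
    assert torch.equal(got, expected)

Because each result is an as_strided view over the original storage, the
decomposition allocates no new data, matching the view semantics of narrow
while avoiding its per-call checks.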