mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use a better decomposition for split_with_sizes (#135728)
This decomposition has less checks and improves the performance of torch.compile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135728 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7647c398ff
commit
dab7d646d5
@ -1409,14 +1409,17 @@ def split_with_sizes(
|
|||||||
sum(split_sizes) == self.shape[dim],
|
sum(split_sizes) == self.shape[dim],
|
||||||
lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
|
lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
|
||||||
)
|
)
|
||||||
num_splits = len(split_sizes)
|
|
||||||
splits = []
|
|
||||||
start_idx = 0
|
|
||||||
|
|
||||||
for i in range(num_splits):
|
splits = []
|
||||||
length = split_sizes[i]
|
offset = self.storage_offset()
|
||||||
splits.append(self.narrow(dim, start_idx, length))
|
|
||||||
start_idx += length
|
for split_size in split_sizes:
|
||||||
|
new_shape = list(self.shape)
|
||||||
|
new_shape[dim] = split_size
|
||||||
|
# We reimplement narrow here to avoid a lot of checks in the
|
||||||
|
# decomposition of narrow which calls slice_in_dim and slice
|
||||||
|
splits.append(self.as_strided(new_shape, self.stride(), offset))
|
||||||
|
offset = offset + self.stride()[dim] * split_size
|
||||||
return splits
|
return splits
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user