[dynamic shapes] avoid unnecessary slices (#157528)

Fixes #157289 by extending the existing optimization (skipping slices that are statically no-ops) to slices whose end index exceeds the dimension size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157528
Approved by: https://github.com/angelayi
Authored by Pian Pawakapan on 2025-07-10 06:34:46 +00:00
Committed by PyTorch MergeBot
parent 565fd07909
commit 4cc13c4af6
4 changed files with 59 additions and 23 deletions
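Background for the change below (an illustrative note, not part of the commit): ATen, like Python, clamps a slice whose end index runs past the dimension, so a slice with start 0, step 1, and an end at or beyond the dimension size returns the whole dimension unchanged. That is why such slices can be elided, rather than only the exact-size case. A minimal sketch:

    import torch

    t = torch.arange(6)
    # Slicing clamps an out-of-range end index, so a stop past the end of the
    # dimension still yields the whole dimension:
    assert torch.equal(t[0:6], t)
    assert torch.equal(t[0:100], t)  # stop > size(0), identical result
    # This is why applySlice (below) may return `self` whenever start == 0,
    # step == 1, and the dimension length is statically known to be <= stop.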

View File

@@ -214,7 +214,7 @@ inline Tensor applySlice(
       "step must be greater than zero");
   // See NOTE [nested tensor size for indexing]
-  if (self_sizes.has_value()) {
+  if (self_sizes.has_value() && self_sizes.value().size() > 0) {
     // Skip this optimization if we are tracing, as the trace may be polymorphic
     // over the shape of the `self` tensor, and we still want to record
     // the slice.
@@ -223,7 +223,7 @@ inline Tensor applySlice(
         : self.sym_size(dim);
     if (!disable_slice_optimization &&
         TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
-        TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) {
+        TORCH_STATICALLY_KNOWN_TRUE(length.sym_le(stop)) && step == 1) {
       return self;
     }
   }
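A rough Python paraphrase of the guarded fast path above (a sketch with a hypothetical helper name, not code from the repository):

    from typing import Optional

    def slice_is_noop(start: int, stop: int, step: int, dim_length: Optional[int]) -> bool:
        # Hypothetical paraphrase of the condition in applySlice: the slice can be
        # skipped only when the dimension length is known, the slice starts at 0
        # with unit step, and the end index is at or past the end of the dimension
        # (previously it had to equal the length exactly, per sym_eq).
        if dim_length is None:  # size unknown, e.g. tracing over a polymorphic shape
            return False
        return start == 0 and step == 1 and dim_length <= stop

    # A full-range slice such as x[:] lowers to start=0, stop=INT64_MAX, step=1,
    # which this predicate classifies as a no-op for any known length.
    assert slice_is_noop(0, 2**63 - 1, 1, 4)
    assert not slice_is_noop(1, 2**63 - 1, 1, 4)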

View File

@@ -2726,6 +2726,52 @@ graph():
         ):
             export(Foo(), inputs, dynamic_shapes=shapes)
 
+    def test_issue_157289(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+
+            def forward(self, causal_mask, fill_value):
+                causal_mask = causal_mask.clone()
+                mask_length = fill_value.shape[-1]
+                causal_mask[:, :, :, :mask_length] = fill_value
+                return causal_mask
+
+        causal_mask = torch.randn(2, 2, 3, 4)
+        fill_value = torch.randn(2, 2, 3, 3)
+        dynamic_shapes = {
+            "causal_mask": {3: Dim("M")},
+            "fill_value": {3: Dim("N")},
+        }
+        ep = export(
+            MyModule(), (causal_mask, fill_value), dynamic_shapes=dynamic_shapes
+        )
+        if not is_training_ir_test(self._testMethodName) and not is_retracebility_test(
+            self._testMethodName
+        ):
+            self.assertExpectedInline(
+                str(ep.graph_module.code).strip(),
+                """\
+def forward(self, causal_mask, fill_value):
+    sym_size_int_4 = torch.ops.aten.sym_size.int(fill_value, 3)
+    clone = torch.ops.aten.clone.default(causal_mask); causal_mask = None
+    slice_1 = torch.ops.aten.slice.Tensor(clone, 3, 0, sym_size_int_4); sym_size_int_4 = None
+    copy_ = torch.ops.aten.copy_.default(slice_1, fill_value); slice_1 = fill_value = copy_ = None
+    return (clone,)""",
+            )
+            decomposed_ep = ep.run_decompositions()
+            self.assertExpectedInline(
+                str(decomposed_ep.graph_module.code).strip(),
+                """\
+def forward(self, causal_mask, fill_value):
+    sym_size_int_5 = torch.ops.aten.sym_size.int(fill_value, 3)
+    clone = torch.ops.aten.clone.default(causal_mask); causal_mask = None
+    slice_1 = torch.ops.aten.slice.Tensor(clone, 3, 0, sym_size_int_5)
+    copy = torch.ops.aten.copy.default(slice_1, fill_value); slice_1 = fill_value = None
+    slice_scatter = torch.ops.aten.slice_scatter.default(clone, copy, 3, 0, sym_size_int_5); clone = copy = sym_size_int_5 = None
+    return (slice_scatter,)""",
+            )
+
     def test_dim_dynamic_specialization(self):
         class Foo(torch.nn.Module):
             def forward(self, x):
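An aside on the decomposed graph above (illustrative, not from the commit): run_decompositions functionalizes the in-place copy_, so the masked write reappears as aten.copy plus aten.slice_scatter, while the prefix slice keeps its symbolic end. A small sketch of what slice_scatter computes:

    import torch

    base = torch.zeros(2, 2, 3, 4)
    src = torch.ones(2, 2, 3, 3)
    # torch.slice_scatter returns a copy of `base` with base[..., start:end] along
    # `dim` replaced by `src` -- the functional form of the slice assignment in
    # the test's forward().
    out = torch.slice_scatter(base, src, dim=3, start=0, end=3)
    assert torch.equal(out[..., :3], src)
    assert torch.equal(out[..., 3:], base[..., 3:])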

View File

@@ -188,14 +188,11 @@ def forward(self, a__1):
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
     add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None
-    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = select_2 = None
-    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select_3 = torch.ops.aten.select.int(slice_3, 1, 1); slice_3 = None
+    select_2 = torch.ops.aten.select.int(clone, 1, 1); select_2 = None
+    select_3 = torch.ops.aten.select.int(clone, 1, 1)
     select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = select_4 = None
     return clone
     """)
@@ -228,8 +225,7 @@ def forward(self, a__1):
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
     add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
@@ -264,8 +260,7 @@ def forward(self, a__1):
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
     add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
@@ -297,8 +292,7 @@ def forward(self, a__1):
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
     add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None
@@ -353,12 +347,9 @@ def forward(self):
 def forward(self):
     zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
     ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
-    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
-    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
-    copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = copy = None
-    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807); slice_3 = None
-    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
-    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = slice_5 = None
+    slice_1 = torch.ops.aten.slice.Tensor(zeros, 1, 2, 9223372036854775807)
+    copy = torch.ops.aten.copy_.default(slice_1, ones); slice_1 = ones = copy = None
+    slice_2 = torch.ops.aten.slice.Tensor(zeros, 1, 2, 9223372036854775807); slice_2 = None
     return zeros
     """)

View File

@@ -741,9 +741,8 @@ $4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
 $0: f32[2, 2] = input('x')
 $1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
 $2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
-$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
-$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
-$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
+$3: f32[2] = torch._ops.aten.select.int($0, 1, 1)
+$4: f32[2] = torch._ops.aten.clone.default($3, memory_format=torch.contiguous_format)""",
         )
 
     def test_optional_tensor_list(self) -> None: