[dynamic shapes] avoid unnecessary slices (#157528)
Fixes #157289 by extending the slice-elision optimization to slices whose end index exceeds the dimension size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157528
Approved by: https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: 565fd07909
Commit: 4cc13c4af6
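Why eliding such slices is safe can be seen directly in eager mode; a minimal sketch (shapes chosen only for illustration):

import torch

t = torch.randn(2, 2, 3, 4)
# Slicing clamps an out-of-range end to the dimension size, so a slice with
# start=0, step=1, and end >= size is just a view over the whole input.
full = t[:, :, :, :10]                   # end (10) exceeds the size (4)
assert full.shape == t.shape
assert full.data_ptr() == t.data_ptr()   # same storage; nothing was copied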
@@ -214,7 +214,7 @@ inline Tensor applySlice(
       "step must be greater than zero");
 
   // See NOTE [nested tensor size for indexing]
-  if (self_sizes.has_value()) {
+  if (self_sizes.has_value() && self_sizes.value().size() > 0) {
     // Skip this optimization if we are tracing, as the trace may be polymorphic
     // over the shape of the `self` tensor, and we still want to record
     // the slice.
@@ -223,7 +223,7 @@ inline Tensor applySlice(
         : self.sym_size(dim);
     if (!disable_slice_optimization &&
         TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
-        TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) {
+        TORCH_STATICALLY_KNOWN_TRUE(length.sym_le(stop)) && step == 1) {
       return self;
     }
   }
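Restated as Python-level pseudocode, the guard above roughly amounts to the following sketch; the function name and the statically_known callback are illustrative only, not part of the PyTorch API:

def can_elide_slice(start, stop, step, length, statically_known):
    # Before this change the slice was only dropped when the end equaled the
    # dimension size (sym_eq); it is now also dropped when the end is
    # statically known to be at least the size (sym_le), e.g. INT64_MAX.
    return (
        step == 1
        and statically_known(start == 0)
        and statically_known(length <= stop)
    )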
@@ -2726,6 +2726,52 @@ graph():
         ):
             export(Foo(), inputs, dynamic_shapes=shapes)
 
+    def test_issue_157289(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+
+            def forward(self, causal_mask, fill_value):
+                causal_mask = causal_mask.clone()
+                mask_length = fill_value.shape[-1]
+                causal_mask[:, :, :, :mask_length] = fill_value
+                return causal_mask
+
+        causal_mask = torch.randn(2, 2, 3, 4)
+        fill_value = torch.randn(2, 2, 3, 3)
+        dynamic_shapes = {
+            "causal_mask": {3: Dim("M")},
+            "fill_value": {3: Dim("N")},
+        }
+        ep = export(
+            MyModule(), (causal_mask, fill_value), dynamic_shapes=dynamic_shapes
+        )
+        if not is_training_ir_test(self._testMethodName) and not is_retracebility_test(
+            self._testMethodName
+        ):
+            self.assertExpectedInline(
+                str(ep.graph_module.code).strip(),
+                """\
+def forward(self, causal_mask, fill_value):
+    sym_size_int_4 = torch.ops.aten.sym_size.int(fill_value, 3)
+    clone = torch.ops.aten.clone.default(causal_mask); causal_mask = None
+    slice_1 = torch.ops.aten.slice.Tensor(clone, 3, 0, sym_size_int_4); sym_size_int_4 = None
+    copy_ = torch.ops.aten.copy_.default(slice_1, fill_value); slice_1 = fill_value = copy_ = None
+    return (clone,)""",
+            )
+            decomposed_ep = ep.run_decompositions()
+            self.assertExpectedInline(
+                str(decomposed_ep.graph_module.code).strip(),
+                """\
+def forward(self, causal_mask, fill_value):
+    sym_size_int_5 = torch.ops.aten.sym_size.int(fill_value, 3)
+    clone = torch.ops.aten.clone.default(causal_mask); causal_mask = None
+    slice_1 = torch.ops.aten.slice.Tensor(clone, 3, 0, sym_size_int_5)
+    copy = torch.ops.aten.copy.default(slice_1, fill_value); slice_1 = fill_value = None
+    slice_scatter = torch.ops.aten.slice_scatter.default(clone, copy, 3, 0, sym_size_int_5); clone = copy = sym_size_int_5 = None
+    return (slice_scatter,)""",
+            )
+
     def test_dim_dynamic_specialization(self):
         class Foo(torch.nn.Module):
             def forward(self, x):
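One way to check the effect on a graph like the one above is to look for leftover full-range slices; a hedged usage sketch, assuming ep is the ExportedProgram produced in a test like the one above:

import torch

INT64_MAX = 2**63 - 1

def full_range_slices(ep):
    # aten.slice.Tensor nodes with start 0 and the INT64_MAX sentinel end,
    # i.e. exactly the no-op slices this optimization removes.
    return [
        n
        for n in ep.graph.nodes
        if n.op == "call_function"
        and n.target == torch.ops.aten.slice.Tensor
        and len(n.args) >= 4
        and n.args[2] == 0
        and n.args[3] == INT64_MAX
    ]

assert not full_range_slices(ep)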
@@ -188,14 +188,11 @@ def forward(self, a__1):
 
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
     add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None
-    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = select_2 = None
-    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select_3 = torch.ops.aten.select.int(slice_3, 1, 1); slice_3 = None
+    select_2 = torch.ops.aten.select.int(clone, 1, 1); select_2 = None
+    select_3 = torch.ops.aten.select.int(clone, 1, 1)
     select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = select_4 = None
     return clone
     """)
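The updated expected trace above comes from functionalizing an in-place update made through chained views; a rough eager-mode sketch of that pattern (not the exact test body):

def f(a):
    b = a.clone()
    view = b[:, 1]   # previously traced as slice(0, INT64_MAX) + select; now a single select
    view[1] += 1     # in-place add through the view
    return b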
@@ -228,8 +225,7 @@ def forward(self, a__1):
 
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
     add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
@@ -264,8 +260,7 @@ def forward(self, a__1):
 
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
     add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None
@@ -297,8 +292,7 @@ def forward(self, a__1):
 
 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1); a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1); select = None
     add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None
@@ -353,12 +347,9 @@ def forward(self):
 def forward(self):
     zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
     ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
-    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
-    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None
-    copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = copy = None
-    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807); slice_3 = None
-    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
-    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = slice_5 = None
+    slice_1 = torch.ops.aten.slice.Tensor(zeros, 1, 2, 9223372036854775807)
+    copy = torch.ops.aten.copy_.default(slice_1, ones); slice_1 = ones = copy = None
+    slice_2 = torch.ops.aten.slice.Tensor(zeros, 1, 2, 9223372036854775807); slice_2 = None
     return zeros
     """)
 
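The expected trace above corresponds to copying into a slice of a freshly allocated tensor; a rough eager-mode sketch (again not the exact test body):

def g():
    zeros = torch.zeros(4, 4, 4)
    ones = torch.ones(4, 2, 4)
    zeros[:, 2:] = ones   # now traced as one slice along dim 1 plus copy_, with no full-range slices
    return zeros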
@ -741,9 +741,8 @@ $4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
|
||||
$0: f32[2, 2] = input('x')
|
||||
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
|
||||
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
|
||||
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
|
||||
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
|
||||
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
|
||||
$3: f32[2] = torch._ops.aten.select.int($0, 1, 1)
|
||||
$4: f32[2] = torch._ops.aten.clone.default($3, memory_format=torch.contiguous_format)""",
|
||||
)
|
||||
|
||||
def test_optional_tensor_list(self) -> None:
|
||||
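The logged ops above correspond roughly to eager code along these lines (a sketch, not the exact test body):

import torch

x = torch.randn(2, 2)
y = x.to(torch.float64)                                    # _to_copy
z = x.cumprod(0, dtype=torch.float64)                      # cumprod
w = x[:, 1].clone(memory_format=torch.contiguous_format)   # now select + clone, no full-range slice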