diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h
index b385ea80b809..d9d8554abc79 100644
--- a/aten/src/ATen/TensorIndexing.h
+++ b/aten/src/ATen/TensorIndexing.h
@@ -214,7 +214,7 @@ inline Tensor applySlice(
       "step must be greater than zero");

   // See NOTE [nested tensor size for indexing]
-  if (self_sizes.has_value()) {
+  if (self_sizes.has_value() && self_sizes.value().size() > 0) {
     // Skip this optimization if we are tracing, as the trace may be polymorphic
     // over the shape of the `self` tensor, and we still want to record
     // the slice.
@@ -223,7 +223,7 @@ inline Tensor applySlice(
         : self.sym_size(dim);
     if (!disable_slice_optimization &&
         TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
-        TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) {
+        TORCH_STATICALLY_KNOWN_TRUE(length.sym_le(stop)) && step == 1) {
       return self;
     }
   }
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 32f1fafa4d99..c2083d6c02f3 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -2726,6 +2726,52 @@ graph():
         ):
             export(Foo(), inputs, dynamic_shapes=shapes)

+    def test_issue_157289(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+
+            def forward(self, causal_mask, fill_value):
+                causal_mask = causal_mask.clone()
+                mask_length = fill_value.shape[-1]
+                causal_mask[:, :, :, :mask_length] = fill_value
+                return causal_mask
+
+        causal_mask = torch.randn(2, 2, 3, 4)
+        fill_value = torch.randn(2, 2, 3, 3)
+        dynamic_shapes = {
+            "causal_mask": {3: Dim("M")},
+            "fill_value": {3: Dim("N")},
+        }
+        ep = export(
+            MyModule(), (causal_mask, fill_value), dynamic_shapes=dynamic_shapes
+        )
+        if not is_training_ir_test(self._testMethodName) and not is_retracebility_test(
+            self._testMethodName
+        ):
+            self.assertExpectedInline(
+                str(ep.graph_module.code).strip(),
+                """\
+def forward(self, causal_mask, fill_value):
+    sym_size_int_4 = torch.ops.aten.sym_size.int(fill_value, 3)
+    clone = torch.ops.aten.clone.default(causal_mask);  causal_mask = None
+    slice_1 = torch.ops.aten.slice.Tensor(clone, 3, 0, sym_size_int_4);  sym_size_int_4 = None
+    copy_ = torch.ops.aten.copy_.default(slice_1, fill_value);  slice_1 = fill_value = copy_ = None
+    return (clone,)""",
+            )
+        decomposed_ep = ep.run_decompositions()
+        self.assertExpectedInline(
+            str(decomposed_ep.graph_module.code).strip(),
+            """\
+def forward(self, causal_mask, fill_value):
+    sym_size_int_5 = torch.ops.aten.sym_size.int(fill_value, 3)
+    clone = torch.ops.aten.clone.default(causal_mask);  causal_mask = None
+    slice_1 = torch.ops.aten.slice.Tensor(clone, 3, 0, sym_size_int_5)
+    copy = torch.ops.aten.copy.default(slice_1, fill_value);  slice_1 = fill_value = None
+    slice_scatter = torch.ops.aten.slice_scatter.default(clone, copy, 3, 0, sym_size_int_5);  clone = copy = sym_size_int_5 = None
+    return (slice_scatter,)""",
+        )
+
     def test_dim_dynamic_specialization(self):
         class Foo(torch.nn.Module):
             def forward(self, x):
diff --git a/test/test_fx_reinplace_pass.py b/test/test_fx_reinplace_pass.py
index 5db11af8e47a..4acda3bece74 100644
--- a/test/test_fx_reinplace_pass.py
+++ b/test/test_fx_reinplace_pass.py
@@ -188,14 +188,11 @@ def forward(self, a__1):

 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1);  a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
     add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = add = None
-    slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select_2 = torch.ops.aten.select.int(slice_2, 1, 1);  slice_2 = select_2 = None
-    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select_3 = torch.ops.aten.select.int(slice_3, 1, 1);  slice_3 = None
+    select_2 = torch.ops.aten.select.int(clone, 1, 1);  select_2 = None
+    select_3 = torch.ops.aten.select.int(clone, 1, 1)
     select_4 = torch.ops.aten.select.int(select_3, 0, 1);  select_3 = select_4 = None
     return clone
 """)
@@ -228,8 +225,7 @@ def forward(self, a__1):

 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1);  a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
     add = torch.ops.aten.add_.Tensor(select_1, 1);  select_1 = add = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
@@ -264,8 +260,7 @@ def forward(self, a__1):

 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1);  a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
     add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1);  clone = None
@@ -297,8 +292,7 @@ def forward(self, a__1):

 def forward(self, a__1):
     clone = torch.ops.aten.clone.default(a__1);  a__1 = None
-    slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
-    select = torch.ops.aten.select.int(slice_1, 1, 1);  slice_1 = None
+    select = torch.ops.aten.select.int(clone, 1, 1)
     select_1 = torch.ops.aten.select.int(select, 0, 1);  select = None
     add = torch.ops.aten.add.Tensor(select_1, 1);  select_1 = None
     as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0);  clone = None
@@ -353,12 +347,9 @@ def forward(self):

 def forward(self):
     zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
     ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
-    slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
-    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807);  slice_1 = None
-    copy = torch.ops.aten.copy_.default(slice_2, ones);  slice_2 = ones = copy = None
-    slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807);  slice_3 = None
-    slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
-    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807);  slice_4 = slice_5 = None
+    slice_1 = torch.ops.aten.slice.Tensor(zeros, 1, 2, 9223372036854775807)
+    copy = torch.ops.aten.copy_.default(slice_1, ones);  slice_1 = ones = copy = None
+    slice_2 = torch.ops.aten.slice.Tensor(zeros, 1, 2, 9223372036854775807);  slice_2 = None
     return zeros
 """)
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index f1a13e3db1a2..aef4cb0e6917 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -741,9 +741,8 @@ $4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
 $0: f32[2, 2] = input('x')
 $1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
 $2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
-$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
-$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
-$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
+$3: f32[2] = torch._ops.aten.select.int($0, 1, 1)
+$4: f32[2] = torch._ops.aten.clone.default($3, memory_format=torch.contiguous_format)""",
         )

     def test_optional_tensor_list(self) -> None:
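
For orientation, below is a standalone sketch (not part of the patch) of the scenario the new test_issue_157289 covers; the module and inputs mirror the test, while the MaskFill name and the trailing print are illustrative only. With independent dynamic dims M and N, the [:, :, :, :mask_length] slice is not statically a no-op and must stay in the traced graph, whereas the relaxed sym_le check lets trivially-true full slices (stop == INT64_MAX) be elided, which is what the updated reinplace and python-dispatch expectations above reflect.

# Standalone repro sketch; assumes a PyTorch build that includes this change.
import torch
from torch.export import Dim, export


class MaskFill(torch.nn.Module):
    def forward(self, causal_mask, fill_value):
        causal_mask = causal_mask.clone()
        mask_length = fill_value.shape[-1]
        # Dim 3 of each input is an independent symbol (M and N), so this
        # slice cannot be statically proven to cover the whole dimension
        # and is kept in the graph; the leading full slices are elided.
        causal_mask[:, :, :, :mask_length] = fill_value
        return causal_mask


ep = export(
    MaskFill(),
    (torch.randn(2, 2, 3, 4), torch.randn(2, 2, 3, 3)),
    dynamic_shapes={"causal_mask": {3: Dim("M")}, "fill_value": {3: Dim("N")}},
)
# Expect an explicit aten.slice followed by aten.copy_ on the clone,
# matching the first assertExpectedInline in test_issue_157289.
print(ep.graph_module.code)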