diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index aea88dd1cc94..c2d0856c3cd4 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -3063,7 +3063,7 @@ Tensor slice( } auto storage_offset = self.storage_offset() + start_val * strides[dim]; auto len = end_val - start_val; - sizes[dim] = (len / step) + (len % step != 0); // safely round-up + sizes[dim] = (len + step - 1) / step; // round-up strides[dim] *= step; Tensor result; diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ac5b538189b3..6a6e3c674179 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4309,31 +4309,6 @@ class CommonTemplate: self.assertEqual(torch.compile(fn)(x1, y), fn(x1, y)) self.assertEqual(torch.compile(fn)(x2, y), fn(x2, y)) - def test_slice_copy(self): - class Model(nn.Module): - def __init__(self, start=449, step=(2**63 - 1)): - super().__init__() - self.start = start - self.step = step - - def forward(self, x: torch.Tensor): - sliced = torch.slice_copy( - x, dim=0, start=self.start, end=None, step=self.step - ) - return torch.reciprocal(sliced) - - with config.patch({"implicit_fallbacks": True}): - # bad case - self.common( - Model(), - (torch.randn(875),), - ) - # normal case - self.common( - Model(step=10), - (torch.randn(875),), - ) - def test_slice1(self): def fn(a): return ( diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 1918373b342e..ba09c6173c5f 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -759,8 +759,7 @@ def slice_forward( storage_offset = self.storage_offset() + start_val * strides[dim] len = end_val - start_val - # safely round-up for corresponding c++ impl - sizes[dim] = (len // step) + (1 if len % step != 0 else 0) + sizes[dim] = (len + step - 1) // step strides[dim] *= step if self.is_quantized: