diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index b4333a0d130f..8c8222d1b917 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -888,6 +888,42 @@ def forward(self, x): loaded_ep = load(buffer) self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs)) + def test_1D_tensor_slicing(self) -> None: + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.arange(8)[::2] + + def forward(self, x): + return x + self.const + + m = M() + sample_inputs = (torch.randn(4),) + ep = torch.export.export(m, sample_inputs) + buffer = io.BytesIO() + save(ep, buffer) + buffer.seek(0) + loaded_ep = load(buffer) + self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs)) + + def test_2D_tensor_slicing(self) -> None: + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.randn(4, 4)[:2, :2] + + def forward(self, x): + return x + self.const + + m = M() + sample_inputs = (torch.randn(2, 2),) + ep = torch.export.export(m, sample_inputs) + buffer = io.BytesIO() + save(ep, buffer) + buffer.seek(0) + loaded_ep = load(buffer) + self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs)) + def test_complex_constant(self) -> None: class M(torch.nn.Module): def forward(self, x): diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index eab67a092e1c..f8849282dd9a 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -346,7 +346,7 @@ def _get_raw_tensor_bytes(value: torch.Tensor) -> bytes: if _is_fake_tensor(value): value_bytes = b"" elif value.data_ptr(): - cpu_tensor = value.cpu().contiguous() + cpu_tensor = value.cpu() value_untyped_storage = cpu_tensor.untyped_storage() # we store the raw bytes the untyped storage. Tensor metadata is stored separately value_bytes = bytes(