[export] Remove .contiguous() when saving weights to raw bytes (#163587)

Summary: `.contiguous()` will discard the original storage size of the tensor, and could lead to issues during loading.

Test Plan:
buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_1D_tensor_slicing
buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_2D_tensor_slicing

Differential Revision: D83016250

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163587
Approved by: https://github.com/angelayi
This commit is contained in:
Yiming Zhou
2025-09-23 15:44:52 +00:00
committed by PyTorch MergeBot
parent 49e7b2f69d
commit 720a7b2887
2 changed files with 37 additions and 1 deletions

View File

@ -888,6 +888,42 @@ def forward(self, x):
loaded_ep = load(buffer)
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
def test_1D_tensor_slicing(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.arange(8)[::2]
def forward(self, x):
return x + self.const
m = M()
sample_inputs = (torch.randn(4),)
ep = torch.export.export(m, sample_inputs)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
loaded_ep = load(buffer)
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
def test_2D_tensor_slicing(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.const = torch.randn(4, 4)[:2, :2]
def forward(self, x):
return x + self.const
m = M()
sample_inputs = (torch.randn(2, 2),)
ep = torch.export.export(m, sample_inputs)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
loaded_ep = load(buffer)
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
def test_complex_constant(self) -> None:
class M(torch.nn.Module):
def forward(self, x):

View File

@ -346,7 +346,7 @@ def _get_raw_tensor_bytes(value: torch.Tensor) -> bytes:
if _is_fake_tensor(value):
value_bytes = b""
elif value.data_ptr():
cpu_tensor = value.cpu().contiguous()
cpu_tensor = value.cpu()
value_untyped_storage = cpu_tensor.untyped_storage()
# we store the raw bytes the untyped storage. Tensor metadata is stored separately
value_bytes = bytes(