mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] Remove .contiguous() when saving weights to raw bytes (#163587)
Summary: `.contiguous()` will discard the original storage size of the tensor, and could lead to issues during loading. Test Plan: buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_1D_tensor_slicing buck2 run mode/dev-nosan caffe2/test:test_export -- -r test_2D_tensor_slicing Differential Revision: D83016250 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163587 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
49e7b2f69d
commit
720a7b2887
@ -888,6 +888,42 @@ def forward(self, x):
|
||||
loaded_ep = load(buffer)
|
||||
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
|
||||
|
||||
def test_1D_tensor_slicing(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.const = torch.arange(8)[::2]
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.const
|
||||
|
||||
m = M()
|
||||
sample_inputs = (torch.randn(4),)
|
||||
ep = torch.export.export(m, sample_inputs)
|
||||
buffer = io.BytesIO()
|
||||
save(ep, buffer)
|
||||
buffer.seek(0)
|
||||
loaded_ep = load(buffer)
|
||||
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
|
||||
|
||||
def test_2D_tensor_slicing(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.const = torch.randn(4, 4)[:2, :2]
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.const
|
||||
|
||||
m = M()
|
||||
sample_inputs = (torch.randn(2, 2),)
|
||||
ep = torch.export.export(m, sample_inputs)
|
||||
buffer = io.BytesIO()
|
||||
save(ep, buffer)
|
||||
buffer.seek(0)
|
||||
loaded_ep = load(buffer)
|
||||
self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))
|
||||
|
||||
def test_complex_constant(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -346,7 +346,7 @@ def _get_raw_tensor_bytes(value: torch.Tensor) -> bytes:
|
||||
if _is_fake_tensor(value):
|
||||
value_bytes = b""
|
||||
elif value.data_ptr():
|
||||
cpu_tensor = value.cpu().contiguous()
|
||||
cpu_tensor = value.cpu()
|
||||
value_untyped_storage = cpu_tensor.untyped_storage()
|
||||
# we store the raw bytes the untyped storage. Tensor metadata is stored separately
|
||||
value_bytes = bytes(
|
||||
|
Reference in New Issue
Block a user