mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix torch.load(..., weights_only=True) for NT (#112516)
Found when looking into #112509 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112516 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
85e93632e7
commit
51a38380d1
@ -2995,7 +2995,8 @@ class TestNestedTensorSubclass(TestCase):
|
||||
|
||||
@dtypes(torch.float, torch.double, torch.half)
|
||||
@parametrize("requires_grad", [False, True])
|
||||
def test_serialization(self, device, dtype, requires_grad):
|
||||
@parametrize("weights_only", [False, True])
|
||||
def test_serialization(self, device, dtype, requires_grad, weights_only):
|
||||
|
||||
def compare_metadata(nt1, nt2):
|
||||
self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
|
||||
@ -3008,7 +3009,7 @@ class TestNestedTensorSubclass(TestCase):
|
||||
buffer = io.BytesIO()
|
||||
serialized = torch.save(a, buffer)
|
||||
buffer.seek(0)
|
||||
b = torch.load(buffer)
|
||||
b = torch.load(buffer, weights_only=weights_only)
|
||||
# should be both conceptually equal and metadata equivalent
|
||||
self.assertEqual(a, b)
|
||||
compare_metadata(a, b)
|
||||
|
@ -376,7 +376,7 @@ class Tensor(torch._C.TensorBase):
|
||||
self._nested_tensor_strides(),
|
||||
self._nested_tensor_storage_offsets(),
|
||||
)
|
||||
return (torch._nested_view_from_buffer, args_nested)
|
||||
return (torch._utils._rebuild_nested_tensor, args_nested)
|
||||
elif (
|
||||
self.data_ptr() == 0
|
||||
and type(self) is not torch.Tensor
|
||||
|
@ -304,6 +304,10 @@ def _rebuild_sparse_tensor(layout, data):
|
||||
raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")
|
||||
|
||||
|
||||
def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
|
||||
return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)
|
||||
|
||||
|
||||
def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
|
||||
tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
|
||||
tensor.requires_grad = requires_grad
|
||||
|
@ -101,6 +101,7 @@ def _get_allowed_globals():
|
||||
torch._utils._rebuild_tensor_v2,
|
||||
torch._utils._rebuild_sparse_tensor,
|
||||
torch._utils._rebuild_meta_tensor_no_storage,
|
||||
torch._utils._rebuild_nested_tensor,
|
||||
]:
|
||||
rc[f"torch._utils.{f.__name__}"] = f
|
||||
|
||||
|
Reference in New Issue
Block a user