Pickle support for NT (#110219)

Fixes #104198
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110219
Approved by: https://github.com/cpuhrsch
Author: Joel Schlosser
Date: 2023-09-28 16:13:21 -04:00
Committed by: PyTorch MergeBot
Parent: c9511e8ac9
Commit: 3693777a86
4 changed files with 39 additions and 3 deletions
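For a quick sense of what this change enables, here is a minimal round-trip sketch mirroring the new test below (the component shapes and the use of io.BytesIO are arbitrary choices for illustration, not part of the diff):

import io
import torch

# Build a small strided nested tensor with ragged components.
nt = torch.nested.nested_tensor([torch.randn(2, 6), torch.randn(3, 6)])

# torch.save / torch.load go through the pickle path extended by this PR.
buffer = io.BytesIO()
torch.save(nt, buffer)
buffer.seek(0)
restored = torch.load(buffer)

assert restored.is_nested  # round-trips as a nested tensor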

View File

@@ -1,5 +1,6 @@
# Owner(s): ["module: nestedtensor"]
import io
import itertools
import unittest
from functools import partial
@@ -103,14 +104,14 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
# Alternate approach to generating a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(dims, device=None, dtype=None):
def random_nt_from_dims(dims, device=None, dtype=None, requires_grad=False):
    sizes = [
        [d if d is not None else torch.randint(2, 10, size=(1,)).item() for d in dims[1:]]
        for d in range(dims[0])
    ]
    return torch.nested.nested_tensor([
        torch.randn(*size) for size in sizes
    ], device=device, dtype=dtype)
    ], device=device, dtype=dtype, requires_grad=requires_grad)
# Creates an NT matching another NT's number of components and
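As a minimal usage sketch of the updated helper (the dims, device, and dtype values below are arbitrary; the helper itself is defined just above):

# 5 components, each of shape (r, 10), with r drawn randomly in [2, 10) per component.
nt = random_nt_from_dims([5, None, 10], device="cpu", dtype=torch.float, requires_grad=True)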
@@ -2917,6 +2918,30 @@ class TestNestedTensorSubclass(TestCase):
        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)

    @dtypes(torch.float, torch.double, torch.half)
    @parametrize("requires_grad", [False, True])
    def test_serialization(self, device, dtype, requires_grad):
        def compare_metadata(nt1, nt2):
            self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
            self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
            self.assertEqual(nt1._nested_tensor_storage_offsets(),
                             nt2._nested_tensor_storage_offsets())

        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
        for a in [nt_contiguous, nt_noncontiguous]:
            buffer = io.BytesIO()
            torch.save(a, buffer)
            buffer.seek(0)
            b = torch.load(buffer)
            # should be both conceptually equal and metadata equivalent
            self.assertEqual(a, b)
            compare_metadata(a, b)
            # should be conceptually equal but not necessarily metadata equivalent
            self.assertEqual(b, nt_contiguous)
            self.assertEqual(b, nt_noncontiguous)
instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
instantiate_device_type_tests(TestNestedTensorAutograd, globals())

View File

@@ -159,7 +159,6 @@ _SKIP_PYTHON_BINDINGS = [
    "fill.Scalar", # only used by the functionalization pass
    "lift.*",
    "normal_functional", # only used by the functionalization pass
    "_nested_view_from_buffer", # View only version of _nested_from_buffer. This will force users to only use the "safe" version.
    "_nested_view_from_buffer_copy",
    "_nested_view_from_buffer_copy_out",
    "nbytes",

View File

@@ -166,6 +166,7 @@ DONT_REQUIRE_DERIVATIVE = {
    # This function returns nested_tensor shape as a tensor that is non-differentiable
    "_nested_tensor_size",
    "_nested_tensor_strides",
    "_nested_tensor_storage_offsets",
}
# The C -> R functions at the time of adding this are still being audited and tested
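For context on why these entries are marked non-differentiable, a rough sketch of what the accessors return on a strided nested tensor (the example shapes are arbitrary and the output comments are illustrative, not exact):

nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
print(nt._nested_tensor_size())             # per-component sizes, e.g. tensor([[2, 3], [4, 3]])
print(nt._nested_tensor_strides())          # per-component strides as an integer tensor
print(nt._nested_tensor_storage_offsets())  # per-component offsets into the underlying buffer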

View File

@@ -366,6 +366,17 @@ class Tensor(torch._C.TensorBase):
                ),
            )
            return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
        elif self.is_nested:
            args_nested = (
                # NB: values() currently returns the storage as a buffer in an unsafe way.
                # Ideally, we'd use a private API for this instead. TODO: Switch to this if
                # we ever get around to adding it.
                self.values(),
                self._nested_tensor_size(),
                self._nested_tensor_strides(),
                self._nested_tensor_storage_offsets(),
            )
            return (torch._nested_view_from_buffer, args_nested)
        elif (
            self.data_ptr() == 0
            and type(self) is not torch.Tensor
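To make the rebuild path concrete: __reduce_ex__ returns a (callable, args) pair, the pair is stored in the pickle, and unpickling applies the callable to the arguments. The rough sketch below only restates the calls already used in the diff above; nothing here is a new API.

nt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])

# What __reduce_ex__ now captures for a nested tensor:
buf = nt.values()                                 # flat buffer view (see the NB above)
sizes = nt._nested_tensor_size()
strides = nt._nested_tensor_strides()
offsets = nt._nested_tensor_storage_offsets()

# What unpickling effectively runs to reconstruct the nested tensor:
rebuilt = torch._nested_view_from_buffer(buf, sizes, strides, offsets)
assert rebuilt.is_nested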