mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pickle support for NT (#110219)
Fixes #104198 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110219 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
c9511e8ac9
commit
3693777a86
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: nestedtensor"]
|
||||
|
||||
import io
|
||||
import itertools
|
||||
import unittest
|
||||
from functools import partial
|
||||
@ -103,14 +104,14 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
|
||||
# Alternate approach to generating a random NT.
|
||||
# dims should be something like [5, None, 10], with None indicating that a
|
||||
# random ragged structure should be used
|
||||
def random_nt_from_dims(dims, device=None, dtype=None):
|
||||
def random_nt_from_dims(dims, device=None, dtype=None, requires_grad=False):
|
||||
sizes = [
|
||||
[d if d is not None else torch.randint(2, 10, size=(1,)).item() for d in dims[1:]]
|
||||
for d in range(dims[0])
|
||||
]
|
||||
return torch.nested.nested_tensor([
|
||||
torch.randn(*size) for size in sizes
|
||||
], device=device, dtype=dtype)
|
||||
], device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
|
||||
# Creates an NT matching another NT's number of components and
|
||||
@ -2917,6 +2918,30 @@ class TestNestedTensorSubclass(TestCase):
|
||||
|
||||
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
|
||||
|
||||
@dtypes(torch.float, torch.double, torch.half)
|
||||
@parametrize("requires_grad", [False, True])
|
||||
def test_serialization(self, device, dtype, requires_grad):
|
||||
|
||||
def compare_metadata(nt1, nt2):
|
||||
self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
|
||||
self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
|
||||
self.assertEqual(nt1._nested_tensor_storage_offsets(),
|
||||
nt2._nested_tensor_storage_offsets())
|
||||
|
||||
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
|
||||
for a in [nt_contiguous, nt_noncontiguous]:
|
||||
buffer = io.BytesIO()
|
||||
serialized = torch.save(a, buffer)
|
||||
buffer.seek(0)
|
||||
b = torch.load(buffer)
|
||||
# should be both conceptually equal and metadata equivalent
|
||||
self.assertEqual(a, b)
|
||||
compare_metadata(a, b)
|
||||
# should be conceptually equal but not necessarily metadata equivalent
|
||||
self.assertEqual(b, nt_contiguous)
|
||||
self.assertEqual(b, nt_noncontiguous)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestNestedTensor)
|
||||
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
|
||||
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
|
||||
|
||||
@ -159,7 +159,6 @@ _SKIP_PYTHON_BINDINGS = [
|
||||
"fill.Scalar", # only used by the functionalization pass
|
||||
"lift.*",
|
||||
"normal_functional", # only used by the functionalization pas
|
||||
"_nested_view_from_buffer", # View only version of _nested_from_buffer. This will force users to only use the "safe" version.
|
||||
"_nested_view_from_buffer_copy",
|
||||
"_nested_view_from_buffer_copy_out",
|
||||
"nbytes",
|
||||
|
||||
@ -166,6 +166,7 @@ DONT_REQUIRE_DERIVATIVE = {
|
||||
# This function returns nested_tensor shape as a tensor that is non-differentiable
|
||||
"_nested_tensor_size",
|
||||
"_nested_tensor_strides",
|
||||
"_nested_tensor_storage_offsets",
|
||||
}
|
||||
|
||||
# The C -> R functions at the time of adding this are still being audited and tested
|
||||
|
||||
@ -366,6 +366,17 @@ class Tensor(torch._C.TensorBase):
|
||||
),
|
||||
)
|
||||
return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
|
||||
elif self.is_nested:
|
||||
args_nested = (
|
||||
# NB: values() currently returns the storage as a buffer in an unsafe way.
|
||||
# Ideally, we'd use a private API for this instead. TODO: Switch to this if
|
||||
# we ever get around to adding it.
|
||||
self.values(),
|
||||
self._nested_tensor_size(),
|
||||
self._nested_tensor_strides(),
|
||||
self._nested_tensor_storage_offsets(),
|
||||
)
|
||||
return (torch._nested_view_from_buffer, args_nested)
|
||||
elif (
|
||||
self.data_ptr() == 0
|
||||
and type(self) is not torch.Tensor
|
||||
|
||||
Reference in New Issue
Block a user