Allow storage() to work on python tensor subclasses, but error on future data accesses (#107417)

This was discussed in feedback from the original version of my "reorder proxy/fake" PR. This PR allows calls to `tensor.untyped_storage()` to **always** return a python storage object to the user. Previously, we would error loudly if we detected that the storage had a null dataptr.

Instead, I updated the python bindings for the python storage methods that I saw involve data access, to throw an error later, only if you try to access those methods (e.g. `storage.data_ptr()` will now raise an error if the data ptr is null).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107417
Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
This commit is contained in:
Brian Hirsh
2023-08-21 18:09:46 -07:00
committed by PyTorch MergeBot
parent df42f15e28
commit 2c8759df9d
5 changed files with 65 additions and 15 deletions

View File

@ -254,6 +254,27 @@ class TestSubclass(TestCase):
with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"):
param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))
def test_tensor_subclass_storage_data_accesses_throw(self):
from torch.testing._internal.logging_tensor import LoggingTensor
x = torch.ones(2)
x_log = LoggingTensor(x)
# Accessing storage on a tensor subclass is valid
storage = x_log.untyped_storage()
# This includes accessing metadata on the storage
sz = storage.size()
# But storage methods that access data will throw
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
storage.data_ptr()
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
storage.resize_(0)
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
storage.copy_(storage)
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
storage.fill_(0)
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
storage._write_file("file")
instantiate_parametrized_tests(TestSubclass)
if __name__ == '__main__':