mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow storage() to work on python tensor subclasses, but error on future data accesses (#107417)
This was discussed in feedback from the original version of my "reorder proxy/fake" PR. This PR allows calls to `tensor.untyped_storage()` to **always** return a python storage object to the user. Previously, we would error loudly if we detected that the storage had a null dataptr. Instead, I updated the python bindings for the python storage methods that I saw involve data access, to throw an error later, only if you try to access those methods (e.g. `storage.data_ptr()` will now raise an error if the data ptr is null). Pull Request resolved: https://github.com/pytorch/pytorch/pull/107417 Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
df42f15e28
commit
2c8759df9d
@ -254,6 +254,27 @@ class TestSubclass(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"):
|
||||
param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))
|
||||
|
||||
def test_tensor_subclass_storage_data_accesses_throw(self):
|
||||
from torch.testing._internal.logging_tensor import LoggingTensor
|
||||
x = torch.ones(2)
|
||||
x_log = LoggingTensor(x)
|
||||
# Accessing storage on a tensor subclass is valid
|
||||
storage = x_log.untyped_storage()
|
||||
# This includes accessing metadata on the storage
|
||||
sz = storage.size()
|
||||
# But storage methods that access data will throw
|
||||
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
|
||||
storage.data_ptr()
|
||||
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
|
||||
storage.resize_(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
|
||||
storage.copy_(storage)
|
||||
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
|
||||
storage.fill_(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
|
||||
storage._write_file("file")
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestSubclass)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user