mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Throw error when saving storages that view same data with different type (#66949)
Summary: Fixes https://github.com/pytorch/pytorch/issues/58970 cc mruberry Pull Request resolved: https://github.com/pytorch/pytorch/pull/66949 Reviewed By: albanD Differential Revision: D31926323 Pulled By: anjali411 fbshipit-source-id: f6e7acc0c1968b70a94f9b0b69a32780e8e21a62
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bf60c6e71b
commit
bc3d380ed1
@ -568,6 +568,39 @@ class SerializationMixin(object):
|
||||
with self.assertRaisesRegex(AttributeError, expected_err_msg):
|
||||
torch.load(resource)
|
||||
|
||||
def test_save_different_dtype_error(self):
|
||||
error_msg = r"Cannot save multiple tensors or storages that view the same data as different types"
|
||||
|
||||
devices = ['cpu']
|
||||
if torch.cuda.is_available():
|
||||
devices.append('cuda')
|
||||
|
||||
for device in devices:
|
||||
a = torch.randn(10, dtype=torch.complex128, device=device)
|
||||
f = io.BytesIO()
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.save([a, a.imag], f)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.save([a.storage(), a.imag], f)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.save([a, a.imag.storage()], f)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.save([a.storage(), a.imag.storage()], f)
|
||||
|
||||
a = torch.randn(10, device=device)
|
||||
s_bytes = torch.TypedStorage(
|
||||
wrap_storage=a.storage()._untyped(),
|
||||
dtype=torch.uint8)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.save([a, s_bytes], f)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.save([a.storage(), s_bytes], f)
|
||||
|
||||
class serialization_method(object):
|
||||
def __init__(self, use_zip):
|
||||
|
Reference in New Issue
Block a user