Throw error when saving storages that view same data with different type (#66949)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/58970

cc mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66949

Reviewed By: albanD

Differential Revision: D31926323

Pulled By: anjali411

fbshipit-source-id: f6e7acc0c1968b70a94f9b0b69a32780e8e21a62
This commit is contained in:
Kurt Mohler
2021-11-16 08:41:14 -08:00
committed by Facebook GitHub Bot
parent bf60c6e71b
commit bc3d380ed1
2 changed files with 66 additions and 0 deletions

View File

@ -568,6 +568,39 @@ class SerializationMixin(object):
with self.assertRaisesRegex(AttributeError, expected_err_msg):
torch.load(resource)
def test_save_different_dtype_error(self):
error_msg = r"Cannot save multiple tensors or storages that view the same data as different types"
devices = ['cpu']
if torch.cuda.is_available():
devices.append('cuda')
for device in devices:
a = torch.randn(10, dtype=torch.complex128, device=device)
f = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, error_msg):
torch.save([a, a.imag], f)
with self.assertRaisesRegex(RuntimeError, error_msg):
torch.save([a.storage(), a.imag], f)
with self.assertRaisesRegex(RuntimeError, error_msg):
torch.save([a, a.imag.storage()], f)
with self.assertRaisesRegex(RuntimeError, error_msg):
torch.save([a.storage(), a.imag.storage()], f)
a = torch.randn(10, device=device)
s_bytes = torch.TypedStorage(
wrap_storage=a.storage()._untyped(),
dtype=torch.uint8)
with self.assertRaisesRegex(RuntimeError, error_msg):
torch.save([a, s_bytes], f)
with self.assertRaisesRegex(RuntimeError, error_msg):
torch.save([a.storage(), s_bytes], f)
class serialization_method(object):
def __init__(self, use_zip):