mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add information about checkpoint offset to untyped storages when torch.load under FakeTensorMode (#147787)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147787 Approved by: https://github.com/albanD ghstack dependencies: #147786
This commit is contained in:
committed by
PyTorch MergeBot
parent
bdcc1b579b
commit
209977e6e5
@ -1616,35 +1616,59 @@ class FakeTensorPropTest(TestCase):
|
||||
sd = model.state_dict()
|
||||
sd['tt'] = TwoTensor(torch.randn(2), torch.randn(2))
|
||||
|
||||
def _read_tensor_and_check(key, sd_loaded, all_bytes, device):
|
||||
dtype = torch.float32
|
||||
t = sd_loaded[key]
|
||||
self.assertEqual(t.device.type, device)
|
||||
if isinstance(t, TwoTensor):
|
||||
untyped_storage_a, untyped_storage_b = t.a.untyped_storage(), t.b.untyped_storage()
|
||||
offset_a, offset_b = untyped_storage_a._checkpoint_offset, untyped_storage_b._checkpoint_offset
|
||||
nbytes_a, nbytes_b = untyped_storage_a.nbytes() // 4, untyped_storage_b.nbytes() // 4
|
||||
result_a = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes_a, offset=offset_a).resize_(t.a.size())
|
||||
result_b = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes_b, offset=offset_b).resize_(t.b.size())
|
||||
self.assertEqual(TwoTensor(result_a, result_b), sd[key])
|
||||
else:
|
||||
untyped_storage = t.untyped_storage()
|
||||
offset = untyped_storage._checkpoint_offset
|
||||
nbytes = untyped_storage.nbytes() // 4
|
||||
result = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes, offset=offset).resize_(t.size())
|
||||
self.assertEqual(result, sd[key])
|
||||
|
||||
with TemporaryFileName() as state_dict_file, torch.serialization.safe_globals([TwoTensor]):
|
||||
|
||||
with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]):
|
||||
# Create state_dict to be loaded later
|
||||
torch.save(sd, state_dict_file)
|
||||
torch.save(sd, f)
|
||||
with open(f, 'rb') as g:
|
||||
all_bytes = g.read()
|
||||
|
||||
fake_mode = FakeTensorMode()
|
||||
with fake_mode:
|
||||
sd_loaded = torch.load(state_dict_file)
|
||||
self.assertEqual(sd_loaded["weight"].device.type, "cpu")
|
||||
self.assertEqual(sd_loaded["tt"].device.type, "cpu")
|
||||
sd_loaded = torch.load(state_dict_file, map_location="cuda")
|
||||
self.assertEqual(sd_loaded["weight"].device.type, "cuda")
|
||||
self.assertEqual(sd_loaded["tt"].device.type, "cuda")
|
||||
sd_loaded = torch.load(f)
|
||||
for k in sd:
|
||||
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cpu')
|
||||
with fake_mode:
|
||||
sd_loaded = torch.load(f, map_location="cuda")
|
||||
for k in sd:
|
||||
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cuda')
|
||||
|
||||
|
||||
for k in sd.keys():
|
||||
sd[k] = sd[k].to('cuda')
|
||||
|
||||
with TemporaryFileName() as state_dict_file, torch.serialization.safe_globals([TwoTensor]):
|
||||
torch.save(sd, state_dict_file)
|
||||
with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]):
|
||||
torch.save(sd, f)
|
||||
with open(f, 'rb') as g:
|
||||
all_bytes = g.read()
|
||||
|
||||
fake_mode = FakeTensorMode()
|
||||
with fake_mode:
|
||||
sd_loaded = torch.load(state_dict_file)
|
||||
self.assertEqual(sd_loaded["weight"].device.type, "cuda")
|
||||
self.assertEqual(sd_loaded["tt"].device.type, "cuda")
|
||||
sd_loaded = torch.load(state_dict_file, map_location="cpu")
|
||||
self.assertEqual(sd_loaded["weight"].device.type, "cpu")
|
||||
self.assertEqual(sd_loaded["tt"].device.type, "cpu")
|
||||
sd_loaded = torch.load(f)
|
||||
for k in sd:
|
||||
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cuda')
|
||||
with fake_mode:
|
||||
sd_loaded = torch.load(f, map_location="cpu")
|
||||
for k in sd:
|
||||
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cpu')
|
||||
|
||||
|
||||
make_propagate_real_tensors_cls(FakeTensorPropTest)
|
||||
|
Reference in New Issue
Block a user