Add information about checkpoint offset to untyped storages when torch.load under FakeTensorMode (#147787)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147787
Approved by: https://github.com/albanD
ghstack dependencies: #147786
This commit is contained in:
Mikayla Gawarecki
2025-03-06 08:50:55 +00:00
committed by PyTorch MergeBot
parent bdcc1b579b
commit 209977e6e5
3 changed files with 43 additions and 16 deletions

View File

@ -1616,35 +1616,59 @@ class FakeTensorPropTest(TestCase):
sd = model.state_dict()
sd['tt'] = TwoTensor(torch.randn(2), torch.randn(2))
def _read_tensor_and_check(key, sd_loaded, all_bytes, device):
dtype = torch.float32
t = sd_loaded[key]
self.assertEqual(t.device.type, device)
if isinstance(t, TwoTensor):
untyped_storage_a, untyped_storage_b = t.a.untyped_storage(), t.b.untyped_storage()
offset_a, offset_b = untyped_storage_a._checkpoint_offset, untyped_storage_b._checkpoint_offset
nbytes_a, nbytes_b = untyped_storage_a.nbytes() // 4, untyped_storage_b.nbytes() // 4
result_a = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes_a, offset=offset_a).resize_(t.a.size())
result_b = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes_b, offset=offset_b).resize_(t.b.size())
self.assertEqual(TwoTensor(result_a, result_b), sd[key])
else:
untyped_storage = t.untyped_storage()
offset = untyped_storage._checkpoint_offset
nbytes = untyped_storage.nbytes() // 4
result = torch.frombuffer(all_bytes, dtype=dtype, count=nbytes, offset=offset).resize_(t.size())
self.assertEqual(result, sd[key])
with TemporaryFileName() as state_dict_file, torch.serialization.safe_globals([TwoTensor]):
with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]):
# Create state_dict to be loaded later
torch.save(sd, state_dict_file)
torch.save(sd, f)
with open(f, 'rb') as g:
all_bytes = g.read()
fake_mode = FakeTensorMode()
with fake_mode:
sd_loaded = torch.load(state_dict_file)
self.assertEqual(sd_loaded["weight"].device.type, "cpu")
self.assertEqual(sd_loaded["tt"].device.type, "cpu")
sd_loaded = torch.load(state_dict_file, map_location="cuda")
self.assertEqual(sd_loaded["weight"].device.type, "cuda")
self.assertEqual(sd_loaded["tt"].device.type, "cuda")
sd_loaded = torch.load(f)
for k in sd:
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cpu')
with fake_mode:
sd_loaded = torch.load(f, map_location="cuda")
for k in sd:
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cuda')
for k in sd.keys():
sd[k] = sd[k].to('cuda')
with TemporaryFileName() as state_dict_file, torch.serialization.safe_globals([TwoTensor]):
torch.save(sd, state_dict_file)
with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]):
torch.save(sd, f)
with open(f, 'rb') as g:
all_bytes = g.read()
fake_mode = FakeTensorMode()
with fake_mode:
sd_loaded = torch.load(state_dict_file)
self.assertEqual(sd_loaded["weight"].device.type, "cuda")
self.assertEqual(sd_loaded["tt"].device.type, "cuda")
sd_loaded = torch.load(state_dict_file, map_location="cpu")
self.assertEqual(sd_loaded["weight"].device.type, "cpu")
self.assertEqual(sd_loaded["tt"].device.type, "cpu")
sd_loaded = torch.load(f)
for k in sd:
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cuda')
with fake_mode:
sd_loaded = torch.load(f, map_location="cpu")
for k in sd:
_read_tensor_and_check(k, sd_loaded, all_bytes, 'cpu')
make_propagate_real_tensors_cls(FakeTensorPropTest)