Fix map_location for wrapper subclass and device tensors that go through numpy (#126728)

Fixes https://github.com/pytorch/pytorch/issues/124418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126728
Approved by: https://github.com/albanD
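
For context, a minimal sketch of the user-facing behavior this change is meant to enable, distilled from the new tests below. It assumes a CUDA build; TwoTensor is the internal test-only wrapper subclass used in the tests (not public API), and the file name is a placeholder.

import torch
from torch.testing._internal.two_tensor import TwoTensor  # test-only wrapper subclass

t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
torch.save({"t": t}, "subclass.pt")

# Previously the wrapper was rebuilt on the device recorded at save time, so it
# could disagree with the map_location applied to its inner tensors; with this
# change the wrapper follows map_location as well.
sd = torch.load("subclass.pt", map_location=torch.device("cuda:0"))
print(sd["t"].device, sd["t"].a.device, sd["t"].b.device)  # cuda:0 for all three
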
Mikayla Gawarecki
2024-05-23 22:32:29 +00:00
committed by PyTorch MergeBot
parent 4ff9113e3d
commit 87f79af24d
4 changed files with 73 additions and 1 deletion

test/test_cpp_extensions_open_device_registration.py

@@ -7,12 +7,18 @@ import tempfile
import types
import unittest
from typing import Union
from unittest.mock import patch
import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import IS_ARM64, skipIfTorchDynamo, TEST_CUDA
from torch.testing._internal.common_utils import (
    IS_ARM64,
    skipIfTorchDynamo,
    TemporaryFileName,
    TEST_CUDA,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
@@ -572,6 +578,24 @@ class TestCppExtensionOpenRgistration(common.TestCase):
        self.assertEqual(z_cpu, z[0])
        self.assertEqual(z_cpu, z[1])

    def test_open_device_numpy_serialization_map_location(self):
        torch.utils.rename_privateuse1_backend("foo")
        device = self.module.custom_device()
        default_protocol = torch.serialization.DEFAULT_PROTOCOL
        # This is a hack to test serialization through numpy
        with patch.object(torch._C, "_has_storage", return_value=False):
            x = torch.randn(2, 3)
            x_foo = x.to(device)
            sd = {"x": x_foo}
            rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
            self.assertTrue(
                rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
            )
            with TemporaryFileName() as f:
                torch.save(sd, f)
                sd_loaded = torch.load(f, map_location="cpu")
                self.assertTrue(sd_loaded["x"].is_cpu)


if __name__ == "__main__":
    common.run_tests()
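
The test above patches torch._C._has_storage to force the numpy serialization path that backends without torch storages hit. A rough sketch of what the rebuild then does at load time, calling the internal helper directly with the thread-local mapping set, the way torch.serialization._load now does:

import numpy as np
import torch
from torch._utils import _rebuild_device_tensor_from_numpy, _thread_local_state

data = np.arange(6, dtype=np.float32).reshape(2, 3)  # payload that was saved through numpy

# With map_location="cpu" active, the recorded backend device is remapped before
# torch.from_numpy(...).to(dtype=..., device=...) runs, so no custom backend is needed here.
_thread_local_state.map_location = "cpu"
try:
    t = _rebuild_device_tensor_from_numpy(data, torch.float32, "privateuseone:0", False)
finally:
    del _thread_local_state.map_location

assert t.is_cpu and not t.requires_grad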

test/test_serialization.py

@@ -4403,6 +4403,22 @@ class TestSubclassSerialization(TestCase):
        with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
            loaded = torch.load(f, weights_only=True)

    @unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
    def test_tensor_subclass_map_location(self):
        t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
        sd = {'t': t}
        with TemporaryFileName() as f:
            torch.save(sd, f)
            sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
            self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
            self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
            self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
            # make sure map_location is not propagated over multiple torch.load calls
            sd_loaded = torch.load(f)
            self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
            self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
            self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))


instantiate_device_type_tests(TestBothSerialization, globals())

torch/_utils.py

@@ -2,6 +2,7 @@ import copyreg
import functools
import logging
import sys
import threading
import traceback
import warnings
from collections import defaultdict
@@ -108,6 +109,31 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    return kwargs["async"]


_thread_local_state = threading.local()


def _get_restore_location(device):
    """Return the map_location location.

    Used for rebuild functions where the tensor device is distinct from the storage
    """

    map_location = getattr(_thread_local_state, "map_location", None)
    if map_location is None:
        return device
    else:
        if isinstance(map_location, dict):
            return map_location.get(device, device)
        elif isinstance(map_location, (str, torch.device)):
            return map_location
        else:
            assert callable(map_location)
            raise RuntimeError(
                "Callable map_location not supported with _rebuild_wrapper_subclass "
                "or _rebuild_device_tensor_from_numpy"
            )


# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
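
A rough sketch of how the new helper resolves devices for these rebuild paths; it pokes the thread-local directly, which torch.serialization._load normally manages (see the last hunk below), so this is illustration rather than public API.

import torch
from torch._utils import _get_restore_location, _thread_local_state

saved = torch.device("cuda:0")  # device recorded in the checkpoint

# No mapping installed: the recorded device is kept as-is.
assert _get_restore_location(saved) == saved

# A str or torch.device map_location overrides the recorded device.
_thread_local_state.map_location = "cpu"
assert _get_restore_location(saved) == "cpu"

# A callable map_location is rejected for these two rebuild functions.
_thread_local_state.map_location = lambda storage, loc: storage
try:
    _get_restore_location(saved)
except RuntimeError:
    pass
finally:
    del _thread_local_state.map_location
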
@@ -303,6 +329,7 @@ def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    device = _get_restore_location(device)
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor
@@ -321,6 +348,7 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
def _rebuild_wrapper_subclass(
    cls, dtype, size, stride, storage_offset, layout, device, requires_grad
):
    device = _get_restore_location(device)
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
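
Similarly, the wrapper-subclass rebuild now remaps the device before allocating the empty subclass shell that serialization later populates. A sketch under the same caveats (internal helpers, TwoTensor is test-only; arguments follow the signature shown above):

import torch
from torch._utils import _rebuild_wrapper_subclass, _thread_local_state
from torch.testing._internal.two_tensor import TwoTensor

_thread_local_state.map_location = torch.device("cpu")
try:
    # The checkpoint recorded cuda:0, but the shell lands on cpu because of the mapping.
    shell = _rebuild_wrapper_subclass(
        TwoTensor, torch.float32, (2, 3), (3, 1), 0, torch.strided,
        torch.device("cuda:0"), False,
    )
finally:
    del _thread_local_state.map_location

assert shell.device == torch.device("cpu")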

torch/serialization.py

@@ -1453,7 +1453,11 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall
    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    # Needed for tensors where storage device and rebuild tensor device are
    # not connected (wrapper subclasses and tensors rebuilt using numpy)
    torch._utils._thread_local_state.map_location = map_location
    result = unpickler.load()
    del torch._utils._thread_local_state.map_location

    torch._utils._validate_loaded_sparse_tensors()

    torch._C._log_api_usage_metadata(
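
Because the mapping lives in a thread-local that is set just before unpickling and deleted right after, it applies to exactly one torch.load call; a callable map_location still works for ordinary storage-backed tensors but now raises a RuntimeError if the checkpoint contains wrapper subclasses or numpy-rebuilt device tensors. A usage sketch ("checkpoint.pt" is a placeholder path):

import torch

sd = torch.load("checkpoint.pt", map_location="cpu")  # mapping installed, used, then removed
sd_again = torch.load("checkpoint.pt")                # rebuilds on the devices recorded at save time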