mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix map_location for wrapper subclass and device tensors that go through numpy (#126728)
Fixes https://github.com/pytorch/pytorch/issues/124418 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126728 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
4ff9113e3d
commit
87f79af24d
@ -7,12 +7,18 @@ import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from typing import Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
import torch.testing._internal.common_utils as common
|
||||
import torch.utils.cpp_extension
|
||||
from torch.testing._internal.common_utils import IS_ARM64, skipIfTorchDynamo, TEST_CUDA
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_ARM64,
|
||||
skipIfTorchDynamo,
|
||||
TemporaryFileName,
|
||||
TEST_CUDA,
|
||||
)
|
||||
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
||||
|
||||
|
||||
@ -572,6 +578,24 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
self.assertEqual(z_cpu, z[0])
|
||||
self.assertEqual(z_cpu, z[1])
|
||||
|
||||
def test_open_device_numpy_serialization_map_location(self):
|
||||
torch.utils.rename_privateuse1_backend("foo")
|
||||
device = self.module.custom_device()
|
||||
default_protocol = torch.serialization.DEFAULT_PROTOCOL
|
||||
# This is a hack to test serialization through numpy
|
||||
with patch.object(torch._C, "_has_storage", return_value=False):
|
||||
x = torch.randn(2, 3)
|
||||
x_foo = x.to(device)
|
||||
sd = {"x": x_foo}
|
||||
rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
|
||||
self.assertTrue(
|
||||
rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
|
||||
)
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(sd, f)
|
||||
sd_loaded = torch.load(f, map_location="cpu")
|
||||
self.assertTrue(sd_loaded["x"].is_cpu)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.run_tests()
|
||||
|
@ -4403,6 +4403,22 @@ class TestSubclassSerialization(TestCase):
|
||||
with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
|
||||
loaded = torch.load(f, weights_only=True)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
|
||||
def test_tensor_subclass_map_location(self):
|
||||
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
|
||||
sd = {'t': t}
|
||||
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(sd, f)
|
||||
sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
|
||||
# make sure map_location is not propagated over multiple torch.load calls
|
||||
sd_loaded = torch.load(f)
|
||||
self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
|
||||
self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
|
||||
self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestBothSerialization, globals())
|
||||
|
@ -2,6 +2,7 @@ import copyreg
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
@ -108,6 +109,31 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
|
||||
return kwargs["async"]
|
||||
|
||||
|
||||
_thread_local_state = threading.local()
|
||||
|
||||
|
||||
def _get_restore_location(device):
|
||||
"""Return the map_location location.
|
||||
|
||||
Used for rebuild functions where the tensor device is distinct from the storage
|
||||
"""
|
||||
|
||||
map_location = getattr(_thread_local_state, "map_location", None)
|
||||
if map_location is None:
|
||||
return device
|
||||
else:
|
||||
if isinstance(map_location, dict):
|
||||
return map_location.get(device, device)
|
||||
elif isinstance(map_location, (str, torch.device)):
|
||||
return map_location
|
||||
else:
|
||||
assert callable(map_location)
|
||||
raise RuntimeError(
|
||||
"Callable map_location not supported with _rebuild_wrapper_subclass "
|
||||
"or _rebuild_device_tensor_from_numpy"
|
||||
)
|
||||
|
||||
|
||||
# Note [Don't serialize hooks]
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# Since time immemorial, we have serialized the backward hooks associated with
|
||||
@ -303,6 +329,7 @@ def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
|
||||
|
||||
|
||||
def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
|
||||
device = _get_restore_location(device)
|
||||
tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
|
||||
tensor.requires_grad = requires_grad
|
||||
return tensor
|
||||
@ -321,6 +348,7 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
|
||||
def _rebuild_wrapper_subclass(
|
||||
cls, dtype, size, stride, storage_offset, layout, device, requires_grad
|
||||
):
|
||||
device = _get_restore_location(device)
|
||||
return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
cls,
|
||||
size,
|
||||
|
@ -1453,7 +1453,11 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall
|
||||
|
||||
unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
|
||||
unpickler.persistent_load = persistent_load
|
||||
# Needed for tensors where storage device and rebuild tensor device are
|
||||
# not connected (wrapper subclasses and tensors rebuilt using numpy)
|
||||
torch._utils._thread_local_state.map_location = map_location
|
||||
result = unpickler.load()
|
||||
del torch._utils._thread_local_state.map_location
|
||||
|
||||
torch._utils._validate_loaded_sparse_tensors()
|
||||
torch._C._log_api_usage_metadata(
|
||||
|
Reference in New Issue
Block a user