Fix map_location for wrapper subclass and device tensors that go through numpy (#126728)

Fixes https://github.com/pytorch/pytorch/issues/124418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126728
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-05-23 22:32:29 +00:00
committed by PyTorch MergeBot
parent 4ff9113e3d
commit 87f79af24d
4 changed files with 73 additions and 1 deletions

View File

@ -2,6 +2,7 @@ import copyreg
import functools
import logging
import sys
import threading
import traceback
import warnings
from collections import defaultdict
@ -108,6 +109,31 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
return kwargs["async"]
_thread_local_state = threading.local()
def _get_restore_location(device):
"""Return the map_location location.
Used for rebuild functions where the tensor device is distinct from the storage
"""
map_location = getattr(_thread_local_state, "map_location", None)
if map_location is None:
return device
else:
if isinstance(map_location, dict):
return map_location.get(device, device)
elif isinstance(map_location, (str, torch.device)):
return map_location
else:
assert callable(map_location)
raise RuntimeError(
"Callable map_location not supported with _rebuild_wrapper_subclass "
"or _rebuild_device_tensor_from_numpy"
)
# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
@ -303,6 +329,7 @@ def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
device = _get_restore_location(device)
tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
tensor.requires_grad = requires_grad
return tensor
@ -321,6 +348,7 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
def _rebuild_wrapper_subclass(
cls, dtype, size, stride, storage_offset, layout, device, requires_grad
):
device = _get_restore_location(device)
return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
size,