Fix map_location for wrapper subclass and device tensors that go through numpy (#126728)
Fixes https://github.com/pytorch/pytorch/issues/124418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126728
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 4ff9113e3d
Commit: 87f79af24d
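The hunks below (torch/_utils.py) add a thread-local map_location and consult it in the two rebuild paths that previously ignored it: device tensors that round-trip through numpy, and wrapper subclasses rebuilt without storage. A minimal user-level sketch of the behavior being fixed, assuming a machine with a backend whose tensors serialize through numpy (MPS is used here as an example) and a throwaway checkpoint path:

import torch

# Assumption: an MPS-capable machine; MPS tensors were serialized through numpy
# at the time of this change.
t = torch.randn(3, device="mps")
torch.save(t, "checkpoint.pt")

# With the fix, map_location is honored when the tensor is rebuilt from numpy,
# so the result lands on CPU instead of being restored to the saved device.
loaded = torch.load("checkpoint.pt", map_location="cpu")
print(loaded.device)  # cpu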
@@ -2,6 +2,7 @@ import copyreg
 import functools
 import logging
 import sys
+import threading
 import traceback
 import warnings
 from collections import defaultdict
@@ -108,6 +109,31 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
     return kwargs["async"]
 
 
+_thread_local_state = threading.local()
+
+
+def _get_restore_location(device):
+    """Return the map_location location.
+
+    Used for rebuild functions where the tensor device is distinct from the storage
+    """
+
+    map_location = getattr(_thread_local_state, "map_location", None)
+    if map_location is None:
+        return device
+    else:
+        if isinstance(map_location, dict):
+            return map_location.get(device, device)
+        elif isinstance(map_location, (str, torch.device)):
+            return map_location
+        else:
+            assert callable(map_location)
+            raise RuntimeError(
+                "Callable map_location not supported with _rebuild_wrapper_subclass "
+                "or _rebuild_device_tensor_from_numpy"
+            )
+
+
 # Note [Don't serialize hooks]
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Since time immemorial, we have serialized the backward hooks associated with
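A small sketch of how _get_restore_location behaves for the supported map_location forms. It sets the thread-local state by hand, which is normally the loader's job during torch.load; doing it directly here is purely illustrative:

from torch._utils import _get_restore_location, _thread_local_state

# A str (or torch.device) map_location overrides whatever device was recorded.
_thread_local_state.map_location = "cpu"
print(_get_restore_location("cuda:0"))   # -> "cpu"

# A dict map_location remaps matching devices and leaves the rest untouched.
_thread_local_state.map_location = {"cuda:0": "cuda:1"}
print(_get_restore_location("cuda:0"))   # -> "cuda:1"
print(_get_restore_location("cpu"))      # -> "cpu" (no entry, unchanged)

# With nothing set, the recorded device is returned as-is; a callable
# map_location raises RuntimeError on these rebuild paths.
del _thread_local_state.map_location
print(_get_restore_location("cuda:0"))   # -> "cuda:0"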
@@ -303,6 +329,7 @@ def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
 
 
 def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
+    device = _get_restore_location(device)
     tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
     tensor.requires_grad = requires_grad
     return tensor
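A direct sketch of this rebuild path with a map_location active, runnable on a CPU-only machine. The recorded device string "cuda:0" is hypothetical; the point is that the mapping, not the recorded device, decides where the tensor is materialized:

import numpy as np
import torch
from torch._utils import _rebuild_device_tensor_from_numpy, _thread_local_state

_thread_local_state.map_location = "cpu"
t = _rebuild_device_tensor_from_numpy(
    np.ones((2, 2), dtype=np.float32),  # payload as it would come out of a checkpoint
    torch.float32,
    "cuda:0",  # device recorded at save time (hypothetical)
    False,     # requires_grad
)
print(t.device)  # cpu: the map_location overrode the recorded device
del _thread_local_state.map_location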
@@ -321,6 +348,7 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
 def _rebuild_wrapper_subclass(
     cls, dtype, size, stride, storage_offset, layout, device, requires_grad
 ):
+    device = _get_restore_location(device)
     return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
         cls,
         size,
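A similar sketch for the wrapper-subclass path. The trivial Tensor subclass below is only a stand-in (real wrapper subclasses normally implement __torch_dispatch__); it just shows that the device handed to _make_wrapper_subclass is the mapped one:

import torch
from torch._utils import _rebuild_wrapper_subclass, _thread_local_state

class MyWrapper(torch.Tensor):
    # Stand-in for a real wrapper subclass; enough for construction via
    # _make_wrapper_subclass in this sketch.
    pass

_thread_local_state.map_location = torch.device("cpu")
t = _rebuild_wrapper_subclass(
    MyWrapper, torch.float32, (2, 2), (2, 1), 0, torch.strided, "cuda:0", False
)
print(type(t).__name__, t.device)  # MyWrapper cpu
del _thread_local_state.map_location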