Fix .to(cpu) for Storage (#138011)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138011
Approved by: https://github.com/albanD
commit 37149d032c
parent 555bddbef7
committed by PyTorch MergeBot
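In brief, as reconstructed from the hunks below: the storage `_to` helper in torch/_utils now handles CPU destinations directly, allocating the destination buffer itself and pinning it when the copy is non-blocking from CUDA (or the privateuse1 backend), and both `_StorageBase.to` and `TypedStorage.to` accept any `DeviceLikeType` instead of requiring a `torch.device`. A minimal usage sketch of the fixed behavior, assuming a CUDA build (it mirrors the new test below):

import torch

t = torch.randn(6, device="cuda")
s = t.untyped_storage()

# Blocking copy: a plain (pageable) CPU storage.
s_cpu = s.to(device="cpu")
assert not s_cpu.is_pinned()

# Non-blocking copy: the new CPU path allocates pinned memory so the
# async device-to-host copy is safe; synchronize before reading.
s_pinned = s.to(device="cpu", non_blocking=True)
torch.cuda.synchronize()
assert s_pinned.is_pinned()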
@@ -5109,10 +5109,20 @@ else:
     @deviceCountAtLeast(1)
-    def test_storage_all_devices(self, devices):
+    @onlyCUDA
+    @parametrize("non_blocking", (True, False))
+    def test_storage_all_devices(self, devices, non_blocking):
         for device in devices:
-            t = torch.tensor((), device=device)
+            t = torch.randn(6, device=device)
             self.assertEqual(t.dtype, t.storage().dtype)
+            s = t.untyped_storage()
+            s_cpu = s.to(device='cpu', non_blocking=non_blocking)
+            if non_blocking:
+                torch.cuda.synchronize()
+                self.assertTrue(s_cpu.is_pinned())
+            else:
+                self.assertFalse(s_cpu.is_pinned())
+            t_cpu = torch.empty(()).set_(s_cpu)
+            self.assertEqual(t.cpu(), t_cpu)
 
     # Note [lazy_clone_ tests with inductor enabled]
     # These `lazy_clone_` tests are written in a way that makes them pass in
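The round-trip at the end of the test rebinds a CPU tensor to the copied storage via `Tensor.set_`, which reinterprets the untyped bytes using the tensor's dtype (float32 here), so the comparison against `t.cpu()` checks the copied bytes end to end. A standalone sketch of that round-trip, again assuming a CUDA build:

import torch

t = torch.randn(6, device="cuda")
s_cpu = t.untyped_storage().to(device="cpu")

# set_() rebinds the float32 tensor to the untyped CPU storage and
# resizes it to fit; the result must match a plain .cpu() copy.
t_cpu = torch.empty(()).set_(s_cpu)
assert torch.equal(t.cpu(), t_cpu)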
@@ -67,6 +67,17 @@ def _to(self, device, non_blocking=False):
     if self.device == device:
         return self
 
+    if device.type == "cpu":
+        pin_memory = non_blocking and self.device.type in (
+            "cuda",
+            torch._C._get_privateuse1_backend_name(),
+        )
+        untyped_storage = torch.empty(
+            self.nbytes(), dtype=torch.uint8, device=device, pin_memory=pin_memory
+        ).untyped_storage()
+        untyped_storage.copy_(self, non_blocking)
+        return untyped_storage
+
     device_module = getattr(torch, device.type, None)
     assert (
         device_module is not None
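The new branch handles CPU destinations up front instead of falling through to the generic device-module dispatch below, which (per the PR title) did not work for `.to(cpu)` on storages. Pinning matters because a device-to-host copy is only truly asynchronous into page-locked memory; hence the buffer is allocated pinned exactly when the copy is non-blocking from a device backend. A standalone restatement of the branch under those assumptions (the helper name `storage_to_cpu` is hypothetical, not part of the PR):

import torch

def storage_to_cpu(src: torch.UntypedStorage, non_blocking: bool = False):
    # Pin the destination only for an async copy from CUDA or the
    # privateuse1 backend; pinned memory keeps the D2H copy truly async.
    pin_memory = non_blocking and src.device.type in (
        "cuda",
        torch._C._get_privateuse1_backend_name(),
    )
    dst = torch.empty(
        src.nbytes(), dtype=torch.uint8, device="cpu", pin_memory=pin_memory
    ).untyped_storage()
    dst.copy_(src, non_blocking)
    return dst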
@@ -8,7 +8,16 @@ import functools
 import io
 import threading
 import warnings
-from typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union
+from typing import (
+    Any,
+    cast,
+    Dict as _Dict,
+    Optional as _Optional,
+    Type,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
 from typing_extensions import Self
 
 import torch
@@ -16,6 +25,10 @@ from torch._utils import _to, _type
 from torch.types import _bool, _int, Storage
 
 
+if TYPE_CHECKING:
+    from torch._prims_common import DeviceLikeType
+
+
 __all__ = ["TypedStorage", "UntypedStorage"]
 
 
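The `if TYPE_CHECKING:` guard makes `DeviceLikeType` visible to type checkers without importing `torch._prims_common` at runtime; names imported this way may only appear in annotations, which must then stay unevaluated at runtime (quoted, or made lazy with `from __future__ import annotations`). A generic sketch of the pattern; the `as_device` helper is illustrative, not part of the PR:

from __future__ import annotations  # keep annotations lazy at runtime

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Seen only by type checkers; no runtime import cost or cycle.
    from torch._prims_common import DeviceLikeType


def as_device(device: DeviceLikeType) -> torch.device:
    # DeviceLikeType (per torch._prims_common) is a str, int, or
    # torch.device; normalize everything to torch.device.
    return device if isinstance(device, torch.device) else torch.device(device)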
@@ -273,9 +286,9 @@ class _StorageBase:
             storage = storage.clone()
         return storage
 
-    def to(
-        self, *, device: torch.device, non_blocking: _bool = False
-    ) -> Union[_StorageBase, TypedStorage]:
+    def to(self, *, device: DeviceLikeType, non_blocking: _bool = False):
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
         return _to(self, device, non_blocking)
 
     def double(self):
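With the widened annotation and the explicit normalization, the `device` keyword now accepts anything `torch.device(...)` does; note from the `_to` hunk above that a storage already on the target device is returned unchanged. A small sketch:

import torch

s = torch.randn(4).untyped_storage()

# Strings (and device indices) are normalized via torch.device(...);
# since s already lives on CPU, _to returns it as-is.
assert s.to(device="cpu") is s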
@@ -1061,8 +1074,10 @@ class TypedStorage:
         hpu_storage = self._untyped_storage.hpu(device, non_blocking)
         return self._new_wrapped_storage(hpu_storage)
 
-    def to(self, *, device: torch.device, non_blocking: bool = False) -> Self:
+    def to(self, *, device: DeviceLikeType, non_blocking: bool = False) -> Self:
         _warn_typed_storage_removal()
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
         if self.dtype in [
             torch.quint8,
             torch.quint4x2,
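`TypedStorage.to` gains the same normalization, placed after the deprecation warning and before the quantized-dtype guard (quantized storages such as `quint8` and `quint4x2` are handled specially further down). A quick sketch; note that the whole TypedStorage API is deprecation-warned:

import torch

ts = torch.randn(4).storage()  # TypedStorage; this API warns about its removal
assert ts.to(device="cpu").device == torch.device("cpu")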