Fix .to(cpu) for Storage (#138011)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138011
Approved by: https://github.com/albanD
commit 37149d032c
parent 555bddbef7
Author: Mikayla Gawarecki
Committed by: PyTorch MergeBot
Date: 2024-10-22 21:12:10 +00:00
3 changed files with 43 additions and 7 deletions
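
What the fix does, in user-facing terms: copying a storage to the CPU now goes through a dedicated CPU branch in _to(), and the destination buffer is allocated as pinned memory when the copy is requested as non_blocking from CUDA (or the registered privateuse1 backend). A minimal sketch of the resulting behavior, mirroring the new test below (assumes a CUDA build; variable names are illustrative, not part of the patch):

import torch

t = torch.randn(6, device="cuda")
s = t.untyped_storage()

# With this fix, a non_blocking copy to CPU lands in pinned host memory.
s_cpu = s.to(device="cpu", non_blocking=True)
torch.cuda.synchronize()  # wait for the async copy before reading
assert s_cpu.is_pinned()

# A blocking copy still returns ordinary pageable memory.
assert not s.to(device="cpu").is_pinned()

# Reinterpret the copied bytes as a tensor and compare with t.cpu().
t_cpu = torch.empty(()).set_(s_cpu)
assert torch.equal(t.cpu(), t_cpu)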

test/test_torch.py

@@ -5109,10 +5109,20 @@ else:
     @deviceCountAtLeast(1)
     @onlyCUDA
-    def test_storage_all_devices(self, devices):
+    @parametrize("non_blocking", (True, False))
+    def test_storage_all_devices(self, devices, non_blocking):
         for device in devices:
-            t = torch.tensor((), device=device)
+            t = torch.randn(6, device=device)
             self.assertEqual(t.dtype, t.storage().dtype)
+            s = t.untyped_storage()
+            s_cpu = s.to(device='cpu', non_blocking=non_blocking)
+            if non_blocking:
+                torch.cuda.synchronize()
+                self.assertTrue(s_cpu.is_pinned())
+            else:
+                self.assertFalse(s_cpu.is_pinned())
+            t_cpu = torch.empty(()).set_(s_cpu)
+            self.assertEqual(t.cpu(), t_cpu)
 
     # Note [lazy_clone_ tests with inductor enabled]
     # These `lazy_clone_` tests are written in a way that makes them pass in

torch/_utils.py

@@ -67,6 +67,17 @@ def _to(self, device, non_blocking=False):
     if self.device == device:
         return self
 
+    if device.type == "cpu":
+        pin_memory = non_blocking and self.device.type in (
+            "cuda",
+            torch._C._get_privateuse1_backend_name(),
+        )
+        untyped_storage = torch.empty(
+            self.nbytes(), dtype=torch.uint8, device=device, pin_memory=pin_memory
+        ).untyped_storage()
+        untyped_storage.copy_(self, non_blocking)
+        return untyped_storage
+
     device_module = getattr(torch, device.type, None)
     assert (
         device_module is not None
torch/storage.py

@@ -8,7 +8,16 @@ import functools
 import io
 import threading
 import warnings
-from typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union
+from typing import (
+    Any,
+    cast,
+    Dict as _Dict,
+    Optional as _Optional,
+    Type,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
 from typing_extensions import Self
 
 import torch
@@ -16,6 +25,10 @@ from torch._utils import _to, _type
 from torch.types import _bool, _int, Storage
 
+if TYPE_CHECKING:
+    from torch._prims_common import DeviceLikeType
+
+
 __all__ = ["TypedStorage", "UntypedStorage"]
@@ -273,9 +286,9 @@ class _StorageBase:
             storage = storage.clone()
         return storage
 
-    def to(
-        self, *, device: torch.device, non_blocking: _bool = False
-    ) -> Union[_StorageBase, TypedStorage]:
+    def to(self, *, device: DeviceLikeType, non_blocking: _bool = False):
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
         return _to(self, device, non_blocking)
 
     def double(self):
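
Because the annotation is now DeviceLikeType and the argument is normalized with torch.device(...) before dispatch, the device keyword accepts strings (and other device-like values) as well as torch.device objects. A quick illustrative sketch (CPU-only, so _to() simply returns the same storage):

import torch

s = torch.randn(4).untyped_storage()

a = s.to(device=torch.device("cpu"))
b = s.to(device="cpu")  # plain string is normalized to torch.device("cpu")
assert a.device == b.device == torch.device("cpu")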
@@ -1061,8 +1074,10 @@ class TypedStorage:
         hpu_storage = self._untyped_storage.hpu(device, non_blocking)
         return self._new_wrapped_storage(hpu_storage)
 
-    def to(self, *, device: torch.device, non_blocking: bool = False) -> Self:
+    def to(self, *, device: DeviceLikeType, non_blocking: bool = False) -> Self:
         _warn_typed_storage_removal()
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
         if self.dtype in [
             torch.quint8,
             torch.quint4x2,