Files
pytorch/torch/cuda/_utils.py
Peter Bell eece6da162 [inductor] Reduce device context manager overhead (#91045)
This adds `torch.cuda._DeviceGuard` which is a stripped down version of
`torch.cuda.device` with lower overhead. To do this, it only accepts `int` as
the device so we don't need to call `_get_device_index` and is implemented
with a new C++ helper `torch._C._cuda_exchangeDevice` that allows
`_DeviceGuard.__enter__` to be just a single function call. On my machine,
I see a drop from 3.8 us of overhead to 0.94 us with this simple benchmark:

```python
def set_device():
    with torch.cuda.device(0):
        pass

%timeit set_device()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91045
Approved by: https://github.com/ngimel, https://github.com/anijain2305
2023-01-12 16:51:59 +00:00

50 lines
2.1 KiB
Python

import torch
from typing import Any
# _get_device_index has been moved to torch._utils._get_device_index
from torch._utils import _get_device_index as _torch_get_device_index
def _get_device_index(device: Any, optional: bool = False, allow_cpu: bool = False) -> int:
    r"""Resolves :attr:`device` to an integer CUDA device index.

    :attr:`device` may be a Python integer, a device string, a torch.device
    object, a ``torch.cuda.device`` object, or ``None``.

    * A Python integer is returned as is.
    * A string is first converted with ``torch.device(...)``.
    * A torch.device object must be a CUDA device. If :attr:`allow_cpu` is
      ``True``, CPU devices are accepted as well and ``-1`` is returned for
      them. Any other device type raises ``ValueError``.
    * Outside of TorchScript, a ``torch.cuda.device`` object yields its
      ``idx`` attribute.

    Every remaining case is delegated to ``torch._utils._get_device_index``.
    For a CUDA device without a specified index, i.e.,
    ``torch.device('cuda')``, or for ``None``, the current default CUDA device
    is returned if :attr:`optional` is ``True``.
    """
    # Plain integers are already device indices.
    if isinstance(device, int):
        return device

    # Turn string specs into torch.device objects so they are validated below.
    if isinstance(device, str):
        device = torch.device(device)

    if isinstance(device, torch.device):
        device_type = device.type
        if allow_cpu and device_type not in ['cuda', 'cpu']:
            raise ValueError('Expected a cuda or cpu device, but got: {}'.format(device))
        if not allow_cpu and device_type != 'cuda':
            raise ValueError('Expected a cuda device, but got: {}'.format(device))

    # NOTE(review): the scripting guard is deliberately a separate statement
    # from the isinstance check -- presumably so TorchScript never has to
    # compile the torch.cuda.device test; confirm before merging the two.
    if not torch.jit.is_scripting():
        if isinstance(device, torch.cuda.device):
            return device.idx

    return _torch_get_device_index(device, optional, allow_cpu)
def _dummy_type(name: str) -> type:
    r"""Creates a placeholder class called :attr:`name` that refuses to be
    instantiated.

    Both ``__new__`` and ``__init__`` of the returned class raise
    ``RuntimeError`` with a message naming the class that was being
    instantiated.

    Args:
        name: the ``__name__`` given to the generated class.

    Returns:
        A new class deriving from ``object``.

    Raises:
        Nothing at creation time; ``RuntimeError`` is raised whenever the
        generated ``__new__`` or ``__init__`` runs.
    """
    def get_err_fn(is_init: bool):
        # The receiver is taken from ``*args`` rather than a named parameter
        # so that no caller-supplied keyword argument (e.g. ``obj=...``) can
        # collide with it and turn the intended RuntimeError into a TypeError.
        def err_fn(*args, **kwargs):
            obj = args[0]
            # __init__ receives an instance, __new__ receives the class itself.
            owner = obj.__class__ if is_init else obj
            raise RuntimeError(
                "Tried to instantiate dummy base class {}".format(owner.__name__))
        return err_fn

    return type(name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)})