Make torch.device usable as a context manager (#91525)

Fixes https://github.com/pytorch/pytorch/issues/82296
Fixes https://github.com/pytorch/pytorch/issues/27878
Fixes https://github.com/pytorch/pytorch/issues/260

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91525
Approved by: https://github.com/albanD
Author: Edward Z. Yang
Date: 2023-01-03 07:59:00 +08:00
Committed by: PyTorch MergeBot
Parent: aa0ca994ca
Commit: 619d52a5d2
12 changed files with 401 additions and 90 deletions


@@ -177,6 +177,20 @@ Via a string and device ordinal:
>>> torch.device('cpu', 0)
device(type='cpu', index=0)
The device object can also be used as a context manager to change the default
device tensors are allocated on:
::
>>> with torch.device('cuda:1'):
... r = torch.randn(2, 3)
>>> r.device
device(type='cuda', index=1)
This context manager has no effect if a factory function is passed an explicit,
non-None device argument. To globally change the default device, see also
:func:`torch.set_default_device`.
.. note::
The :class:`torch.device` argument in functions can generally be substituted with a string.
This allows for fast prototyping of code.
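
To make the precedence rule above concrete, here is a minimal sketch (it uses the
``meta`` device so it runs without a GPU) showing that an explicit, non-None
device argument bypasses the context manager:

::

>>> with torch.device('meta'):
...     a = torch.zeros(2)                # no device argument: follows the context
...     b = torch.zeros(2, device='cpu')  # explicit device: context has no effect
>>> a.device
device(type='meta')
>>> b.device
device(type='cpu')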


@@ -17,6 +17,7 @@ Tensors
is_nonzero
set_default_dtype
get_default_dtype
set_default_device
set_default_tensor_type
numel
set_printoptions


@@ -760,7 +760,7 @@ class Wrapper:
val = getattr(self._data, name)
# If it's a method
- if callable(val):
+ if not isinstance(val, torch.device) and callable(val):
c = getattr(type(self._data), name)
# Don't append self to args if classmethod/staticmethod
if c is val:


@@ -14,8 +14,17 @@ import torch
import torch.nn as nn
import torch.utils.data
from torch.utils.data import DataLoader
from torch.testing._internal.common_device_type import (
ops,
onlyCPU,
instantiate_device_type_tests,
)
from torch.testing._internal.common_methods_invocations import op_db
import torch.cuda
from torch.utils._pytree import tree_any, tree_all_only
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from torch import set_default_device
from torch.utils._device import set_device
import torch.utils.cpp_extension
from torch.autograd._functions.utils import check_onnx_broadcast
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
@@ -796,6 +805,74 @@ class TestExtensionUtils(TestCase):
torch._register_device_module('xpu', DummyXPUModule)
class TestDeviceUtils(TestCase):
def test_basic(self):
with torch.device('meta') as dev:
x = torch.empty(3, 3)
self.assertEqual(x.device.type, 'meta')
self.assertEqual(dev, torch.device('meta'))
def test_decorator(self):
@set_device('meta')
def f():
return torch.empty(3, 3)
self.assertEqual(f().device.type, 'meta')
def test_decorator_generator(self):
@set_device('meta')
def f():
yield torch.empty(3, 3)
yield torch.empty(3, 3)
r1, r2 = list(f())
self.assertEqual(r1.device.type, 'meta')
self.assertEqual(r2.device.type, 'meta')
def test_nn_module(self):
with torch.device('meta'):
m = nn.Linear(40, 50)
self.assertEqual(m.weight.device.type, 'meta')
def test_set_default_device(self):
try:
set_default_device('meta')
r = torch.empty(2, 2)
finally:
set_default_device(None)
self.assertEqual(r.device.type, 'meta')
@onlyCPU
@ops(op_db)
def test_device_mode_ops(self, device, dtype, op):
func = op.get_op()
samples = op.sample_inputs(device, dtype, requires_grad=False)
for sample in samples:
# Only test samples which don't have Tensor inputs.  (We don't rely on
# the factory-function property on OpInfo for this, as it is very
# incomplete.)
if tree_any(
lambda x: isinstance(x, torch.Tensor),
(sample.input, sample.args, sample.kwargs)
):
continue
# Many OpInfos will explicitly pass in a device. DeviceContext
# will respect device if it is explicitly specified. To test
# DeviceContext, we have to remove the device kwarg in this case.
# NB: Can't pass None to sample_inputs, the function can't
# handle it.
kwargs = sample.kwargs.copy()
kwargs.pop('device', None)
with torch.device('meta'):
r = func(sample.input, *sample.args, **kwargs)
self.assertTrue(
tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r)
)
instantiate_device_type_tests(TestDeviceUtils, globals())
class TestCppExtensionUtils(TestCase):
def test_cpp_compiler_is_ok(self):
self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform('c++'))


@@ -45,6 +45,12 @@ class device:
@overload
def __init__(self, type: str, index: _int) -> None: ...
def __call__(self, func: T) -> T: ...
def __enter__(self) -> "device": ...
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
def __reduce__(self) -> Tuple[Any, ...]: ... # THPDevice_reduce
# Defined in torch/csrc/Stream.cpp


@@ -35,6 +35,7 @@ import builtins
__all__ = [
'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
'set_default_device',
'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
@@ -444,6 +445,49 @@ def is_storage(obj):
return type(obj) in _storage_classes
_GLOBAL_DEVICE_CONTEXT = None
def set_default_device(device):
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
does not affect factory function calls which are called with an explicit
``device`` argument. Factory calls will be performed as if they
were passed ``device`` as an argument.
To change the default device only temporarily instead of setting it
globally, use ``with torch.device(device):``.
The default device is initially ``cpu``. If you set the default device to
another device type (e.g., ``cuda``) without a device index, tensors will
be allocated on whatever the current device for that device type is,
even after :func:`torch.cuda.set_device` is called.
Args:
device (device or string): the device to set as default
Example::
>>> # xdoctest: +SKIP("requires cuda, changes global state")
>>> torch.tensor([1.2, 3]).device
device(type='cpu')
>>> torch.set_default_device('cuda') # current device is 0
>>> torch.tensor([1.2, 3]).device
device(type='cuda', index=0)
>>> torch.set_default_device('cuda:1')
>>> torch.tensor([1.2, 3]).device
device(type='cuda', index=1)
"""
global _GLOBAL_DEVICE_CONTEXT
if _GLOBAL_DEVICE_CONTEXT is not None:
_GLOBAL_DEVICE_CONTEXT.__exit__(None, None, None)
if device is None:
_GLOBAL_DEVICE_CONTEXT = None
return
from torch.utils._device import DeviceContext
_GLOBAL_DEVICE_CONTEXT = DeviceContext(device)
_GLOBAL_DEVICE_CONTEXT.__enter__()
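
A brief usage sketch for the function above: each call tears down the previous
global DeviceContext before installing a new one, and passing None restores the
initial behavior:

import torch

torch.set_default_device('meta')
print(torch.empty(3).device)    # meta
torch.set_default_device(None)  # tears down the global DeviceContext
print(torch.empty(3).device)    # cpu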
def set_default_tensor_type(t):
r"""Sets the default ``torch.Tensor`` type to floating point tensor type
``t``. This type will also be used as default floating point type for


@@ -1,96 +1,11 @@
import sys
import torch
import functools
import inspect
import warnings
from typing import Any, Callable, TypeVar, cast
from typing import Any
from torch.utils._contextlib import _DecoratorContextManager
__all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
'inference_mode', 'set_multithreading_enabled']
# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""
def __call__(self, func: F) -> F:
if inspect.isclass(func):
warnings.warn("Decorating classes is deprecated and will be disabled in "
"future versions. You should only decorate functions or methods. "
"To preserve the current behavior of class decoration, you can "
"directly decorate the `__init__` method and nothing else.")
if inspect.isgeneratorfunction(func):
return self._wrap_generator(func)
@functools.wraps(func)
def decorate_context(*args, **kwargs):
with self.clone():
return func(*args, **kwargs)
return cast(F, decorate_context)
def _wrap_generator(self, func):
"""Wrap each generator invocation with the context manager"""
@functools.wraps(func)
def generator_context(*args, **kwargs):
gen = func(*args, **kwargs)
# Generators are suspended and unsuspended at `yield`, hence we
# make sure the grad mode is properly set every time the execution
# flow returns into the wrapped generator and restored when it
# returns through our `yield` to our caller (see PR #49017).
try:
# Issuing `None` to a generator fires it up
with self.clone():
response = gen.send(None)
while True:
try:
# Forward the response to our caller and get its next request
request = yield response
except GeneratorExit:
# Inform the still active generator about its imminent closure
with self.clone():
gen.close()
raise
except BaseException:
# Propagate the exception thrown at us by the caller
with self.clone():
response = gen.throw(*sys.exc_info())
else:
# Pass the last request to the generator and get its response
with self.clone():
response = gen.send(request)
# We let the exceptions raised above by the generator's `.throw` or
# `.send` methods bubble up to our caller, except for StopIteration
except StopIteration as e:
# The generator informed us that it is done: take whatever its
# returned value (if any) was and indicate that we're done too
# by returning it (see docs for python's return-statement).
return e.value
return generator_context
def __enter__(self) -> None:
raise NotImplementedError
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
raise NotImplementedError
def clone(self):
# override this method if your children class takes __init__ parameters
return self.__class__()
class no_grad(_DecoratorContextManager):
r"""Context-manager that disabled gradient calculation.


@@ -169,6 +169,36 @@ PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
END_HANDLE_TH_ERRORS
}
PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
py::object mode = py::module::import("torch.utils._device")
.attr("DeviceContext")(py::handle(self));
at::impl::PythonTorchFunctionTLS::push_onto_stack(
std::make_shared<c10::SafePyObject>(
mode.release().ptr(), getPyInterpreter()));
// So that with torch.device('cuda') as dev: works
Py_INCREF(self);
return self;
END_HANDLE_TH_ERRORS
}
PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
HANDLE_TH_ERRORS
at::impl::PythonTorchFunctionTLS::pop_stack();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_TH_ERRORS
py::object deco =
py::module::import("torch.utils._device").attr("device_decorator");
return deco(py::handle(self), *py::handle(args), **py::handle(kwargs))
.release()
.ptr();
END_HANDLE_TH_ERRORS
}
typedef PyObject* (*getter)(PyObject*, void*);
// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
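
What THPDevice_enter and THPDevice_exit do is, in effect, push and pop a
DeviceContext torch-function mode around the with-block. A rough Python
equivalent of the same push/pop, written against the DeviceContext class added
later in this diff (a sketch, not what the C++ literally executes):

import torch
from torch.utils._device import DeviceContext

ctx = DeviceContext(torch.device('meta'))
ctx.__enter__()                      # pushes the mode onto the torch-function stack
try:
    assert torch.ones(2).device.type == 'meta'
finally:
    ctx.__exit__(None, None, None)   # pops the mode again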
@@ -182,6 +212,8 @@ static struct PyGetSetDef THPDevice_properties[] = {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static PyMethodDef THPDevice_methods[] = {
{"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
{"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
{"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
{nullptr} /* Sentinel */
};
@@ -199,6 +231,11 @@ PyTypeObject THPDeviceType = {
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
(hashfunc)THPDevice_hash, /* tp_hash */
// TODO: We're not sure if this is a good idea or not, because making
// torch.device callable means that it will start returning true
// for callable() queries, and that is unexpected. We can always add
// this later, so for now, don't actually implement this
// THPDevice_call, /* tp_call */
nullptr, /* tp_call */
(reprfunc)THPDevice_str, /* tp_str */
nullptr, /* tp_getattro */
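
Because tp_call is left as nullptr here, a device instance does not pass a
callable() check at runtime; decorator usage instead goes through
torch.utils._device.set_device. A quick check of the behavior as of this commit:

import torch

print(callable(torch.device('cpu')))  # False -- tp_call is not implemented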


@@ -75,6 +75,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.is_tensor,
torch.is_storage,
torch.set_default_tensor_type,
torch.set_default_device,
torch.set_rng_state,
torch.get_rng_state,
torch.manual_seed,


@@ -378,7 +378,7 @@ def kaiser(
device=device,
requires_grad=requires_grad)
- return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(torch.tensor(beta))
+ return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(torch.tensor(beta, device=device))
@_add_docstr(
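
The one-line kaiser change above matters once a device context can be active:
without device=, the bare torch.tensor(beta) would follow the ambient default
device while k follows the function's explicit device argument, producing a
cross-device operation. A minimal sketch of the mismatch being avoided (device
names are illustrative):

import torch

with torch.device('meta'):
    k = torch.arange(4, device='cpu')  # explicit device: stays on cpu
    beta = torch.tensor(12.0)          # no device: follows the context (meta)
print(k.device.type, beta.device.type)  # cpu meta -- mixing them would error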

torch/utils/_contextlib.py (new file, 141 lines)

@@ -0,0 +1,141 @@
# Extra utilities for working with context managers that should have been
# in the standard library but are not
import functools
import inspect
import warnings
import sys
from typing import Any, Callable, TypeVar, cast
# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)
def _wrap_generator(ctx_factory, func):
"""
Wrap each generator invocation with the context manager factory.
The input should be a function that returns a context manager,
not a context manager itself, to handle one-shot context managers.
"""
@functools.wraps(func)
def generator_context(*args, **kwargs):
gen = func(*args, **kwargs)
# Generators are suspended and unsuspended at `yield`, hence we
# make sure the grad mode is properly set every time the execution
# flow returns into the wrapped generator and restored when it
# returns through our `yield` to our caller (see PR #49017).
try:
# Issuing `None` to a generator fires it up
with ctx_factory():
response = gen.send(None)
while True:
try:
# Forward the response to our caller and get its next request
request = yield response
except GeneratorExit:
# Inform the still active generator about its imminent closure
with ctx_factory():
gen.close()
raise
except BaseException:
# Propagate the exception thrown at us by the caller
with ctx_factory():
response = gen.throw(*sys.exc_info())
else:
# Pass the last request to the generator and get its response
with ctx_factory():
response = gen.send(request)
# We let the exceptions raised above by the generator's `.throw` or
# `.send` methods bubble up to our caller, except for StopIteration
except StopIteration as e:
# The generator informed us that it is done: take whatever its
# returned value (if any) was and indicate that we're done too
# by returning it (see docs for python's return-statement).
return e.value
return generator_context
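
Because ctx_factory is re-invoked around every send/throw/close, the context is
re-entered each time the generator resumes. A small sketch using torch.device as
the context produced by the factory (make_two is a hypothetical example
generator):

import torch
from torch.utils._contextlib import _wrap_generator

def make_two():
    yield torch.empty(1)
    yield torch.empty(1)

wrapped = _wrap_generator(lambda: torch.device('meta'), make_two)
assert all(t.device.type == 'meta' for t in wrapped())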
def context_decorator(ctx, func):
"""
Like contextlib.ContextDecorator, but:
1. Is done by wrapping, rather than inheritance, so it works with context
managers that are implemented from C and thus cannot easily inherit from
Python classes
2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
3. Errors out if you try to wrap a class, because it is ambiguous whether
or not you intended to wrap only the constructor
The input argument can either be a context manager (in which case it must
be a multi-shot context manager that can be entered multiple times)
or a callable that produces a context manager.
"""
assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
"context_decorator(lambda: ctx()); if you intended to pass a context "
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
)
if not callable(ctx):
def ctx_factory():
return ctx
else:
ctx_factory = ctx
if inspect.isclass(func):
raise RuntimeError(
"Cannot decorate classes; it is ambiguous whether or not only the "
"constructor or all methods should have the context manager applied; "
"additionally, decorating a class at definition-site will prevent "
"use of the identifier as a conventional type. "
"To specify which methods to decorate, decorate each of them "
"individually."
)
if inspect.isgeneratorfunction(func):
return _wrap_generator(ctx_factory, func)
@functools.wraps(func)
def decorate_context(*args, **kwargs):
with ctx_factory():
return func(*args, **kwargs)
return decorate_context
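
A usage sketch for context_decorator, passing a factory (the lambda) so the
one-shot rule above is respected; torch.device serves as the context manager,
since this PR makes it one (make and make_on_meta are hypothetical names):

import torch
from torch.utils._contextlib import context_decorator

def make():
    return torch.empty(3)

make_on_meta = context_decorator(lambda: torch.device('meta'), make)
assert make_on_meta().device.type == 'meta'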
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""
def __call__(self, func: F) -> F:
if inspect.isclass(func):
warnings.warn("Decorating classes is deprecated and will be disabled in "
"future versions. You should only decorate functions or methods. "
"To preserve the current behavior of class decoration, you can "
"directly decorate the `__init__` method and nothing else.")
# Bind the original callable to a fresh name first; referencing `func`
# inside the lambda after rebinding it would recurse into the lambda itself.
orig_func = func
func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
return cast(F, context_decorator(self.clone, func))
def __enter__(self) -> None:
raise NotImplementedError
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
raise NotImplementedError
def clone(self):
# override this method if your children class takes __init__ parameters
return self.__class__()
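
For subclasses, __call__ now simply defers to context_decorator with self.clone
as the context manager factory. A minimal hypothetical subclass to illustrate
the contract (announce is not part of the codebase):

from torch.utils._contextlib import _DecoratorContextManager

class announce(_DecoratorContextManager):
    def __enter__(self):
        print("enter")
    def __exit__(self, exc_type, exc_value, traceback):
        print("exit")

@announce()
def f():
    print("body")

f()  # prints: enter, body, exit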

torch/utils/_device.py (new file, 75 lines)

@@ -0,0 +1,75 @@
import torch
from torch.overrides import TorchFunctionMode
from torch.utils._contextlib import context_decorator
import functools
@functools.lru_cache(1)
def _device_constructors():
return {
# standard ones
torch.empty,
torch.empty_strided,
torch.empty_quantized,
torch.ones,
torch.arange,
torch.bartlett_window,
torch.blackman_window,
torch.eye,
torch.fft.fftfreq,
torch.fft.rfftfreq,
torch.full,
torch.fill,
torch.hamming_window,
torch.hann_window,
torch.kaiser_window,
torch.linspace,
torch.logspace,
torch.nested.nested_tensor,
# This function doesn't actually take a device argument
# torch.normal,
torch.rand,
torch.randn,
torch.randint,
torch.randperm,
torch.range,
torch.sparse_coo_tensor,
torch.sparse_compressed_tensor,
torch.sparse_csr_tensor,
torch.sparse_csc_tensor,
torch.sparse_bsr_tensor,
torch.sparse_bsc_tensor,
torch.tril_indices,
torch.triu_indices,
torch.vander,
torch.zeros,
torch.asarray,
# weird ones
torch.tensor,
torch.as_tensor,
torch.scalar_tensor,
}
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
def __init__(self, device):
self.device = torch.device(device)
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if func in _device_constructors() and kwargs.get('device') is None:
kwargs['device'] = self.device
return func(*args, **kwargs)
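
The mode above only rewrites calls that are both in the _device_constructors set
and missing an explicit device. A small sketch (the meta device keeps it
runnable anywhere):

import torch
from torch.utils._device import DeviceContext

with DeviceContext('meta'):
    a = torch.rand(2)                # in the set, no device kwarg: rewritten
    b = torch.rand(2, device='cpu')  # explicit device: passed through untouched
print(a.device.type, b.device.type)  # meta cpu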
# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
return context_decorator(lambda: device, func)
def set_device(device):
"""
Decorator which sets the default device inside the wrapped
function.  If you would like a context manager instead, use
``torch.device`` as a context manager directly, e.g.,
``with torch.device(device):``.
"""
return lambda func: device_decorator(torch.device(device), func)