mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make torch.device usable as a context manager (#91525)
Fixes https://github.com/pytorch/pytorch/issues/82296 Fixes https://github.com/pytorch/pytorch/issues/27878 Fixes https://github.com/pytorch/pytorch/issues/260 Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/91525 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
aa0ca994ca
commit
619d52a5d2
@ -177,6 +177,20 @@ Via a string and device ordinal:
|
||||
>>> torch.device('cpu', 0)
|
||||
device(type='cpu', index=0)
|
||||
|
||||
The device object can also be used as a context manager to change the default
|
||||
device tensors are allocated on:
|
||||
|
||||
::
|
||||
|
||||
>>> with torch.device('cuda:1'):
|
||||
... r = torch.randn(2, 3)
|
||||
>>> r.device
|
||||
device(type='cuda', index=1)
|
||||
|
||||
This context manager has no effect if a factory function is passed an explicit,
|
||||
non-None device argument. To globally change the default device, see also
|
||||
:func:`torch.set_default_device`.
|
||||
|
||||
.. note::
|
||||
The :class:`torch.device` argument in functions can generally be substituted with a string.
|
||||
This allows for fast prototyping of code.
|
||||
|
@ -17,6 +17,7 @@ Tensors
|
||||
is_nonzero
|
||||
set_default_dtype
|
||||
get_default_dtype
|
||||
set_default_device
|
||||
set_default_tensor_type
|
||||
numel
|
||||
set_printoptions
|
||||
|
@ -760,7 +760,7 @@ class Wrapper:
|
||||
val = getattr(self._data, name)
|
||||
|
||||
# If it's a method
|
||||
if callable(val):
|
||||
if not isinstance(val, torch.device) and callable(val):
|
||||
c = getattr(type(self._data), name)
|
||||
# Don't append self to args if classmethod/staticmethod
|
||||
if c is val:
|
||||
|
@ -14,8 +14,17 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.testing._internal.common_device_type import (
|
||||
ops,
|
||||
onlyCPU,
|
||||
instantiate_device_type_tests,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
import torch.cuda
|
||||
from torch.utils._pytree import tree_any, tree_all_only
|
||||
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||
from torch import set_default_device
|
||||
from torch.utils._device import set_device
|
||||
import torch.utils.cpp_extension
|
||||
from torch.autograd._functions.utils import check_onnx_broadcast
|
||||
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
|
||||
@ -796,6 +805,74 @@ class TestExtensionUtils(TestCase):
|
||||
torch._register_device_module('xpu', DummyXPUModule)
|
||||
|
||||
|
||||
class TestDeviceUtils(TestCase):
|
||||
def test_basic(self):
|
||||
with torch.device('meta') as dev:
|
||||
x = torch.empty(3, 3)
|
||||
self.assertEqual(x.device.type, 'meta')
|
||||
self.assertEqual(dev, torch.device('meta'))
|
||||
|
||||
def test_decorator(self):
|
||||
@set_device('meta')
|
||||
def f():
|
||||
return torch.empty(3, 3)
|
||||
self.assertEqual(f().device.type, 'meta')
|
||||
|
||||
def test_decorator_generator(self):
|
||||
@set_device('meta')
|
||||
def f():
|
||||
yield torch.empty(3, 3)
|
||||
yield torch.empty(3, 3)
|
||||
r1, r2 = list(f())
|
||||
self.assertEqual(r1.device.type, 'meta')
|
||||
self.assertEqual(r2.device.type, 'meta')
|
||||
|
||||
|
||||
def test_nn_module(self):
|
||||
with torch.device('meta'):
|
||||
m = nn.Linear(40, 50)
|
||||
self.assertEqual(m.weight.device.type, 'meta')
|
||||
|
||||
def test_set_default_device(self):
|
||||
try:
|
||||
set_default_device('meta')
|
||||
r = torch.empty(2, 2)
|
||||
finally:
|
||||
set_default_device(None)
|
||||
|
||||
self.assertEqual(r.device.type, 'meta')
|
||||
|
||||
@onlyCPU
|
||||
@ops(op_db)
|
||||
def test_device_mode_ops(self, device, dtype, op):
|
||||
func = op.get_op()
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
for sample in samples:
|
||||
# Only test samples which don't have Tensor inputs. However,
|
||||
# we don't test the factory property on OpInfo as it is very,
|
||||
# very incomplete
|
||||
if tree_any(
|
||||
lambda x: isinstance(x, torch.Tensor),
|
||||
(sample.input, sample.args, sample.kwargs)
|
||||
):
|
||||
continue
|
||||
# Many OpInfos will explicitly pass in a device. DeviceContext
|
||||
# will respect device if it is explicitly specified. To test
|
||||
# DeviceContext, we have to remove the device kwarg in this case.
|
||||
# NB: Can't pass None to sample_inputs, the function can't
|
||||
# handle it.
|
||||
kwargs = sample.kwargs.copy()
|
||||
kwargs.pop('device', None)
|
||||
with torch.device('meta'):
|
||||
r = func(sample.input, *sample.args, **kwargs)
|
||||
self.assertTrue(
|
||||
tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r)
|
||||
)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestDeviceUtils, globals())
|
||||
|
||||
|
||||
class TestCppExtensionUtils(TestCase):
|
||||
def test_cpp_compiler_is_ok(self):
|
||||
self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform('c++'))
|
||||
|
@ -45,6 +45,12 @@ class device:
|
||||
@overload
|
||||
def __init__(self, type: str, index: _int) -> None: ...
|
||||
|
||||
def __call__(self, func: T) -> T: ...
|
||||
|
||||
def __enter__(self) -> "device": ...
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
|
||||
|
||||
def __reduce__(self) -> Tuple[Any, ...]: ... # THPDevice_reduce
|
||||
|
||||
# Defined in torch/csrc/Stream.cpp
|
||||
|
@ -35,6 +35,7 @@ import builtins
|
||||
|
||||
__all__ = [
|
||||
'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
|
||||
'set_default_device',
|
||||
'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
|
||||
'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
|
||||
'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
|
||||
@ -444,6 +445,49 @@ def is_storage(obj):
|
||||
return type(obj) in _storage_classes
|
||||
|
||||
|
||||
_GLOBAL_DEVICE_CONTEXT = None
|
||||
|
||||
def set_default_device(device):
|
||||
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
|
||||
does not affect factory function calls which are called with an explicit
|
||||
``device`` argument. Factory calls will be performed as if they
|
||||
were passed ``device`` as an argument.
|
||||
|
||||
To only temporarily change the default device instead of setting it
|
||||
globally, use ``with torch.device(device):`` instead.
|
||||
|
||||
The default device is initially ``cpu``. If you set the default tensor
|
||||
device to another device (e.g., ``cuda``) without a device index, tensors
|
||||
will be allocated on whatever the current device for the device type,
|
||||
even after :func:`torch.cuda.set_device` is called.
|
||||
|
||||
Args:
|
||||
device (device or string): the device to set as default
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +SKIP("requires cuda, changes global state")
|
||||
>>> torch.tensor([1.2, 3]).device
|
||||
device(type='cpu')
|
||||
>>> torch.set_default_device('cuda') # current device is 0
|
||||
>>> torch.tensor([1.2, 3]).device
|
||||
device(type='cuda', index=0)
|
||||
>>> torch.set_default_device('cuda:1')
|
||||
>>> torch.tensor([1.2, 3]).device
|
||||
device(type='cuda', index=1)
|
||||
|
||||
"""
|
||||
global _GLOBAL_DEVICE_CONTEXT
|
||||
if _GLOBAL_DEVICE_CONTEXT is not None:
|
||||
_GLOBAL_DEVICE_CONTEXT.__exit__(None, None, None)
|
||||
if device is None:
|
||||
_GLOBAL_DEVICE_CONTEXT = None
|
||||
return
|
||||
from torch.utils._device import DeviceContext
|
||||
_GLOBAL_DEVICE_CONTEXT = DeviceContext(device)
|
||||
_GLOBAL_DEVICE_CONTEXT.__enter__()
|
||||
|
||||
|
||||
def set_default_tensor_type(t):
|
||||
r"""Sets the default ``torch.Tensor`` type to floating point tensor type
|
||||
``t``. This type will also be used as default floating point type for
|
||||
|
@ -1,96 +1,11 @@
|
||||
import sys
|
||||
import torch
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Callable, TypeVar, cast
|
||||
from typing import Any
|
||||
|
||||
from torch.utils._contextlib import _DecoratorContextManager
|
||||
|
||||
__all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
|
||||
'inference_mode', 'set_multithreading_enabled']
|
||||
|
||||
|
||||
# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
|
||||
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
|
||||
FuncType = Callable[..., Any]
|
||||
F = TypeVar('F', bound=FuncType)
|
||||
|
||||
|
||||
class _DecoratorContextManager:
|
||||
"""Allow a context manager to be used as a decorator"""
|
||||
|
||||
def __call__(self, func: F) -> F:
|
||||
if inspect.isclass(func):
|
||||
warnings.warn("Decorating classes is deprecated and will be disabled in "
|
||||
"future versions. You should only decorate functions or methods. "
|
||||
"To preserve the current behavior of class decoration, you can "
|
||||
"directly decorate the `__init__` method and nothing else.")
|
||||
|
||||
if inspect.isgeneratorfunction(func):
|
||||
return self._wrap_generator(func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def decorate_context(*args, **kwargs):
|
||||
with self.clone():
|
||||
return func(*args, **kwargs)
|
||||
return cast(F, decorate_context)
|
||||
|
||||
def _wrap_generator(self, func):
|
||||
"""Wrap each generator invocation with the context manager"""
|
||||
@functools.wraps(func)
|
||||
def generator_context(*args, **kwargs):
|
||||
gen = func(*args, **kwargs)
|
||||
|
||||
# Generators are suspended and unsuspended at `yield`, hence we
|
||||
# make sure the grad mode is properly set every time the execution
|
||||
# flow returns into the wrapped generator and restored when it
|
||||
# returns through our `yield` to our caller (see PR #49017).
|
||||
try:
|
||||
# Issuing `None` to a generator fires it up
|
||||
with self.clone():
|
||||
response = gen.send(None)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Forward the response to our caller and get its next request
|
||||
request = yield response
|
||||
|
||||
except GeneratorExit:
|
||||
# Inform the still active generator about its imminent closure
|
||||
with self.clone():
|
||||
gen.close()
|
||||
raise
|
||||
|
||||
except BaseException:
|
||||
# Propagate the exception thrown at us by the caller
|
||||
with self.clone():
|
||||
response = gen.throw(*sys.exc_info())
|
||||
|
||||
else:
|
||||
# Pass the last request to the generator and get its response
|
||||
with self.clone():
|
||||
response = gen.send(request)
|
||||
|
||||
# We let the exceptions raised above by the generator's `.throw` or
|
||||
# `.send` methods bubble up to our caller, except for StopIteration
|
||||
except StopIteration as e:
|
||||
# The generator informed us that it is done: take whatever its
|
||||
# returned value (if any) was and indicate that we're done too
|
||||
# by returning it (see docs for python's return-statement).
|
||||
return e.value
|
||||
|
||||
return generator_context
|
||||
|
||||
def __enter__(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def clone(self):
|
||||
# override this method if your children class takes __init__ parameters
|
||||
return self.__class__()
|
||||
|
||||
|
||||
class no_grad(_DecoratorContextManager):
|
||||
r"""Context-manager that disabled gradient calculation.
|
||||
|
||||
|
@ -169,6 +169,36 @@ PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
py::object mode = py::module::import("torch.utils._device")
|
||||
.attr("DeviceContext")(py::handle(self));
|
||||
at::impl::PythonTorchFunctionTLS::push_onto_stack(
|
||||
std::make_shared<c10::SafePyObject>(
|
||||
mode.release().ptr(), getPyInterpreter()));
|
||||
// So that with torch.device('cuda') as dev: works
|
||||
Py_INCREF(self);
|
||||
return self;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
|
||||
HANDLE_TH_ERRORS
|
||||
at::impl::PythonTorchFunctionTLS::pop_stack();
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
py::object deco =
|
||||
py::module::import("torch.utils._device").attr("device_decorator");
|
||||
return deco(py::handle(self), *py::handle(args), **py::handle(kwargs))
|
||||
.release()
|
||||
.ptr();
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
typedef PyObject* (*getter)(PyObject*, void*);
|
||||
|
||||
// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
|
||||
@ -182,6 +212,8 @@ static struct PyGetSetDef THPDevice_properties[] = {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
|
||||
static PyMethodDef THPDevice_methods[] = {
|
||||
{"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
|
||||
{"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
|
||||
{"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
|
||||
{nullptr} /* Sentinel */
|
||||
};
|
||||
|
||||
@ -199,6 +231,11 @@ PyTypeObject THPDeviceType = {
|
||||
nullptr, /* tp_as_sequence */
|
||||
nullptr, /* tp_as_mapping */
|
||||
(hashfunc)THPDevice_hash, /* tp_hash */
|
||||
// TODO: We're not sure if this is a good idea or not, because making
|
||||
// torch.device callable means that it will start returning true
|
||||
// for callable() queries, and that is unexpected. We can always add
|
||||
// this later, so for now, don't actually implement this
|
||||
// THPDevice_call, /* tp_call */
|
||||
nullptr, /* tp_call */
|
||||
(reprfunc)THPDevice_str, /* tp_str */
|
||||
nullptr, /* tp_getattro */
|
||||
|
@ -75,6 +75,7 @@ def get_ignored_functions() -> Set[Callable]:
|
||||
torch.is_tensor,
|
||||
torch.is_storage,
|
||||
torch.set_default_tensor_type,
|
||||
torch.set_default_device,
|
||||
torch.set_rng_state,
|
||||
torch.get_rng_state,
|
||||
torch.manual_seed,
|
||||
|
@ -378,7 +378,7 @@ def kaiser(
|
||||
device=device,
|
||||
requires_grad=requires_grad)
|
||||
|
||||
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(torch.tensor(beta))
|
||||
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(torch.tensor(beta, device=device))
|
||||
|
||||
|
||||
@_add_docstr(
|
||||
|
141
torch/utils/_contextlib.py
Normal file
141
torch/utils/_contextlib.py
Normal file
@ -0,0 +1,141 @@
|
||||
# Extra utilities for working with context managers that should have been
|
||||
# in the standard library but are not
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
import sys
|
||||
from typing import Any, Callable, TypeVar, cast
|
||||
|
||||
# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
|
||||
# 'no_grad' and 'enable_grad').
|
||||
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
|
||||
FuncType = Callable[..., Any]
|
||||
F = TypeVar('F', bound=FuncType)
|
||||
|
||||
|
||||
def _wrap_generator(ctx_factory, func):
|
||||
"""
|
||||
Wrap each generator invocation with the context manager factory.
|
||||
|
||||
The input should be a function that returns a context manager,
|
||||
not a context manager itself, to handle one-shot context managers.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def generator_context(*args, **kwargs):
|
||||
gen = func(*args, **kwargs)
|
||||
|
||||
# Generators are suspended and unsuspended at `yield`, hence we
|
||||
# make sure the grad mode is properly set every time the execution
|
||||
# flow returns into the wrapped generator and restored when it
|
||||
# returns through our `yield` to our caller (see PR #49017).
|
||||
try:
|
||||
# Issuing `None` to a generator fires it up
|
||||
with ctx_factory():
|
||||
response = gen.send(None)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Forward the response to our caller and get its next request
|
||||
request = yield response
|
||||
|
||||
except GeneratorExit:
|
||||
# Inform the still active generator about its imminent closure
|
||||
with ctx_factory():
|
||||
gen.close()
|
||||
raise
|
||||
|
||||
except BaseException:
|
||||
# Propagate the exception thrown at us by the caller
|
||||
with ctx_factory():
|
||||
response = gen.throw(*sys.exc_info())
|
||||
|
||||
else:
|
||||
# Pass the last request to the generator and get its response
|
||||
with ctx_factory():
|
||||
response = gen.send(request)
|
||||
|
||||
# We let the exceptions raised above by the generator's `.throw` or
|
||||
# `.send` methods bubble up to our caller, except for StopIteration
|
||||
except StopIteration as e:
|
||||
# The generator informed us that it is done: take whatever its
|
||||
# returned value (if any) was and indicate that we're done too
|
||||
# by returning it (see docs for python's return-statement).
|
||||
return e.value
|
||||
|
||||
return generator_context
|
||||
|
||||
|
||||
def context_decorator(ctx, func):
|
||||
"""
|
||||
Like contextlib.ContextDecorator, but:
|
||||
|
||||
1. Is done by wrapping, rather than inheritance, so it works with context
|
||||
managers that are implemented from C and thus cannot easily inherit from
|
||||
Python classes
|
||||
2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
|
||||
3. Errors out if you try to wrap a class, because it is ambiguous whether
|
||||
or not you intended to wrap only the constructor
|
||||
|
||||
The input argument can either be a context manager (in which case it must
|
||||
be a multi-shot context manager that can be directly invoked multiple times)
|
||||
or a callable that produces a context manager.
|
||||
"""
|
||||
|
||||
assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
|
||||
f"Passed in {ctx} is both callable and also a valid context manager "
|
||||
"(has __enter__), making it ambiguous which interface to use. If you "
|
||||
"intended to pass a context manager factory, rewrite your call as "
|
||||
"context_decorator(lambda: ctx()); if you intended to pass a context "
|
||||
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
|
||||
)
|
||||
|
||||
if not callable(ctx):
|
||||
def ctx_factory():
|
||||
return ctx
|
||||
else:
|
||||
ctx_factory = ctx
|
||||
|
||||
if inspect.isclass(func):
|
||||
raise RuntimeError(
|
||||
"Cannot decorate classes; it is ambiguous whether or not only the "
|
||||
"constructor or all methods should have the context manager applied; "
|
||||
"additionally, decorating a class at definition-site will prevent "
|
||||
"use of the identifier as a conventional type. "
|
||||
"To specify which methods to decorate, decorate each of them "
|
||||
"individually."
|
||||
)
|
||||
|
||||
if inspect.isgeneratorfunction(func):
|
||||
return _wrap_generator(ctx_factory, func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def decorate_context(*args, **kwargs):
|
||||
with ctx_factory():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorate_context
|
||||
|
||||
|
||||
class _DecoratorContextManager:
|
||||
"""Allow a context manager to be used as a decorator"""
|
||||
|
||||
def __call__(self, func: F) -> F:
|
||||
if inspect.isclass(func):
|
||||
warnings.warn("Decorating classes is deprecated and will be disabled in "
|
||||
"future versions. You should only decorate functions or methods. "
|
||||
"To preserve the current behavior of class decoration, you can "
|
||||
"directly decorate the `__init__` method and nothing else.")
|
||||
func = cast(F, lambda *args, **kwargs: func(*args, **kwargs))
|
||||
|
||||
return cast(F, context_decorator(self.clone, func))
|
||||
|
||||
def __enter__(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def clone(self):
|
||||
# override this method if your children class takes __init__ parameters
|
||||
return self.__class__()
|
75
torch/utils/_device.py
Normal file
75
torch/utils/_device.py
Normal file
@ -0,0 +1,75 @@
|
||||
import torch
|
||||
from torch.overrides import TorchFunctionMode
|
||||
from torch.utils._contextlib import context_decorator
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def _device_constructors():
|
||||
return {
|
||||
# standard ones
|
||||
torch.empty,
|
||||
torch.empty_strided,
|
||||
torch.empty_quantized,
|
||||
torch.ones,
|
||||
torch.arange,
|
||||
torch.bartlett_window,
|
||||
torch.blackman_window,
|
||||
torch.eye,
|
||||
torch.fft.fftfreq,
|
||||
torch.fft.rfftfreq,
|
||||
torch.full,
|
||||
torch.fill,
|
||||
torch.hamming_window,
|
||||
torch.hann_window,
|
||||
torch.kaiser_window,
|
||||
torch.linspace,
|
||||
torch.logspace,
|
||||
torch.nested.nested_tensor,
|
||||
# This function doesn't actually take a device argument
|
||||
# torch.normal,
|
||||
torch.ones,
|
||||
torch.rand,
|
||||
torch.randn,
|
||||
torch.randint,
|
||||
torch.randperm,
|
||||
torch.range,
|
||||
torch.sparse_coo_tensor,
|
||||
torch.sparse_compressed_tensor,
|
||||
torch.sparse_csr_tensor,
|
||||
torch.sparse_csc_tensor,
|
||||
torch.sparse_bsr_tensor,
|
||||
torch.sparse_bsc_tensor,
|
||||
torch.tril_indices,
|
||||
torch.triu_indices,
|
||||
torch.vander,
|
||||
torch.zeros,
|
||||
torch.asarray,
|
||||
# weird ones
|
||||
torch.tensor,
|
||||
torch.as_tensor,
|
||||
torch.scalar_tensor,
|
||||
}
|
||||
|
||||
# NB: This is directly called from C++ in torch/csrc/Device.cpp
|
||||
class DeviceContext(TorchFunctionMode):
|
||||
def __init__(self, device):
|
||||
self.device = torch.device(device)
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
if func in _device_constructors() and kwargs.get('device') is None:
|
||||
kwargs['device'] = self.device
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# NB: This is directly called from C++ in torch/csrc/Device.cpp
|
||||
def device_decorator(device, func):
|
||||
return context_decorator(lambda: device, func)
|
||||
|
||||
def set_device(device):
|
||||
"""
|
||||
Decorator which sets the default device inside of the wrapped
|
||||
function. If you would like to use this as a context manager,
|
||||
use device as a context manager directly, e.g.,
|
||||
``with torch.device(device)``.
|
||||
"""
|
||||
return lambda func: device_decorator(torch.device(device), func)
|
Reference in New Issue
Block a user