pyfmt lint more torch/utils files (#155812)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155812
Approved by: https://github.com/Skylion007
ghstack dependencies: #155782, #155783
This commit is contained in:
Laith Sakka
2025-06-12 09:16:10 -07:00
committed by PyTorch MergeBot
parent 4d3ecefda5
commit d7e657da35
13 changed files with 277 additions and 176 deletions

View File

@ -1343,19 +1343,6 @@ exclude_patterns = [
'torch/testing/_internal/test_module/__init__.py',
'torch/testing/_internal/test_module/future_div.py',
'torch/testing/_internal/test_module/no_future_div.py',
'torch/utils/_contextlib.py',
'torch/utils/_cpp_extension_versioner.py',
'torch/utils/_crash_handler.py',
'torch/utils/_device.py',
'torch/utils/_foreach_utils.py',
'torch/utils/_freeze.py',
'torch/utils/_mode_utils.py',
'torch/utils/_python_dispatch.py',
'torch/utils/_stats.py',
'torch/utils/_traceback.py',
'torch/utils/_zip.py',
'torch/utils/backcompat/__init__.py',
'torch/utils/backend_registration.py',
'torch/utils/benchmark/__init__.py',
'torch/utils/benchmark/examples/__init__.py',
'torch/utils/benchmark/examples/compare.py',

View File

@ -4,15 +4,16 @@
import functools
import inspect
import warnings
import sys
from typing import Any, Callable, TypeVar, cast
import warnings
from typing import Any, Callable, cast, TypeVar
# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)
F = TypeVar("F", bound=FuncType)
def _wrap_generator(ctx_factory, func):
@ -22,6 +23,7 @@ def _wrap_generator(ctx_factory, func):
The input should be a function that returns a context manager,
not a context manager itself, to handle one-shot context managers.
"""
@functools.wraps(func)
def generator_context(*args, **kwargs):
gen = func(*args, **kwargs)
@ -83,7 +85,7 @@ def context_decorator(ctx, func):
be a multi-shot context manager that can be directly invoked multiple times)
or a callable that produces a context manager.
"""
assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
@ -92,8 +94,10 @@ def context_decorator(ctx, func):
)
if not callable(ctx):
def ctx_factory():
return ctx
else:
ctx_factory = ctx

View File

@ -2,18 +2,18 @@
import collections
Entry = collections.namedtuple('Entry', 'version, hash')
Entry = collections.namedtuple("Entry", "version, hash")
def update_hash(seed, value):
# Good old boost::hash_combine
# https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html
return seed ^ (hash(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2))
return seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2))
def hash_source_files(hash_value, source_files):
for filename in source_files:
with open(filename, 'rb') as file:
with open(filename, "rb") as file:
hash_value = update_hash(hash_value, file.read())
return hash_value
@ -34,15 +34,17 @@ class ExtensionVersioner:
entry = self.entries.get(name)
return None if entry is None else entry.version
def bump_version_if_changed(self,
name,
source_files,
build_arguments,
build_directory,
with_cuda,
with_sycl,
is_python_module,
is_standalone):
def bump_version_if_changed(
self,
name,
source_files,
build_arguments,
build_directory,
with_cuda,
with_sycl,
is_python_module,
is_standalone,
):
hash_value = 0
hash_value = hash_source_files(hash_value, source_files)
hash_value = hash_build_arguments(hash_value, build_arguments)

View File

@ -1,13 +1,16 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode
from torch.utils._contextlib import context_decorator
from torch._C import _len_torch_function_stack
import functools
from typing import Optional
import torch
from torch._C import _len_torch_function_stack
from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode
from torch.utils._contextlib import context_decorator
CURRENT_DEVICE: Optional[torch.device] = None
@functools.lru_cache(1)
def _device_constructors():
return {
@ -50,9 +53,10 @@ def _device_constructors():
# weird ones
torch.tensor,
torch.as_tensor,
torch.scalar_tensor
torch.scalar_tensor,
}
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
def __init__(self, device):
@ -73,7 +77,6 @@ class DeviceContext(TorchFunctionMode):
for mode in reversed(cur_stack):
_push_mode(mode)
def __exit__(self, exc_type, exc_val, exc_tb):
global CURRENT_DEVICE
CURRENT_DEVICE = self.old_device
@ -95,14 +98,16 @@ class DeviceContext(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if func in _device_constructors() and kwargs.get('device') is None:
kwargs['device'] = self.device
if func in _device_constructors() and kwargs.get("device") is None:
kwargs["device"] = self.device
return func(*args, **kwargs)
# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
return context_decorator(lambda: device, func)
def set_device(device):
"""
Set the default device inside of the wrapped function by decorating it with this function.

View File

@ -1,17 +1,27 @@
from typing import Optional
from typing_extensions import TypeAlias
import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias
def _get_foreach_kernels_supported_devices() -> list[str]:
r"""Return the device type list that supports foreach kernels."""
return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
def _get_fused_kernels_supported_devices() -> list[str]:
r"""Return the device type list that supports fused kernels in optimizer."""
return ["mps", "cuda", "xpu", "hpu", "cpu", torch._C._get_privateuse1_backend_name()]
return [
"mps",
"cuda",
"xpu",
"hpu",
"cpu",
torch._C._get_privateuse1_backend_name(),
]
TensorListList: TypeAlias = list[list[Optional[Tensor]]]
Indices: TypeAlias = list[int]
@ -36,9 +46,15 @@ def _group_tensors_by_device_and_dtype(
) -> dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]]:
return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
def _device_has_foreach_support(device: torch.device) -> bool:
return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
return (
device.type in (_get_foreach_kernels_supported_devices() + ["cpu"])
and not torch.jit.is_scripting()
)
def _has_foreach_support(tensors: list[Tensor], device: torch.device) -> bool:
return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)
return _device_has_foreach_support(device) and all(
t is None or type(t) in _foreach_supported_types for t in tensors
)

View File

@ -3,6 +3,9 @@
"""
Freeze Python packages.
Freezing makes it possible to ship arbitrary Python modules as part of a C++
library. The Python source of the module is compiled to bytecode and written
to `.c` files, to be imported by Python's built-in FrozenImporter.

View File

@ -1,11 +1,15 @@
# mypy: allow-untyped-defs
import torch
from typing import TypeVar
T = TypeVar('T')
import torch
T = TypeVar("T")
# returns if all are the same mode
def all_same_mode(modes):
return all(tuple(mode == modes[0] for mode in modes))
no_dispatch = torch._C._DisableTorchDispatch

View File

@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Union, Protocol, overload
from collections.abc import Sequence
from typing_extensions import TypeIs
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional, overload, Protocol, Union
from typing_extensions import TypeIs
import torch
import torchgen
@ -29,8 +28,13 @@ from torch._C import (
_is_in_torch_dispatch_mode = False
_is_in_non_infra_torch_dispatch_mode = False
def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool:
return _is_in_torch_dispatch_mode if include_infra_modes else _is_in_non_infra_torch_dispatch_mode
return (
_is_in_torch_dispatch_mode
if include_infra_modes
else _is_in_non_infra_torch_dispatch_mode
)
class TorchDispatchMode:
@ -79,7 +83,6 @@ class TorchDispatchMode:
if not hasattr(self, "old_non_infra_dispatch_mode_flags"):
self.old_non_infra_dispatch_mode_flags: deque[bool] = deque() # type: ignore[no-redef]
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise NotImplementedError
@ -93,8 +96,12 @@ class TorchDispatchMode:
self._lazy_init_old_dispatch_mode_flags()
self.old_dispatch_mode_flags.append(_is_in_torch_dispatch_mode)
_is_in_torch_dispatch_mode = True
self.old_non_infra_dispatch_mode_flags.append(_is_in_non_infra_torch_dispatch_mode)
_is_in_non_infra_torch_dispatch_mode = _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
self.old_non_infra_dispatch_mode_flags.append(
_is_in_non_infra_torch_dispatch_mode
)
_is_in_non_infra_torch_dispatch_mode = (
_is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
)
_push_mode(self)
return self
@ -107,7 +114,9 @@ class TorchDispatchMode:
global _is_in_torch_dispatch_mode
_is_in_torch_dispatch_mode = self.old_dispatch_mode_flags.pop()
global _is_in_non_infra_torch_dispatch_mode
_is_in_non_infra_torch_dispatch_mode = self.old_non_infra_dispatch_mode_flags.pop()
_is_in_non_infra_torch_dispatch_mode = (
self.old_non_infra_dispatch_mode_flags.pop()
)
_pop_mode(mb_dk_or_mode_key)
@classmethod
@ -123,7 +132,6 @@ class TorchDispatchMode:
return False
def _get_current_dispatch_mode():
stack_len = _len_torch_dispatch_stack()
# Return a user mode on the stack if there are any
@ -133,19 +141,16 @@ def _get_current_dispatch_mode():
def _detect_infra_mode(key):
assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY]
assert key in [
torch._C._TorchDispatchModeKey.FUNCTIONAL,
torch._C._TorchDispatchModeKey.PROXY,
]
from torch._ops import _get_dispatch_mode_pre_dispatch
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(
key
)
post_dispatch_mode = torch._C._get_dispatch_mode(
key
)
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
post_dispatch_mode = torch._C._get_dispatch_mode(key)
assert (pre_dispatch_mode is None) or (
post_dispatch_mode is None
)
assert (pre_dispatch_mode is None) or (post_dispatch_mode is None)
if pre_dispatch_mode is None:
return post_dispatch_mode
@ -232,8 +237,8 @@ def _disable_current_modes():
_pop_mode_from_pre_dispatch,
)
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch._subclasses.schema_check_mode import SchemaCheckMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch()
old_pre_dispatch_modes = [
@ -267,10 +272,7 @@ def _disable_current_modes():
raise AssertionError(
"Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key"
)
if (
isinstance(old, SchemaCheckMode)
and has_schema_check_mode_in_pre_dispatch
):
if isinstance(old, SchemaCheckMode) and has_schema_check_mode_in_pre_dispatch:
raise AssertionError(
"Can't have SchemaCheckMode available both in PreDispatch and Python Key"
)
@ -298,7 +300,9 @@ class TensorWithFlatten(Protocol):
...
@staticmethod
def __tensor_unflatten__(inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int) -> torch.Tensor:
def __tensor_unflatten__(
inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int
) -> torch.Tensor:
...
# It would be really nice to be able to say that the return of
@ -331,41 +335,39 @@ class TensorWithFlatten(Protocol):
@overload
def to(
self,
dtype: torch.types._dtype,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None
self,
dtype: torch.types._dtype,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
...
@overload
def to(
self,
device: Optional["torch._prims_common.DeviceLikeType"] = None,
dtype: Optional[torch.types._dtype] = None,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None
self,
device: Optional["torch._prims_common.DeviceLikeType"] = None,
dtype: Optional[torch.types._dtype] = None,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
...
@overload
def to(
self,
other: torch.Tensor,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None
self,
other: torch.Tensor,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
...
def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
"""
Returns whether or not a tensor subclass that implements __torch_dispatch__
@ -403,10 +405,15 @@ def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
and hasattr(t, "__tensor_unflatten__")
)
def is_traceable_wrapper_subclass_type(t: type) -> TypeIs[type[TensorWithFlatten]]:
"""Same as above, but takes a type argument instead of an instance."""
return (issubclass(t, torch.Tensor) and t != torch.Tensor
and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))
return (
issubclass(t, torch.Tensor)
and t != torch.Tensor
and hasattr(t, "__tensor_flatten__")
and hasattr(t, "__tensor_unflatten__")
)
def transform_subclass(t, callback, outer_size=None, outer_stride=None):
@ -551,7 +558,9 @@ def get_alias_info(func) -> SchemaInfo:
torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
# for aten::rot90 / aten:fft_*
torchgen_schema_str = re.sub(r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str)
torchgen_schema_str = re.sub(
r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str
)
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
arg_schemas = [
AliasInfo(

View File

@ -3,8 +3,8 @@
# AND SCRUB AWAY TORCH NOTIONS THERE.
import collections
import functools
from typing import Callable, TypeVar
from collections import OrderedDict
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
@ -18,6 +18,7 @@ def count_label(label: str) -> None:
prev = simple_call_counter.setdefault(label, 0)
simple_call_counter[label] = prev + 1
def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@functools.wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
@ -25,4 +26,5 @@ def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
simple_call_counter[fn.__qualname__] = 0
simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
return fn(*args, **kwargs)
return wrapper

View File

@ -1,11 +1,12 @@
# mypy: allow-untyped-defs
from types import TracebackType
from typing import Optional
import tempfile
import traceback
import contextlib
import inspect
import os.path
import tempfile
import traceback
from types import TracebackType
from typing import Optional
# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.
@ -44,6 +45,7 @@ import os.path
# - Before running the compiled code, enter the
# report_compile_source_on_error() context manager.
@contextlib.contextmanager
def report_compile_source_on_error():
try:
@ -83,15 +85,17 @@ def report_compile_source_on_error():
# Don't delete the temporary file so the user can inspect it
# TODO: This creates a temporary file for every frame, but we
# technically only need one per distinct __compile_source__
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".py"
) as f:
f.write(source)
# Create a frame. Python doesn't let you construct
# FrameType directly, so just make one with compile
frame = tb.tb_frame
code = compile('__inspect_currentframe()', f.name, 'eval')
code = compile("__inspect_currentframe()", f.name, "eval")
code = code.replace(co_name=frame.f_code.co_name)
# Python 3.11 only
if hasattr(frame.f_code, 'co_linetable'):
if hasattr(frame.f_code, "co_linetable"):
# We can't copy ALL of the metadata over, because you
# can cause Python to segfault this way. What exactly
# do we need? We need enough information for
@ -109,14 +113,9 @@ def report_compile_source_on_error():
fake_frame = eval(
code,
frame.f_globals,
{
**frame.f_locals,
'__inspect_currentframe': inspect.currentframe
}
)
fake_tb = TracebackType(
None, fake_frame, tb.tb_lasti, tb.tb_lineno
{**frame.f_locals, "__inspect_currentframe": inspect.currentframe},
)
fake_tb = TracebackType(None, fake_frame, tb.tb_lasti, tb.tb_lineno)
stack.append(fake_tb)
else:
stack.append(tb)
@ -131,6 +130,7 @@ def report_compile_source_on_error():
raise exc.with_traceback(tb_next) # noqa: B904
def shorten_filename(fn, *, base=None):
"""Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user."""
if base is None:
@ -141,7 +141,8 @@ def shorten_filename(fn, *, base=None):
except ValueError:
return fn
else:
return fn[len(prefix) + 1:]
return fn[len(prefix) + 1 :]
def format_frame(frame, *, base=None, line=False):
"""
@ -154,12 +155,14 @@ def format_frame(frame, *, base=None, line=False):
extra_line = f"{frame.line} # "
return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"
def format_traceback_short(tb):
"""Format a TracebackType in a short way, printing only the inner-most frame."""
return format_frame(traceback.extract_tb(tb)[-1])
class CapturedTraceback:
__slots__ = ['tb', 'skip']
__slots__ = ["tb", "skip"]
def __init__(self, tb, skip=0):
self.tb = tb
@ -176,15 +179,17 @@ class CapturedTraceback:
return traceback.StackSummary()
return _extract_symbolized_tb(
torch._C._profiler.symbolize_tracebacks([self.tb])[0],
self.skip
torch._C._profiler.symbolize_tracebacks([self.tb])[0], self.skip
)
def __getstate__(self):
return (None, {
'tb': None, # TB is not pickleable
'skip': self.skip,
})
return (
None,
{
"tb": None, # TB is not pickleable
"skip": self.skip,
},
)
@staticmethod
def extract(*, script=False, cpp=False, skip=0):
@ -207,7 +212,7 @@ class CapturedTraceback:
torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
# Elide extract() frame if we don't have script/cpp frames. If
# we do have those frames, it doesn't work so force zero.
0 if script or cpp else skip + 1
0 if script or cpp else skip + 1,
)
def format(self):
@ -251,5 +256,5 @@ def _extract_symbolized_tb(tb, skip):
"""
stack = traceback.StackSummary()
for f in reversed(tb[skip:]):
stack.append(traceback.FrameSummary(f['filename'], f['line'], f['name']))
stack.append(traceback.FrameSummary(f["filename"], f["line"], f["name"]))
return stack

View File

@ -5,6 +5,7 @@ import os
from pathlib import Path
from zipfile import ZipFile
# Exclude some standard library modules to:
# 1. Slim down the final zipped file size
# 2. Remove functionality we don't want to support.

View File

@ -1,8 +1,10 @@
# mypy: allow-untyped-defs
from torch._C import _set_backcompat_broadcast_warn
from torch._C import _get_backcompat_broadcast_warn
from torch._C import _set_backcompat_keepdim_warn
from torch._C import _get_backcompat_keepdim_warn
from torch._C import (
_get_backcompat_broadcast_warn,
_get_backcompat_keepdim_warn,
_set_backcompat_broadcast_warn,
_set_backcompat_keepdim_warn,
)
class Warning:
@ -18,5 +20,8 @@ class Warning:
enabled = property(get_enabled, set_enabled)
broadcast_warning = Warning(_set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn)
broadcast_warning = Warning(
_set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn
)
keepdim_warning = Warning(_set_backcompat_keepdim_warn, _get_backcompat_keepdim_warn)

View File

@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
import torch
from torch.overrides import (
handle_torch_function,
has_torch_function_unary,
)
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
from typing import Optional, Union
import torch
from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
from torch.overrides import handle_torch_function, has_torch_function_unary
__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
@ -15,6 +14,7 @@ __all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backe
# `_privateuse1_backend_name`.
_privateuse1_backend_name = "privateuseone"
def rename_privateuse1_backend(backend_name: str) -> None:
r"""
Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.
@ -78,16 +78,22 @@ def rename_privateuse1_backend(backend_name: str) -> None:
global _privateuse1_backend_name
_privateuse1_backend_name = backend_name
def _check_register_once(module, attr):
if hasattr(module, attr):
raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}")
raise RuntimeError(
f"The custom device module of {module} has already been registered with {attr}"
)
def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int:
def _normalization_device(
custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None
) -> int:
def _get_current_device_index():
_get_device_index = "current_device"
if hasattr(torch, custom_backend_name) and \
hasattr(getattr(torch, custom_backend_name), _get_device_index):
if hasattr(torch, custom_backend_name) and hasattr(
getattr(torch, custom_backend_name), _get_device_index
):
return getattr(getattr(torch, custom_backend_name), _get_device_index)()
else:
# The default device index is 0.
@ -122,12 +128,16 @@ def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -
return handle_torch_function(wrap_tensor_backend.__get__, (self,), self) # type: ignore[attr-defined]
return self.device.type == custom_backend_name
_check_register_once(torch.Tensor, f'is_{custom_backend_name}')
wrap_tensor_backend.fget.__name__ = f'is_{custom_backend_name}' # type: ignore[attr-defined]
setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend)
_check_register_once(torch.Tensor, f"is_{custom_backend_name}")
wrap_tensor_backend.fget.__name__ = f"is_{custom_backend_name}" # type: ignore[attr-defined]
setattr(torch.Tensor, f"is_{custom_backend_name}", wrap_tensor_backend)
def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False,
**kwargs) -> torch.Tensor:
def wrap_tensor_to(
self: torch.Tensor,
device: Optional[Union[int, torch.device]] = None,
non_blocking=False,
**kwargs,
) -> torch.Tensor:
r"""Perform Tensor device conversion. Call the to operator implementation.
.. note::
@ -143,9 +153,20 @@ def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -
**kwargs (dict): For compatibility, may contain the key ``memory_format`` argument.
"""
if has_torch_function_unary(self):
return handle_torch_function(wrap_tensor_to, (self,), self, device=device, non_blocking=False, **kwargs)
return handle_torch_function(
wrap_tensor_to,
(self,),
self,
device=device,
non_blocking=False,
**kwargs,
)
device_idx = _normalization_device(custom_backend_name, device)
return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)
return self.to(
device=torch.device(f"{custom_backend_name}:{device_idx}"),
non_blocking=non_blocking,
**kwargs,
)
_check_register_once(torch.Tensor, custom_backend_name)
wrap_tensor_to.__name__ = custom_backend_name
@ -159,10 +180,13 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
raise RuntimeError(
f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module."
f"Because torch.Tensor doesn't has the method {custom_backend_name}()."
f"For this error, you can try setting for_tensor=True.")
f"For this error, you can try setting for_tensor=True."
)
def wrap_module_to(self: torch.nn.modules.module.T,
device: Optional[Union[int, torch.device]] = None) -> torch.nn.modules.module.T:
def wrap_module_to(
self: torch.nn.modules.module.T,
device: Optional[Union[int, torch.device]] = None,
) -> torch.nn.modules.module.T:
r"""Move all model parameters and buffers to the custom device.
This also makes associated parameters and buffers different objects. So
@ -180,27 +204,37 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
_check_register_once(torch.nn.Module, custom_backend_name)
setattr(torch.nn.Module, custom_backend_name, wrap_module_to)
def _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
def _generate_packed_sequence_methods_for_privateuse1_backend(
custom_backend_name: str,
) -> None:
# Generate PackedSequence Module attributes and methods depends on Tensor methods,
# so we need to check whether Tensor methods is already registered.
if not hasattr(torch.Tensor, f'is_{custom_backend_name}') or \
not hasattr(torch.Tensor, custom_backend_name):
if not hasattr(torch.Tensor, f"is_{custom_backend_name}") or not hasattr(
torch.Tensor, custom_backend_name
):
raise RuntimeError(
f"Can not automatically generate is_{custom_backend_name}() or "
f"{custom_backend_name}() method for torch.nn.utils.rnn.PackedSequence."
f"Because torch.Tensor doesn't has the method is_{custom_backend_name}()"
f"or {custom_backend_name}()."
f"For this error, you can try setting for_tensor=True.")
f"For this error, you can try setting for_tensor=True."
)
@property # type: ignore[misc]
def wrap_tensor_backend(self: torch.nn.utils.rnn.PackedSequence) -> bool:
return self.data.device.type == custom_backend_name
_check_register_once(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}')
setattr(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}', wrap_tensor_backend)
_check_register_once(torch.nn.utils.rnn.PackedSequence, f"is_{custom_backend_name}")
setattr(
torch.nn.utils.rnn.PackedSequence,
f"is_{custom_backend_name}",
wrap_tensor_backend,
)
def wrap_module_to(self: torch.nn.utils.rnn.PackedSequence,
*args, **kwargs) -> torch.nn.utils.rnn.PackedSequence:
def wrap_module_to(
self: torch.nn.utils.rnn.PackedSequence, *args, **kwargs
) -> torch.nn.utils.rnn.PackedSequence:
r"""Move all model parameters and buffers to the custom device.
This also makes associated parameters and buffers different objects. So
@ -213,17 +247,21 @@ def _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_nam
Args:
device (int, optional): if specified, all parameters will be copied to that device
"""
ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
*args, **kwargs
)
if ex.device.type == custom_backend_name:
return self.to(*args, **kwargs)
kwargs.update({'device': custom_backend_name})
kwargs.update({"device": custom_backend_name})
return self.to(*args, **kwargs)
_check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)
setattr(torch.nn.utils.rnn.PackedSequence, custom_backend_name, wrap_module_to)
def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
unsupported_dtype: Optional[list[torch.dtype]] = None) -> None:
def _generate_storage_methods_for_privateuse1_backend(
custom_backend_name: str, unsupported_dtype: Optional[list[torch.dtype]] = None
) -> None:
# Attribute is registered in the _StorageBase class
# and UntypedStorage obtains through inheritance.
@property # type: ignore[misc]
@ -231,8 +269,10 @@ def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
r"""Return the internal :class:`torch.UntypedStorage`."""
return self.device.type == custom_backend_name
_check_register_once(torch.storage._StorageBase, f'is_{custom_backend_name}')
setattr(torch.storage._StorageBase, f'is_{custom_backend_name}', wrap_storage_backend)
_check_register_once(torch.storage._StorageBase, f"is_{custom_backend_name}")
setattr(
torch.storage._StorageBase, f"is_{custom_backend_name}", wrap_storage_backend
)
def wrap_storage_to(self, device=None, non_blocking=False):
r"""Return a copy of this object in custom device memory.
@ -250,16 +290,18 @@ def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
# but it depends on the extended function, so this part is temporarily omitted in the automatic generation.
device_idx = _normalization_device(custom_backend_name, device)
if getattr(self, f'is_{custom_backend_name}'):
if getattr(self, f"is_{custom_backend_name}"):
# storage has already on expected device.
if self.get_device() == device_idx:
return self
# For sparse storage, custom need to extend the implementation by themselves.
if self.is_sparse:
raise RuntimeError(f"Can not support a sparse storage move to {custom_backend_name} backend")
raise RuntimeError(
f"Can not support a sparse storage move to {custom_backend_name} backend"
)
# create untyped_storage and copy data
untyped_storage = torch.UntypedStorage(
self.size(), device=torch.device(f'{custom_backend_name}:{device_idx}')
self.size(), device=torch.device(f"{custom_backend_name}:{device_idx}")
)
untyped_storage.copy_(self, non_blocking)
return untyped_storage
@ -275,27 +317,38 @@ def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
torch.storage._warn_typed_storage_removal()
return self._untyped_storage.device.type == custom_backend_name
_check_register_once(torch.TypedStorage, f'is_{custom_backend_name}')
setattr(torch.storage.TypedStorage, f'is_{custom_backend_name}', wrap_typed_storage_backend)
_check_register_once(torch.TypedStorage, f"is_{custom_backend_name}")
setattr(
torch.storage.TypedStorage,
f"is_{custom_backend_name}",
wrap_typed_storage_backend,
)
def wrap_typed_storage_to(self: torch.storage.TypedStorage,
device=None, non_blocking=False, **kwargs) -> torch.storage.TypedStorage:
def wrap_typed_storage_to(
self: torch.storage.TypedStorage, device=None, non_blocking=False, **kwargs
) -> torch.storage.TypedStorage:
torch.storage._warn_typed_storage_removal()
if unsupported_dtype and self.dtype in unsupported_dtype:
raise RuntimeError(f"Cannot create {custom_backend_name} storage "
f"as {self.dtype} dtype is not supported by this backend")
raise RuntimeError(
f"Cannot create {custom_backend_name} storage "
f"as {self.dtype} dtype is not supported by this backend"
)
custom_backend_storage: torch.UntypedStorage = getattr(
self._untyped_storage, custom_backend_name)(device, non_blocking, **kwargs)
self._untyped_storage, custom_backend_name
)(device, non_blocking, **kwargs)
return self._new_wrapped_storage(custom_backend_storage)
_check_register_once(torch.TypedStorage, custom_backend_name)
setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to)
def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
for_packed_sequence: bool = True,
for_storage: bool = False,
unsupported_dtype: Optional[list[torch.dtype]] = None) -> None:
def generate_methods_for_privateuse1_backend(
for_tensor: bool = True,
for_module: bool = True,
for_packed_sequence: bool = True,
for_storage: bool = False,
unsupported_dtype: Optional[list[torch.dtype]] = None,
) -> None:
r"""
Automatically generate attributes and methods for the custom backend after rename privateuse1 backend.
@ -337,11 +390,14 @@ def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module
_generate_module_methods_for_privateuse1_backend(custom_backend_name)
if for_storage:
_generate_storage_methods_for_privateuse1_backend(custom_backend_name, unsupported_dtype)
_generate_storage_methods_for_privateuse1_backend(
custom_backend_name, unsupported_dtype
)
if for_packed_sequence:
_generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name)
def _get_custom_mod_func(func_name: str):
r"""
Return the func named `func_name` defined in custom device module. If not defined,
@ -370,12 +426,14 @@ def _get_custom_mod_func(func_name: str):
it is marked as private. It is a convenience function for backend implementers to
more easily call the hooks into their backend extensions.
"""
assert isinstance(func_name, str), f"func_name must be `str`, but got `{type(func_name)}`."
assert isinstance(
func_name, str
), f"func_name must be `str`, but got `{type(func_name)}`."
backend_name = _get_privateuse1_backend_name()
custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type]
function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type]
if custom_device_mod is None or function is None:
message = f'Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend '
message = f"Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend "
message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And "
message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
raise RuntimeError(message)