Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
pyfmt lint more torch/utils files (#155812)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155812
Approved by: https://github.com/Skylion007
ghstack dependencies: #155782, #155783
committed by: PyTorch MergeBot
parent: 4d3ecefda5
commit: d7e657da35

.lintrunner.toml
@@ -1343,19 +1343,6 @@ exclude_patterns = [
    'torch/testing/_internal/test_module/__init__.py',
    'torch/testing/_internal/test_module/future_div.py',
    'torch/testing/_internal/test_module/no_future_div.py',
    'torch/utils/_contextlib.py',
    'torch/utils/_cpp_extension_versioner.py',
    'torch/utils/_crash_handler.py',
    'torch/utils/_device.py',
    'torch/utils/_foreach_utils.py',
    'torch/utils/_freeze.py',
    'torch/utils/_mode_utils.py',
    'torch/utils/_python_dispatch.py',
    'torch/utils/_stats.py',
    'torch/utils/_traceback.py',
    'torch/utils/_zip.py',
    'torch/utils/backcompat/__init__.py',
    'torch/utils/backend_registration.py',
    'torch/utils/benchmark/__init__.py',
    'torch/utils/benchmark/examples/__init__.py',
    'torch/utils/benchmark/examples/compare.py',

torch/utils/_contextlib.py
@@ -4,15 +4,16 @@
import functools
import inspect
import warnings
import sys
from typing import Any, Callable, TypeVar, cast
import warnings
from typing import Any, Callable, cast, TypeVar


# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)
F = TypeVar("F", bound=FuncType)


def _wrap_generator(ctx_factory, func):

@@ -22,6 +23,7 @@ def _wrap_generator(ctx_factory, func):
    The input should be a function that returns a context manager,
    not a context manager itself, to handle one-shot context managers.
    """

    @functools.wraps(func)
    def generator_context(*args, **kwargs):
        gen = func(*args, **kwargs)

@@ -83,7 +85,7 @@ def context_decorator(ctx, func):
    be a multi-shot context manager that can be directly invoked multiple times)
    or a callable that produces a context manager.
    """
    assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
    assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
        f"Passed in {ctx} is both callable and also a valid context manager "
        "(has __enter__), making it ambiguous which interface to use. If you "
        "intended to pass a context manager factory, rewrite your call as "

@@ -92,8 +94,10 @@ def context_decorator(ctx, func):
    )

    if not callable(ctx):

        def ctx_factory():
            return ctx

    else:
        ctx_factory = ctx
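
A hedged usage sketch for the context_decorator helper reformatted above (the assertion is why a factory is passed as a lambda rather than as a bare context-manager class such as torch.no_grad):

import torch
from torch.utils._contextlib import context_decorator

def _sum(x):
    return x.sum()

# Pass a factory (lambda) so the helper unambiguously gets a fresh context
# manager per call; torch.no_grad itself is both callable and has __enter__.
sum_no_grad = context_decorator(lambda: torch.no_grad(), _sum)
out = sum_no_grad(torch.ones(3, requires_grad=True))
assert not out.requires_grad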

torch/utils/_cpp_extension_versioner.py
@@ -2,18 +2,18 @@
import collections


Entry = collections.namedtuple('Entry', 'version, hash')
Entry = collections.namedtuple("Entry", "version, hash")


def update_hash(seed, value):
    # Good old boost::hash_combine
    # https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html
    return seed ^ (hash(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2))
    return seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2))


def hash_source_files(hash_value, source_files):
    for filename in source_files:
        with open(filename, 'rb') as file:
        with open(filename, "rb") as file:
            hash_value = update_hash(hash_value, file.read())
    return hash_value

@@ -34,15 +34,17 @@ class ExtensionVersioner:
        entry = self.entries.get(name)
        return None if entry is None else entry.version

    def bump_version_if_changed(self,
                                name,
                                source_files,
                                build_arguments,
                                build_directory,
                                with_cuda,
                                with_sycl,
                                is_python_module,
                                is_standalone):
    def bump_version_if_changed(
        self,
        name,
        source_files,
        build_arguments,
        build_directory,
        with_cuda,
        with_sycl,
        is_python_module,
        is_standalone,
    ):
        hash_value = 0
        hash_value = hash_source_files(hash_value, source_files)
        hash_value = hash_build_arguments(hash_value, build_arguments)
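
A hedged sketch of the hashing scheme above: the boost::hash_combine-style update_hash folds each input into a running seed, so the extension version bumps whenever any source file or build argument changes.

def update_hash(seed, value):
    # mix the new value's hash into the running seed (boost::hash_combine)
    return seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2))

h = 0
for item in (b"kernel source bytes", "-O3", "with_cuda=False"):
    h = update_hash(h, item)
print(h)  # changes if any of the items above changes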

torch/utils/_device.py
@@ -1,13 +1,16 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode
from torch.utils._contextlib import context_decorator
from torch._C import _len_torch_function_stack
import functools
from typing import Optional

import torch
from torch._C import _len_torch_function_stack
from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode
from torch.utils._contextlib import context_decorator


CURRENT_DEVICE: Optional[torch.device] = None


@functools.lru_cache(1)
def _device_constructors():
    return {

@@ -50,9 +53,10 @@ def _device_constructors():
        # weird ones
        torch.tensor,
        torch.as_tensor,
        torch.scalar_tensor
        torch.scalar_tensor,
    }


# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    def __init__(self, device):

@@ -73,7 +77,6 @@ class DeviceContext(TorchFunctionMode):
        for mode in reversed(cur_stack):
            _push_mode(mode)


    def __exit__(self, exc_type, exc_val, exc_tb):
        global CURRENT_DEVICE
        CURRENT_DEVICE = self.old_device

@@ -95,14 +98,16 @@ class DeviceContext(TorchFunctionMode):

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        if func in _device_constructors() and kwargs.get("device") is None:
            kwargs["device"] = self.device
        return func(*args, **kwargs)


# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    return context_decorator(lambda: device, func)


def set_device(device):
    """
    Set the default device inside of the wrapped function by decorating it with this function.
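
A hedged sketch of what DeviceContext does in practice: it is the TorchFunctionMode behind the public "with torch.device(...)" and torch.set_default_device APIs, filling in the device argument for the constructors listed in _device_constructors() when the caller omits it.

import torch

with torch.device("meta"):      # enters a DeviceContext under the hood
    x = torch.empty(2, 2)       # no device= passed, so the mode supplies it
assert x.is_meta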

torch/utils/_foreach_utils.py
@@ -1,17 +1,27 @@
from typing import Optional
from typing_extensions import TypeAlias

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias


def _get_foreach_kernels_supported_devices() -> list[str]:
    r"""Return the device type list that supports foreach kernels."""
    return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]


def _get_fused_kernels_supported_devices() -> list[str]:
    r"""Return the device type list that supports fused kernels in optimizer."""
    return ["mps", "cuda", "xpu", "hpu", "cpu", torch._C._get_privateuse1_backend_name()]
    return [
        "mps",
        "cuda",
        "xpu",
        "hpu",
        "cpu",
        torch._C._get_privateuse1_backend_name(),
    ]


TensorListList: TypeAlias = list[list[Optional[Tensor]]]
Indices: TypeAlias = list[int]

@@ -36,9 +46,15 @@ def _group_tensors_by_device_and_dtype(
) -> dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]]:
    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)


def _device_has_foreach_support(device: torch.device) -> bool:
    return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
    return (
        device.type in (_get_foreach_kernels_supported_devices() + ["cpu"])
        and not torch.jit.is_scripting()
    )


def _has_foreach_support(tensors: list[Tensor], device: torch.device) -> bool:
    return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)
    return _device_has_foreach_support(device) and all(
        t is None or type(t) in _foreach_supported_types for t in tensors
    )
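
A hedged sketch of how these private helpers are used by the optimizers: parameter and gradient lists are bucketed by (device, dtype) so a single _foreach_* kernel can be launched per bucket. This pokes a private API, so treat it as an illustration rather than a supported interface.

import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

params = [torch.zeros(2), torch.zeros(3, dtype=torch.float64)]
grads = [torch.ones(2), torch.ones(3, dtype=torch.float64)]

grouped = _group_tensors_by_device_and_dtype([params, grads], True)
for (device, dtype), ((p_group, g_group), indices) in grouped.items():
    # one bucket per (device, dtype); indices map back into the original lists
    print(device, dtype, indices, len(p_group), len(g_group))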

torch/utils/_freeze.py
@@ -3,6 +3,9 @@
"""
Freeze Python packages.


Freezing makes it possible to ship arbitrary Python modules as part of a C++
library. The Python source of the module is compiled to bytecode and written
to `.c` files, to be imported by Python's built-in FrozenImporter.

torch/utils/_mode_utils.py
@@ -1,11 +1,15 @@
# mypy: allow-untyped-defs
import torch
from typing import TypeVar

T = TypeVar('T')
import torch


T = TypeVar("T")


# returns if all are the same mode
def all_same_mode(modes):
    return all(tuple(mode == modes[0] for mode in modes))


no_dispatch = torch._C._DisableTorchDispatch

torch/utils/_python_dispatch.py
@@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
import contextlib

import warnings
from dataclasses import dataclass
from typing import Any, Optional, Union, Protocol, overload
from collections.abc import Sequence
from typing_extensions import TypeIs
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional, overload, Protocol, Union
from typing_extensions import TypeIs

import torch
import torchgen

@@ -29,8 +28,13 @@ from torch._C import (
_is_in_torch_dispatch_mode = False
_is_in_non_infra_torch_dispatch_mode = False


def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool:
    return _is_in_torch_dispatch_mode if include_infra_modes else _is_in_non_infra_torch_dispatch_mode
    return (
        _is_in_torch_dispatch_mode
        if include_infra_modes
        else _is_in_non_infra_torch_dispatch_mode
    )


class TorchDispatchMode:

@@ -79,7 +83,6 @@ class TorchDispatchMode:
        if not hasattr(self, "old_non_infra_dispatch_mode_flags"):
            self.old_non_infra_dispatch_mode_flags: deque[bool] = deque()  # type: ignore[no-redef]


    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError

@@ -93,8 +96,12 @@ TorchDispatchMode:
        self._lazy_init_old_dispatch_mode_flags()
        self.old_dispatch_mode_flags.append(_is_in_torch_dispatch_mode)
        _is_in_torch_dispatch_mode = True
        self.old_non_infra_dispatch_mode_flags.append(_is_in_non_infra_torch_dispatch_mode)
        _is_in_non_infra_torch_dispatch_mode = _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
        self.old_non_infra_dispatch_mode_flags.append(
            _is_in_non_infra_torch_dispatch_mode
        )
        _is_in_non_infra_torch_dispatch_mode = (
            _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
        )
        _push_mode(self)
        return self

@@ -107,7 +114,9 @@
        global _is_in_torch_dispatch_mode
        _is_in_torch_dispatch_mode = self.old_dispatch_mode_flags.pop()
        global _is_in_non_infra_torch_dispatch_mode
        _is_in_non_infra_torch_dispatch_mode = self.old_non_infra_dispatch_mode_flags.pop()
        _is_in_non_infra_torch_dispatch_mode = (
            self.old_non_infra_dispatch_mode_flags.pop()
        )
        _pop_mode(mb_dk_or_mode_key)

    @classmethod

@@ -123,7 +132,6 @@
        return False


def _get_current_dispatch_mode():
    stack_len = _len_torch_dispatch_stack()
    # Return a user mode on the stack if there are any

@@ -133,19 +141,16 @@


def _detect_infra_mode(key):
    assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY]
    assert key in [
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
        torch._C._TorchDispatchModeKey.PROXY,
    ]
    from torch._ops import _get_dispatch_mode_pre_dispatch

    pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(
        key
    )
    post_dispatch_mode = torch._C._get_dispatch_mode(
        key
    )
    pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
    post_dispatch_mode = torch._C._get_dispatch_mode(key)

    assert (pre_dispatch_mode is None) or (
        post_dispatch_mode is None
    )
    assert (pre_dispatch_mode is None) or (post_dispatch_mode is None)

    if pre_dispatch_mode is None:
        return post_dispatch_mode

@@ -232,8 +237,8 @@ def _disable_current_modes():
        _pop_mode_from_pre_dispatch,
    )
    from torch._subclasses.functional_tensor import FunctionalTensorMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
    from torch._subclasses.schema_check_mode import SchemaCheckMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

    mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch()
    old_pre_dispatch_modes = [

@@ -267,10 +272,7 @@
            raise AssertionError(
                "Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key"
            )
        if (
            isinstance(old, SchemaCheckMode)
            and has_schema_check_mode_in_pre_dispatch
        ):
        if isinstance(old, SchemaCheckMode) and has_schema_check_mode_in_pre_dispatch:
            raise AssertionError(
                "Can't have SchemaCheckMode available both in PreDispatch and Python Key"
            )

@@ -298,7 +300,9 @@ class TensorWithFlatten(Protocol):
        ...

    @staticmethod
    def __tensor_unflatten__(inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int) -> torch.Tensor:
    def __tensor_unflatten__(
        inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int
    ) -> torch.Tensor:
        ...

    # It would be really nice to be able to say that the return of

@@ -331,41 +335,39 @@

    @overload
    def to(
            self,
            dtype: torch.types._dtype,
            non_blocking: bool = False,
            copy: bool = False,
            *,
            memory_format: Optional[torch.memory_format] = None
        self,
        dtype: torch.types._dtype,
        non_blocking: bool = False,
        copy: bool = False,
        *,
        memory_format: Optional[torch.memory_format] = None,
    ) -> torch.Tensor:
        ...

    @overload
    def to(
            self,
            device: Optional["torch._prims_common.DeviceLikeType"] = None,
            dtype: Optional[torch.types._dtype] = None,
            non_blocking: bool = False,
            copy: bool = False,
            *,
            memory_format: Optional[torch.memory_format] = None
        self,
        device: Optional["torch._prims_common.DeviceLikeType"] = None,
        dtype: Optional[torch.types._dtype] = None,
        non_blocking: bool = False,
        copy: bool = False,
        *,
        memory_format: Optional[torch.memory_format] = None,
    ) -> torch.Tensor:
        ...

    @overload
    def to(
            self,
            other: torch.Tensor,
            non_blocking: bool = False,
            copy: bool = False,
            *,
            memory_format: Optional[torch.memory_format] = None
        self,
        other: torch.Tensor,
        non_blocking: bool = False,
        copy: bool = False,
        *,
        memory_format: Optional[torch.memory_format] = None,
    ) -> torch.Tensor:
        ...


def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
    """
    Returns whether or not a tensor subclass that implements __torch_dispatch__

@@ -403,10 +405,15 @@ def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
        and hasattr(t, "__tensor_unflatten__")
    )


def is_traceable_wrapper_subclass_type(t: type) -> TypeIs[type[TensorWithFlatten]]:
    """Same as above, but takes a type argument instead of an instance."""
    return (issubclass(t, torch.Tensor) and t != torch.Tensor
            and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))
    return (
        issubclass(t, torch.Tensor)
        and t != torch.Tensor
        and hasattr(t, "__tensor_flatten__")
        and hasattr(t, "__tensor_unflatten__")
    )


def transform_subclass(t, callback, outer_size=None, outer_stride=None):

@@ -551,7 +558,9 @@ def get_alias_info(func) -> SchemaInfo:
    torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
    torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
    # for aten::rot90 / aten:fft_*
    torchgen_schema_str = re.sub(r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str)
    torchgen_schema_str = re.sub(
        r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str
    )
    torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
    arg_schemas = [
        AliasInfo(
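
A hedged sketch of the TorchDispatchMode API whose internals are reflowed above: a minimal mode that logs every ATen op reaching __torch_dispatch__ while it is active.

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatch: {func}")          # e.g. aten.add.Tensor
        return func(*args, **kwargs)

with LoggingMode():
    torch.ones(2) + torch.ones(2)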

torch/utils/_stats.py
@@ -3,8 +3,8 @@
# AND SCRUB AWAY TORCH NOTIONS THERE.
import collections
import functools
from typing import Callable, TypeVar
from collections import OrderedDict
from typing import Callable, TypeVar
from typing_extensions import ParamSpec


@@ -18,6 +18,7 @@ def count_label(label: str) -> None:
    prev = simple_call_counter.setdefault(label, 0)
    simple_call_counter[label] = prev + 1


def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    @functools.wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:

@@ -25,4 +26,5 @@ def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
            simple_call_counter[fn.__qualname__] = 0
        simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
        return fn(*args, **kwargs)

    return wrapper
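
A hedged usage sketch for the counters in torch/utils/_stats.py: the @count decorator tallies calls per qualified function name in simple_call_counter.

from torch.utils._stats import count, simple_call_counter

@count
def step():
    pass

for _ in range(3):
    step()
print(dict(simple_call_counter))  # {'step': 3}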

torch/utils/_traceback.py
@@ -1,11 +1,12 @@
# mypy: allow-untyped-defs
from types import TracebackType
from typing import Optional
import tempfile
import traceback
import contextlib
import inspect
import os.path
import tempfile
import traceback
from types import TracebackType
from typing import Optional


# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.

@@ -44,6 +45,7 @@ import os.path
# - Before running the compiled code, enter the
# report_compile_source_on_error() context manager.


@contextlib.contextmanager
def report_compile_source_on_error():
    try:

@@ -83,15 +85,17 @@ def report_compile_source_on_error():
                # Don't delete the temporary file so the user can inspect it
                # TODO: This creates a temporary file for every frame, but we
                # technically only need one per distinct __compile_source__
                with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
                with tempfile.NamedTemporaryFile(
                    mode="w", delete=False, suffix=".py"
                ) as f:
                    f.write(source)
                # Create a frame. Python doesn't let you construct
                # FrameType directly, so just make one with compile
                frame = tb.tb_frame
                code = compile('__inspect_currentframe()', f.name, 'eval')
                code = compile("__inspect_currentframe()", f.name, "eval")
                code = code.replace(co_name=frame.f_code.co_name)
                # Python 3.11 only
                if hasattr(frame.f_code, 'co_linetable'):
                if hasattr(frame.f_code, "co_linetable"):
                    # We can't copy ALL of the metadata over, because you
                    # can cause Python to segfault this way. What exactly
                    # do we need? We need enough information for

@@ -109,14 +113,9 @@
                    fake_frame = eval(
                        code,
                        frame.f_globals,
                        {
                            **frame.f_locals,
                            '__inspect_currentframe': inspect.currentframe
                        }
                    )
                    fake_tb = TracebackType(
                        None, fake_frame, tb.tb_lasti, tb.tb_lineno
                        {**frame.f_locals, "__inspect_currentframe": inspect.currentframe},
                    )
                    fake_tb = TracebackType(None, fake_frame, tb.tb_lasti, tb.tb_lineno)
                    stack.append(fake_tb)
                else:
                    stack.append(tb)

@@ -131,6 +130,7 @@

        raise exc.with_traceback(tb_next)  # noqa: B904


def shorten_filename(fn, *, base=None):
    """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user."""
    if base is None:

@@ -141,7 +141,8 @@ def shorten_filename(fn, *, base=None):
    except ValueError:
        return fn
    else:
        return fn[len(prefix) + 1:]
        return fn[len(prefix) + 1 :]


def format_frame(frame, *, base=None, line=False):
    """

@@ -154,12 +155,14 @@ def format_frame(frame, *, base=None, line=False):
        extra_line = f"{frame.line} # "
    return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"


def format_traceback_short(tb):
    """Format a TracebackType in a short way, printing only the inner-most frame."""
    return format_frame(traceback.extract_tb(tb)[-1])


class CapturedTraceback:
    __slots__ = ['tb', 'skip']
    __slots__ = ["tb", "skip"]

    def __init__(self, tb, skip=0):
        self.tb = tb

@@ -176,15 +179,17 @@ class CapturedTraceback:
            return traceback.StackSummary()

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
            torch._C._profiler.symbolize_tracebacks([self.tb])[0], self.skip
        )

    def __getstate__(self):
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })
        return (
            None,
            {
                "tb": None,  # TB is not pickleable
                "skip": self.skip,
            },
        )

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):

@@ -207,7 +212,7 @@
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames. If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
            0 if script or cpp else skip + 1,
        )

    def format(self):

@@ -251,5 +256,5 @@ def _extract_symbolized_tb(tb, skip):
    """
    stack = traceback.StackSummary()
    for f in reversed(tb[skip:]):
        stack.append(traceback.FrameSummary(f['filename'], f['line'], f['name']))
        stack.append(traceback.FrameSummary(f["filename"], f["line"], f["name"]))
    return stack
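
A hedged sketch of the CapturedTraceback helper reformatted above: capture the current Python stack cheaply now, and only symbolize and format it when it is actually needed.

from torch.utils._traceback import CapturedTraceback

def f():
    return CapturedTraceback.extract()

tb = f()
print("".join(tb.format()))   # formatted frames for the captured stack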

torch/utils/_zip.py
@@ -5,6 +5,7 @@ import os
from pathlib import Path
from zipfile import ZipFile


# Exclude some standard library modules to:
# 1. Slim down the final zipped file size
# 2. Remove functionality we don't want to support.

torch/utils/backcompat/__init__.py
@@ -1,8 +1,10 @@
# mypy: allow-untyped-defs
from torch._C import _set_backcompat_broadcast_warn
from torch._C import _get_backcompat_broadcast_warn
from torch._C import _set_backcompat_keepdim_warn
from torch._C import _get_backcompat_keepdim_warn
from torch._C import (
    _get_backcompat_broadcast_warn,
    _get_backcompat_keepdim_warn,
    _set_backcompat_broadcast_warn,
    _set_backcompat_keepdim_warn,
)


class Warning:

@@ -18,5 +20,8 @@ class Warning:

    enabled = property(get_enabled, set_enabled)


broadcast_warning = Warning(_set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn)
broadcast_warning = Warning(
    _set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn
)
keepdim_warning = Warning(_set_backcompat_keepdim_warn, _get_backcompat_keepdim_warn)
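
A hedged usage sketch of the backcompat switches above: each Warning instance exposes an `enabled` property backed by the torch._C getter/setter pair it is constructed with.

import torch.utils.backcompat

torch.utils.backcompat.broadcast_warning.enabled = True   # warn on old-style broadcasting
print(torch.utils.backcompat.keepdim_warning.enabled)
torch.utils.backcompat.broadcast_warning.enabled = False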

torch/utils/backend_registration.py
@@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
import torch
from torch.overrides import (
    handle_torch_function,
    has_torch_function_unary,
)
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
from typing import Optional, Union

import torch
from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
from torch.overrides import handle_torch_function, has_torch_function_unary


__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]

# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get

@@ -15,6 +14,7 @@ __all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backe
# `_privateuse1_backend_name`.
_privateuse1_backend_name = "privateuseone"


def rename_privateuse1_backend(backend_name: str) -> None:
    r"""
    Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.

@@ -78,16 +78,22 @@ def rename_privateuse1_backend(backend_name: str) -> None:
    global _privateuse1_backend_name
    _privateuse1_backend_name = backend_name


def _check_register_once(module, attr):
    if hasattr(module, attr):
        raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}")
        raise RuntimeError(
            f"The custom device module of {module} has already been registered with {attr}"
        )


def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int:
def _normalization_device(
    custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None
) -> int:
    def _get_current_device_index():
        _get_device_index = "current_device"
        if hasattr(torch, custom_backend_name) and \
                hasattr(getattr(torch, custom_backend_name), _get_device_index):
        if hasattr(torch, custom_backend_name) and hasattr(
            getattr(torch, custom_backend_name), _get_device_index
        ):
            return getattr(getattr(torch, custom_backend_name), _get_device_index)()
        else:
            # The default device index is 0.

@@ -122,12 +128,16 @@ def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -
            return handle_torch_function(wrap_tensor_backend.__get__, (self,), self)  # type: ignore[attr-defined]
        return self.device.type == custom_backend_name

    _check_register_once(torch.Tensor, f'is_{custom_backend_name}')
    wrap_tensor_backend.fget.__name__ = f'is_{custom_backend_name}'  # type: ignore[attr-defined]
    setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend)
    _check_register_once(torch.Tensor, f"is_{custom_backend_name}")
    wrap_tensor_backend.fget.__name__ = f"is_{custom_backend_name}"  # type: ignore[attr-defined]
    setattr(torch.Tensor, f"is_{custom_backend_name}", wrap_tensor_backend)

    def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False,
                       **kwargs) -> torch.Tensor:
    def wrap_tensor_to(
        self: torch.Tensor,
        device: Optional[Union[int, torch.device]] = None,
        non_blocking=False,
        **kwargs,
    ) -> torch.Tensor:
        r"""Perform Tensor device conversion. Call the to operator implementation.

        .. note::

@@ -143,9 +153,20 @@ def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -
            **kwargs (dict): For compatibility, may contain the key ``memory_format`` argument.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(wrap_tensor_to, (self,), self, device=device, non_blocking=False, **kwargs)
            return handle_torch_function(
                wrap_tensor_to,
                (self,),
                self,
                device=device,
                non_blocking=False,
                **kwargs,
            )
        device_idx = _normalization_device(custom_backend_name, device)
        return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)
        return self.to(
            device=torch.device(f"{custom_backend_name}:{device_idx}"),
            non_blocking=non_blocking,
            **kwargs,
        )

    _check_register_once(torch.Tensor, custom_backend_name)
    wrap_tensor_to.__name__ = custom_backend_name

@@ -159,10 +180,13 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
        raise RuntimeError(
            f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module."
            f"Because torch.Tensor doesn't has the method {custom_backend_name}()."
            f"For this error, you can try setting for_tensor=True.")
            f"For this error, you can try setting for_tensor=True."
        )

    def wrap_module_to(self: torch.nn.modules.module.T,
                       device: Optional[Union[int, torch.device]] = None) -> torch.nn.modules.module.T:
    def wrap_module_to(
        self: torch.nn.modules.module.T,
        device: Optional[Union[int, torch.device]] = None,
    ) -> torch.nn.modules.module.T:
        r"""Move all model parameters and buffers to the custom device.

        This also makes associated parameters and buffers different objects. So

@@ -180,27 +204,37 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
    _check_register_once(torch.nn.Module, custom_backend_name)
    setattr(torch.nn.Module, custom_backend_name, wrap_module_to)


def _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
def _generate_packed_sequence_methods_for_privateuse1_backend(
    custom_backend_name: str,
) -> None:
    # Generate PackedSequence Module attributes and methods depends on Tensor methods,
    # so we need to check whether Tensor methods is already registered.
    if not hasattr(torch.Tensor, f'is_{custom_backend_name}') or \
            not hasattr(torch.Tensor, custom_backend_name):
    if not hasattr(torch.Tensor, f"is_{custom_backend_name}") or not hasattr(
        torch.Tensor, custom_backend_name
    ):
        raise RuntimeError(
            f"Can not automatically generate is_{custom_backend_name}() or "
            f"{custom_backend_name}() method for torch.nn.utils.rnn.PackedSequence."
            f"Because torch.Tensor doesn't has the method is_{custom_backend_name}()"
            f"or {custom_backend_name}()."
            f"For this error, you can try setting for_tensor=True.")
            f"For this error, you can try setting for_tensor=True."
        )

    @property  # type: ignore[misc]
    def wrap_tensor_backend(self: torch.nn.utils.rnn.PackedSequence) -> bool:
        return self.data.device.type == custom_backend_name

    _check_register_once(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}')
    setattr(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}', wrap_tensor_backend)
    _check_register_once(torch.nn.utils.rnn.PackedSequence, f"is_{custom_backend_name}")
    setattr(
        torch.nn.utils.rnn.PackedSequence,
        f"is_{custom_backend_name}",
        wrap_tensor_backend,
    )

    def wrap_module_to(self: torch.nn.utils.rnn.PackedSequence,
                       *args, **kwargs) -> torch.nn.utils.rnn.PackedSequence:
    def wrap_module_to(
        self: torch.nn.utils.rnn.PackedSequence, *args, **kwargs
    ) -> torch.nn.utils.rnn.PackedSequence:
        r"""Move all model parameters and buffers to the custom device.

        This also makes associated parameters and buffers different objects. So

@@ -213,17 +247,21 @@ def _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_nam
        Args:
            device (int, optional): if specified, all parameters will be copied to that device
        """
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
            *args, **kwargs
        )
        if ex.device.type == custom_backend_name:
            return self.to(*args, **kwargs)
        kwargs.update({'device': custom_backend_name})
        kwargs.update({"device": custom_backend_name})
        return self.to(*args, **kwargs)

    _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)
    setattr(torch.nn.utils.rnn.PackedSequence, custom_backend_name, wrap_module_to)


def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
                                                       unsupported_dtype: Optional[list[torch.dtype]] = None) -> None:
def _generate_storage_methods_for_privateuse1_backend(
    custom_backend_name: str, unsupported_dtype: Optional[list[torch.dtype]] = None
) -> None:
    # Attribute is registered in the _StorageBase class
    # and UntypedStorage obtains through inheritance.
    @property  # type: ignore[misc]

@@ -231,8 +269,10 @@ def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
        r"""Return the internal :class:`torch.UntypedStorage`."""
        return self.device.type == custom_backend_name

    _check_register_once(torch.storage._StorageBase, f'is_{custom_backend_name}')
    setattr(torch.storage._StorageBase, f'is_{custom_backend_name}', wrap_storage_backend)
    _check_register_once(torch.storage._StorageBase, f"is_{custom_backend_name}")
    setattr(
        torch.storage._StorageBase, f"is_{custom_backend_name}", wrap_storage_backend
    )

    def wrap_storage_to(self, device=None, non_blocking=False):
        r"""Return a copy of this object in custom device memory.

@@ -250,16 +290,18 @@ def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
        # but it depends on the extended function, so this part is temporarily omitted in the automatic generation.
        device_idx = _normalization_device(custom_backend_name, device)

        if getattr(self, f'is_{custom_backend_name}'):
        if getattr(self, f"is_{custom_backend_name}"):
            # storage has already on expected device.
            if self.get_device() == device_idx:
                return self
        # For sparse storage, custom need to extend the implementation by themselves.
        if self.is_sparse:
            raise RuntimeError(f"Can not support a sparse storage move to {custom_backend_name} backend")
            raise RuntimeError(
                f"Can not support a sparse storage move to {custom_backend_name} backend"
            )
        # create untyped_storage and copy data
        untyped_storage = torch.UntypedStorage(
            self.size(), device=torch.device(f'{custom_backend_name}:{device_idx}')
            self.size(), device=torch.device(f"{custom_backend_name}:{device_idx}")
        )
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage

@@ -275,27 +317,38 @@ def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
        torch.storage._warn_typed_storage_removal()
        return self._untyped_storage.device.type == custom_backend_name

    _check_register_once(torch.TypedStorage, f'is_{custom_backend_name}')
    setattr(torch.storage.TypedStorage, f'is_{custom_backend_name}', wrap_typed_storage_backend)
    _check_register_once(torch.TypedStorage, f"is_{custom_backend_name}")
    setattr(
        torch.storage.TypedStorage,
        f"is_{custom_backend_name}",
        wrap_typed_storage_backend,
    )

    def wrap_typed_storage_to(self: torch.storage.TypedStorage,
                              device=None, non_blocking=False, **kwargs) -> torch.storage.TypedStorage:
    def wrap_typed_storage_to(
        self: torch.storage.TypedStorage, device=None, non_blocking=False, **kwargs
    ) -> torch.storage.TypedStorage:
        torch.storage._warn_typed_storage_removal()
        if unsupported_dtype and self.dtype in unsupported_dtype:
            raise RuntimeError(f"Cannot create {custom_backend_name} storage "
                               f"as {self.dtype} dtype is not supported by this backend")
            raise RuntimeError(
                f"Cannot create {custom_backend_name} storage "
                f"as {self.dtype} dtype is not supported by this backend"
            )
        custom_backend_storage: torch.UntypedStorage = getattr(
            self._untyped_storage, custom_backend_name)(device, non_blocking, **kwargs)
            self._untyped_storage, custom_backend_name
        )(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(custom_backend_storage)

    _check_register_once(torch.TypedStorage, custom_backend_name)
    setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to)


def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
                                             for_packed_sequence: bool = True,
                                             for_storage: bool = False,
                                             unsupported_dtype: Optional[list[torch.dtype]] = None) -> None:
def generate_methods_for_privateuse1_backend(
    for_tensor: bool = True,
    for_module: bool = True,
    for_packed_sequence: bool = True,
    for_storage: bool = False,
    unsupported_dtype: Optional[list[torch.dtype]] = None,
) -> None:
    r"""
    Automatically generate attributes and methods for the custom backend after rename privateuse1 backend.

@@ -337,11 +390,14 @@ def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module
        _generate_module_methods_for_privateuse1_backend(custom_backend_name)

    if for_storage:
        _generate_storage_methods_for_privateuse1_backend(custom_backend_name, unsupported_dtype)
        _generate_storage_methods_for_privateuse1_backend(
            custom_backend_name, unsupported_dtype
        )

    if for_packed_sequence:
        _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name)


def _get_custom_mod_func(func_name: str):
    r"""
    Return the func named `func_name` defined in custom device module. If not defined,

@@ -370,12 +426,14 @@ def _get_custom_mod_func(func_name: str):
    it is marked as private. It is a convenience function for backend implementers to
    more easily call the hooks into their backend extensions.
    """
    assert isinstance(func_name, str), f"func_name must be `str`, but got `{type(func_name)}`."
    assert isinstance(
        func_name, str
    ), f"func_name must be `str`, but got `{type(func_name)}`."
    backend_name = _get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, backend_name, None)  # type: ignore[arg-type]
    function = getattr(custom_device_mod, func_name, None)  # type: ignore[arg-type]
    if custom_device_mod is None or function is None:
        message = f'Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend '
        message = f"Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend "
        message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And "
        message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
        raise RuntimeError(message)
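
A hedged sketch of the privateuse1 registration flow reformatted above. The backend name "my_device" is a stand-in; a real out-of-tree backend would also register a device module with the matching runtime hooks.

import torch
from torch.utils.backend_registration import (
    generate_methods_for_privateuse1_backend,
    rename_privateuse1_backend,
)

rename_privateuse1_backend("my_device")          # PrivateUse1 now answers to "my_device"
generate_methods_for_privateuse1_backend(for_tensor=True, for_module=True)

x = torch.empty(1)
print(x.is_my_device)   # generated Tensor property; False for a CPU tensor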