mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[remove untyped defs] batch 1 (#157011)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157011 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
fee2377f9e
commit
7709ff5512
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -10,7 +10,7 @@ LOAD_TENSOR_READER: Optional[ContentStoreReader] = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def load_tensor_reader(loc):
|
||||
def load_tensor_reader(loc: str) -> Generator[None, None, None]:
|
||||
global LOAD_TENSOR_READER
|
||||
assert LOAD_TENSOR_READER is None
|
||||
# load_tensor is an "op", and we will play merry hell on
|
||||
@ -26,14 +26,20 @@ def load_tensor_reader(loc):
|
||||
LOAD_TENSOR_READER = None
|
||||
|
||||
|
||||
def register_debug_prims():
|
||||
def register_debug_prims() -> None:
|
||||
torch.library.define(
|
||||
"debugprims::load_tensor",
|
||||
"(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor",
|
||||
)
|
||||
|
||||
@torch.library.impl("debugprims::load_tensor", "BackendSelect")
|
||||
def load_tensor_factory(name, size, stride, dtype, device):
|
||||
def load_tensor_factory(
|
||||
name: str,
|
||||
size: Sequence[int],
|
||||
stride: Sequence[int],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if LOAD_TENSOR_READER is None:
|
||||
from torch._dynamo.testing import rand_strided
|
||||
|
||||
@ -50,5 +56,5 @@ def register_debug_prims():
|
||||
# Unlike the other properties, we will do coercions for dtype
|
||||
# mismatch
|
||||
if r.dtype != dtype:
|
||||
r = clone_input(r, dtype=dtype)
|
||||
r = clone_input(r, dtype=dtype) # type: ignore[no-untyped-call]
|
||||
return r
|
||||
|
@ -1,11 +1,18 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.ao.nn.intrinsic as nni
|
||||
import torch.ao.nn.qat as nnqat
|
||||
import torch.nn.functional as F
|
||||
from torch.ao.nn.intrinsic.modules.fused import _FusedModule
|
||||
|
||||
|
||||
class LinearReLU(nnqat.Linear, nni._FusedModule):
|
||||
__all__ = ["LinearReLU"]
|
||||
|
||||
|
||||
class LinearReLU(nnqat.Linear, _FusedModule):
|
||||
r"""
|
||||
A LinearReLU module fused from Linear and ReLU modules, attached with
|
||||
FakeQuantize modules for weight, used in
|
||||
@ -29,19 +36,29 @@ class LinearReLU(nnqat.Linear, nni._FusedModule):
|
||||
torch.Size([128, 30])
|
||||
"""
|
||||
|
||||
_FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
|
||||
_FLOAT_MODULE = nni.LinearReLU
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, qconfig=None):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
qconfig: Optional[object] = None,
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, bias, qconfig)
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod, use_precomputed_fake_quant=False):
|
||||
return super().from_float(mod, use_precomputed_fake_quant)
|
||||
def from_float(
|
||||
cls,
|
||||
mod: torch.nn.Module,
|
||||
use_precomputed_fake_quant: bool = False,
|
||||
) -> LinearReLU:
|
||||
return super().from_float(mod, use_precomputed_fake_quant) # type: ignore[no-untyped-call,no-any-return]
|
||||
|
||||
def to_float(self):
|
||||
def to_float(self) -> nni.LinearReLU:
|
||||
linear = torch.nn.Linear(
|
||||
self.in_features, self.out_features, self.bias is not None
|
||||
)
|
||||
@ -49,4 +66,4 @@ class LinearReLU(nnqat.Linear, nni._FusedModule):
|
||||
if self.bias is not None:
|
||||
linear.bias = torch.nn.Parameter(self.bias.detach())
|
||||
relu = torch.nn.ReLU()
|
||||
return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
|
||||
return torch.ao.nn.intrinsic.LinearReLU(linear, relu) # type: ignore[no-untyped-call]
|
||||
|
@ -1,9 +1,10 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity
|
||||
from .base_structured_sparsifier import BaseStructuredSparsifier
|
||||
from .parametrization import FakeStructuredSparsity
|
||||
|
||||
|
||||
class LSTMSaliencyPruner(BaseStructuredSparsifier):
|
||||
@ -25,7 +26,7 @@ class LSTMSaliencyPruner(BaseStructuredSparsifier):
|
||||
This applies to both weight_ih_l{k} and weight_hh_l{k}.
|
||||
"""
|
||||
|
||||
def update_mask(self, module, tensor_name, **kwargs):
|
||||
def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Any) -> None:
|
||||
weights = getattr(module, tensor_name)
|
||||
|
||||
for p in getattr(module.parametrizations, tensor_name):
|
||||
|
@ -1,5 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from typing import Callable, Union
|
||||
|
||||
from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier
|
||||
|
||||
from .base_scheduler import BaseScheduler
|
||||
|
||||
@ -30,7 +32,13 @@ class LambdaSL(BaseScheduler):
|
||||
>>> scheduler.step()
|
||||
"""
|
||||
|
||||
def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
|
||||
def __init__(
|
||||
self,
|
||||
sparsifier: BaseSparsifier,
|
||||
sl_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
|
||||
last_epoch: int = -1,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.sparsifier = sparsifier
|
||||
|
||||
if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
|
||||
@ -41,9 +49,9 @@ class LambdaSL(BaseScheduler):
|
||||
f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}"
|
||||
)
|
||||
self.sl_lambdas = list(sl_lambda)
|
||||
super().__init__(sparsifier, last_epoch, verbose)
|
||||
super().__init__(sparsifier, last_epoch, verbose) # type: ignore[no-untyped-call]
|
||||
|
||||
def get_sl(self):
|
||||
def get_sl(self) -> list[float]:
|
||||
if not self._get_sl_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last sparsity level computed by the scheduler, "
|
||||
|
@ -1,9 +1,15 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_available():
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
|
||||
def is_available() -> bool:
|
||||
return hasattr(torch._C, "_dist_autograd_init")
|
||||
|
||||
|
||||
@ -25,6 +31,8 @@ if is_available():
|
||||
get_gradients,
|
||||
)
|
||||
|
||||
__all__ = ["context", "is_available"]
|
||||
|
||||
|
||||
class context:
|
||||
"""
|
||||
@ -45,9 +53,14 @@ class context:
|
||||
>>> dist_autograd.backward(context_id, [loss])
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> int:
|
||||
self.autograd_context = _new_context()
|
||||
return self.autograd_context._context_id()
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
_release_context(self.autograd_context._context_id())
|
||||
|
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-defs
|
||||
r"""
|
||||
PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference.
|
||||
Profiler's context manager API can be used to better understand what model operators are the most expensive,
|
||||
@ -9,12 +8,14 @@ examine their input shapes and stack traces, study device kernel activity and vi
|
||||
|
||||
"""
|
||||
import os
|
||||
from typing import Any
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
from torch._C._autograd import _supported_activities, DeviceType, kineto_available
|
||||
from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope
|
||||
from torch._environment import is_fbcode
|
||||
from torch.autograd.profiler import KinetoStepTracker, record_function
|
||||
from torch.optim.optimizer import register_optimizer_step_post_hook
|
||||
from torch.optim.optimizer import Optimizer, register_optimizer_step_post_hook
|
||||
|
||||
from .profiler import (
|
||||
_KinetoProfile,
|
||||
@ -43,7 +44,12 @@ __all__ = [
|
||||
from . import itt
|
||||
|
||||
|
||||
def _optimizer_post_hook(optimizer, args, kwargs):
|
||||
_Ts = TypeVarTuple("_Ts")
|
||||
|
||||
|
||||
def _optimizer_post_hook(
|
||||
optimizer: Optimizer, args: tuple[Unpack[_Ts]], kwargs: dict[str, Any]
|
||||
) -> None:
|
||||
KinetoStepTracker.increment_step("Optimizer")
|
||||
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-defs
|
||||
r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py.
|
||||
|
||||
A lot of multiprocessing is used in data loading, which only supports running
|
||||
@ -43,7 +42,7 @@ except ModuleNotFoundError:
|
||||
HAS_NUMPY = False
|
||||
|
||||
|
||||
def _set_python_exit_flag():
|
||||
def _set_python_exit_flag() -> None:
|
||||
global python_exit_status
|
||||
python_exit_status = True
|
||||
|
||||
|
@ -1,14 +1,17 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import warnings
|
||||
from collections.abc import Iterable, Iterator, Sized
|
||||
from typing import TypeVar
|
||||
|
||||
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
__all__ = ["IterableWrapperIterDataPipe"]
|
||||
|
||||
|
||||
class IterableWrapperIterDataPipe(IterDataPipe):
|
||||
class IterableWrapperIterDataPipe(IterDataPipe[_T]):
|
||||
r"""
|
||||
Wraps an iterable object to create an IterDataPipe.
|
||||
|
||||
@ -30,11 +33,11 @@ class IterableWrapperIterDataPipe(IterDataPipe):
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
"""
|
||||
|
||||
def __init__(self, iterable, deepcopy=True):
|
||||
def __init__(self, iterable: Iterable[_T], deepcopy: bool = True) -> None:
|
||||
self.iterable = iterable
|
||||
self.deepcopy = deepcopy
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[_T]:
|
||||
source_data = self.iterable
|
||||
if self.deepcopy:
|
||||
try:
|
||||
@ -50,5 +53,7 @@ class IterableWrapperIterDataPipe(IterDataPipe):
|
||||
)
|
||||
yield from source_data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.iterable)
|
||||
def __len__(self) -> int:
|
||||
if isinstance(self.iterable, Sized):
|
||||
return len(self.iterable)
|
||||
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
|
||||
|
@ -1,14 +1,17 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import warnings
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, TypeVar, Union
|
||||
|
||||
from torch.utils.data.datapipes.datapipe import MapDataPipe
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
__all__ = ["SequenceWrapperMapDataPipe"]
|
||||
|
||||
|
||||
class SequenceWrapperMapDataPipe(MapDataPipe):
|
||||
class SequenceWrapperMapDataPipe(MapDataPipe[_T]):
|
||||
r"""
|
||||
Wraps a sequence object into a MapDataPipe.
|
||||
|
||||
@ -33,7 +36,11 @@ class SequenceWrapperMapDataPipe(MapDataPipe):
|
||||
100
|
||||
"""
|
||||
|
||||
def __init__(self, sequence, deepcopy=True):
|
||||
sequence: Union[Sequence[_T], Mapping[Any, _T]]
|
||||
|
||||
def __init__(
|
||||
self, sequence: Union[Sequence[_T], Mapping[Any, _T]], deepcopy: bool = True
|
||||
) -> None:
|
||||
if deepcopy:
|
||||
try:
|
||||
self.sequence = copy.deepcopy(sequence)
|
||||
@ -46,8 +53,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe):
|
||||
else:
|
||||
self.sequence = sequence
|
||||
|
||||
def __getitem__(self, index):
|
||||
def __getitem__(self, index: int) -> _T:
|
||||
return self.sequence[index]
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.sequence)
|
||||
|
@ -1,11 +1,13 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Optional
|
||||
import torch
|
||||
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Sequence
|
||||
from tensorboard.compat.proto.node_def_pb2 import NodeDef
|
||||
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
|
||||
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
|
||||
|
||||
|
||||
def attr_value_proto(dtype, shape, s):
|
||||
def attr_value_proto(dtype: object, shape: Optional[Sequence[int]], s: Optional[str]) -> dict[str, AttrValue]:
|
||||
"""Create a dict of objects matching a NodeDef's attr field.
|
||||
|
||||
Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
|
||||
@ -21,7 +23,7 @@ def attr_value_proto(dtype, shape, s):
|
||||
return attr
|
||||
|
||||
|
||||
def tensor_shape_proto(outputsize):
|
||||
def tensor_shape_proto(outputsize: Sequence[int]) -> TensorShapeProto:
|
||||
"""Create an object matching a tensor_shape field.
|
||||
|
||||
Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto .
|
||||
@ -30,14 +32,14 @@ def tensor_shape_proto(outputsize):
|
||||
|
||||
|
||||
def node_proto(
|
||||
name,
|
||||
op="UnSpecified",
|
||||
input=None,
|
||||
dtype=None,
|
||||
shape: Optional[tuple] = None,
|
||||
outputsize=None,
|
||||
attributes="",
|
||||
):
|
||||
name: str,
|
||||
op: str = "UnSpecified",
|
||||
input: Optional[Union[list[str], str]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
shape: Optional[tuple[int, ...]] = None,
|
||||
outputsize: Optional[Sequence[int]] = None,
|
||||
attributes: str = "",
|
||||
) -> NodeDef:
|
||||
"""Create an object matching a NodeDef.
|
||||
|
||||
Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto .
|
||||
|
Reference in New Issue
Block a user