[remove untyped defs] batch 1 (#157011)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157011
Approved by: https://github.com/Skylion007
Authored by Bob Ren on 2025-06-30 13:13:55 -07:00
Committed by PyTorch MergeBot
parent fee2377f9e
commit 7709ff5512
10 changed files with 118 additions and 54 deletions

View File

@@ -1,5 +1,5 @@
-# mypy: allow-untyped-defs
 import contextlib
+from collections.abc import Generator, Sequence
 from typing import Optional

 import torch
@@ -10,7 +10,7 @@ LOAD_TENSOR_READER: Optional[ContentStoreReader] = None

 @contextlib.contextmanager
-def load_tensor_reader(loc):
+def load_tensor_reader(loc: str) -> Generator[None, None, None]:
     global LOAD_TENSOR_READER
     assert LOAD_TENSOR_READER is None
     # load_tensor is an "op", and we will play merry hell on
@@ -26,14 +26,20 @@ def load_tensor_reader(loc):
         LOAD_TENSOR_READER = None


-def register_debug_prims():
+def register_debug_prims() -> None:
     torch.library.define(
         "debugprims::load_tensor",
         "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor",
     )

     @torch.library.impl("debugprims::load_tensor", "BackendSelect")
-    def load_tensor_factory(name, size, stride, dtype, device):
+    def load_tensor_factory(
+        name: str,
+        size: Sequence[int],
+        stride: Sequence[int],
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
         if LOAD_TENSOR_READER is None:
             from torch._dynamo.testing import rand_strided
@@ -50,5 +56,5 @@ def register_debug_prims():
         # Unlike the other properties, we will do coercions for dtype
         # mismatch
         if r.dtype != dtype:
-            r = clone_input(r, dtype=dtype)
+            r = clone_input(r, dtype=dtype)  # type: ignore[no-untyped-call]
         return r
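As a freestanding illustration of the annotation pattern adopted above (hypothetical reader_scope and _READER names, not part of this commit): a @contextlib.contextmanager generator that only yields control once is typed Generator[None, None, None].

import contextlib
from collections.abc import Generator

_READER: str | None = None


@contextlib.contextmanager
def reader_scope(loc: str) -> Generator[None, None, None]:
    # Yields exactly once and neither sends nor returns values,
    # hence all three Generator parameters are None.
    global _READER
    assert _READER is None
    _READER = loc
    try:
        yield
    finally:
        _READER = None


with reader_scope("/tmp/store"):
    assert _READER == "/tmp/store"
assert _READER is None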

View File

@@ -1,11 +1,18 @@
-# mypy: allow-untyped-defs
+from __future__ import annotations
+
+from typing import Optional
+
 import torch
 import torch.ao.nn.intrinsic as nni
 import torch.ao.nn.qat as nnqat
 import torch.nn.functional as F
+from torch.ao.nn.intrinsic.modules.fused import _FusedModule


-class LinearReLU(nnqat.Linear, nni._FusedModule):
+__all__ = ["LinearReLU"]
+
+
+class LinearReLU(nnqat.Linear, _FusedModule):
     r"""
     A LinearReLU module fused from Linear and ReLU modules, attached with
     FakeQuantize modules for weight, used in
@@ -29,19 +36,29 @@ class LinearReLU(nnqat.Linear, nni._FusedModule):
     torch.Size([128, 30])
     """

-    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]
+    _FLOAT_MODULE = nni.LinearReLU

-    def __init__(self, in_features, out_features, bias=True, qconfig=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        qconfig: Optional[object] = None,
+    ) -> None:
         super().__init__(in_features, out_features, bias, qconfig)

-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))

     @classmethod
-    def from_float(cls, mod, use_precomputed_fake_quant=False):
-        return super().from_float(mod, use_precomputed_fake_quant)
+    def from_float(
+        cls,
+        mod: torch.nn.Module,
+        use_precomputed_fake_quant: bool = False,
+    ) -> LinearReLU:
+        return super().from_float(mod, use_precomputed_fake_quant)  # type: ignore[no-untyped-call,no-any-return]

-    def to_float(self):
+    def to_float(self) -> nni.LinearReLU:
         linear = torch.nn.Linear(
             self.in_features, self.out_features, self.bias is not None
         )
@@ -49,4 +66,4 @@ class LinearReLU(nnqat.Linear, nni._FusedModule):
         if self.bias is not None:
             linear.bias = torch.nn.Parameter(self.bias.detach())
         relu = torch.nn.ReLU()
-        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
+        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)  # type: ignore[no-untyped-call]
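The `-> LinearReLU` annotation on `from_float` is only valid at class-definition time because of the `from __future__ import annotations` line added at the top of this file. A minimal sketch of the same pattern, with a hypothetical Box class:

from __future__ import annotations


class Box:
    def __init__(self, value: float) -> None:
        self.value = value

    @classmethod
    def from_float(cls, value: float) -> Box:
        # Without the __future__ import, evaluating the "Box" return
        # annotation here would raise NameError, because the class is
        # not bound until its body finishes executing.
        return cls(value)


assert isinstance(Box.from_float(1.5), Box)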

View File

@@ -1,9 +1,10 @@
-# mypy: allow-untyped-defs
-from typing import cast
+from typing import Any, cast

 import torch
 from torch import nn

-from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity
+from .base_structured_sparsifier import BaseStructuredSparsifier
+from .parametrization import FakeStructuredSparsity


 class LSTMSaliencyPruner(BaseStructuredSparsifier):
@@ -25,7 +26,7 @@ class LSTMSaliencyPruner(BaseStructuredSparsifier):
     This applies to both weight_ih_l{k} and weight_hh_l{k}.
     """

-    def update_mask(self, module, tensor_name, **kwargs):
+    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Any) -> None:
         weights = getattr(module, tensor_name)

         for p in getattr(module.parametrizations, tensor_name):
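A self-contained sketch of the typed update_mask signature; TinyPruner is a hypothetical stand-in for BaseStructuredSparsifier, and only the signature mirrors this diff:

from typing import Any

import torch
from torch import nn


class TinyPruner:
    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Any) -> None:
        # Toy saliency: zero the rows of the tensor with below-median L1 norm.
        weights = getattr(module, tensor_name)
        saliency = weights.abs().sum(dim=1)
        mask = saliency >= saliency.median()
        with torch.no_grad():
            weights[~mask] = 0.0


lin = nn.Linear(4, 4)
TinyPruner().update_mask(lin, "weight")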

View File

@@ -1,5 +1,7 @@
-# mypy: allow-untyped-defs
 import warnings
+from typing import Callable, Union
+
+from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier

 from .base_scheduler import BaseScheduler
@@ -30,7 +32,13 @@ class LambdaSL(BaseScheduler):
     >>> scheduler.step()
     """

-    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
+    def __init__(
+        self,
+        sparsifier: BaseSparsifier,
+        sl_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
+        last_epoch: int = -1,
+        verbose: bool = False,
+    ) -> None:
         self.sparsifier = sparsifier

         if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
@ -41,9 +49,9 @@ class LambdaSL(BaseScheduler):
f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}"
)
self.sl_lambdas = list(sl_lambda)
super().__init__(sparsifier, last_epoch, verbose)
super().__init__(sparsifier, last_epoch, verbose) # type: ignore[no-untyped-call]
def get_sl(self):
def get_sl(self) -> list[float]:
if not self._get_sl_called_within_step:
warnings.warn(
"To get the last sparsity level computed by the scheduler, "

View File

@@ -1,9 +1,15 @@
-# mypy: allow-untyped-defs
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING

 import torch

+if TYPE_CHECKING:
+    from types import TracebackType

-def is_available():
+
+def is_available() -> bool:
     return hasattr(torch._C, "_dist_autograd_init")
@@ -25,6 +31,8 @@ if is_available():
         get_gradients,
     )

+__all__ = ["context", "is_available"]
+

 class context:
     """
@@ -45,9 +53,14 @@ class context:
     >>> dist_autograd.backward(context_id, [loss])
     """

-    def __enter__(self):
+    def __enter__(self) -> int:
         self.autograd_context = _new_context()
         return self.autograd_context._context_id()

-    def __exit__(self, type, value, traceback):
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
         _release_context(self.autograd_context._context_id())
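A minimal standalone version of the typed context-manager protocol used here, with TracebackType imported only under TYPE_CHECKING so it carries no runtime cost (hypothetical scope class, not part of this commit):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from types import TracebackType


class scope:
    def __enter__(self) -> int:
        return 42

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Returning None (falsy) means exceptions propagate normally.
        pass


with scope() as ctx_id:
    assert ctx_id == 42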

View File

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-defs
 r"""
 PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference.
 Profiler's context manager API can be used to better understand what model operators are the most expensive,
@@ -9,12 +8,14 @@ examine their input shapes and stack traces, study device kernel activity and vi
 """

 import os
 from typing import Any
+from typing_extensions import TypeVarTuple, Unpack

 from torch._C._autograd import _supported_activities, DeviceType, kineto_available
 from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope
 from torch._environment import is_fbcode
 from torch.autograd.profiler import KinetoStepTracker, record_function
-from torch.optim.optimizer import register_optimizer_step_post_hook
+from torch.optim.optimizer import Optimizer, register_optimizer_step_post_hook

 from .profiler import (
     _KinetoProfile,
@@ -43,7 +44,12 @@ __all__ = [
 from . import itt


-def _optimizer_post_hook(optimizer, args, kwargs):
+_Ts = TypeVarTuple("_Ts")
+
+
+def _optimizer_post_hook(
+    optimizer: Optimizer, args: tuple[Unpack[_Ts]], kwargs: dict[str, Any]
+) -> None:
     KinetoStepTracker.increment_step("Optimizer")
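A runnable sketch of the PEP 646 variadic tuple annotation adopted for the hook above (hypothetical post_hook function; assumes typing_extensions is installed, which it is wherever torch is):

from typing import Any

from typing_extensions import TypeVarTuple, Unpack

_Ts = TypeVarTuple("_Ts")


def post_hook(args: tuple[Unpack[_Ts]], kwargs: dict[str, Any]) -> None:
    # tuple[Unpack[_Ts]] accepts a positional-args tuple of any length
    # and element types; unlike tuple[Any, ...], it stays generic for
    # type checkers that support PEP 646.
    print(f"called with {len(args)} positional args")


post_hook((1, "two", 3.0), {"verbose": True})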

View File

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-defs
 r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py.

 A lot of multiprocessing is used in data loading, which only supports running
@@ -43,7 +42,7 @@ except ModuleNotFoundError:
     HAS_NUMPY = False


-def _set_python_exit_flag():
+def _set_python_exit_flag() -> None:
     global python_exit_status
     python_exit_status = True

View File

@@ -1,14 +1,17 @@
-# mypy: allow-untyped-defs
 import copy
 import warnings
+from collections.abc import Iterable, Iterator, Sized
+from typing import TypeVar

 from torch.utils.data.datapipes.datapipe import IterDataPipe


+_T = TypeVar("_T")
+
 __all__ = ["IterableWrapperIterDataPipe"]


-class IterableWrapperIterDataPipe(IterDataPipe):
+class IterableWrapperIterDataPipe(IterDataPipe[_T]):
     r"""
     Wraps an iterable object to create an IterDataPipe.
@@ -30,11 +33,11 @@ class IterableWrapperIterDataPipe(IterDataPipe):
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
     """

-    def __init__(self, iterable, deepcopy=True):
+    def __init__(self, iterable: Iterable[_T], deepcopy: bool = True) -> None:
         self.iterable = iterable
         self.deepcopy = deepcopy

-    def __iter__(self):
+    def __iter__(self) -> Iterator[_T]:
         source_data = self.iterable
         if self.deepcopy:
             try:
@ -50,5 +53,7 @@ class IterableWrapperIterDataPipe(IterDataPipe):
)
yield from source_data
def __len__(self):
return len(self.iterable)
def __len__(self) -> int:
if isinstance(self.iterable, Sized):
return len(self.iterable)
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
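The new __len__ narrows the wrapped iterable with an isinstance check, because Iterable alone does not guarantee len() support. A standalone sketch (hypothetical Wrapper class, not part of this commit):

from collections.abc import Iterable, Sized
from typing import Generic, TypeVar

_T = TypeVar("_T")


class Wrapper(Generic[_T]):
    def __init__(self, iterable: Iterable[_T]) -> None:
        self.iterable = iterable

    def __len__(self) -> int:
        # Generators are Iterable but not Sized, so an unguarded
        # len(self.iterable) would neither type-check nor run for them.
        if isinstance(self.iterable, Sized):
            return len(self.iterable)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


assert len(Wrapper([1, 2, 3])) == 3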

View File

@@ -1,14 +1,17 @@
-# mypy: allow-untyped-defs
 import copy
 import warnings
+from collections.abc import Mapping, Sequence
+from typing import Any, TypeVar, Union

 from torch.utils.data.datapipes.datapipe import MapDataPipe


+_T = TypeVar("_T")
+
 __all__ = ["SequenceWrapperMapDataPipe"]


-class SequenceWrapperMapDataPipe(MapDataPipe):
+class SequenceWrapperMapDataPipe(MapDataPipe[_T]):
     r"""
     Wraps a sequence object into a MapDataPipe.
@@ -33,7 +36,11 @@ class SequenceWrapperMapDataPipe(MapDataPipe):
         100
     """

-    def __init__(self, sequence, deepcopy=True):
+    sequence: Union[Sequence[_T], Mapping[Any, _T]]
+
+    def __init__(
+        self, sequence: Union[Sequence[_T], Mapping[Any, _T]], deepcopy: bool = True
+    ) -> None:
         if deepcopy:
             try:
                 self.sequence = copy.deepcopy(sequence)
@ -46,8 +53,8 @@ class SequenceWrapperMapDataPipe(MapDataPipe):
else:
self.sequence = sequence
def __getitem__(self, index):
def __getitem__(self, index: int) -> _T:
return self.sequence[index]
def __len__(self):
def __len__(self) -> int:
return len(self.sequence)
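The class-level `sequence:` annotation gives mypy a single declared type for an attribute that the real __init__ assigns on two different branches. A freestanding sketch with a hypothetical Wrap class:

from collections.abc import Mapping, Sequence
from typing import Any, Generic, TypeVar, Union

_T = TypeVar("_T")


class Wrap(Generic[_T]):
    # One declared type for the attribute, whichever branch assigns it.
    sequence: Union[Sequence[_T], Mapping[Any, _T]]

    def __init__(self, sequence: Union[Sequence[_T], Mapping[Any, _T]]) -> None:
        self.sequence = sequence

    def __getitem__(self, index: int) -> _T:
        # Indexing is valid for both Sequence and Mapping members.
        return self.sequence[index]


w = Wrap({0: "a", 1: "b"})
assert w[1] == "b"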

View File

@@ -1,11 +1,13 @@
-# mypy: allow-untyped-defs
-from typing import Optional
+from typing import Optional, Union
+from collections.abc import Sequence

+import torch
 from tensorboard.compat.proto.node_def_pb2 import NodeDef
 from tensorboard.compat.proto.attr_value_pb2 import AttrValue
 from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto


-def attr_value_proto(dtype, shape, s):
+def attr_value_proto(dtype: object, shape: Optional[Sequence[int]], s: Optional[str]) -> dict[str, AttrValue]:
     """Create a dict of objects matching a NodeDef's attr field.

     Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
@@ -21,7 +23,7 @@ def attr_value_proto(dtype, shape, s):
     return attr


-def tensor_shape_proto(outputsize):
+def tensor_shape_proto(outputsize: Sequence[int]) -> TensorShapeProto:
    """Create an object matching a tensor_shape field.

    Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto .
@@ -30,14 +32,14 @@ def tensor_shape_proto(outputsize):
 def node_proto(
-    name,
-    op="UnSpecified",
-    input=None,
-    dtype=None,
-    shape: Optional[tuple] = None,
-    outputsize=None,
-    attributes="",
-):
+    name: str,
+    op: str = "UnSpecified",
+    input: Optional[Union[list[str], str]] = None,
+    dtype: Optional[torch.dtype] = None,
+    shape: Optional[tuple[int, ...]] = None,
+    outputsize: Optional[Sequence[int]] = None,
+    attributes: str = "",
+) -> NodeDef:
     """Create an object matching a NodeDef.

     Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto .