Compare commits

...

7 Commits

Author SHA1 Message Date
da54f617bc Update
[ghstack-poisoned]
2025-11-10 11:51:45 +00:00
6676677c78 Update
[ghstack-poisoned]
2025-11-06 15:44:19 +00:00
7dd93abde7 Update
[ghstack-poisoned]
2025-11-06 15:35:44 +00:00
b657c266cf Update
[ghstack-poisoned]
2025-11-06 15:20:54 +00:00
44d157de4e Update (base update)
[ghstack-poisoned]
2025-11-06 15:20:54 +00:00
767fcf270d Update
[ghstack-poisoned]
2025-11-06 13:31:19 +00:00
c6a7c32b40 Update (base update)
[ghstack-poisoned]
2025-11-06 13:31:19 +00:00
5 changed files with 54 additions and 42 deletions

View File

@ -91,7 +91,7 @@ class Logger:
broadcast_buffers: bool,
has_sync_bn: bool,
static_graph: bool,
): ...
) -> None: ...
def set_runtime_stats_and_log(self) -> None: ...
def set_error_and_log(self, error: str) -> None: ...
def _get_ddp_logging_data(self) -> DDPLoggingData: ...
@ -103,15 +103,15 @@ class _WorkerServer:
def __init__(self, socket_path: str) -> None: ...
def shutdown(self) -> None: ...
def get_debug_level(): ...
def set_debug_level(): ...
def set_debug_level_from_env(): ...
class DebugLevel(Enum):
OFF = ...
INFO = ...
DETAIL = ...
def get_debug_level() -> DebugLevel: ...
def set_debug_level(level: DebugLevel) -> None: ...
def set_debug_level_from_env() -> None: ...
class ReduceOp:
# pyrefly: ignore # unknown-name
def __init__(self, op: RedOpType) -> None: ...
@ -195,7 +195,7 @@ class AllToAllOptions:
asyncOp: bool
class Store:
def set(self, key: str, value: str): ...
def set(self, key: str, value: str) -> None: ...
def get(self, key: str) -> bytes: ...
def add(self, key: str, value: int) -> int: ...
def check(self, keys: list[str]) -> bool: ...
@ -207,11 +207,11 @@ class Store:
) -> bytes: ...
def delete_key(self, key: str) -> bool: ...
def num_keys(self) -> int: ...
def set_timeout(self, timeout: timedelta): ...
def set_timeout(self, timeout: timedelta) -> None: ...
@overload
def wait(self, keys: list[str]): ...
@overload
def wait(self, keys: list[str], timeout: timedelta): ...
def wait(self, keys: list[str], timeout: timedelta) -> None: ...
def queue_pop(self, key: str, block: bool = True) -> bytes: ...
def queue_push(self, key: str, value: Union[bytes, str]) -> None: ...
def queue_len(self, key: str) -> int: ...

View File

@ -23,6 +23,7 @@ import weakref
from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by torch.jit.annotations
Any,
Callable,
Class,
Dict,
Final,
ForwardRef,
@ -30,6 +31,7 @@ from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by tor
get_origin,
List,
Optional,
Protocol,
Tuple,
TypeVar,
Union,
@ -49,9 +51,14 @@ from torch._sources import fake_range, get_source_lines_and_file, parse_def
from torch.futures import Future
class HasGetattr(Protocol):
def __getattr__(self, key: str) -> Any: ...
_P = ParamSpec("_P")
_R = TypeVar("_R")
BuiltinUnionType: Union[type, tuple[type, ...]] = types.UnionType
LockType: type
@ -202,7 +209,7 @@ class SourceLoader:
loader = SourceLoader()
def createResolutionCallbackFromEnv(lookup_base):
def createResolutionCallbackFromEnv(lookup_base: HasGetattr) -> Callable[[str], Any]:
"""
Creates a resolution callback that will look up qualified names in an
environment, starting with `lookup_base` for the base of any qualified
@ -212,7 +219,7 @@ def createResolutionCallbackFromEnv(lookup_base):
createResolutionCallbackFrom* functions.
"""
def lookupInModule(qualified_name, module):
def lookupInModule(qualified_name: str, module: Any) -> Any:
if "." in qualified_name:
base, remaining_pieces = qualified_name.split(".", maxsplit=1)
module_value = getattr(module, base)
@ -220,7 +227,7 @@ def createResolutionCallbackFromEnv(lookup_base):
else:
return getattr(module, qualified_name)
def parseNestedExpr(expr, module) -> tuple[Any, int]:
def parseNestedExpr(expr: str, module: Any) -> tuple[Any, int]:
i = 0
while i < len(expr) and expr[i] not in (",", "[", "]"):
i += 1
@ -248,7 +255,7 @@ def createResolutionCallbackFromEnv(lookup_base):
else:
return base[parts[0]], i + 1
def parseExpr(expr, module):
def parseExpr(expr: str, module: Any) -> Any:
try:
value, len_parsed = parseNestedExpr(expr, module)
assert len_parsed == len(expr), (
@ -267,7 +274,7 @@ def createResolutionCallbackFromEnv(lookup_base):
return lambda expr: parseExpr(expr, lookup_base)
def createResolutionCallbackFromFrame(frames_up: int = 0):
def createResolutionCallbackFromFrame(frames_up: int = 0) -> Callable[[str], Any]:
"""
Creates a function which, given a string variable name,
returns the value of the variable in the scope of the caller of
@ -308,7 +315,7 @@ def createResolutionCallbackFromFrame(frames_up: int = 0):
f_globals = frame.f_globals
class env:
def __getattr__(self, key):
def __getattr__(self, key: str) -> Any:
if key in f_locals:
return f_locals[key]
elif key in f_globals:
@ -377,7 +384,7 @@ def get_closure(fn):
# This could be worked around by manually adding it to `global()` dictionary.
def createResolutionCallbackFromClosure(fn):
def createResolutionCallbackFromClosure(fn) -> Callable[[str], Any]:
"""
Create a resolutionCallback by introspecting the function instead of
looking up the stack for the enclosing scope
@ -387,7 +394,7 @@ def createResolutionCallbackFromClosure(fn):
class closure_lookup:
# This is a class since `closure` is a dict and it's easier in
# `env_helper` if everything just works with `getattr` calls
def __getattr__(self, key):
def __getattr__(self, key: str) -> Any:
if key in closure:
return closure[key]
elif hasattr(typing, key):
@ -560,7 +567,7 @@ def get_type_hint_captures(fn):
return annotation_to_type
def createResolutionCallbackForClassMethods(cls):
def createResolutionCallbackForClassMethods(cls: Class) -> Callable[[str], Any]:
"""
This looks at all the methods defined in a class and pulls their closed-over
variables into a dictionary and uses that to resolve variables.
@ -582,7 +589,7 @@ def createResolutionCallbackForClassMethods(cls):
captures.update(get_closure(fn))
captures.update(get_type_hint_captures(fn))
def lookup_in_class(key):
def lookup_in_class(key: str) -> Any:
if key in captures:
return captures[key]
else:

View File

@ -2007,7 +2007,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
optim.load_state_dict(result)
return result
def register_comm_hook(self, state: object, hook: callable):
def register_comm_hook(self, state: object, hook: callable) -> None:
"""Register a communication hook.
This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates

View File

@ -1140,7 +1140,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
return loss
def join_hook(self, **kwargs):
def join_hook(self, **_kwargs: Any):
r"""
Return the ZeRO join hook.

View File

@ -15,8 +15,8 @@ import inspect
import pickle
import warnings
from collections.abc import Callable
from typing import Any, Union
from typing_extensions import deprecated
from typing import Any, Mapping, Sequence, TypeVar, Union
from typing_extensions import deprecated, Self
import torch
import torch._jit_internal as _jit_internal
@ -56,6 +56,9 @@ from torch.utils import set_module
from ._serialization import validate_map_location
_T = TypeVar("_T")
type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType
torch._C.ScriptMethod.graph_for = _script_method_graph_for # type: ignore[attr-defined]
@ -371,10 +374,10 @@ def script_method(fn):
class ConstMap:
def __init__(self, const_mapping):
def __init__(self, const_mapping: Mapping[str, Any]) -> None:
self.const_mapping = const_mapping
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
return self.const_mapping[attr]
@ -461,7 +464,7 @@ if _enabled:
self.__dict__["_initializing"] = False
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
if self.__dict__.get("_initializing"):
return super().__getattr__(attr) # type: ignore[misc]
@ -470,7 +473,7 @@ if _enabled:
return getattr(self._c, attr)
def __setattr__(self, attr, value):
def __setattr__(self, attr: str, value: Any) -> None:
if self.__dict__.get("_initializing"):
return super().__setattr__(attr, value)
@ -481,7 +484,9 @@ if _enabled:
# Delegate calls to magic methods like __len__ to the C++ module backing the
# RecursiveScriptClass.
def forward_magic_method(self, method_name, *args, **kwargs):
def forward_magic_method(
self, method_name: str, *args: Any, **kwargs: Any
) -> Any:
if not self._c._has_method(method_name):
raise TypeError
@ -491,7 +496,7 @@ if _enabled:
def __getstate__(self):
raise pickle.PickleError("ScriptClasses cannot be pickled")
def __iadd__(self, other):
def __iadd__(self, other: Self) -> Self:
if self._c._has_method("__iadd__"):
return self.forward_magic_method("__iadd__", other)
else:
@ -533,12 +538,12 @@ if _enabled:
forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
if "_actual_script_module" not in self.__dict__:
return super().__getattr__(attr)
return getattr(self._actual_script_module, attr)
def __setattr__(self, attr, value):
def __setattr__(self, attr: str, value: Any) -> None:
if "_actual_script_module" not in self.__dict__:
# Unwrap torch.jit.Attribute into a regular setattr + record
# the provided type in __annotations__.
@ -798,7 +803,7 @@ if _enabled:
def get_debug_state(self, *args, **kwargs):
return self._c.get_debug_state()
def extra_repr(self):
def extra_repr(self) -> str:
return f"original_name={self.original_name}"
def graph_for(self, *args, **kwargs):
@ -822,7 +827,7 @@ if _enabled:
rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
self._c._define(self._concrete_type, src, rcb)
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
if "_initializing" not in self.__dict__:
raise RuntimeError(
"ScriptModule has not been initialized, did you forget to call super's init?"
@ -846,7 +851,7 @@ if _enabled:
return super().__getattr__(attr)
def __setattr__(self, attr, value):
def __setattr__(self, attr: str, value: Any) -> None:
if self._initializing:
return super().__setattr__(attr, value)
@ -873,10 +878,10 @@ if _enabled:
# It's fairly trivial to save enough info to warn in this case.
return super().__setattr__(attr, value)
def __copy__(self):
def __copy__(self) -> Self:
return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c))
def __deepcopy__(self, memo):
def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo))
# Python magic methods do method lookups on an object's class type, instead of looking up
@ -905,7 +910,7 @@ if _enabled:
# dir is defined by the base nn.Module, so instead of throwing if
# it is not overridden, we call into the nn.Module __dir__ method
def __dir__(self):
def __dir__(self) -> Sequence[str]:
self_method = self.__dir__
if (
self_method.__func__ # type: ignore[attr-defined]
@ -1251,12 +1256,12 @@ def _script_impl(
def script(
obj,
optimize=None,
_frames_up=0,
_rcb=None,
obj: Any,
optimize: None = None,
_frames_up: int = 0,
_rcb: Callable[[str], Any] | None = None,
example_inputs: Union[list[tuple], dict[Callable, list[tuple]], None] = None,
):
) -> Any:
r"""Script the function.
Scripting a function or ``nn.Module`` will inspect the source code, compile
@ -1555,7 +1560,7 @@ def _check_directly_compile_overloaded(obj):
)
def interface(obj):
def interface(obj: _T) -> _T:
r"""Decorate to annotate classes or modules of different types.
This decorator can be used to define an interface that can be used to annotate