Add None return type to init (#132335)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
Author: Oguz Ulgen
Date: 2024-08-01 00:22:47 -07:00
Committed by: PyTorch MergeBot
Parent: 30d7f0b15a
Commit: 72d2dba992
130 changed files with 295 additions and 295 deletions
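The change applied across every file below is uniform: `__init__` methods that previously carried no annotations gain an explicit `-> None` return type, as PEP 484 recommends for constructors. A minimal sketch of the before/after pattern (the `Repro` module and layer sizes here are illustrative, not taken from any particular file in this commit); with the annotation present, strict type checkers such as mypy run with `--disallow-untyped-defs` treat the method as typed and check its body:

```python
import torch


class Repro(torch.nn.Module):
    # Before: `def __init__(self):` had no annotations, so strict type
    # checkers treated the constructor as untyped and skipped its body.
    # After: the explicit `-> None` return type marks it as fully typed.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)  # illustrative layer only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
```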


@@ -119,7 +119,7 @@ inner(torch.randn(20, 20, requires_grad=True) + 1)
 backend_name = "relu_compile_error_TESTING_ONLY"
 run_code = f"""\
 class CpuCudaModule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.m_x = torch.nn.Linear(20, 20).cuda()
 self.m_y = torch.nn.Linear(20, 20)
@@ -149,7 +149,7 @@ inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
 res.minifier_module(),
 """\
 class Repro(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
 self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
@@ -204,7 +204,7 @@ inner(torch.randn(20, 20))
 res.repro_module(),
 """\
 class Repro(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def forward(self, x_19):


@@ -122,7 +122,7 @@ inner(torch.randn(20))
 res.repro_module(),
 """\
 class Repro(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def forward(self, arg0_1):
@@ -138,7 +138,7 @@ class Repro(torch.nn.Module):
 res.repro_module(),
 """\
 class Repro(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def forward(self, arg0_1):


@@ -19,7 +19,7 @@ class _ClassNamespace(types.ModuleType):
 class _Classes(types.ModuleType):
 __file__ = "_classes.py"
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("torch.classes")
 def __getattr__(self, name):


@@ -71,7 +71,7 @@ class PhiloxState:
 trace time.
 """
-def __init__(self):
+def __init__(self) -> None:
 self.reset()
 def reset(self):


@@ -247,7 +247,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
 # This gives us the appropriately strided outputs here which will reflect runtime strides.
 class FakeifyFirstAOTInvocationGuard:
-def __init__(self):
+def __init__(self) -> None:
 self.tc = torch._guards.TracingContext.try_get()
 assert self.tc
 torch._guards.TracingContext.try_get().fakify_first_call = True


@@ -5,7 +5,7 @@ from .utils import ExactWeakKeyDictionary
 class CodeContextDict:
-def __init__(self):
+def __init__(self) -> None:
 self.code_context = ExactWeakKeyDictionary()
 def has_context(self, code: types.CodeType):


@@ -170,7 +170,7 @@ class NNModuleToString:
 """
 from torch.nn import *
 class Repro(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 """
 )
@@ -491,7 +491,7 @@ _is_leaf_or_default = _mk_defaulter(False)
 class NopInputReader:
-def __init__(self):
+def __init__(self) -> None:
 self.total = 0
 def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):


@@ -497,7 +497,7 @@ class _TorchDynamoContext:
 wrapper function.
 >> class CallableClass:
->> def __init__(self):
+>> def __init__(self) -> None:
 >> super().__init__()
 >> self.relu = torch.nn.ReLU()
 >>
@@ -578,7 +578,7 @@ class OptimizeContext(_TorchDynamoContext):
 class RunOnlyContext(_TorchDynamoContext):
-def __init__(self):
+def __init__(self) -> None:
 # cudagraph trees relies on generation increment
 def on_enter():
 torch._dynamo.mutation_guard.GenerationTracker.generation += 1
@@ -590,7 +590,7 @@ class RunOnlyContext(_TorchDynamoContext):
 class DisableContext(_TorchDynamoContext):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__(callback=None)
 def __call__(self, fn):


@@ -74,7 +74,7 @@ class InvalidBackend(TorchDynamoException):
 class ResetRequired(TorchDynamoException):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__(
 textwrap.dedent(
 """


@@ -92,7 +92,7 @@ def print_missing(stack):
 class Profiler:
 unique_graphs = 0
-def __init__(self):
+def __init__(self) -> None:
 self.prof = torch.profiler.profile(
 activities=[torch.profiler.ProfilerActivity.CPU],
 with_stack=should_print_missing(),


@@ -70,7 +70,7 @@ class MutableLocal(MutableLocalBase):
 state.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__(MutableLocalSource.Local)
 def __hash__(self):


@@ -274,7 +274,7 @@ class GraphArg:
 class BackwardStateGraphArg(GraphArg):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__(
 source=None,
 _example=BackwardState(),
@@ -2646,7 +2646,7 @@ class SourcelessBuilder:
 if/else type->VariableTracker trees that were cropping up all over dynamo.
 """
-def __init__(self):
+def __init__(self) -> None:
 raise AssertionError("Use SourcelessBuilder.create()")
 @staticmethod


@@ -10,7 +10,7 @@ class ClassMethod(torch.nn.Module):
 def method(cls, x):
 return x + 1
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = torch.nn.Linear(4, 2)


@@ -26,7 +26,7 @@ class CondBranchClassMethod(torch.nn.Module):
 NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.subm = MySubModule()


@@ -8,7 +8,7 @@ class ModelAttrMutation(torch.nn.Module):
 Attribute mutation is not supported.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]


@@ -11,7 +11,7 @@ class ScalarOutput(torch.nn.Module):
 Returning scalar values from the graph is supported, in addition to Tensor
 outputs. Symbolic shapes are captured and rank is specialized.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def forward(self, x):


@@ -11,7 +11,7 @@ class SpecializedAttribute(torch.nn.Module):
 Model attributes are specialized.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.a = "moo"
 self.b = 4


@@ -24,7 +24,7 @@ class ConstantAttrMap(collections.abc.MutableMapping):
 if that's the case).
 """
-def __init__(self):
+def __init__(self) -> None:
 # Underlying dict that we use to implement this mapping.
 self._constant_attrs: Dict[
 Union[int, torch.Tensor, FakeScriptObject], List[Any]


@@ -1413,7 +1413,7 @@ class GraphModuleDeserializer(metaclass=Final):
 constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
 example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]]
-def __init__(self):
+def __init__(self) -> None:
 self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
 self.serialized_name_to_meta: Dict[str, MetaType] = {}
 self.graph = torch.fx.Graph()


@@ -602,7 +602,7 @@ class SubclassMeta:
 # Optional field because we don't compute for inference graphs
 grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] = None
-def __init__(self):
+def __init__(self) -> None:
 # The fields in this class get set after its construction.
 pass


@@ -878,7 +878,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
 )
 class AOTModule(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.orig_module = mod


@@ -30,7 +30,7 @@ from torch.autograd.forward_ad import _set_fwd_grad_enabled
 # We do this by using creating a custom HigherOrderOperator that only functorch
 # dispatches specially.
 class CustomFunctionHigherOrderOperator(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("custom_function_call")
 def __call__(self, autograd_function, *args, **kwargs):
@@ -713,7 +713,7 @@ def autograd_function_forward_rewritten(original_forward, original_setup_context
 class AutogradFunctionApply(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("autograd_function_apply")
 def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):


@@ -427,7 +427,7 @@ class ModuleContextCheckpointState:
 class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
-def __init__(self):
+def __init__(self) -> None:
 self.nn_modules: Dict[str, Any] = {}
 def copy_graphstate(self):
@@ -476,7 +476,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
 "autocast_cache_enabled",
 }
-def __init__(self):
+def __init__(self) -> None:
 self.global_state: Dict[str, Tuple[Callable, ...]] = {}
 def copy_graphstate(self):
@@ -544,7 +544,7 @@ class GuardsSet:
 class GuardsContext(Checkpointable[GuardsCheckpointState]):
-def __init__(self):
+def __init__(self) -> None:
 self.dynamo_guards: GuardsSet = GuardsSet()
 self.aotautograd_guards: List[GuardEnvExpr] = []


@@ -54,7 +54,7 @@ class AutoFunctionalized(HigherOrderOperator):
 underscore is to prevent collisions with kwarg names in **kwargs.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("auto_functionalized")
 def __call__(


@@ -55,7 +55,7 @@ class WithEffects(HigherOrderOperator):
 per "effect type", which are enumerated in the _EffectType enum.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("with_effects")
 def __call__(


@@ -38,7 +38,7 @@ class TransformGetItemToIndex(TorchFunctionMode):
 class FlexAttentionHOP(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("flex_attention")
 def __call__(
@@ -74,7 +74,7 @@ flex_attention.__module__ = "torch.ops.higher_order"
 class FlexAttentionBackwardHOP(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("flex_attention_backward")
 def __call__(


@@ -45,7 +45,7 @@ class OutDtypeOperator(HigherOrderOperator):
 3. Cast the output to `out_dtype`
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("out_dtype")
 # TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
 # become different (torch._higher_order_ops.out_dtype) which will result


@@ -519,7 +519,7 @@ def identify_mutated_tensors(kernel, kwargs):
 # Used for wrapping a Triton Kernel
 class TritonKernelWrapperMutation(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("triton_kernel_wrapper_mutation")
@@ -528,7 +528,7 @@ triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
 # Used for wrapping a Triton Kernel in a functional manner
 class TritonKernelWrapperFunctional(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("triton_kernel_wrapper_functional")


@@ -18,7 +18,7 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
 class WhileLoopOp(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("while_loop")
 def __call__(


@@ -15,7 +15,7 @@ uid = itertools.count(1)
 # Used for testing the HigherOrderOperator mechanism
 class Wrap(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("wrap")
 def __call__(self, func, *args, **kwargs):
@@ -36,7 +36,7 @@ wrap = Wrap()
 class WrapWithSetGradEnabled(HigherOrderOperator):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("wrap_with_set_grad_enabled")
 def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
@@ -74,7 +74,7 @@ class WrapActivationCheckpoint(HigherOrderOperator):
 partitioners. See TagActivationCheckpoint for more information.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("wrap_activation_checkpoint")
 def __call__(self, function, *args, **kwargs):
@@ -113,7 +113,7 @@ class TagActivationCheckpoint(HigherOrderOperator):
 the forward and recomputed forward in backward.
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__("tag_activation_checkpoint")
 @staticmethod


@@ -1560,7 +1560,7 @@ class CSE:
 class CodeGen:
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.exit_stack = contextlib.ExitStack()


@@ -29,7 +29,7 @@ class CppWrapperCuda(CppWrapperCpu):
 Generates cpp wrapper for running on GPU and calls CUDA kernels
 """
-def __init__(self):
+def __init__(self) -> None:
 self.device = "cuda"
 super().__init__()
 self.grid_id = count()


@@ -1113,7 +1113,7 @@ class HelperFunctions:
 _templates_seen: Dict[str, str] # Template code to function name
 finalized_helpers: List[str]
-def __init__(self):
+def __init__(self) -> None:
 self._templates_seen = {}
 self.finalized_helpers = []


@@ -589,7 +589,7 @@ def canonicalization_prefix():
 class FreeUnbackedSymbolsOpsHandler:
 symbols: OrderedSet[sympy.Symbol]
-def __init__(self):
+def __init__(self) -> None:
 self.symbols = OrderedSet()
 def __getattr__(self, name: str) -> Callable[..., Any]:


@@ -65,7 +65,7 @@ class SubgraphLoweringException(RuntimeError):
 class InvalidCxxCompiler(RuntimeError):
-def __init__(self):
+def __init__(self) -> None:
 from . import config
 super().__init__(


@@ -79,7 +79,7 @@ class NumpyCompatNormalization:
 inverse_mapping: Dict[str, str]
 cache: Dict["torch.fx.graph.Target", Set[str]]
-def __init__(self):
+def __init__(self) -> None:
 self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"]
 self.inverse_mapping = {}
 for actual_kwarg, numpy_kwargs in self.numpy_compat.items():


@@ -1207,7 +1207,7 @@ if torch._C._has_mkldnn:
 Combine packed weight nodes with the same inputs to reduce memory usage.
 for example:
 class Model(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = nn.Linear(32, 32, bias=True)


@@ -99,7 +99,7 @@ class CachedMetricsHelper:
 apply on a cache hit.
 """
-def __init__(self):
+def __init__(self) -> None:
 self.cached_metrics = {}
 for metric in get_metric_fields():
 self.cached_metrics[metric] = globals()[metric]


@@ -940,7 +940,7 @@ class IndentedBuffer:
 class FakeIndentedBuffer(IndentedBuffer):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def __getattribute__(self, name):
@@ -1219,7 +1219,7 @@ class DebugDirManager:
 counter = itertools.count(0)
 prev_debug_name: str
-def __init__(self):
+def __init__(self) -> None:
 self.id = next(DebugDirManager.counter)
 def __enter__(self):
@@ -1268,7 +1268,7 @@ def get_code(fn, *args, **kwargs):
 class DummyModule:
 """This is empty to replace the generated triton module"""
-def __init__(self):
+def __init__(self) -> None:
 pass
 def call(self, *args, **kwargs):


@@ -7,7 +7,7 @@ from torch._lazy.device_context import get_device_context
 class ClosureHandler:
-def __init__(self):
+def __init__(self) -> None:
 pass
 def run(self, closure):


@@ -42,7 +42,7 @@ class HasStaticMethodFromReal(Protocol):
 class FakeClassRegistry:
-def __init__(self):
+def __init__(self) -> None:
 self._registered_class: Dict[str, Any] = {}
 def has_impl(self, full_qualname: str) -> bool:


@@ -70,7 +70,7 @@ class PythonDispatcher:
 ]
 supported_keys = runtime_keys + alias_keys
-def __init__(self):
+def __init__(self) -> None:
 C._dispatch_check_invariants(self.name) # type: ignore[attr-defined]
 self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
 self.ref.def_("foo(Tensor x) -> Tensor")


@@ -60,7 +60,7 @@ def clone_inputs(args):
 class SchemaCheckMode(TorchDispatchMode):
-def __init__(self):
+def __init__(self) -> None:
 # Information recorded for testing purposes. For example:
 # - incorrect schemas
 # - overly conservative schemas


@@ -36,7 +36,7 @@ class FloatFunctional(torch.nn.Module):
 - mul_scalar
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.activation_post_process = torch.nn.Identity()
@@ -190,7 +190,7 @@ class QFunctional(torch.nn.Module):
 - mul_scalar
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.scale = 1.0
 self.zero_point = 0


@@ -72,7 +72,7 @@ class QConfigMultiMapping:
 """
-def __init__(self):
+def __init__(self) -> None:
 # initialize this with 1 QConfigMapping to avoid corner cases
 self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]


@@ -99,7 +99,7 @@ from torch.ao.pruning._experimental.pruner import SaliencyPruner
 # Define model
 class Model(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.seq = nn.Sequential(
 nn.Linear(700, 500, bias=True),


@@ -85,7 +85,7 @@ class FakeQuantizeBase(ABC, Module):
 fake_quant_enabled: torch.Tensor
 observer_enabled: torch.Tensor
-def __init__(self):
+def __init__(self) -> None:
 """Set fake_quant_enabled and observer_enabled."""
 super().__init__()
 # fake_quant_enabled and observer_enabled are buffers to support their


@@ -70,7 +70,7 @@ In the following, I'll first have a detailed description for each step, and th
 ```
 class LinearReLUModule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = torch.nn.Linear(5, 10).float()
 self.relu = torch.nn.ReLU()


@@ -137,7 +137,7 @@ class DetectorBase(ABC):
 - Should return a str-based report and dict info in Tuple[str,Dict] format
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.detector_config_info = None


@@ -63,7 +63,7 @@ class PrepareCustomConfig:
 .set_preserved_attributes(["attr1", "attr2"])
 """
-def __init__(self):
+def __init__(self) -> None:
 self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
 self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
 self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
@@ -382,7 +382,7 @@ class ConvertCustomConfig:
 .set_preserved_attributes(["attr1", "attr2"])
 """
-def __init__(self):
+def __init__(self) -> None:
 self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {}
 self.preserved_attributes: List[str] = []
@@ -477,7 +477,7 @@ class FuseCustomConfig:
 fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
 """
-def __init__(self):
+def __init__(self) -> None:
 self.preserved_attributes: List[str] = []
 def __repr__(self):


@@ -1568,7 +1568,7 @@ class ReuseInputObserver(ObserverBase):
 Note: this is only enabled in FX Graph Mode Quantization
 """
-def __init__(self):
+def __init__(self) -> None:
 super().__init__(torch.quint8, is_dynamic=False)
 def forward(self, x):


@@ -229,7 +229,7 @@ class QConfigMapping:
 """
-def __init__(self):
+def __init__(self) -> None:
 # In increasing match priority:
 self.global_qconfig: QConfigAny = None
 self.object_type_qconfigs: OrderedDict[


@@ -289,7 +289,7 @@ def prepare_fx(
 from torch.ao.quantization.quantize_fx import prepare_fx
 class Submodule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = torch.nn.Linear(5, 5)
 def forward(self, x):
@@ -297,7 +297,7 @@ def prepare_fx(
 return x
 class M(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = torch.nn.Linear(5, 5)
 self.sub = Submodule()
@@ -427,7 +427,7 @@ def prepare_qat_fx(
 from torch.ao.quantization.quantize_fx import prepare_qat_fx
 class Submodule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = torch.nn.Linear(5, 5)
 def forward(self, x):
@@ -435,7 +435,7 @@ def prepare_qat_fx(
 return x
 class M(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = torch.nn.Linear(5, 5)
 self.sub = Submodule()


@@ -56,7 +56,7 @@ def prepare_pt2e(
 )
 class M(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = torch.nn.Linear(5, 10)
@@ -129,7 +129,7 @@ def prepare_qat_pt2e(
 )
 class M(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear = torch.nn.Linear(5, 10)


@@ -42,7 +42,7 @@ def get_embedding_operators_config() -> OperatorConfig:
 class EmbeddingQuantizer(Quantizer):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 @classmethod


@@ -436,7 +436,7 @@ class X86InductorQuantizer(Quantizer):
 supported_config_and_operators = _get_supported_config_and_operators()
 module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type()
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.global_config: Optional[QuantizationConfig] = None
 self.operator_type_qconfig: Dict[


@@ -268,7 +268,7 @@ class XNNPACKQuantizer(Quantizer):
 "linear",
 ]
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.global_config: Optional[QuantizationConfig] = None
 self.operator_type_config: Dict[


@@ -513,7 +513,7 @@ def _get_path_of_module(
 Example::
 >> class M(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 self.linear = torch.nn.Linear(5, 5)
 def forward(self, x):
 return self.linear(x)


@@ -645,7 +645,7 @@ class FunctionEvent(FormattedTimesMixin):
 class FunctionEventAvg(FormattedTimesMixin):
 """Used to average stats over multiple FunctionEvent objects."""
-def __init__(self):
+def __init__(self) -> None:
 self.key: Optional[str] = None
 self.count: int = 0
 self.node_id: int = 0


@@ -266,7 +266,7 @@ class _Launcher:
 or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \
 {expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set."
-def __init__(self):
+def __init__(self) -> None:
 self.cpuinfo = _CPUinfo()
 def add_lib_preload(self, lib_type):


@@ -77,17 +77,17 @@ namespace jit {
 *
 * So why does debug handle map to DebugInfoTuple = {source range and inlined
 * cs}? {debug_handle, source_range_tag, serialized_callstack} Take this
-* example: class L(nn.Module): def __init__(self):
+* example: class L(nn.Module): def __init__(self) -> None:
 * ...
 * def forward(self, x):
 * return x * 5
 * class M(nn.Module):
-* def __init__(self):
+* def __init__(self) -> None:
 * ...
 * def forward(self, x):
 * return x - 2
 * class N(nn.Module):
-* def __init__(self):
+* def __init__(self) -> None:
 * self.m = M()
 * def forward(self, x):
 * return self.m(x) + 3


@@ -328,7 +328,7 @@ For example:
 ```
 class M(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 self.a = torch.rand(2, 3)
 self.b = torch.nn.Linear(10, 10)


@@ -37,7 +37,7 @@ When making changes to the operators, the first thing to identify is if it's BC/
 1. Add a test module in `test/jit/fixtures_srcs/fixtures_src.py`. In `test/jit/fixtures_srcs/generate_models.py`,
 ```
 class TestVersionedLinspaceV7(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
@@ -163,7 +163,7 @@ When making changes to the operators, the first thing to identify is if it's BC/
 # Step 2. Write down how current module should look like
 class MyModuleFloat(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def forward(self, a, b: float):


@@ -25,7 +25,7 @@ namespace onnx {
 //
 // clang-format off
 // class M(torch.nn.Module):
-// def __init__(self):
+// def __init__(self) -> None:
 // super().__init__()
 // self.lns = torch.nn.ModuleList([torch.nn.LayerNorm(3, eps = i) for i in range(2)])
 // self.celu1 = torch.nn.CELU(1.0)


@@ -17,7 +17,7 @@ torch._lazy.ts_backend.init()
 class Net(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.conv1 = nn.Conv2d(1, 32, 3, 1)
 self.conv2 = nn.Conv2d(32, 64, 3, 1)


@@ -135,7 +135,7 @@ Here's our model definition:
 ```python
 class Net(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.conv1 = nn.Conv2d(1, 32, 3, 1)
 self.conv2 = nn.Conv2d(32, 64, 3, 1)


@@ -163,7 +163,7 @@ class TensorInfo:
 class _TensorsAccessed:
-def __init__(self):
+def __init__(self) -> None:
 self.accesses: Dict[DataPtr, TensorInfo] = {}
 def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
@@ -218,7 +218,7 @@ class _TensorsAccessed:
 class StreamSynchronizations:
-def __init__(self):
+def __init__(self) -> None:
 self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
 self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
 self.host_sync_state: Dict[StreamId, SeqNum] = {}
@@ -338,7 +338,7 @@ class EventHandler:
 data race.
 """
-def __init__(self):
+def __init__(self) -> None:
 self.tensors_accessed = _TensorsAccessed()
 self.syncs = StreamSynchronizations()
 self.seq_num: SeqNum = 0
@@ -478,7 +478,7 @@ def zip_arguments(
 class ArgumentHandler:
-def __init__(self):
+def __init__(self) -> None:
 self.dataptrs_read: Set[DataPtr] = set()
 self.dataptrs_written: Set[DataPtr] = set()
 self.tensor_aliases: Dict[DataPtr, List[str]] = {}
@@ -527,7 +527,7 @@ class ArgumentHandler:
 class CUDASanitizerDispatchMode(TorchDispatchMode):
-def __init__(self):
+def __init__(self) -> None:
 self.event_handler = EventHandler()
 torch._C._activate_gpu_trace()
 gpu_trace.register_callback_for_event_creation(
@@ -596,7 +596,7 @@ class CUDASanitizer:
 This approach was deemed more elegant than using the atexit module.
 """
-def __init__(self):
+def __init__(self) -> None:
 self.dispatch = CUDASanitizerDispatchMode()
 self.enabled = False


@@ -49,7 +49,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
 >>> import torch.nn as nn
 >>>
 >>> class MyModel(nn.Module):
->>> def __init__(self):
+>>> def __init__(self) -> None:
 >>> super().__init__()
 >>> self.l1 = nn.Linear(10, 10)
 >>> self.l2 = nn.Linear(10, 10)


@@ -47,7 +47,7 @@ def contract(state_cls: Type[_State] = _State):
 >>> import torch.nn as nn
 >>>
 >>> class MyModel(nn.Module):
->>> def __init__(self):
+>>> def __init__(self) -> None:
 >>> super().__init__()
 >>> self.l1 = nn.Linear(10, 10)
 >>> self.l2 = nn.Linear(10, 10)


@@ -43,7 +43,7 @@ logger = logging.getLogger("torch.distributed._composable.fsdp")
 class FSDPStateContext:
 """This has state shared across FSDP states."""
-def __init__(self):
+def __init__(self) -> None:
 # All FSDP states in the root state's module tree
 self.all_states: List[FSDPState] = []
 # Iteration's forward root runs the once-per-forward logic; this root
@@ -71,7 +71,7 @@ def disable_if_config_true(func):
 class FSDPState(_State):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self._fsdp_param_group: Optional[FSDPParamGroup] = None
 self._is_root: Optional[bool] = None # root set during lazy init


@@ -38,7 +38,7 @@ class ShardingPlan:
 >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
 >>> class MyModule(nn.Module):
->>> def __init__(self):
+>>> def __init__(self) -> None:
 >>> super().__init__()
 >>> self.fc1 = nn.Linear()
 >>> self.gelu = nn.GELU()


@@ -117,7 +117,7 @@ import torch.nn as nn
 from torch.distributed._tensor import Shard, distribute_tensor, distribute_module, init_device_mesh
 class MyModule(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.fc1 = nn.Linear(8, 8)
 self.fc2 = nn.Linear(8, 8)


@@ -25,7 +25,7 @@ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_modul
 class SimpleMLP(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.net1 = torch.nn.Linear(5, 128)
 self.relu = torch.nn.ReLU()


@@ -55,7 +55,7 @@ class Joinable(ABC):
 """
 @abstractmethod
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self._join_config = _JoinConfig.construct_disabled_join_config()


@@ -31,7 +31,7 @@ class InjectedException(Exception):
 class Model(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.net1 = nn.Linear(8, 32)
 self.net2 = nn.Linear(32, 128)


@@ -22,7 +22,7 @@ CHECKPOINT_DIR = f"~/{os.environ['LOGNAME']}/checkpoint"
 class Model(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 torch.manual_seed(0)
 self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())


@@ -434,7 +434,7 @@ class _reduce_op:
 :class:`~torch.distributed.ReduceOp` is recommended to use instead.
 """
-def __init__(self):
+def __init__(self) -> None:
 # __members__ is a dict storing key-value pairs for enum classes
 for k, v in ReduceOp.RedOpType.__members__.items():
 setattr(self, k, v)
@@ -568,7 +568,7 @@ class _World:
 of c10d and is subject to change..
 """
-def __init__(self):
+def __init__(self) -> None:
 self._default_pg = None
 self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {}
 self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
@@ -2194,7 +2194,7 @@ class _IllegalWork(Work):
 class _CoalescingManager:
-def __init__(self):
+def __init__(self) -> None:
 self.works: List[Work] = []
 def append(self, work: Work):


@@ -106,7 +106,7 @@ class _FSDPDeviceHandle:
 class _UninitializedDeviceHandle(_FSDPDeviceHandle):
-def __init__(self):
+def __init__(self) -> None:
 pass
 def __getattribute__(self, __name: str) -> Any:


@@ -156,7 +156,7 @@ class _RemoteModule(nn.Module):
 created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``).
 Hybrid Example:
 >>> class HybridModel(nn.Module):
->>> def __init__(self):
+>>> def __init__(self) -> None:
 >>> nn.Module.__init__(self)
 >>> self.remote_embedding = RemoteModule(...)
 >>> self.local_linear = nn.Linear(...)


@@ -248,7 +248,7 @@ class ExportGraphSignature:
 e.g. If following module is exported::
 class CustomModule(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super(CustomModule, self).__init__()
 # Define a parameter


@@ -45,7 +45,7 @@ FX's front-end makes use of the dynamic nature of Python to intercept call-sit
 import torch
 class MyModule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.param = torch.nn.Parameter(
 torch.rand(3, 4))


@@ -9,7 +9,7 @@ demonstration of these components in action:
 import torch
 # Simple module for demonstration
 class MyModule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.param = torch.nn.Parameter(torch.rand(3, 4))
 self.linear = torch.nn.Linear(4, 5)


@@ -1012,7 +1012,7 @@ class _PatchedFnSetAttr(_PatchedFn):
 class _Patcher:
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.patches_made: List[_PatchedFn] = []
 self.visited: Set[int] = set()


@@ -63,7 +63,7 @@ class T(Constraint):
 """
 True
 """
-def __init__(self):
+def __init__(self) -> None:
 pass
 def __eq__(self, other):
@@ -76,7 +76,7 @@ class F(Constraint):
 """
 False
 """
-def __init__(self):
+def __init__(self) -> None:
 pass
 def __eq__(self, other):


@@ -117,7 +117,7 @@ if HAS_PYDOT:
 >>> # xdoctest: +REQUIRES(module:ubelt)
 >>> # define module
 >>> class MyModule(torch.nn.Module):
->>> def __init__(self):
+>>> def __init__(self) -> None:
 >>> super().__init__()
 >>> self.linear = torch.nn.Linear(4, 5)
 >>> def forward(self, x):


@@ -83,7 +83,7 @@ def split_module(
 from torch.fx.passes.split_module import split_module
 class MyModule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.param = torch.nn.Parameter(torch.rand(3, 4))
 self.linear = torch.nn.Linear(4, 5)


@@ -83,7 +83,7 @@ def split_by_tags(
 Given the following module def:
 class SimpleModule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.linear1 = torch.nn.Linear(...)
 self.linear2 = torch.nn.Linear(...)


@@ -38,7 +38,7 @@ class Scope:
 return x.transpose(1, 2)
 class M(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 self.sub = Sub()
 def forward(self, x):


@@ -118,7 +118,7 @@ def replace_pattern(
 from torch.fx import symbolic_trace, subgraph_rewriter
 class M(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def forward(self, x, w1, w2):


@@ -38,7 +38,7 @@ class _DynType:
 """
 _DynType defines a type which stands for the absence of type information.
 """
-def __init__(self):
+def __init__(self) -> None:
 self.__name__ = '_DynType'
 def __eq__(self, other):


@@ -219,7 +219,7 @@ def isinstance(obj, target_type):
 from typing import Any, Dict, List
 class MyModule(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 def forward(self, input: Any): # note the Any type
@@ -255,7 +255,7 @@ class strict_fusion:
 """
-def __init__(self):
+def __init__(self) -> None:
 if not torch._jit_internal.is_scripting():
 warnings.warn("Only works in script mode")
 pass


@@ -73,7 +73,7 @@ def fork(func, *args, **kwargs):
 def forward(self, a: Tensor, b : int):
 return a + b
 class Mod(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super(self).__init__()
 self.mod = AddMod()
 def forward(self, input):


@@ -39,7 +39,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
 def fn(self):
 return []
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.x: List[int] = []


@@ -65,7 +65,7 @@ def freeze(
 .. testcode::
 import torch
 class MyModule2(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.modified_tensor = torch.tensor(10.)
 self.version = 1


@@ -89,7 +89,7 @@ if _IS_MONKEYTYPE_INSTALLED:
 self.traces.append(trace)
 class JitTypeTraceStore(CallTraceStore):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 # A dictionary keeping all collected CallTrace
 # key is fully qualified name of called function
@@ -159,15 +159,15 @@ else:
 # When MonkeyType is not installed, we provide dummy class definitions
 # for the below classes.
 class JitTypeTraceStoreLogger: # type: ignore[no-redef]
-def __init__(self):
+def __init__(self) -> None:
 pass
 class JitTypeTraceStore: # type: ignore[no-redef]
-def __init__(self):
+def __init__(self) -> None:
 self.trace_records = None
 class JitTypeTraceConfig: # type: ignore[no-redef]
-def __init__(self):
+def __init__(self) -> None:
 pass
 monkeytype_trace = None # type: ignore[assignment] # noqa: F811


@@ -426,7 +426,7 @@ class ConcreteTypeStore:
 type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
 methods_compiled: Set[torch._C.ConcreteModuleType]
-def __init__(self):
+def __init__(self) -> None:
 # Python module type => List[ConcreteModuleType)]
 self.type_store = {}
 # ConcreteTypes that have had their methods already compiled


@@ -107,7 +107,7 @@ Attribute.__doc__ = """
 from typing import Dict
 class AttributeModule(torch.jit.ScriptModule):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.foo = torch.jit.Attribute(0.1, float)
@@ -138,7 +138,7 @@ Attribute.__doc__ = """
 class AttributeModule(torch.nn.Module):
 names: Dict[str, int]
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.names = {}
@@ -522,7 +522,7 @@ if _enabled:
 "original_name",
 ]
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
@@ -1351,7 +1351,7 @@ def script(
 import torch.nn.functional as F
 class MyModule(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 # torch.jit.trace produces a ScriptModule's conv1 and conv2
 self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
@@ -1374,7 +1374,7 @@ def script(
 import torch.nn as nn
 class MyModule(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 @torch.jit.export
@@ -1547,7 +1547,7 @@ def interface(obj):
 return x.relu()
 class Impl2(torch.nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.val = torch.rand(())
@@ -1671,7 +1671,7 @@ class _ScriptProfileTable:
 class _ScriptProfile:
-def __init__(self):
+def __init__(self) -> None:
 self.profile = classes.profiling._ScriptProfile()
 def enable(self):


@@ -19,7 +19,7 @@ class EnabledProxy:
 This is just a wrapper for a bool, so that we get reference semantics
 """
-def __init__(self):
+def __init__(self) -> None:
 self.enabled = self.parse_env(
 "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
 )


@@ -966,7 +966,7 @@ def trace(
 import torch.nn as nn
 class Net(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.conv = nn.Conv2d(1, 1, 3)
@@ -1182,7 +1182,7 @@ def trace_module(
 import torch.nn as nn
 class Net(nn.Module):
-def __init__(self):
+def __init__(self) -> None:
 super().__init__()
 self.conv = nn.Conv2d(1, 1, 3)


@@ -61,7 +61,7 @@ class StorageWeakRef:
 class SharedCache(dict):
 """Dictionary from multiprocessing handles to StorageWeakRef."""
-def __init__(self):
+def __init__(self) -> None:
 # free_dead_references() is called if the len exceeds the current
 # limit. The limit scales with the number of remaining live objects.
 self.limit = 128

Some files were not shown because too many files have changed in this diff.