Add None return type to init (#132335)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
Author: Oguz Ulgen
Date: 2024-08-01 00:22:47 -07:00
Committed by: PyTorch MergeBot
Parent: 30d7f0b15a
Commit: 72d2dba992

130 changed files with 295 additions and 295 deletions
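
The change itself is mechanical: every `def __init__(self):` in the touched files gains an explicit `-> None` return annotation. The motivation is type checking: tools such as mypy treat a function with no annotations as dynamically typed and skip checking its body by default, and an `__init__` that takes only `self` needs the explicit `-> None` to count as annotated. Below is a minimal before/after sketch; the `Repro` class and `Linear(4, 2)` layer are illustrative, borrowed from the hunks in this diff.

```python
import torch


class Repro(torch.nn.Module):
    # Before this commit the signature was `def __init__(self):`, which has
    # no annotations at all, so mypy skips the body unless
    # --check-untyped-defs is enabled.
    #
    # With the explicit `-> None`, __init__ is a fully annotated function
    # and its body is type-checked like any other.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
```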

View File

@ -119,7 +119,7 @@ inner(torch.randn(20, 20, requires_grad=True) + 1)
backend_name = "relu_compile_error_TESTING_ONLY"
run_code = f"""\
class CpuCudaModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.m_x = torch.nn.Linear(20, 20).cuda()
self.m_y = torch.nn.Linear(20, 20)
@ -149,7 +149,7 @@ inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
res.minifier_module(),
"""\
class Repro(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
@ -204,7 +204,7 @@ inner(torch.randn(20, 20))
res.repro_module(),
"""\
class Repro(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x_19):

View File

@ -122,7 +122,7 @@ inner(torch.randn(20))
res.repro_module(),
"""\
class Repro(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, arg0_1):
@ -138,7 +138,7 @@ class Repro(torch.nn.Module):
res.repro_module(),
"""\
class Repro(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, arg0_1):

View File

@ -19,7 +19,7 @@ class _ClassNamespace(types.ModuleType):
class _Classes(types.ModuleType):
__file__ = "_classes.py"
def __init__(self):
def __init__(self) -> None:
super().__init__("torch.classes")
def __getattr__(self, name):

View File

@ -71,7 +71,7 @@ class PhiloxState:
trace time.
"""
def __init__(self):
def __init__(self) -> None:
self.reset()
def reset(self):

View File

@ -247,7 +247,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
# This gives us the appropriately strided outputs here which will reflect runtime strides.
class FakeifyFirstAOTInvocationGuard:
def __init__(self):
def __init__(self) -> None:
self.tc = torch._guards.TracingContext.try_get()
assert self.tc
torch._guards.TracingContext.try_get().fakify_first_call = True

View File

@ -5,7 +5,7 @@ from .utils import ExactWeakKeyDictionary
class CodeContextDict:
def __init__(self):
def __init__(self) -> None:
self.code_context = ExactWeakKeyDictionary()
def has_context(self, code: types.CodeType):

View File

@ -170,7 +170,7 @@ class NNModuleToString:
"""
from torch.nn import *
class Repro(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
"""
)
@ -491,7 +491,7 @@ _is_leaf_or_default = _mk_defaulter(False)
class NopInputReader:
def __init__(self):
def __init__(self) -> None:
self.total = 0
def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):

View File

@ -497,7 +497,7 @@ class _TorchDynamoContext:
wrapper function.
>> class CallableClass:
>> def __init__(self):
>> def __init__(self) -> None:
>> super().__init__()
>> self.relu = torch.nn.ReLU()
>>
@ -578,7 +578,7 @@ class OptimizeContext(_TorchDynamoContext):
class RunOnlyContext(_TorchDynamoContext):
def __init__(self):
def __init__(self) -> None:
# cudagraph trees relies on generation increment
def on_enter():
torch._dynamo.mutation_guard.GenerationTracker.generation += 1
@ -590,7 +590,7 @@ class RunOnlyContext(_TorchDynamoContext):
class DisableContext(_TorchDynamoContext):
def __init__(self):
def __init__(self) -> None:
super().__init__(callback=None)
def __call__(self, fn):

View File

@ -74,7 +74,7 @@ class InvalidBackend(TorchDynamoException):
class ResetRequired(TorchDynamoException):
def __init__(self):
def __init__(self) -> None:
super().__init__(
textwrap.dedent(
"""

View File

@ -92,7 +92,7 @@ def print_missing(stack):
class Profiler:
unique_graphs = 0
def __init__(self):
def __init__(self) -> None:
self.prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU],
with_stack=should_print_missing(),

View File

@ -70,7 +70,7 @@ class MutableLocal(MutableLocalBase):
state.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__(MutableLocalSource.Local)
def __hash__(self):

View File

@ -274,7 +274,7 @@ class GraphArg:
class BackwardStateGraphArg(GraphArg):
def __init__(self):
def __init__(self) -> None:
super().__init__(
source=None,
_example=BackwardState(),
@ -2646,7 +2646,7 @@ class SourcelessBuilder:
if/else type->VariableTracker trees that were cropping up all over dynamo.
"""
def __init__(self):
def __init__(self) -> None:
raise AssertionError("Use SourcelessBuilder.create()")
@staticmethod

View File

@ -10,7 +10,7 @@ class ClassMethod(torch.nn.Module):
def method(cls, x):
return x + 1
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 2)

View File

@ -26,7 +26,7 @@ class CondBranchClassMethod(torch.nn.Module):
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.subm = MySubModule()

View File

@ -8,7 +8,7 @@ class ModelAttrMutation(torch.nn.Module):
Attribute mutation is not supported.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

View File

@ -11,7 +11,7 @@ class ScalarOutput(torch.nn.Module):
Returning scalar values from the graph is supported, in addition to Tensor
outputs. Symbolic shapes are captured and rank is specialized.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):

View File

@ -11,7 +11,7 @@ class SpecializedAttribute(torch.nn.Module):
Model attributes are specialized.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.a = "moo"
self.b = 4

View File

@ -24,7 +24,7 @@ class ConstantAttrMap(collections.abc.MutableMapping):
if that's the case).
"""
def __init__(self):
def __init__(self) -> None:
# Underlying dict that we use to implement this mapping.
self._constant_attrs: Dict[
Union[int, torch.Tensor, FakeScriptObject], List[Any]

View File

@ -1413,7 +1413,7 @@ class GraphModuleDeserializer(metaclass=Final):
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]]
def __init__(self):
def __init__(self) -> None:
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: Dict[str, MetaType] = {}
self.graph = torch.fx.Graph()

View File

@ -602,7 +602,7 @@ class SubclassMeta:
# Optional field because we don't compute for inference graphs
grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] = None
def __init__(self):
def __init__(self) -> None:
# The fields in this class get set after its construction.
pass

View File

@ -878,7 +878,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
)
class AOTModule(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.orig_module = mod

View File

@ -30,7 +30,7 @@ from torch.autograd.forward_ad import _set_fwd_grad_enabled
# We do this by using creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("custom_function_call")
def __call__(self, autograd_function, *args, **kwargs):
@ -713,7 +713,7 @@ def autograd_function_forward_rewritten(original_forward, original_setup_context
class AutogradFunctionApply(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("autograd_function_apply")
def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):

View File

@ -427,7 +427,7 @@ class ModuleContextCheckpointState:
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
def __init__(self):
def __init__(self) -> None:
self.nn_modules: Dict[str, Any] = {}
def copy_graphstate(self):
@ -476,7 +476,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
"autocast_cache_enabled",
}
def __init__(self):
def __init__(self) -> None:
self.global_state: Dict[str, Tuple[Callable, ...]] = {}
def copy_graphstate(self):
@ -544,7 +544,7 @@ class GuardsSet:
class GuardsContext(Checkpointable[GuardsCheckpointState]):
def __init__(self):
def __init__(self) -> None:
self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: List[GuardEnvExpr] = []

View File

@ -54,7 +54,7 @@ class AutoFunctionalized(HigherOrderOperator):
underscore is to prevent collisions with kwarg names in **kwargs.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__("auto_functionalized")
def __call__(

View File

@ -55,7 +55,7 @@ class WithEffects(HigherOrderOperator):
per "effect type", which are enumerated in the _EffectType enum.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__("with_effects")
def __call__(

View File

@ -38,7 +38,7 @@ class TransformGetItemToIndex(TorchFunctionMode):
class FlexAttentionHOP(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("flex_attention")
def __call__(
@ -74,7 +74,7 @@ flex_attention.__module__ = "torch.ops.higher_order"
class FlexAttentionBackwardHOP(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("flex_attention_backward")
def __call__(

View File

@ -45,7 +45,7 @@ class OutDtypeOperator(HigherOrderOperator):
3. Cast the output to `out_dtype`
"""
def __init__(self):
def __init__(self) -> None:
super().__init__("out_dtype")
# TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
# become different (torch._higher_order_ops.out_dtype) which will result

View File

@ -519,7 +519,7 @@ def identify_mutated_tensors(kernel, kwargs):
# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_mutation")
@ -528,7 +528,7 @@ triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_functional")

View File

@ -18,7 +18,7 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten
class WhileLoopOp(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("while_loop")
def __call__(

View File

@ -15,7 +15,7 @@ uid = itertools.count(1)
# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("wrap")
def __call__(self, func, *args, **kwargs):
@ -36,7 +36,7 @@ wrap = Wrap()
class WrapWithSetGradEnabled(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("wrap_with_set_grad_enabled")
def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
@ -74,7 +74,7 @@ class WrapActivationCheckpoint(HigherOrderOperator):
partitioners. See TagActivationCheckpoint for more information.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__("wrap_activation_checkpoint")
def __call__(self, function, *args, **kwargs):
@ -113,7 +113,7 @@ class TagActivationCheckpoint(HigherOrderOperator):
the forward and recomputed forward in backward.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__("tag_activation_checkpoint")
@staticmethod

View File

@ -1560,7 +1560,7 @@ class CSE:
class CodeGen:
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.exit_stack = contextlib.ExitStack()

View File

@ -29,7 +29,7 @@ class CppWrapperCuda(CppWrapperCpu):
Generates cpp wrapper for running on GPU and calls CUDA kernels
"""
def __init__(self):
def __init__(self) -> None:
self.device = "cuda"
super().__init__()
self.grid_id = count()

View File

@ -1113,7 +1113,7 @@ class HelperFunctions:
_templates_seen: Dict[str, str] # Template code to function name
finalized_helpers: List[str]
def __init__(self):
def __init__(self) -> None:
self._templates_seen = {}
self.finalized_helpers = []

View File

@ -589,7 +589,7 @@ def canonicalization_prefix():
class FreeUnbackedSymbolsOpsHandler:
symbols: OrderedSet[sympy.Symbol]
def __init__(self):
def __init__(self) -> None:
self.symbols = OrderedSet()
def __getattr__(self, name: str) -> Callable[..., Any]:

View File

@ -65,7 +65,7 @@ class SubgraphLoweringException(RuntimeError):
class InvalidCxxCompiler(RuntimeError):
def __init__(self):
def __init__(self) -> None:
from . import config
super().__init__(

View File

@ -79,7 +79,7 @@ class NumpyCompatNormalization:
inverse_mapping: Dict[str, str]
cache: Dict["torch.fx.graph.Target", Set[str]]
def __init__(self):
def __init__(self) -> None:
self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"]
self.inverse_mapping = {}
for actual_kwarg, numpy_kwargs in self.numpy_compat.items():

View File

@ -1207,7 +1207,7 @@ if torch._C._has_mkldnn:
Combine packed weight nodes with the same inputs to reduce memory usage.
for example:
class Model(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(32, 32, bias=True)

View File

@ -99,7 +99,7 @@ class CachedMetricsHelper:
apply on a cache hit.
"""
def __init__(self):
def __init__(self) -> None:
self.cached_metrics = {}
for metric in get_metric_fields():
self.cached_metrics[metric] = globals()[metric]

View File

@ -940,7 +940,7 @@ class IndentedBuffer:
class FakeIndentedBuffer(IndentedBuffer):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def __getattribute__(self, name):
@ -1219,7 +1219,7 @@ class DebugDirManager:
counter = itertools.count(0)
prev_debug_name: str
def __init__(self):
def __init__(self) -> None:
self.id = next(DebugDirManager.counter)
def __enter__(self):
@ -1268,7 +1268,7 @@ def get_code(fn, *args, **kwargs):
class DummyModule:
"""This is empty to replace the generated triton module"""
def __init__(self):
def __init__(self) -> None:
pass
def call(self, *args, **kwargs):

View File

@ -7,7 +7,7 @@ from torch._lazy.device_context import get_device_context
class ClosureHandler:
def __init__(self):
def __init__(self) -> None:
pass
def run(self, closure):

View File

@ -42,7 +42,7 @@ class HasStaticMethodFromReal(Protocol):
class FakeClassRegistry:
def __init__(self):
def __init__(self) -> None:
self._registered_class: Dict[str, Any] = {}
def has_impl(self, full_qualname: str) -> bool:

View File

@ -70,7 +70,7 @@ class PythonDispatcher:
]
supported_keys = runtime_keys + alias_keys
def __init__(self):
def __init__(self) -> None:
C._dispatch_check_invariants(self.name) # type: ignore[attr-defined]
self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
self.ref.def_("foo(Tensor x) -> Tensor")

View File

@ -60,7 +60,7 @@ def clone_inputs(args):
class SchemaCheckMode(TorchDispatchMode):
def __init__(self):
def __init__(self) -> None:
# Information recorded for testing purposes. For example:
# - incorrect schemas
# - overly conservative schemas

View File

@ -36,7 +36,7 @@ class FloatFunctional(torch.nn.Module):
- mul_scalar
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.activation_post_process = torch.nn.Identity()
@ -190,7 +190,7 @@ class QFunctional(torch.nn.Module):
- mul_scalar
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.scale = 1.0
self.zero_point = 0

View File

@ -72,7 +72,7 @@ class QConfigMultiMapping:
"""
def __init__(self):
def __init__(self) -> None:
# initialize this with 1 QConfigMapping to avoid corner cases
self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]

View File

@ -99,7 +99,7 @@ from torch.ao.pruning._experimental.pruner import SaliencyPruner
# Define model
class Model(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(700, 500, bias=True),

View File

@ -85,7 +85,7 @@ class FakeQuantizeBase(ABC, Module):
fake_quant_enabled: torch.Tensor
observer_enabled: torch.Tensor
def __init__(self):
def __init__(self) -> None:
"""Set fake_quant_enabled and observer_enabled."""
super().__init__()
# fake_quant_enabled and observer_enabled are buffers to support their

View File

@ -70,7 +70,7 @@ In the following, I'll first have a detailed description for each step, and th
```
class LinearReLUModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 10).float()
self.relu = torch.nn.ReLU()

View File

@ -137,7 +137,7 @@ class DetectorBase(ABC):
- Should return a str-based report and dict info in Tuple[str,Dict] format
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.detector_config_info = None

View File

@ -63,7 +63,7 @@ class PrepareCustomConfig:
.set_preserved_attributes(["attr1", "attr2"])
"""
def __init__(self):
def __init__(self) -> None:
self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
@ -382,7 +382,7 @@ class ConvertCustomConfig:
.set_preserved_attributes(["attr1", "attr2"])
"""
def __init__(self):
def __init__(self) -> None:
self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {}
self.preserved_attributes: List[str] = []
@ -477,7 +477,7 @@ class FuseCustomConfig:
fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
"""
def __init__(self):
def __init__(self) -> None:
self.preserved_attributes: List[str] = []
def __repr__(self):

View File

@ -1568,7 +1568,7 @@ class ReuseInputObserver(ObserverBase):
Note: this is only enabled in FX Graph Mode Quantization
"""
def __init__(self):
def __init__(self) -> None:
super().__init__(torch.quint8, is_dynamic=False)
def forward(self, x):

View File

@ -229,7 +229,7 @@ class QConfigMapping:
"""
def __init__(self):
def __init__(self) -> None:
# In increasing match priority:
self.global_qconfig: QConfigAny = None
self.object_type_qconfigs: OrderedDict[

View File

@ -289,7 +289,7 @@ def prepare_fx(
from torch.ao.quantization.quantize_fx import prepare_fx
class Submodule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
@ -297,7 +297,7 @@ def prepare_fx(
return x
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
self.sub = Submodule()
@ -427,7 +427,7 @@ def prepare_qat_fx(
from torch.ao.quantization.quantize_fx import prepare_qat_fx
class Submodule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
@ -435,7 +435,7 @@ def prepare_qat_fx(
return x
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
self.sub = Submodule()

View File

@ -56,7 +56,7 @@ def prepare_pt2e(
)
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 10)
@ -129,7 +129,7 @@ def prepare_qat_pt2e(
)
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 10)

View File

@ -42,7 +42,7 @@ def get_embedding_operators_config() -> OperatorConfig:
class EmbeddingQuantizer(Quantizer):
def __init__(self):
def __init__(self) -> None:
super().__init__()
@classmethod

View File

@ -436,7 +436,7 @@ class X86InductorQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type()
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_qconfig: Dict[

View File

@ -268,7 +268,7 @@ class XNNPACKQuantizer(Quantizer):
"linear",
]
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_config: Dict[

View File

@ -513,7 +513,7 @@ def _get_path_of_module(
Example::
>> class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
return self.linear(x)

View File

@ -645,7 +645,7 @@ class FunctionEvent(FormattedTimesMixin):
class FunctionEventAvg(FormattedTimesMixin):
"""Used to average stats over multiple FunctionEvent objects."""
def __init__(self):
def __init__(self) -> None:
self.key: Optional[str] = None
self.count: int = 0
self.node_id: int = 0

View File

@ -266,7 +266,7 @@ class _Launcher:
or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \
{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set."
def __init__(self):
def __init__(self) -> None:
self.cpuinfo = _CPUinfo()
def add_lib_preload(self, lib_type):

View File

@ -77,17 +77,17 @@ namespace jit {
*
* So why does debug handle map to DebugInfoTuple = {source range and inlined
* cs}? {debug_handle, source_range_tag, serialized_callstack} Take this
* example: class L(nn.Module): def __init__(self):
* example: class L(nn.Module): def __init__(self) -> None:
* ...
* def forward(self, x):
* return x * 5
* class M(nn.Module):
* def __init__(self):
* def __init__(self) -> None:
* ...
* def forward(self, x):
* return x - 2
* class N(nn.Module):
* def __init__(self):
* def __init__(self) -> None:
* self.m = M()
* def forward(self, x):
* return self.m(x) + 3

View File

@ -328,7 +328,7 @@ For example:
```
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
self.a = torch.rand(2, 3)
self.b = torch.nn.Linear(10, 10)

View File

@ -37,7 +37,7 @@ When making changes to the operators, the first thing to identify is if it's BC/
1. Add a test module in `test/jit/fixtures_srcs/fixtures_src.py`. In `test/jit/fixtures_srcs/generate_models.py`,
```
class TestVersionedLinspaceV7(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
@ -163,7 +163,7 @@ When making changes to the operators, the first thing to identify is if it's BC/
# Step 2. Write down how current module should look like
class MyModuleFloat(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, a, b: float):

View File

@ -25,7 +25,7 @@ namespace onnx {
//
// clang-format off
// class M(torch.nn.Module):
// def __init__(self):
// def __init__(self) -> None:
// super().__init__()
// self.lns = torch.nn.ModuleList([torch.nn.LayerNorm(3, eps = i) for i in range(2)])
// self.celu1 = torch.nn.CELU(1.0)

View File

@ -17,7 +17,7 @@ torch._lazy.ts_backend.init()
class Net(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)

View File

@ -135,7 +135,7 @@ Here's our model definition:
```python
class Net(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)

View File

@ -163,7 +163,7 @@ class TensorInfo:
class _TensorsAccessed:
def __init__(self):
def __init__(self) -> None:
self.accesses: Dict[DataPtr, TensorInfo] = {}
def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
@ -218,7 +218,7 @@ class _TensorsAccessed:
class StreamSynchronizations:
def __init__(self):
def __init__(self) -> None:
self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
self.host_sync_state: Dict[StreamId, SeqNum] = {}
@ -338,7 +338,7 @@ class EventHandler:
data race.
"""
def __init__(self):
def __init__(self) -> None:
self.tensors_accessed = _TensorsAccessed()
self.syncs = StreamSynchronizations()
self.seq_num: SeqNum = 0
@ -478,7 +478,7 @@ def zip_arguments(
class ArgumentHandler:
def __init__(self):
def __init__(self) -> None:
self.dataptrs_read: Set[DataPtr] = set()
self.dataptrs_written: Set[DataPtr] = set()
self.tensor_aliases: Dict[DataPtr, List[str]] = {}
@ -527,7 +527,7 @@ class ArgumentHandler:
class CUDASanitizerDispatchMode(TorchDispatchMode):
def __init__(self):
def __init__(self) -> None:
self.event_handler = EventHandler()
torch._C._activate_gpu_trace()
gpu_trace.register_callback_for_event_creation(
@ -596,7 +596,7 @@ class CUDASanitizer:
This approach was deemed more elegant than using the atexit module.
"""
def __init__(self):
def __init__(self) -> None:
self.dispatch = CUDASanitizerDispatchMode()
self.enabled = False

View File

@ -49,7 +49,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
>>> import torch.nn as nn
>>>
>>> class MyModel(nn.Module):
>>> def __init__(self):
>>> def __init__(self) -> None:
>>> super().__init__()
>>> self.l1 = nn.Linear(10, 10)
>>> self.l2 = nn.Linear(10, 10)

View File

@ -47,7 +47,7 @@ def contract(state_cls: Type[_State] = _State):
>>> import torch.nn as nn
>>>
>>> class MyModel(nn.Module):
>>> def __init__(self):
>>> def __init__(self) -> None:
>>> super().__init__()
>>> self.l1 = nn.Linear(10, 10)
>>> self.l2 = nn.Linear(10, 10)

View File

@ -43,7 +43,7 @@ logger = logging.getLogger("torch.distributed._composable.fsdp")
class FSDPStateContext:
"""This has state shared across FSDP states."""
def __init__(self):
def __init__(self) -> None:
# All FSDP states in the root state's module tree
self.all_states: List[FSDPState] = []
# Iteration's forward root runs the once-per-forward logic; this root
@ -71,7 +71,7 @@ def disable_if_config_true(func):
class FSDPState(_State):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._fsdp_param_group: Optional[FSDPParamGroup] = None
self._is_root: Optional[bool] = None # root set during lazy init

View File

@ -38,7 +38,7 @@ class ShardingPlan:
>>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> class MyModule(nn.Module):
>>> def __init__(self):
>>> def __init__(self) -> None:
>>> super().__init__()
>>> self.fc1 = nn.Linear()
>>> self.gelu = nn.GELU()

View File

@ -117,7 +117,7 @@ import torch.nn as nn
from torch.distributed._tensor import Shard, distribute_tensor, distribute_module, init_device_mesh
class MyModule(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.fc1 = nn.Linear(8, 8)
self.fc2 = nn.Linear(8, 8)

View File

@ -25,7 +25,7 @@ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_modul
class SimpleMLP(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.net1 = torch.nn.Linear(5, 128)
self.relu = torch.nn.ReLU()

View File

@ -55,7 +55,7 @@ class Joinable(ABC):
"""
@abstractmethod
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._join_config = _JoinConfig.construct_disabled_join_config()

View File

@ -31,7 +31,7 @@ class InjectedException(Exception):
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.net1 = nn.Linear(8, 32)
self.net2 = nn.Linear(32, 128)

View File

@ -22,7 +22,7 @@ CHECKPOINT_DIR = f"~/{os.environ['LOGNAME']}/checkpoint"
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())

View File

@ -434,7 +434,7 @@ class _reduce_op:
:class:`~torch.distributed.ReduceOp` is recommended to use instead.
"""
def __init__(self):
def __init__(self) -> None:
# __members__ is a dict storing key-value pairs for enum classes
for k, v in ReduceOp.RedOpType.__members__.items():
setattr(self, k, v)
@ -568,7 +568,7 @@ class _World:
of c10d and is subject to change..
"""
def __init__(self):
def __init__(self) -> None:
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {}
self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
@ -2194,7 +2194,7 @@ class _IllegalWork(Work):
class _CoalescingManager:
def __init__(self):
def __init__(self) -> None:
self.works: List[Work] = []
def append(self, work: Work):

View File

@ -106,7 +106,7 @@ class _FSDPDeviceHandle:
class _UninitializedDeviceHandle(_FSDPDeviceHandle):
def __init__(self):
def __init__(self) -> None:
pass
def __getattribute__(self, __name: str) -> Any:

View File

@ -156,7 +156,7 @@ class _RemoteModule(nn.Module):
created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``).
Hybrid Example:
>>> class HybridModel(nn.Module):
>>> def __init__(self):
>>> def __init__(self) -> None:
>>> nn.Module.__init__(self)
>>> self.remote_embedding = RemoteModule(...)
>>> self.local_linear = nn.Linear(...)

View File

@ -248,7 +248,7 @@ class ExportGraphSignature:
e.g. If following module is exported::
class CustomModule(nn.Module):
def __init__(self):
def __init__(self) -> None:
super(CustomModule, self).__init__()
# Define a parameter

View File

@ -45,7 +45,7 @@ FX's front-end makes use of the dynamic nature of Python to intercept call-sit
import torch
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(
torch.rand(3, 4))

View File

@ -9,7 +9,7 @@ demonstration of these components in action:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)

View File

@ -1012,7 +1012,7 @@ class _PatchedFnSetAttr(_PatchedFn):
class _Patcher:
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.patches_made: List[_PatchedFn] = []
self.visited: Set[int] = set()

View File

@ -63,7 +63,7 @@ class T(Constraint):
"""
True
"""
def __init__(self):
def __init__(self) -> None:
pass
def __eq__(self, other):
@ -76,7 +76,7 @@ class F(Constraint):
"""
False
"""
def __init__(self):
def __init__(self) -> None:
pass
def __eq__(self, other):

View File

@ -117,7 +117,7 @@ if HAS_PYDOT:
>>> # xdoctest: +REQUIRES(module:ubelt)
>>> # define module
>>> class MyModule(torch.nn.Module):
>>> def __init__(self):
>>> def __init__(self) -> None:
>>> super().__init__()
>>> self.linear = torch.nn.Linear(4, 5)
>>> def forward(self, x):

View File

@ -83,7 +83,7 @@ def split_module(
from torch.fx.passes.split_module import split_module
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)

View File

@ -83,7 +83,7 @@ def split_by_tags(
Given the following module def:
class SimpleModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(...)
self.linear2 = torch.nn.Linear(...)

View File

@ -38,7 +38,7 @@ class Scope:
return x.transpose(1, 2)
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
self.sub = Sub()
def forward(self, x):

View File

@ -118,7 +118,7 @@ def replace_pattern(
from torch.fx import symbolic_trace, subgraph_rewriter
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x, w1, w2):

View File

@ -38,7 +38,7 @@ class _DynType:
"""
_DynType defines a type which stands for the absence of type information.
"""
def __init__(self):
def __init__(self) -> None:
self.__name__ = '_DynType'
def __eq__(self, other):

View File

@ -219,7 +219,7 @@ def isinstance(obj, target_type):
from typing import Any, Dict, List
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, input: Any): # note the Any type
@ -255,7 +255,7 @@ class strict_fusion:
"""
def __init__(self):
def __init__(self) -> None:
if not torch._jit_internal.is_scripting():
warnings.warn("Only works in script mode")
pass

View File

@ -73,7 +73,7 @@ def fork(func, *args, **kwargs):
def forward(self, a: Tensor, b : int):
return a + b
class Mod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super(self).__init__()
self.mod = AddMod()
def forward(self, input):

View File

@ -39,7 +39,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
def fn(self):
return []
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.x: List[int] = []

View File

@ -65,7 +65,7 @@ def freeze(
.. testcode::
import torch
class MyModule2(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.modified_tensor = torch.tensor(10.)
self.version = 1

View File

@ -89,7 +89,7 @@ if _IS_MONKEYTYPE_INSTALLED:
self.traces.append(trace)
class JitTypeTraceStore(CallTraceStore):
def __init__(self):
def __init__(self) -> None:
super().__init__()
# A dictionary keeping all collected CallTrace
# key is fully qualified name of called function
@ -159,15 +159,15 @@ else:
# When MonkeyType is not installed, we provide dummy class definitions
# for the below classes.
class JitTypeTraceStoreLogger: # type: ignore[no-redef]
def __init__(self):
def __init__(self) -> None:
pass
class JitTypeTraceStore: # type: ignore[no-redef]
def __init__(self):
def __init__(self) -> None:
self.trace_records = None
class JitTypeTraceConfig: # type: ignore[no-redef]
def __init__(self):
def __init__(self) -> None:
pass
monkeytype_trace = None # type: ignore[assignment] # noqa: F811

View File

@ -426,7 +426,7 @@ class ConcreteTypeStore:
type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
methods_compiled: Set[torch._C.ConcreteModuleType]
def __init__(self):
def __init__(self) -> None:
# Python module type => List[ConcreteModuleType)]
self.type_store = {}
# ConcreteTypes that have had their methods already compiled

View File

@ -107,7 +107,7 @@ Attribute.__doc__ = """
from typing import Dict
class AttributeModule(torch.jit.ScriptModule):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.foo = torch.jit.Attribute(0.1, float)
@ -138,7 +138,7 @@ Attribute.__doc__ = """
class AttributeModule(torch.nn.Module):
names: Dict[str, int]
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.names = {}
@ -522,7 +522,7 @@ if _enabled:
"original_name",
]
def __init__(self):
def __init__(self) -> None:
super().__init__()
forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
@ -1351,7 +1351,7 @@ def script(
import torch.nn.functional as F
class MyModule(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
# torch.jit.trace produces a ScriptModule's conv1 and conv2
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
@ -1374,7 +1374,7 @@ def script(
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
@torch.jit.export
@ -1547,7 +1547,7 @@ def interface(obj):
return x.relu()
class Impl2(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.val = torch.rand(())
@ -1671,7 +1671,7 @@ class _ScriptProfileTable:
class _ScriptProfile:
def __init__(self):
def __init__(self) -> None:
self.profile = classes.profiling._ScriptProfile()
def enable(self):

View File

@ -19,7 +19,7 @@ class EnabledProxy:
This is just a wrapper for a bool, so that we get reference semantics
"""
def __init__(self):
def __init__(self) -> None:
self.enabled = self.parse_env(
"PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
)

View File

@ -966,7 +966,7 @@ def trace(
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(1, 1, 3)
@ -1182,7 +1182,7 @@ def trace_module(
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(1, 1, 3)

View File

@ -61,7 +61,7 @@ class StorageWeakRef:
class SharedCache(dict):
"""Dictionary from multiprocessing handles to StorageWeakRef."""
def __init__(self):
def __init__(self) -> None:
# free_dead_references() is called if the len exceeds the current
# limit. The limit scales with the number of remaining live objects.
self.limit = 128

Some files were not shown because too many files have changed in this diff.