Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 30d7f0b15a
Commit: 72d2dba992
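The change below is mechanical: every `__init__` in the touched files gains an explicit `-> None` return annotation. As a minimal sketch of why this matters (the `Toy` module here is illustrative only, not part of the patch): type checkers such as mypy treat a function with no annotations at all as untyped and skip checking its body by default, so the explicit return type opts `__init__` into type checking.

```python
import torch


class Toy(torch.nn.Module):
    # Without any annotation, mypy treats __init__ as an untyped function and,
    # by default, does not type-check its body; "-> None" makes it checked.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
```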
@@ -119,7 +119,7 @@ inner(torch.randn(20, 20, requires_grad=True) + 1)
backend_name = "relu_compile_error_TESTING_ONLY"
run_code = f"""\
class CpuCudaModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.m_x = torch.nn.Linear(20, 20).cuda()
self.m_y = torch.nn.Linear(20, 20)
@@ -149,7 +149,7 @@ inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
res.minifier_module(),
"""\
class Repro(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
@@ -204,7 +204,7 @@ inner(torch.randn(20, 20))
res.repro_module(),
"""\
class Repro(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

def forward(self, x_19):
@@ -122,7 +122,7 @@ inner(torch.randn(20))
res.repro_module(),
"""\
class Repro(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

def forward(self, arg0_1):
@@ -138,7 +138,7 @@ class Repro(torch.nn.Module):
res.repro_module(),
"""\
class Repro(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

def forward(self, arg0_1):
@@ -19,7 +19,7 @@ class _ClassNamespace(types.ModuleType):
class _Classes(types.ModuleType):
__file__ = "_classes.py"

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("torch.classes")

def __getattr__(self, name):
@@ -71,7 +71,7 @@ class PhiloxState:
trace time.
"""

- def __init__(self):
+ def __init__(self) -> None:
self.reset()

def reset(self):
@@ -247,7 +247,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
# This gives us the appropriately strided outputs here which will reflect runtime strides.

class FakeifyFirstAOTInvocationGuard:
- def __init__(self):
+ def __init__(self) -> None:
self.tc = torch._guards.TracingContext.try_get()
assert self.tc
torch._guards.TracingContext.try_get().fakify_first_call = True
@@ -5,7 +5,7 @@ from .utils import ExactWeakKeyDictionary


class CodeContextDict:
- def __init__(self):
+ def __init__(self) -> None:
self.code_context = ExactWeakKeyDictionary()

def has_context(self, code: types.CodeType):
@@ -170,7 +170,7 @@ class NNModuleToString:
"""
from torch.nn import *
class Repro(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
"""
)
@@ -491,7 +491,7 @@ _is_leaf_or_default = _mk_defaulter(False)


class NopInputReader:
- def __init__(self):
+ def __init__(self) -> None:
self.total = 0

def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
@@ -497,7 +497,7 @@ class _TorchDynamoContext:
wrapper function.

>> class CallableClass:
- >> def __init__(self):
+ >> def __init__(self) -> None:
>> super().__init__()
>> self.relu = torch.nn.ReLU()
>>
@@ -578,7 +578,7 @@ class OptimizeContext(_TorchDynamoContext):


class RunOnlyContext(_TorchDynamoContext):
- def __init__(self):
+ def __init__(self) -> None:
# cudagraph trees relies on generation increment
def on_enter():
torch._dynamo.mutation_guard.GenerationTracker.generation += 1
@@ -590,7 +590,7 @@ class RunOnlyContext(_TorchDynamoContext):


class DisableContext(_TorchDynamoContext):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(callback=None)

def __call__(self, fn):
@@ -74,7 +74,7 @@ class InvalidBackend(TorchDynamoException):


class ResetRequired(TorchDynamoException):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(
textwrap.dedent(
"""
@@ -92,7 +92,7 @@ def print_missing(stack):
class Profiler:
unique_graphs = 0

- def __init__(self):
+ def __init__(self) -> None:
self.prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU],
with_stack=should_print_missing(),
@@ -70,7 +70,7 @@ class MutableLocal(MutableLocalBase):
state.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__(MutableLocalSource.Local)

def __hash__(self):
@@ -274,7 +274,7 @@ class GraphArg:


class BackwardStateGraphArg(GraphArg):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(
source=None,
_example=BackwardState(),
@@ -2646,7 +2646,7 @@ class SourcelessBuilder:
if/else type->VariableTracker trees that were cropping up all over dynamo.
"""

- def __init__(self):
+ def __init__(self) -> None:
raise AssertionError("Use SourcelessBuilder.create()")

@staticmethod
@@ -10,7 +10,7 @@ class ClassMethod(torch.nn.Module):
def method(cls, x):
return x + 1

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 2)
@@ -26,7 +26,7 @@ class CondBranchClassMethod(torch.nn.Module):
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.subm = MySubModule()

@@ -8,7 +8,7 @@ class ModelAttrMutation(torch.nn.Module):
Attribute mutation is not supported.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

@@ -11,7 +11,7 @@ class ScalarOutput(torch.nn.Module):
Returning scalar values from the graph is supported, in addition to Tensor
outputs. Symbolic shapes are captured and rank is specialized.
"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

def forward(self, x):
@@ -11,7 +11,7 @@ class SpecializedAttribute(torch.nn.Module):
Model attributes are specialized.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.a = "moo"
self.b = 4
@@ -24,7 +24,7 @@ class ConstantAttrMap(collections.abc.MutableMapping):
if that's the case).
"""

- def __init__(self):
+ def __init__(self) -> None:
# Underlying dict that we use to implement this mapping.
self._constant_attrs: Dict[
Union[int, torch.Tensor, FakeScriptObject], List[Any]
@@ -1413,7 +1413,7 @@ class GraphModuleDeserializer(metaclass=Final):
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]]

- def __init__(self):
+ def __init__(self) -> None:
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: Dict[str, MetaType] = {}
self.graph = torch.fx.Graph()
@@ -602,7 +602,7 @@ class SubclassMeta:
# Optional field because we don't compute for inference graphs
grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] = None

- def __init__(self):
+ def __init__(self) -> None:
# The fields in this class get set after its construction.
pass

@@ -878,7 +878,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
)

class AOTModule(nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.orig_module = mod

@@ -30,7 +30,7 @@ from torch.autograd.forward_ad import _set_fwd_grad_enabled
# We do this by using creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("custom_function_call")

def __call__(self, autograd_function, *args, **kwargs):
@@ -713,7 +713,7 @@ def autograd_function_forward_rewritten(original_forward, original_setup_context


class AutogradFunctionApply(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("autograd_function_apply")

def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
@@ -427,7 +427,7 @@ class ModuleContextCheckpointState:


class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
- def __init__(self):
+ def __init__(self) -> None:
self.nn_modules: Dict[str, Any] = {}

def copy_graphstate(self):
@@ -476,7 +476,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
"autocast_cache_enabled",
}

- def __init__(self):
+ def __init__(self) -> None:
self.global_state: Dict[str, Tuple[Callable, ...]] = {}

def copy_graphstate(self):
@@ -544,7 +544,7 @@ class GuardsSet:


class GuardsContext(Checkpointable[GuardsCheckpointState]):
- def __init__(self):
+ def __init__(self) -> None:
self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: List[GuardEnvExpr] = []

@@ -54,7 +54,7 @@ class AutoFunctionalized(HigherOrderOperator):
underscore is to prevent collisions with kwarg names in **kwargs.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("auto_functionalized")

def __call__(
@@ -55,7 +55,7 @@ class WithEffects(HigherOrderOperator):
per "effect type", which are enumerated in the _EffectType enum.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("with_effects")

def __call__(
@@ -38,7 +38,7 @@ class TransformGetItemToIndex(TorchFunctionMode):


class FlexAttentionHOP(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("flex_attention")

def __call__(
@@ -74,7 +74,7 @@ flex_attention.__module__ = "torch.ops.higher_order"


class FlexAttentionBackwardHOP(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("flex_attention_backward")

def __call__(
@@ -45,7 +45,7 @@ class OutDtypeOperator(HigherOrderOperator):
3. Cast the output to `out_dtype`
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("out_dtype")
# TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
# become different (torch._higher_order_ops.out_dtype) which will result
@@ -519,7 +519,7 @@ def identify_mutated_tensors(kernel, kwargs):

# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_mutation")

@@ -528,7 +528,7 @@ triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()

# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_functional")
@@ -18,7 +18,7 @@ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_ten


class WhileLoopOp(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("while_loop")

def __call__(
@@ -15,7 +15,7 @@ uid = itertools.count(1)

# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("wrap")

def __call__(self, func, *args, **kwargs):
@@ -36,7 +36,7 @@ wrap = Wrap()


class WrapWithSetGradEnabled(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("wrap_with_set_grad_enabled")

def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
@@ -74,7 +74,7 @@ class WrapActivationCheckpoint(HigherOrderOperator):
partitioners. See TagActivationCheckpoint for more information.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("wrap_activation_checkpoint")

def __call__(self, function, *args, **kwargs):
@@ -113,7 +113,7 @@ class TagActivationCheckpoint(HigherOrderOperator):
the forward and recomputed forward in backward.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("tag_activation_checkpoint")

@staticmethod
@@ -1560,7 +1560,7 @@ class CSE:


class CodeGen:
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.exit_stack = contextlib.ExitStack()

@@ -29,7 +29,7 @@ class CppWrapperCuda(CppWrapperCpu):
Generates cpp wrapper for running on GPU and calls CUDA kernels
"""

- def __init__(self):
+ def __init__(self) -> None:
self.device = "cuda"
super().__init__()
self.grid_id = count()
@@ -1113,7 +1113,7 @@ class HelperFunctions:
_templates_seen: Dict[str, str] # Template code to function name
finalized_helpers: List[str]

- def __init__(self):
+ def __init__(self) -> None:
self._templates_seen = {}
self.finalized_helpers = []

@@ -589,7 +589,7 @@ def canonicalization_prefix():
class FreeUnbackedSymbolsOpsHandler:
symbols: OrderedSet[sympy.Symbol]

- def __init__(self):
+ def __init__(self) -> None:
self.symbols = OrderedSet()

def __getattr__(self, name: str) -> Callable[..., Any]:
@@ -65,7 +65,7 @@ class SubgraphLoweringException(RuntimeError):


class InvalidCxxCompiler(RuntimeError):
- def __init__(self):
+ def __init__(self) -> None:
from . import config

super().__init__(
@@ -79,7 +79,7 @@ class NumpyCompatNormalization:
inverse_mapping: Dict[str, str]
cache: Dict["torch.fx.graph.Target", Set[str]]

- def __init__(self):
+ def __init__(self) -> None:
self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"]
self.inverse_mapping = {}
for actual_kwarg, numpy_kwargs in self.numpy_compat.items():
@@ -1207,7 +1207,7 @@ if torch._C._has_mkldnn:
Combine packed weight nodes with the same inputs to reduce memory usage.
for example:
class Model(nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(32, 32, bias=True)

@@ -99,7 +99,7 @@ class CachedMetricsHelper:
apply on a cache hit.
"""

- def __init__(self):
+ def __init__(self) -> None:
self.cached_metrics = {}
for metric in get_metric_fields():
self.cached_metrics[metric] = globals()[metric]
@@ -940,7 +940,7 @@ class IndentedBuffer:


class FakeIndentedBuffer(IndentedBuffer):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

def __getattribute__(self, name):
@@ -1219,7 +1219,7 @@ class DebugDirManager:
counter = itertools.count(0)
prev_debug_name: str

- def __init__(self):
+ def __init__(self) -> None:
self.id = next(DebugDirManager.counter)

def __enter__(self):
@@ -1268,7 +1268,7 @@ def get_code(fn, *args, **kwargs):
class DummyModule:
"""This is empty to replace the generated triton module"""

- def __init__(self):
+ def __init__(self) -> None:
pass

def call(self, *args, **kwargs):
@@ -7,7 +7,7 @@ from torch._lazy.device_context import get_device_context


class ClosureHandler:
- def __init__(self):
+ def __init__(self) -> None:
pass

def run(self, closure):
@@ -42,7 +42,7 @@ class HasStaticMethodFromReal(Protocol):


class FakeClassRegistry:
- def __init__(self):
+ def __init__(self) -> None:
self._registered_class: Dict[str, Any] = {}

def has_impl(self, full_qualname: str) -> bool:
@@ -70,7 +70,7 @@ class PythonDispatcher:
]
supported_keys = runtime_keys + alias_keys

- def __init__(self):
+ def __init__(self) -> None:
C._dispatch_check_invariants(self.name) # type: ignore[attr-defined]
self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
self.ref.def_("foo(Tensor x) -> Tensor")
@@ -60,7 +60,7 @@ def clone_inputs(args):


class SchemaCheckMode(TorchDispatchMode):
- def __init__(self):
+ def __init__(self) -> None:
# Information recorded for testing purposes. For example:
# - incorrect schemas
# - overly conservative schemas
@@ -36,7 +36,7 @@ class FloatFunctional(torch.nn.Module):
- mul_scalar
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.activation_post_process = torch.nn.Identity()

@@ -190,7 +190,7 @@ class QFunctional(torch.nn.Module):
- mul_scalar
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.scale = 1.0
self.zero_point = 0
@@ -72,7 +72,7 @@ class QConfigMultiMapping:

"""

- def __init__(self):
+ def __init__(self) -> None:
# initialize this with 1 QConfigMapping to avoid corner cases
self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]

@@ -99,7 +99,7 @@ from torch.ao.pruning._experimental.pruner import SaliencyPruner

# Define model
class Model(nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(700, 500, bias=True),
@@ -85,7 +85,7 @@ class FakeQuantizeBase(ABC, Module):
fake_quant_enabled: torch.Tensor
observer_enabled: torch.Tensor

- def __init__(self):
+ def __init__(self) -> None:
"""Set fake_quant_enabled and observer_enabled."""
super().__init__()
# fake_quant_enabled and observer_enabled are buffers to support their
@@ -70,7 +70,7 @@ In the following, I’ll first have a detailed description for each step, and th

```
class LinearReLUModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 10).float()
self.relu = torch.nn.ReLU()
@@ -137,7 +137,7 @@ class DetectorBase(ABC):
- Should return a str-based report and dict info in Tuple[str,Dict] format
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.detector_config_info = None

@@ -63,7 +63,7 @@ class PrepareCustomConfig:
.set_preserved_attributes(["attr1", "attr2"])
"""

- def __init__(self):
+ def __init__(self) -> None:
self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
@@ -382,7 +382,7 @@ class ConvertCustomConfig:
.set_preserved_attributes(["attr1", "attr2"])
"""

- def __init__(self):
+ def __init__(self) -> None:
self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {}
self.preserved_attributes: List[str] = []

@@ -477,7 +477,7 @@ class FuseCustomConfig:
fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
"""

- def __init__(self):
+ def __init__(self) -> None:
self.preserved_attributes: List[str] = []

def __repr__(self):
@@ -1568,7 +1568,7 @@ class ReuseInputObserver(ObserverBase):
Note: this is only enabled in FX Graph Mode Quantization
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__(torch.quint8, is_dynamic=False)

def forward(self, x):
@@ -229,7 +229,7 @@ class QConfigMapping:

"""

- def __init__(self):
+ def __init__(self) -> None:
# In increasing match priority:
self.global_qconfig: QConfigAny = None
self.object_type_qconfigs: OrderedDict[
@@ -289,7 +289,7 @@ def prepare_fx(
from torch.ao.quantization.quantize_fx import prepare_fx

class Submodule(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
@@ -297,7 +297,7 @@ def prepare_fx(
return x

class M(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
self.sub = Submodule()
@@ -427,7 +427,7 @@ def prepare_qat_fx(
from torch.ao.quantization.quantize_fx import prepare_qat_fx

class Submodule(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
@@ -435,7 +435,7 @@ def prepare_qat_fx(
return x

class M(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
self.sub = Submodule()
@@ -56,7 +56,7 @@ def prepare_pt2e(
)

class M(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 10)

@@ -129,7 +129,7 @@ def prepare_qat_pt2e(
)

class M(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 10)

@@ -42,7 +42,7 @@ def get_embedding_operators_config() -> OperatorConfig:


class EmbeddingQuantizer(Quantizer):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

@classmethod
@@ -436,7 +436,7 @@ class X86InductorQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type()

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_qconfig: Dict[
@@ -268,7 +268,7 @@ class XNNPACKQuantizer(Quantizer):
"linear",
]

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_config: Dict[
@@ -513,7 +513,7 @@ def _get_path_of_module(
Example::

>> class M(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
return self.linear(x)
@@ -645,7 +645,7 @@ class FunctionEvent(FormattedTimesMixin):
class FunctionEventAvg(FormattedTimesMixin):
"""Used to average stats over multiple FunctionEvent objects."""

- def __init__(self):
+ def __init__(self) -> None:
self.key: Optional[str] = None
self.count: int = 0
self.node_id: int = 0
@@ -266,7 +266,7 @@ class _Launcher:
or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \
{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set."

- def __init__(self):
+ def __init__(self) -> None:
self.cpuinfo = _CPUinfo()

def add_lib_preload(self, lib_type):
@@ -77,17 +77,17 @@ namespace jit {
*
* So why does debug handle map to DebugInfoTuple = {source range and inlined
* cs}? {debug_handle, source_range_tag, serialized_callstack} Take this
- * example: class L(nn.Module): def __init__(self):
+ * example: class L(nn.Module): def __init__(self) -> None:
* ...
* def forward(self, x):
* return x * 5
* class M(nn.Module):
- * def __init__(self):
+ * def __init__(self) -> None:
* ...
* def forward(self, x):
* return x - 2
* class N(nn.Module):
- * def __init__(self):
+ * def __init__(self) -> None:
* self.m = M()
* def forward(self, x):
* return self.m(x) + 3
@@ -328,7 +328,7 @@ For example:

```
class M(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
self.a = torch.rand(2, 3)
self.b = torch.nn.Linear(10, 10)

@@ -37,7 +37,7 @@ When making changes to the operators, the first thing to identify is if it's BC/
1. Add a test module in `test/jit/fixtures_srcs/fixtures_src.py`. In `test/jit/fixtures_srcs/generate_models.py`,
```
class TestVersionedLinspaceV7(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
@@ -163,7 +163,7 @@ When making changes to the operators, the first thing to identify is if it's BC/

# Step 2. Write down how current module should look like
class MyModuleFloat(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

def forward(self, a, b: float):
@@ -25,7 +25,7 @@ namespace onnx {
//
// clang-format off
// class M(torch.nn.Module):
- // def __init__(self):
+ // def __init__(self) -> None:
// super().__init__()
// self.lns = torch.nn.ModuleList([torch.nn.LayerNorm(3, eps = i) for i in range(2)])
// self.celu1 = torch.nn.CELU(1.0)
@@ -17,7 +17,7 @@ torch._lazy.ts_backend.init()


class Net(nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
@@ -135,7 +135,7 @@ Here's our model definition:

```python
class Net(nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
@@ -163,7 +163,7 @@ class TensorInfo:


class _TensorsAccessed:
- def __init__(self):
+ def __init__(self) -> None:
self.accesses: Dict[DataPtr, TensorInfo] = {}

def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
@@ -218,7 +218,7 @@ class _TensorsAccessed:


class StreamSynchronizations:
- def __init__(self):
+ def __init__(self) -> None:
self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
self.host_sync_state: Dict[StreamId, SeqNum] = {}
@@ -338,7 +338,7 @@ class EventHandler:
data race.
"""

- def __init__(self):
+ def __init__(self) -> None:
self.tensors_accessed = _TensorsAccessed()
self.syncs = StreamSynchronizations()
self.seq_num: SeqNum = 0
@@ -478,7 +478,7 @@ def zip_arguments(


class ArgumentHandler:
- def __init__(self):
+ def __init__(self) -> None:
self.dataptrs_read: Set[DataPtr] = set()
self.dataptrs_written: Set[DataPtr] = set()
self.tensor_aliases: Dict[DataPtr, List[str]] = {}
@@ -527,7 +527,7 @@ class ArgumentHandler:


class CUDASanitizerDispatchMode(TorchDispatchMode):
- def __init__(self):
+ def __init__(self) -> None:
self.event_handler = EventHandler()
torch._C._activate_gpu_trace()
gpu_trace.register_callback_for_event_creation(
@@ -596,7 +596,7 @@ class CUDASanitizer:
This approach was deemed more elegant than using the atexit module.
"""

- def __init__(self):
+ def __init__(self) -> None:
self.dispatch = CUDASanitizerDispatchMode()
self.enabled = False

@@ -49,7 +49,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
>>> import torch.nn as nn
>>>
>>> class MyModel(nn.Module):
- >>> def __init__(self):
+ >>> def __init__(self) -> None:
>>> super().__init__()
>>> self.l1 = nn.Linear(10, 10)
>>> self.l2 = nn.Linear(10, 10)
@@ -47,7 +47,7 @@ def contract(state_cls: Type[_State] = _State):
>>> import torch.nn as nn
>>>
>>> class MyModel(nn.Module):
- >>> def __init__(self):
+ >>> def __init__(self) -> None:
>>> super().__init__()
>>> self.l1 = nn.Linear(10, 10)
>>> self.l2 = nn.Linear(10, 10)
@@ -43,7 +43,7 @@ logger = logging.getLogger("torch.distributed._composable.fsdp")
class FSDPStateContext:
"""This has state shared across FSDP states."""

- def __init__(self):
+ def __init__(self) -> None:
# All FSDP states in the root state's module tree
self.all_states: List[FSDPState] = []
# Iteration's forward root runs the once-per-forward logic; this root
@@ -71,7 +71,7 @@ def disable_if_config_true(func):


class FSDPState(_State):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self._fsdp_param_group: Optional[FSDPParamGroup] = None
self._is_root: Optional[bool] = None # root set during lazy init
@@ -38,7 +38,7 @@ class ShardingPlan:

>>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> class MyModule(nn.Module):
- >>> def __init__(self):
+ >>> def __init__(self) -> None:
>>> super().__init__()
>>> self.fc1 = nn.Linear()
>>> self.gelu = nn.GELU()
@@ -117,7 +117,7 @@ import torch.nn as nn
from torch.distributed._tensor import Shard, distribute_tensor, distribute_module, init_device_mesh

class MyModule(nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.fc1 = nn.Linear(8, 8)
self.fc2 = nn.Linear(8, 8)
@@ -25,7 +25,7 @@ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_modul


class SimpleMLP(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.net1 = torch.nn.Linear(5, 128)
self.relu = torch.nn.ReLU()
@@ -55,7 +55,7 @@ class Joinable(ABC):
"""

@abstractmethod
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self._join_config = _JoinConfig.construct_disabled_join_config()

@@ -31,7 +31,7 @@ class InjectedException(Exception):


class Model(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.net1 = nn.Linear(8, 32)
self.net2 = nn.Linear(32, 128)
@@ -22,7 +22,7 @@ CHECKPOINT_DIR = f"~/{os.environ['LOGNAME']}/checkpoint"


class Model(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
@@ -434,7 +434,7 @@ class _reduce_op:
:class:`~torch.distributed.ReduceOp` is recommended to use instead.
"""

- def __init__(self):
+ def __init__(self) -> None:
# __members__ is a dict storing key-value pairs for enum classes
for k, v in ReduceOp.RedOpType.__members__.items():
setattr(self, k, v)
@@ -568,7 +568,7 @@ class _World:
of c10d and is subject to change..
"""

- def __init__(self):
+ def __init__(self) -> None:
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {}
self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
@@ -2194,7 +2194,7 @@ class _IllegalWork(Work):


class _CoalescingManager:
- def __init__(self):
+ def __init__(self) -> None:
self.works: List[Work] = []

def append(self, work: Work):
@@ -106,7 +106,7 @@ class _FSDPDeviceHandle:


class _UninitializedDeviceHandle(_FSDPDeviceHandle):
- def __init__(self):
+ def __init__(self) -> None:
pass

def __getattribute__(self, __name: str) -> Any:
@@ -156,7 +156,7 @@ class _RemoteModule(nn.Module):
created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``).
Hybrid Example:
>>> class HybridModel(nn.Module):
- >>> def __init__(self):
+ >>> def __init__(self) -> None:
>>> nn.Module.__init__(self)
>>> self.remote_embedding = RemoteModule(...)
>>> self.local_linear = nn.Linear(...)
@@ -248,7 +248,7 @@ class ExportGraphSignature:
e.g. If following module is exported::

class CustomModule(nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super(CustomModule, self).__init__()

# Define a parameter
@@ -45,7 +45,7 @@ FX’s front-end makes use of the dynamic nature of Python to intercept call-sit
import torch

class MyModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(
torch.rand(3, 4))
@@ -9,7 +9,7 @@ demonstration of these components in action:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
@@ -1012,7 +1012,7 @@ class _PatchedFnSetAttr(_PatchedFn):


class _Patcher:
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.patches_made: List[_PatchedFn] = []
self.visited: Set[int] = set()
@@ -63,7 +63,7 @@ class T(Constraint):
"""
True
"""
- def __init__(self):
+ def __init__(self) -> None:
pass

def __eq__(self, other):
@@ -76,7 +76,7 @@ class F(Constraint):
"""
False
"""
- def __init__(self):
+ def __init__(self) -> None:
pass

def __eq__(self, other):
@@ -117,7 +117,7 @@ if HAS_PYDOT:
>>> # xdoctest: +REQUIRES(module:ubelt)
>>> # define module
>>> class MyModule(torch.nn.Module):
- >>> def __init__(self):
+ >>> def __init__(self) -> None:
>>> super().__init__()
>>> self.linear = torch.nn.Linear(4, 5)
>>> def forward(self, x):
@@ -83,7 +83,7 @@ def split_module(
from torch.fx.passes.split_module import split_module

class MyModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
@ -83,7 +83,7 @@ def split_by_tags(
|
|||||||
Given the following module def:
|
Given the following module def:
|
||||||
|
|
||||||
class SimpleModule(torch.nn.Module):
|
class SimpleModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.linear1 = torch.nn.Linear(...)
|
self.linear1 = torch.nn.Linear(...)
|
||||||
self.linear2 = torch.nn.Linear(...)
|
self.linear2 = torch.nn.Linear(...)
|
||||||
|
@ -38,7 +38,7 @@ class Scope:
|
|||||||
return x.transpose(1, 2)
|
return x.transpose(1, 2)
|
||||||
|
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.sub = Sub()
|
self.sub = Sub()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -118,7 +118,7 @@ def replace_pattern(
|
|||||||
from torch.fx import symbolic_trace, subgraph_rewriter
|
from torch.fx import symbolic_trace, subgraph_rewriter
|
||||||
|
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x, w1, w2):
|
def forward(self, x, w1, w2):
|
||||||
|
@ -38,7 +38,7 @@ class _DynType:
|
|||||||
"""
|
"""
|
||||||
_DynType defines a type which stands for the absence of type information.
|
_DynType defines a type which stands for the absence of type information.
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.__name__ = '_DynType'
|
self.__name__ = '_DynType'
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
@ -219,7 +219,7 @@ def isinstance(obj, target_type):
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
class MyModule(torch.nn.Module):
|
class MyModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, input: Any): # note the Any type
|
def forward(self, input: Any): # note the Any type
|
||||||
@ -255,7 +255,7 @@ class strict_fusion:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
if not torch._jit_internal.is_scripting():
|
if not torch._jit_internal.is_scripting():
|
||||||
warnings.warn("Only works in script mode")
|
warnings.warn("Only works in script mode")
|
||||||
pass
|
pass
|
||||||
|
@ -73,7 +73,7 @@ def fork(func, *args, **kwargs):
|
|||||||
def forward(self, a: Tensor, b : int):
|
def forward(self, a: Tensor, b : int):
|
||||||
return a + b
|
return a + b
|
||||||
class Mod(torch.nn.Module):
|
class Mod(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super(self).__init__()
|
super(self).__init__()
|
||||||
self.mod = AddMod()
|
self.mod = AddMod()
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
@ -39,7 +39,7 @@ class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
|
|||||||
def fn(self):
|
def fn(self):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.x: List[int] = []
|
self.x: List[int] = []
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ def freeze(
|
|||||||
.. testcode::
|
.. testcode::
|
||||||
import torch
|
import torch
|
||||||
class MyModule2(torch.nn.Module):
|
class MyModule2(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.modified_tensor = torch.tensor(10.)
|
self.modified_tensor = torch.tensor(10.)
|
||||||
self.version = 1
|
self.version = 1
|
||||||
|
@ -89,7 +89,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
|||||||
self.traces.append(trace)
|
self.traces.append(trace)
|
||||||
|
|
||||||
class JitTypeTraceStore(CallTraceStore):
|
class JitTypeTraceStore(CallTraceStore):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# A dictionary keeping all collected CallTrace
|
# A dictionary keeping all collected CallTrace
|
||||||
# key is fully qualified name of called function
|
# key is fully qualified name of called function
|
||||||
@ -159,15 +159,15 @@ else:
|
|||||||
# When MonkeyType is not installed, we provide dummy class definitions
|
# When MonkeyType is not installed, we provide dummy class definitions
|
||||||
# for the below classes.
|
# for the below classes.
|
||||||
class JitTypeTraceStoreLogger: # type: ignore[no-redef]
|
class JitTypeTraceStoreLogger: # type: ignore[no-redef]
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class JitTypeTraceStore: # type: ignore[no-redef]
|
class JitTypeTraceStore: # type: ignore[no-redef]
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.trace_records = None
|
self.trace_records = None
|
||||||
|
|
||||||
class JitTypeTraceConfig: # type: ignore[no-redef]
|
class JitTypeTraceConfig: # type: ignore[no-redef]
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
monkeytype_trace = None # type: ignore[assignment] # noqa: F811
|
monkeytype_trace = None # type: ignore[assignment] # noqa: F811
|
||||||
|
@ -426,7 +426,7 @@ class ConcreteTypeStore:
|
|||||||
type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
|
type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
|
||||||
methods_compiled: Set[torch._C.ConcreteModuleType]
|
methods_compiled: Set[torch._C.ConcreteModuleType]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# Python module type => List[ConcreteModuleType)]
|
# Python module type => List[ConcreteModuleType)]
|
||||||
self.type_store = {}
|
self.type_store = {}
|
||||||
# ConcreteTypes that have had their methods already compiled
|
# ConcreteTypes that have had their methods already compiled
|
||||||
|
@ -107,7 +107,7 @@ Attribute.__doc__ = """
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
class AttributeModule(torch.jit.ScriptModule):
|
class AttributeModule(torch.jit.ScriptModule):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.foo = torch.jit.Attribute(0.1, float)
|
self.foo = torch.jit.Attribute(0.1, float)
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ Attribute.__doc__ = """
|
|||||||
class AttributeModule(torch.nn.Module):
|
class AttributeModule(torch.nn.Module):
|
||||||
names: Dict[str, int]
|
names: Dict[str, int]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.names = {}
|
self.names = {}
|
||||||
|
|
||||||
@ -522,7 +522,7 @@ if _enabled:
|
|||||||
"original_name",
|
"original_name",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
|
forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
|
||||||
@ -1351,7 +1351,7 @@ def script(
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
class MyModule(nn.Module):
|
class MyModule(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# torch.jit.trace produces a ScriptModule's conv1 and conv2
|
# torch.jit.trace produces a ScriptModule's conv1 and conv2
|
||||||
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
|
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
|
||||||
@ -1374,7 +1374,7 @@ def script(
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class MyModule(nn.Module):
|
class MyModule(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@torch.jit.export
|
@torch.jit.export
|
||||||
@ -1547,7 +1547,7 @@ def interface(obj):
|
|||||||
return x.relu()
|
return x.relu()
|
||||||
|
|
||||||
class Impl2(torch.nn.Module):
|
class Impl2(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.val = torch.rand(())
|
self.val = torch.rand(())
|
||||||
|
|
||||||
@ -1671,7 +1671,7 @@ class _ScriptProfileTable:
|
|||||||
|
|
||||||
|
|
||||||
class _ScriptProfile:
|
class _ScriptProfile:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.profile = classes.profiling._ScriptProfile()
|
self.profile = classes.profiling._ScriptProfile()
|
||||||
|
|
||||||
def enable(self):
|
def enable(self):
|
||||||
|
@ -19,7 +19,7 @@ class EnabledProxy:
|
|||||||
This is just a wrapper for a bool, so that we get reference semantics
|
This is just a wrapper for a bool, so that we get reference semantics
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.enabled = self.parse_env(
|
self.enabled = self.parse_env(
|
||||||
"PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
|
"PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
|
||||||
)
|
)
|
||||||
|
@ -966,7 +966,7 @@ def trace(
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(1, 1, 3)
|
self.conv = nn.Conv2d(1, 1, 3)
|
||||||
|
|
||||||
@ -1182,7 +1182,7 @@ def trace_module(
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv2d(1, 1, 3)
|
self.conv = nn.Conv2d(1, 1, 3)
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ class StorageWeakRef:
|
|||||||
class SharedCache(dict):
|
class SharedCache(dict):
|
||||||
"""Dictionary from multiprocessing handles to StorageWeakRef."""
|
"""Dictionary from multiprocessing handles to StorageWeakRef."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
# free_dead_references() is called if the len exceeds the current
|
# free_dead_references() is called if the len exceeds the current
|
||||||
# limit. The limit scales with the number of remaining live objects.
|
# limit. The limit scales with the number of remaining live objects.
|
||||||
self.limit = 128
|
self.limit = 128
|
||||||
|
Some files were not shown because too many files have changed in this diff.
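Every hunk above applies the same mechanical change: annotating __init__ with an explicit None return type, in line with PEP 484's guidance. Below is a minimal, hypothetical sketch of why the annotation matters; ToyModule is an illustration and not a file touched by this diff. By default, type checkers such as mypy skip the bodies of functions with no annotations, so adding -> None is what makes the body of __init__ get checked (and keeps strict settings such as --disallow-untyped-defs from flagging it).

# Hypothetical illustration of the change applied throughout this diff.
import torch


class ToyModule(torch.nn.Module):
    # Before the change this signature read "def __init__(self):", which mypy
    # treats as untyped and therefore skips when checking the body.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


if __name__ == "__main__":
    # Quick usage check: runtime behavior is unchanged; only the annotation
    # (and hence static checking of __init__) differs.
    out = ToyModule()(torch.rand(2, 4))
    print(out.shape)  # torch.Size([2, 5])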