Add basic mypy annotations to dynamo (#132415)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132415
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
Author:    Oguz Ulgen
Date:      2024-08-01 08:53:32 -07:00
Committer: PyTorch MergeBot
Commit:    6e79932543
Parent:    3558a8cf4a
35 changed files with 178 additions and 166 deletions
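
Every hunk below follows the same pattern: a return-type annotation is added to a dunder method that previously had none (`-> None` for `__init__`, `__setattr__`, and `__del__`; `-> str` for `__str__` and `__repr__`; `-> bool` for `__contains__`). Per PEP 484, `__init__` is annotated as returning `None`, and mypy only treats a function as typed, and therefore type-checks its body and accepts it under strictness options such as `disallow-untyped-defs`, once the signature carries annotations. The sketch below is purely illustrative of that pattern, assuming a made-up class; none of its names come from this diff.

# Illustrative only: a hypothetical class showing the annotation pattern this
# commit applies across torch/_dynamo.
class Example:
    def __init__(self, value: object) -> None:
        # With "-> None", mypy treats __init__ as annotated and checks its body.
        self.value = value

    def __str__(self) -> str:
        return f"Example({self.value!r})"

    def __repr__(self) -> str:
        return f"Example(value={self.value!r})"

    def __contains__(self, item: object) -> bool:
        return item == self.value

Note that only return types are added in this commit; most parameters stay unannotated, which is why the majority of signatures change only by a trailing `-> None`.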


@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
class AotAutograd:
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
self.__name__ = "compiler_fn"
self.kwargs = kwargs


@ -203,7 +203,7 @@ class ExplainOutput:
out_guards: Optional[List[_guards.Guard]] = None
compile_times: Optional[str] = None
def __str__(self):
def __str__(self) -> str:
output = f"Graph Count: {self.graph_count}\n"
output += f"Graph Break Count: {self.graph_break_count}\n"
output += f"Op Count: {self.op_count}\n"
@ -289,7 +289,7 @@ class ExplainWithBackend:
print(eb.output())
"""
def __init__(self, backend):
def __init__(self, backend) -> None:
from .registry import lookup_backend
self.backend = lookup_backend(backend)


@ -131,7 +131,7 @@ def has_higher_order_op(gm):
# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
def __init__(self, module, compiler, fake_mode):
def __init__(self, module, compiler, fake_mode) -> None:
super().__init__(module)
self.compiler = compiler
self.fake_mode = fake_mode
@ -145,7 +145,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
assert len(kwargs) == 0, "We assume only args for these modules"
class WrapperModule(torch.nn.Module):
def __init__(self, submod, unwrap_singleton_tuple):
def __init__(self, submod, unwrap_singleton_tuple) -> None:
super().__init__()
self.submod = submod
self.unwrap_singleton_tuple = unwrap_singleton_tuple
@ -252,7 +252,7 @@ class SubmodCompiler(torch.fx.interpreter.Interpreter):
assert self.tc
torch._guards.TracingContext.try_get().fakify_first_call = True
def __del__(self):
def __del__(self) -> None:
self.tc.fakify_first_call = False
# For aot_eager and other backends, tracing context is not set
@ -362,7 +362,7 @@ class DDPOptimizer:
bucket_bytes_cap: int,
backend_compile_fn,
first_bucket_cap: Optional[int] = None,
):
) -> None:
if first_bucket_cap is not None:
self.first_bucket_cap = first_bucket_cap
elif torch.distributed.is_available():


@ -51,7 +51,7 @@ class PyCodegen:
root: Optional[torch.nn.Module] = None,
graph_output_var: Optional[str] = None,
tempvars=None,
):
) -> None:
self.root = root
self.top_of_stack: Optional[VariableTracker] = None
self.uses: Counter[VariableTracker] = collections.Counter()


@ -32,7 +32,7 @@ class ComptimeVar:
actual data in the Tensor is.)
"""
def __init__(self, v):
def __init__(self, v) -> None:
self.__variable = v
def as_proxy(self):
@ -128,7 +128,7 @@ class ComptimeVar:
"""
return self.__variable
def __repr__(self):
def __repr__(self) -> str:
return self.__variable.debug_repr()
# TODO: API for adding a custom guard
@ -141,7 +141,7 @@ class ComptimeContext:
file a feature request at https://github.com/pytorch/pytorch/
"""
def __init__(self, tx):
def __init__(self, tx) -> None:
self.__tx = tx
def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar:


@ -130,7 +130,9 @@ class DeviceGuard:
The device is switched using the provided device interface.
"""
def __init__(self, device_interface: Type[DeviceInterface], index: Optional[int]):
def __init__(
self, device_interface: Type[DeviceInterface], index: Optional[int]
) -> None:
self.device_interface = device_interface
self.idx = index
self.prev_idx = -1


@ -158,7 +158,7 @@ class OptimizedModule(torch.nn.Module):
"named_children_walk",
}
def __init__(self, mod: torch.nn.Module, dynamo_ctx):
def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None:
super().__init__()
# Installs the params/buffer
self._orig_mod = mod
@ -218,7 +218,7 @@ class OptimizedModule(torch.nn.Module):
return self._modules["_orig_mod"]
return getattr(self._orig_mod, name)
def __setattr__(self, name, val):
def __setattr__(self, name, val) -> None:
# Allow patching over class attributes
if hasattr(type(self), name):
return super().__setattr__(name, val)
@ -304,7 +304,7 @@ class _TorchDynamoContext:
export=False,
dynamic=None,
compiler_config=None,
):
) -> None:
super().__init__()
assert callable(callback) or callback is False or callback is None
self.callback: DynamoCallback = callback
@ -539,7 +539,7 @@ class OptimizeContext(_TorchDynamoContext):
rebuild_ctx: Optional[
Callable[[], Union[OptimizeContext, _NullDecorator]]
] = None,
):
) -> None:
def on_enter():
install_generation_tagging_init()
@ -879,7 +879,7 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
example_fake_inputs: List[torch.Tensor],
flat_args_dynamic_dims: List[Set[int]],
fake_mode: Optional[fake_tensor.FakeTensorMode] = None,
):
) -> None:
super().__init__(m)
assert len(flat_args_dynamic_dims) == len(flat_args)


@ -41,7 +41,7 @@ class InternalTorchDynamoError(TorchDynamoException):
class RestartAnalysis(TorchDynamoException):
restart_reason: str
def __init__(self, *args, restart_reason=None):
def __init__(self, *args, restart_reason=None) -> None:
self.restart_reason = restart_reason
super().__init__(*args)
@ -67,7 +67,7 @@ class TorchRuntimeError(TorchDynamoException):
class InvalidBackend(TorchDynamoException):
def __init__(self, name):
def __init__(self, name) -> None:
super().__init__(
f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends."
)
@ -86,7 +86,7 @@ class ResetRequired(TorchDynamoException):
class BackendCompilerFailed(TorchDynamoException):
def __init__(self, backend_fn, inner_exception):
def __init__(self, backend_fn, inner_exception) -> None:
self.backend_name = getattr(backend_fn, "__name__", "?")
self.inner_exception = inner_exception
msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}"
@ -94,7 +94,7 @@ class BackendCompilerFailed(TorchDynamoException):
class Unsupported(TorchDynamoException):
def __init__(self, msg, *, case_name=None):
def __init__(self, msg, *, case_name=None) -> None:
super().__init__(msg)
self.real_stack = torch._guards.TracingContext.extract_stack()
self.msg = msg
@ -118,12 +118,12 @@ class RecompileError(TorchDynamoException):
class ArgsMismatchError(Unsupported):
def __init__(self, msg):
def __init__(self, msg) -> None:
super().__init__(msg)
class AttributeMutationError(Unsupported):
def __init__(self, msg):
def __init__(self, msg) -> None:
super().__init__(msg)
@ -132,7 +132,7 @@ class CondOpArgsMismatchError(ArgsMismatchError):
Internal error from cond() due to arguments mismatch.
"""
def __init__(self, msg):
def __init__(self, msg) -> None:
super().__init__(msg)
@ -147,7 +147,7 @@ class UserErrorType(Enum):
class UserError(Unsupported):
def __init__(self, error_type: UserErrorType, msg, case_name=None):
def __init__(self, error_type: UserErrorType, msg, case_name=None) -> None:
"""
Type of errors that would be valid in Eager, but not supported in TorchDynamo.
The error message should tell user about next actions.
@ -191,7 +191,7 @@ class ObservedUserStopIteration(ObservedException):
# Reference `StopIteration_init` in CPython
# https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__("unhandled `raise StopIteration`")
if len(args) > 0:
self.value = args[0]
@ -291,10 +291,10 @@ def warning(msg: str) -> None:
# KeyError has special handling for its args
# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details
class KeyErrorMsg:
def __init__(self, value):
def __init__(self, value) -> None:
self.value = value
def __str__(self):
def __str__(self) -> str:
return str(self.value)
def __repr__(self) -> str:


@ -79,7 +79,7 @@ class FakeBackwardCFunction:
self,
real: torch.autograd.function.BackwardCFunction,
saved_tensors: List[torch.Tensor],
):
) -> None:
self.real = real
self.saved_tensors = saved_tensors


@ -38,7 +38,7 @@ class ProfileMetrics:
self.fusions / max(1, other.fusions),
)
def __str__(self):
def __str__(self) -> str:
return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"
def tocsv(self):
@ -46,7 +46,7 @@ class ProfileMetrics:
class ProfileResult:
def __init__(self, captured, total, unique_graphs):
def __init__(self, captured, total, unique_graphs) -> None:
self.captured: ProfileMetrics = captured or ProfileMetrics()
self.total: ProfileMetrics = total or ProfileMetrics()
self.unique_graphs: int = unique_graphs
@ -60,7 +60,7 @@ class ProfileResult:
def percent(self):
return self.captured / self.total
def __str__(self):
def __str__(self) -> str:
return (
f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
f"{self.captured.operators:4}/{self.total.operators:4} = "


@ -629,7 +629,7 @@ def repro_analyze(options, mod, load_args):
assert not new_args
class WriterInterp(fx.Interpreter):
def __init__(self, mod, subdir):
def __init__(self, mod, subdir) -> None:
super().__init__(mod)
self.subdir = subdir


@ -58,7 +58,7 @@ def _accuracy_fails(gm, example_inputs, compiler_fn):
class WrapBackendDebug:
def __init__(self, unconfigured_compiler_fn, compiler_name: str):
def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None:
functools.wraps(unconfigured_compiler_fn)(self)
self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined]
self._compiler_name = compiler_name


@ -684,7 +684,7 @@ def break_graph_if_unsupported(*, push):
class BytecodeDistpatchTableMeta(type):
"""Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()"""
def __init__(cls, name, bases, dct):
def __init__(cls, name, bases, dct) -> None:
super().__init__(name, bases, dct)
def _missing(opname, *args):
@ -2515,7 +2515,7 @@ class InstructionTranslatorBase(
inline_depth: int,
speculation_log: SpeculationLog,
distributed_state: Optional[DistributedState],
):
) -> None:
super().__init__()
self.speculation_log = speculation_log
self.distributed_state = distributed_state
@ -2621,7 +2621,7 @@ class InstructionTranslator(InstructionTranslatorBase):
frame_state,
speculation_log: SpeculationLog,
distributed_state: Optional[DistributedState],
):
) -> None:
_step_logger()(
logging.INFO,
f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}",
@ -3095,7 +3095,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
symbolic_globals: Dict[str, VariableTracker],
closure_cells: Dict[str, VariableTracker],
funcvar: BaseUserFunctionVariable,
):
) -> None:
f_globals = funcvar.get_globals() # type: ignore[attr-defined]
f_builtins = f_globals["__builtins__"]
if not isinstance(f_builtins, dict):
@ -3264,7 +3264,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
generated_items: List[VariableTracker]
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.generated_items = []


@ -2929,7 +2929,9 @@ class FunctionIdSet:
function_ids: Optional[Set[int]] = None
function_names: Optional[Dict[int, str]] = None
def __init__(self, lazy_initializer: Callable[[], Union[Dict[int, str], Set[int]]]):
def __init__(
self, lazy_initializer: Callable[[], Union[Dict[int, str], Set[int]]]
) -> None:
self.lazy_initializer = lazy_initializer
def __call__(self):
@ -2957,7 +2959,7 @@ class FunctionIdSet:
if idx in function_ids:
function_ids.remove(idx)
def __contains__(self, idx: int):
def __contains__(self, idx: int) -> bool:
return idx in self()


@ -31,7 +31,7 @@ class MutableLocalBase:
Base class for Variable.mutable_local
"""
def __init__(self, typ: MutableLocalSource):
def __init__(self, typ: MutableLocalSource) -> None:
# In HigherOrderOperator tracing, we need to distinguish
# between MutableLocals inside the HigherOrderOperator and
# ones outside it. For example, it is not safe to mutate
@ -110,7 +110,7 @@ class VariableTrackerMeta(type):
instance = instance.realize()
return type.__instancecheck__(cls, instance)
def __init__(cls, name, bases, attrs):
def __init__(cls, name, bases, attrs) -> None:
super().__init__(name, bases, attrs)
VariableTrackerMeta.all_subclasses.append(cls)
@ -173,7 +173,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
for subvalue in value.values():
cls.visit(fn, subvalue, cache)
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
def debug_repr(self):
@ -365,7 +365,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
*,
source: Source = None,
mutable_local: MutableLocal = None,
):
) -> None:
super().__init__()
self.source = source
self.mutable_local = mutable_local


@ -308,7 +308,7 @@ class VariableBuilder:
self,
tx,
source: Source,
):
) -> None:
assert (
source is not None
), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."


@ -637,11 +637,11 @@ class BuiltinVariable(VariableTracker):
def can_insert_in_graph(self):
return self.fn in self._fx_graph_functions()
def __init__(self, fn, **kwargs):
def __init__(self, fn, **kwargs) -> None:
super().__init__(**kwargs)
self.fn = fn
def __str__(self):
def __str__(self) -> str:
if self.fn is None:
name = "None"
else:


@ -64,7 +64,7 @@ class ConstantVariable(VariableTracker):
return ConstantVariable(value, **kwargs)
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
if not ConstantVariable.is_literal(value):
for disallowed_type, reason in _type_to_assert_reason.items():
@ -81,7 +81,7 @@ class ConstantVariable(VariableTracker):
def as_proxy(self):
return self.value
def __str__(self):
def __str__(self) -> str:
return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"
def python_type(self):
@ -211,7 +211,7 @@ class ConstantVariable(VariableTracker):
class EnumVariable(VariableTracker):
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
@ -226,7 +226,7 @@ class EnumVariable(VariableTracker):
def as_proxy(self):
return self.value
def __str__(self):
def __str__(self) -> str:
return f"EnumVariable({type(self.value)})"
def python_type(self):


@ -63,7 +63,9 @@ class ContextWrappingVariable(VariableTracker):
*VariableTracker._nonvar_fields,
}
def __init__(self, target_values, initial_values=None, *, state=None, **kwargs):
def __init__(
self, target_values, initial_values=None, *, state=None, **kwargs
) -> None:
super().__init__(**kwargs)
self.target_values = target_values
self.initial_values = initial_values
@ -127,7 +129,7 @@ class ContextWrappingVariable(VariableTracker):
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are
# python contants. Which might not always be the case here.
def __init__(self, cm_obj, **kwargs):
def __init__(self, cm_obj, **kwargs) -> None:
assert cm_obj is not None
super().__init__(
value=cm_obj,
@ -389,7 +391,7 @@ class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
initial_values=None,
)
def __init__(self, catch_warnings_args, **kwargs):
def __init__(self, catch_warnings_args, **kwargs) -> None:
assert isinstance(catch_warnings_args, dict), catch_warnings_args
super().__init__(**kwargs)
self.catch_warnings_args = catch_warnings_args
@ -465,7 +467,9 @@ class GradModeVariable(ContextWrappingVariable):
var._call_func(tx, var.target_values)
return var
def __init__(self, target_values, initial_values=None, initialized=True, **kwargs):
def __init__(
self, target_values, initial_values=None, initialized=True, **kwargs
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
@ -518,7 +522,7 @@ class InferenceModeVariable(ContextWrappingVariable):
target_values,
initial_values=None,
**kwargs,
):
) -> None:
if initial_values is None:
# This must be called here since function defaults are evaluated at import time
initial_values = torch.is_inference_mode_enabled()
@ -572,7 +576,7 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
var.set_cleanup_hook(tx)
return var
def __init__(self, target_values, initial_values=None, **kwargs):
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
@ -604,7 +608,7 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable):
var.set_cleanup_hook(tx)
return var
def __init__(self, target_values, initial_values=None, **kwargs):
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
@ -644,7 +648,7 @@ class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
var.set_cleanup_hook(tx)
return var
def __init__(self, target_values, initial_values=None, **kwargs):
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
@ -713,7 +717,7 @@ class AutocastModeVariable(ContextWrappingVariable):
var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
return var
def __init__(self, target_values, initial_values=None, **kwargs):
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
@ -746,7 +750,7 @@ class NullContextVariable(ContextWrappingVariable):
support yet, e.g, torch.autograd.profiler.record_function.
"""
def __init__(self, target_values=None, **kwargs):
def __init__(self, target_values=None, **kwargs) -> None:
super().__init__(target_values=target_values, **kwargs)
def enter(self, tx):
@ -787,7 +791,7 @@ class StreamContextVariable(ContextWrappingVariable):
**kwargs,
)
def __init__(self, target_values, device, initial_values=None, **kwargs):
def __init__(self, target_values, device, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
@ -840,7 +844,7 @@ class PreserveVersionContextVariable(ContextWrappingVariable):
)
)
def __init__(self, tensor, prev_version, **kwargs):
def __init__(self, tensor, prev_version, **kwargs) -> None:
kwargs.setdefault("target_values", None)
super().__init__(**kwargs)
self.tensor = tensor
@ -875,7 +879,9 @@ class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
)
return var
def __init__(self, param_group_var, target_values, initial_values=None, **kwargs):
def __init__(
self, param_group_var, target_values, initial_values=None, **kwargs
) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
@ -922,7 +928,7 @@ class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
class StreamVariable(VariableTracker):
def __init__(self, proxy, value, device, **kwargs):
def __init__(self, proxy, value, device, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
assert (
@ -995,7 +1001,7 @@ class StreamVariable(VariableTracker):
class EventVariable(VariableTracker):
def __init__(self, proxy, value, **kwargs):
def __init__(self, proxy, value, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
super().__init__(**kwargs)
@ -1043,7 +1049,7 @@ class WithExitFunctionVariable(VariableTracker):
ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable],
target,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
assert isinstance(
ctx, (ContextWrappingVariable, GenericContextWrappingVariable)


@ -72,7 +72,7 @@ class ConstDictVariable(VariableTracker):
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
"""
def __init__(self, vt):
def __init__(self, vt) -> None:
# We specialize SymNodes
vt = specialize_symnode(vt)
# TODO Temorarily remove to figure out what keys are we breaking on
@ -129,7 +129,7 @@ class ConstDictVariable(VariableTracker):
def __init__(
self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs
):
) -> None:
super().__init__(**kwargs)
Hashable = ConstDictVariable._HashableTracker
@ -171,7 +171,7 @@ class ConstDictVariable(VariableTracker):
def python_type(self):
return self.user_cls
def __contains__(self, vt):
def __contains__(self, vt) -> bool:
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
@ -344,7 +344,7 @@ class ConstDictVariable(VariableTracker):
class DefaultDictVariable(ConstDictVariable):
def __init__(self, items, user_cls, default_factory=None, **kwargs):
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
super().__init__(items, user_cls, **kwargs)
assert user_cls is collections.defaultdict
self.default_factory = default_factory
@ -400,7 +400,7 @@ class SetVariable(ConstDictVariable):
self,
items: List[VariableTracker],
**kwargs,
):
) -> None:
items = dict.fromkeys(items, SetVariable._default_value())
super().__init__(items, **kwargs)
@ -511,7 +511,7 @@ class DictView(VariableTracker):
kv: Optional[str] = None
def __init__(self, dv_dict: ConstDictVariable, **kwargs):
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
super().__init__(**kwargs)
assert self.kv in ("keys", "values")
assert isinstance(dv_dict, ConstDictVariable)
@ -745,7 +745,7 @@ class CustomizedDictVariable(ConstDictVariable):
items[key] = var
return cls(items, user_cls)
def __init__(self, items, user_cls, **options):
def __init__(self, items, user_cls, **options) -> None:
super().__init__(items, user_cls, **options)
assert self.is_matching_cls(user_cls)
@ -872,7 +872,7 @@ class HFPretrainedConfigVariable(VariableTracker):
def is_matching_object(cls, obj):
return cls.is_matching_cls(type(obj))
def __init__(self, obj, **kwargs):
def __init__(self, obj, **kwargs) -> None:
super().__init__(**kwargs)
self.obj = obj
assert self.is_matching_cls(type(obj))


@ -32,7 +32,7 @@ class DistributedVariable(VariableTracker):
and hold the tracking value for the corresponding distributed object.
"""
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
if not DistributedVariable.is_available():
unimplemented("torch.distributed package is not available!")
@ -362,7 +362,7 @@ class BackwardHookVariable(VariableTracker):
user_hooks: VariableTracker,
user_pre_hooks: VariableTracker,
**options,
):
) -> None:
super().__init__(**options)
self.proxy = proxy
self.module = module


@ -139,7 +139,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
source=source,
)
def __init__(self, fn, is_constant=False, **kwargs):
def __init__(self, fn, is_constant=False, **kwargs) -> None:
super().__init__(**kwargs)
if getattr(fn, "_dynamo_marked_constant", False):
# This method should be treated as a constant for the purposes of compilation
@ -325,11 +325,11 @@ class UserFunctionVariable(BaseUserFunctionVariable):
class UserMethodVariable(UserFunctionVariable):
"""Some unsupported user-defined method"""
def __init__(self, fn, obj, **kwargs):
def __init__(self, fn, obj, **kwargs) -> None:
super().__init__(fn=fn, **kwargs)
self.obj = obj
def __str__(self):
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.fn}, {self.obj})"
def self_args(self):
@ -387,7 +387,7 @@ class UserMethodVariable(UserFunctionVariable):
class WrappedUserMethodVariable(UserMethodVariable):
def __init__(self, wrapped, context, **kwargs):
def __init__(self, wrapped, context, **kwargs) -> None:
kwargs.pop("fn", None)
kwargs.pop("obj", None)
super().__init__(wrapped.fn, wrapped.obj, **kwargs)
@ -407,7 +407,7 @@ class WrappedUserMethodVariable(UserMethodVariable):
class WrappedUserFunctionVariable(UserFunctionVariable):
def __init__(self, wrapped, context, **kwargs):
def __init__(self, wrapped, context, **kwargs) -> None:
kwargs.pop("fn", None)
kwargs.pop("obj", None)
super().__init__(wrapped.fn, **kwargs)
@ -461,7 +461,7 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
closure_scope,
wrapped_reconstructible=None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
assert isinstance(fn_name.as_python_constant(), str)
assert isinstance(code.as_python_constant(), types.CodeType)
@ -619,7 +619,7 @@ class SkipFunctionVariable(VariableTracker):
*VariableTracker._nonvar_fields,
}
def __init__(self, value, reason=None, **kwargs):
def __init__(self, value, reason=None, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
self.reason = reason
@ -800,7 +800,7 @@ class CollectiveFunctionRewriteVariable(UserFunctionVariable):
than status-quo as we currently graph-break on all distributed.* collectives.
"""
def __init__(self, fn, *, replacement_var, **kwargs):
def __init__(self, fn, *, replacement_var, **kwargs) -> None:
super().__init__(fn, **kwargs)
assert isinstance(replacement_var, UserFunctionVariable)
self.replacement_var = replacement_var
@ -869,7 +869,7 @@ class CollectiveFunctionRewriteVariable(UserFunctionVariable):
class FunctoolsPartialVariable(VariableTracker):
def __init__(self, func: VariableTracker, args, keywords, **kwargs):
def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None:
super().__init__(**kwargs)
self.func = func
assert isinstance(args, list)
@ -1006,7 +1006,7 @@ dynamo_triton_hopifier_singleton = DynamoTritonHOPifier()
class TritonKernelVariable(VariableTracker):
def __init__(self, kernel, kernel_idx, grid, **kwargs):
def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None:
super().__init__(**kwargs)
dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)


@ -559,7 +559,7 @@ def add_subgraph(tx: "InstructionTranslator", name, gm):
class TorchHigherOrderOperatorVariable(VariableTracker):
def __init__(
self, value: HigherOrderOperator, source: Optional[Source] = None, **kwargs
):
) -> None:
super().__init__(**kwargs)
self.value = value
self.source = source
@ -810,7 +810,7 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable):
def __init__(self, hop, source, script_obj_var, method_name):
def __init__(self, hop, source, script_obj_var, method_name) -> None:
super().__init__(hop, source)
self.script_obj_var = script_obj_var
self.method_name = method_name
@ -1750,7 +1750,7 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
class AutogradFunctionApplyVariable(VariableTracker):
def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs):
def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs) -> None:
super().__init__(**kwargs)
self.fwd_graph = fwd_graph
self.bwd_graph = bwd_graph


@ -26,11 +26,11 @@ MAX_CYCLE = 3000
class ItertoolsVariable(VariableTracker):
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
def __repr__(self):
def __repr__(self) -> str:
return f"ItertoolsVariable({self.value})"
def python_type(self):
@ -206,7 +206,7 @@ class ItertoolsVariable(VariableTracker):
class IteratorVariable(VariableTracker):
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def next_variable(self, tx):
@ -233,7 +233,7 @@ class IteratorVariable(VariableTracker):
class RepeatIteratorVariable(IteratorVariable):
def __init__(self, item: VariableTracker, **kwargs):
def __init__(self, item: VariableTracker, **kwargs) -> None:
super().__init__(**kwargs)
self.item = item
@ -255,7 +255,7 @@ class RepeatIteratorVariable(IteratorVariable):
class CountIteratorVariable(IteratorVariable):
def __init__(self, item: int = 0, step: int = 1, **kwargs):
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
super().__init__(**kwargs)
if not isinstance(item, VariableTracker):
item = ConstantVariable.create(item)
@ -293,7 +293,7 @@ class CycleIteratorVariable(IteratorVariable):
saved_index: int = 0,
item: Optional[VariableTracker] = None,
**kwargs,
):
) -> None:
if saved is None:
saved = []
super().__init__(**kwargs)
@ -346,7 +346,7 @@ class ZipVariable(IteratorVariable):
iterables: List[Union[List[VariableTracker], VariableTracker]],
strict: bool = False,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
assert isinstance(iterables, list)
# can be list[Variable] or VariableTracker (with next_variable implemented)
@ -459,7 +459,7 @@ class MapVariable(ZipVariable):
fn: VariableTracker,
iterables: List[Union[List[VariableTracker], VariableTracker]],
**kwargs,
):
) -> None:
super().__init__(iterables, **kwargs)
self.fn = fn
@ -493,7 +493,7 @@ class EnumerateVariable(ZipVariable):
iterable: Union[List[VariableTracker], VariableTracker],
start: int = 0,
**kwargs,
):
) -> None:
super().__init__(
[CountIteratorVariable(start, mutable_local=MutableLocal()), iterable],
**kwargs,


@ -10,7 +10,7 @@ from .tensor import SymNodeVariable
class LazyCache:
"""Container to cache the real VariableTracker"""
def __init__(self, value, source):
def __init__(self, value, source) -> None:
if not isinstance(value, LazySymNodeFormatString):
assert source
self.value = value
@ -52,7 +52,7 @@ class LazyVariableTracker(VariableTracker):
def create(value, source, **options):
return LazyVariableTracker(LazyCache(value, source), source=source, **options)
def __init__(self, _cache, **kwargs):
def __init__(self, _cache, **kwargs) -> None:
assert isinstance(_cache, LazyCache)
super().__init__(**kwargs)
self._cache = _cache
@ -79,7 +79,7 @@ class LazyVariableTracker(VariableTracker):
self.realize()
return VariableTracker.clone(self.unwrap(), **kwargs)
def __str__(self):
def __str__(self) -> str:
if self.is_realized():
return self.unwrap().__str__()
return VariableTracker.__str__(self.unwrap())
@ -135,7 +135,7 @@ class LazyVariableTracker(VariableTracker):
class LazySymNodeFormatString:
def __init__(
self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
):
) -> None:
from .constant import ConstantVariable
self.sym_node_var = sym_node_variable
@ -143,7 +143,7 @@ class LazySymNodeFormatString:
"{:" + fmt_spec_var.as_python_constant() + "}"
)
def __str__(self):
def __str__(self) -> str:
return str.format(
self.fmt_var.as_python_constant(),
str(self.sym_node_var.evaluate_expr()),


@ -61,7 +61,7 @@ class BaseListVariable(VariableTracker):
self,
items: List[VariableTracker],
**kwargs,
):
) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
assert all(isinstance(x, VariableTracker) for x in items)
@ -157,7 +157,7 @@ class BaseListVariable(VariableTracker):
class RangeVariable(BaseListVariable):
def __init__(self, items, **kwargs):
def __init__(self, items, **kwargs) -> None:
items_to_map = items
start = variables.ConstantVariable.create(0)
stop = None
@ -401,7 +401,7 @@ class ListVariable(CommonListMethodsVariable):
def python_type(self):
return list
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"
def debug_repr(self):
@ -558,7 +558,7 @@ class SizeVariable(TupleVariable):
items: List[VariableTracker],
proxy: Optional[torch.fx.Proxy] = None,
**kwargs,
):
) -> None:
self.proxy = proxy
super().__init__(items, **kwargs)
@ -694,7 +694,7 @@ class NamedTupleVariable(TupleVariable):
*TupleVariable._nonvar_fields,
}
def __init__(self, items, tuple_cls, **kwargs):
def __init__(self, items, tuple_cls, **kwargs) -> None:
super().__init__(items, **kwargs)
self.tuple_cls = tuple_cls
@ -753,7 +753,7 @@ class NamedTupleVariable(TupleVariable):
class SliceVariable(BaseListVariable):
def __init__(self, items, **kwargs):
def __init__(self, items, **kwargs) -> None:
items_to_map = items
start, stop, step = [variables.ConstantVariable.create(None)] * 3
@ -802,7 +802,7 @@ class ListIteratorVariable(IteratorVariable):
*IteratorVariable._nonvar_fields,
}
def __init__(self, items, index: int = 0, **kwargs):
def __init__(self, items, index: int = 0, **kwargs) -> None:
super().__init__(**kwargs)
assert isinstance(items, list)
# Removing this check as it slows things down too much
@ -812,7 +812,7 @@ class ListIteratorVariable(IteratorVariable):
self.items = items
self.index = index
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})"
def next_variable(self, tx):
@ -916,7 +916,9 @@ class RestrictedListSubclassVariable(ListVariable):
def is_matching_cls(cls, user_cls: type):
return cls._is_non_conflicting_subclass(user_cls, list)
def __init__(self, items, *, user_cls: type, user_cls_source: Source, **kwargs):
def __init__(
self, items, *, user_cls: type, user_cls_source: Source, **kwargs
) -> None:
super().__init__(items=items, **kwargs)
self.user_cls = user_cls
self.user_cls_source = user_cls_source


@ -48,7 +48,7 @@ class SuperVariable(VariableTracker):
*VariableTracker._nonvar_fields,
}
def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
def __init__(self, typevar, objvar=None, specialized=False, **kwargs) -> None:
super().__init__(**kwargs)
# typevar is the fist argument to super(). In the case where no argument
# is provided to super(), it is the __class__ object where
@ -209,7 +209,7 @@ class SuperVariable(VariableTracker):
class ExceptionVariable(VariableTracker):
def __init__(self, exc_type, args, **kwargs):
def __init__(self, exc_type, args, **kwargs) -> None:
super().__init__(**kwargs)
self.exc_type = exc_type
self.args = args
@ -301,7 +301,7 @@ class ClosureVariable(UnknownVariable):
*UnknownVariable._nonvar_fields,
}
def __init__(self, name, **kwargs):
def __init__(self, name, **kwargs) -> None:
super().__init__(**kwargs)
self.name = name
@ -316,7 +316,7 @@ class InlinedClosureVariable(UnknownVariable):
*UnknownVariable._nonvar_fields,
}
def __init__(self, name, **kwargs):
def __init__(self, name, **kwargs) -> None:
super().__init__(**kwargs)
self.name = name
@ -325,12 +325,12 @@ class InlinedClosureVariable(UnknownVariable):
class NewCellVariable(VariableTracker):
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
class NewGlobalVariable(VariableTracker):
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
@ -350,7 +350,7 @@ class InspectSignatureVariable(VariableTracker):
callable, mutable_local=variables.base.MutableLocal()
)
def __init__(self, inspected: VariableTracker, **kwargs):
def __init__(self, inspected: VariableTracker, **kwargs) -> None:
super().__init__(**kwargs)
self.inspected = inspected
@ -574,7 +574,7 @@ class AutogradFunctionVariable(VariableTracker):
*VariableTracker._nonvar_fields,
}
def __init__(self, fn_cls, **kwargs):
def __init__(self, fn_cls, **kwargs) -> None:
super().__init__(**kwargs)
self.fn_cls = fn_cls
@ -767,7 +767,7 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
saved_tensors=None,
needs_input_grad=None,
**kwargs,
):
) -> None:
super().__init__(value=value, value_type=value_type, **kwargs)
self.inference = inference
self.proxy = proxy
@ -862,7 +862,7 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
value,
value_type=None,
**kwargs,
):
) -> None:
super().__init__(value=value, value_type=value_type, **kwargs)
def call_method(
@ -894,7 +894,7 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
class LambdaVariable(VariableTracker):
def __init__(self, fn, **kwargs):
def __init__(self, fn, **kwargs) -> None:
super().__init__(**kwargs)
self.fn = fn
@ -913,14 +913,14 @@ class GetAttrVariable(VariableTracker):
*VariableTracker._nonvar_fields,
}
def __init__(self, obj, name, **kwargs):
def __init__(self, obj, name, **kwargs) -> None:
super().__init__(**kwargs)
assert isinstance(obj, VariableTracker)
assert isinstance(name, str)
self.obj = obj
self.name = name
def __str__(self):
def __str__(self) -> str:
return f"{self.__class__.__name__}({self.obj}, {self.name})"
@staticmethod
@ -1013,7 +1013,7 @@ class GetAttrVariable(VariableTracker):
class MethodWrapperVariable(VariableTracker):
def __init__(self, method_wrapper, **kwargs):
def __init__(self, method_wrapper, **kwargs) -> None:
super().__init__(**kwargs)
self.method_wrapper = method_wrapper
@ -1040,7 +1040,7 @@ class MethodWrapperVariable(VariableTracker):
class GetSetDescriptorVariable(VariableTracker):
def __init__(self, desc, **kwargs):
def __init__(self, desc, **kwargs) -> None:
super().__init__(**kwargs)
self.desc = desc
@ -1068,7 +1068,7 @@ class PythonModuleVariable(VariableTracker):
*VariableTracker._nonvar_fields,
}
def __init__(self, value: types.ModuleType, **kwargs):
def __init__(self, value: types.ModuleType, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")
@ -1079,7 +1079,7 @@ class PythonModuleVariable(VariableTracker):
def as_python_constant(self):
return self.value
def __repr__(self):
def __repr__(self) -> str:
return f"PythonModuleVariable({self.value})"
def call_hasattr(self, tx: "InstructionTranslator", name):
@ -1104,7 +1104,7 @@ class PythonModuleVariable(VariableTracker):
class TypingVariable(VariableTracker):
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
@ -1152,7 +1152,7 @@ class NumpyVariable(VariableTracker):
constant_fold_functions = (tnp.issubdtype,)
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
@ -1259,10 +1259,10 @@ class NumpyVariable(VariableTracker):
# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def __str__(self):
def __str__(self) -> str:
return "NullVariable"
def reconstruct(self, codegen):
@ -1296,14 +1296,14 @@ class StringFormatVariable(VariableTracker):
)
return cls(format_string, list(sym_args), dict(sym_kwargs))
def __init__(self, format_string, sym_args, sym_kwargs, **kwargs):
def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None:
super().__init__(**kwargs)
assert isinstance(format_string, str)
self.format_string = format_string
self.sym_args = sym_args
self.sym_kwargs = sym_kwargs
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"
def reconstruct(self, codegen):
@ -1330,7 +1330,7 @@ class DebuggingVariable(VariableTracker):
registered to config.reorderable_logging_functions.
"""
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
@ -1384,7 +1384,7 @@ class LoggingLoggerVariable(VariableTracker):
Represents a call to any of logging.Logger methods
"""
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
def call_method(
@ -1414,7 +1414,7 @@ class ConstantLikeVariable(VariableTracker):
np_floating = type("invalid_type", (), {})
np_dtype = type("invalid_type", (), {})
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
@ -1472,7 +1472,7 @@ class ConstantRegexMatchVariable(ConstantLikeVariable):
class TorchVersionVariable(ConstantLikeVariable):
_error_prefix = "torch.__version__"
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
kwargs.setdefault("value", torch.__version__)
assert kwargs["value"] is torch.__version__
super().__init__(**kwargs)


@ -133,7 +133,7 @@ class NNModuleVariable(VariableTracker):
def __init__(
self, module_type: type, module_key: str, module: torch.nn.Module, **kwargs
):
) -> None:
super().__init__(**kwargs)
self.module_type = module_type
self.module_key = module_key
@ -776,7 +776,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
Giving one graph per module class.
"""
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
if type(value) is torch.jit._script.RecursiveScriptModule:
raise Unsupported(
"ScriptModules aren't supported in UnspecializedNNModuleVariable"
@ -1160,7 +1160,7 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
compilation.
"""
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
source = kwargs.get("source", None)
assert (
source is not None


@ -51,7 +51,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
static_tensor_names=None,
tensor_to_source=None,
**kwargs,
):
) -> None:
super().__init__(value, **kwargs)
self.grad_to_source = grad_to_source or {}
self.tensor_to_source = tensor_to_source or {}


@ -35,7 +35,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
def create(proxy, value, **options):
return TorchScriptObjectVariable(proxy, value, **options)
def __init__(self, proxy, value, source, **kwargs):
def __init__(self, proxy, value, source, **kwargs) -> None:
super().__init__(value, **kwargs)
self.proxy = proxy
self.proxy.node.meta["example_value"] = value


@ -46,7 +46,7 @@ class SDPAParamsVariable(VariableTracker):
tx, param_vars, {}
)
def __init__(self, proxy, param_vars, **kwargs):
def __init__(self, proxy, param_vars, **kwargs) -> None:
self.proxy = proxy
self.param_vars = param_vars
super().__init__(**kwargs)


@ -137,7 +137,7 @@ class TensorVariable(VariableTracker):
is_contiguous=None,
_is_name_set=None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.proxy = proxy
self.dtype = dtype
@ -1073,7 +1073,7 @@ class SymNodeVariable(VariableTracker):
return SymNodeVariable(proxy, sym_num, **options)
def __init__(self, proxy, sym_num, **kwargs):
def __init__(self, proxy, sym_num, **kwargs) -> None:
super().__init__(**kwargs)
self.proxy = proxy
# TODO: Should we allow non SymTypes here? Today it is allowed
@ -1252,7 +1252,7 @@ class UnspecializedPythonVariable(TensorVariable):
def __init__(
self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs
):
) -> None:
super().__init__(proxy, **kwargs)
self.raw_value = raw_value
self.need_unwrap = need_unwrap
@ -1276,7 +1276,7 @@ class FakeItemVariable(TensorVariable):
*TensorVariable._nonvar_fields,
}
def __init__(self, proxy: torch.fx.Proxy, **kwargs):
def __init__(self, proxy: torch.fx.Proxy, **kwargs) -> None:
need_unwrap = kwargs.pop("need_unwrap", False)
super().__init__(proxy, **kwargs)
self.need_unwrap = need_unwrap
@ -1287,7 +1287,7 @@ class FakeItemVariable(TensorVariable):
class TensorSubclassVariable(VariableTracker):
def __init__(self, value, *args, **kwargs):
def __init__(self, value, *args, **kwargs) -> None:
self.value = value
super().__init__(*args, **kwargs)
@ -1329,7 +1329,7 @@ class UntypedStorageVariable(VariableTracker):
from_tensor: TensorVariable,
example_value: torch.UntypedStorage,
**kwargs,
):
) -> None:
super().__init__(**kwargs),
self.from_tensor = from_tensor
# Example_value will always have device="meta"


@ -154,7 +154,7 @@ class BaseTorchVariable(VariableTracker):
source=source,
)
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
@ -190,7 +190,7 @@ class BaseTorchVariable(VariableTracker):
class TorchCtxManagerClassVariable(BaseTorchVariable):
"""Points to a context manager class in torch.* that dynamo has implementations"""
def __repr__(self):
def __repr__(self) -> str:
return f"TorchCtxManagerClassVariable({self.value})"
@staticmethod
@ -331,7 +331,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
class TorchInGraphFunctionVariable(BaseTorchVariable):
"""Points to a torch function/method that should be put in FX graph"""
def __repr__(self):
def __repr__(self) -> str:
return f"TorchInGraphFunctionVariable({self.value})"
def get_function(self):


@ -175,7 +175,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
Represents a tensor subclass instance with a __torch_function__ override.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self.torch_function_fn = kwargs.pop("torch_function_fn")
super().__init__(*args, **kwargs)


@ -97,7 +97,7 @@ class UserDefinedVariable(VariableTracker):
class UserDefinedClassVariable(UserDefinedVariable):
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
@ -110,7 +110,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
def as_proxy(self):
return self.value
def __str__(self):
def __str__(self) -> str:
return f"UserDefinedClassVariable({self.value})"
@staticmethod
@ -523,13 +523,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
_nonvar_fields = {"value", "value_type", *UserDefinedVariable._nonvar_fields}
def __init__(self, value, value_type=None, **kwargs):
def __init__(self, value, value_type=None, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
self.value_type = value_type or type(value)
assert type(value) is self.value_type
def __str__(self):
def __str__(self) -> str:
inner = self.value_type.__name__
if inner in [
"builtin_function_or_method",
@ -540,7 +540,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
inner = str(getattr(self.value, "__name__", None))
return f"{self.__class__.__name__}({inner})"
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.value_type.__name__})"
def python_type(self):
@ -1111,7 +1111,7 @@ class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
self,
value,
**kwargs,
):
) -> None:
super().__init__(value, **kwargs)
def call_method(
@ -1133,7 +1133,7 @@ class SourcelessGraphModuleVariable(UserDefinedObjectVariable):
class WeakRefVariable(UserDefinedObjectVariable):
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
super().__init__(value, **kwargs)
def call_function(
@ -1162,7 +1162,7 @@ class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
mod = sys.modules.get("torchrec.sparse.jagged_tensor")
return mod is not None and type(obj) is mod.KeyedJaggedTensor
def __init__(self, value, **kwargs):
def __init__(self, value, **kwargs) -> None:
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
assert type(value) is KeyedJaggedTensor
@ -1194,7 +1194,7 @@ class RemovableHandleVariable(VariableTracker):
# index of the registration in the side_effects owned register_hook/handle list, used during removal.
idx=None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.mutable_local = mutable_local
self.idx = idx