[BE] typing for decorators - fx/_compatibility (#131568)

See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
This commit is contained in:
Aaron Orenstein
2024-07-25 08:05:29 -07:00
committed by PyTorch MergeBot
parent 709ddf7a9d
commit 193f62fde9
77 changed files with 232 additions and 249 deletions

View File

@ -336,7 +336,7 @@ class AutogradCompilerInstance:
def bind_tensors_to_proxies(self, tensors, proxies):
if isinstance(proxies, torch.fx.Proxy):
proxies = [proxies[i] for i in range(len(tensors))]
proxies = [proxies[i] for i in range(len(tensors))] # type: ignore[index]
assert len(tensors) == len(proxies)
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)

View File

@ -1069,7 +1069,7 @@ def replay(filename: str) -> None:
record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
try:
_compile(
_compile( # type: ignore[call-arg]
record.code,
record.globals,
record.locals,

View File

@ -1512,7 +1512,7 @@ def export(
# Running graph with interpreter is needed for propagating the stack_trace
def graph_with_interpreter(*args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args)
return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type]
with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), (
fake_mode
@ -1536,9 +1536,9 @@ def export(
assert graph is not None
for node in graph.graph.find_nodes(op="get_attr"):
if isinstance(getattr(graph, node.target), torch.Tensor):
if isinstance(getattr(graph, node.target), torch.Tensor): # type: ignore[arg-type]
node.meta["val"] = fake_mode.from_tensor(
getattr(graph, node.target), static_shapes=True
getattr(graph, node.target), static_shapes=True # type: ignore[arg-type]
)
if same_signature:

View File

@ -2039,7 +2039,7 @@ def get_real_value(node, tracer):
return cache[node]
op = node.op
args, kwargs = torch.fx.node.map_arg(
args, kwargs = torch.fx.node.map_arg( # type: ignore[misc]
(node.args, node.kwargs),
lambda n: get_real_value(n, tracer),
)

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import copy
import dataclasses

View File

@ -942,7 +942,7 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
self.name_to_node,
# Dummy node.
torch.fx.Node(
None,
None, # type: ignore[arg-type]
"mock",
"call_function",
lambda: None,

View File

@ -68,7 +68,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self.fake_tensor_mode: Optional[FakeTensorMode] = None
self.submodules: Dict[torch.nn.Module, str] = {}
def trace(self) -> None:
def trace(self) -> None: # type: ignore[override]
raise ExportPassBaseError("ExportTracer doesn't support trace().")
def create_arg(self, a: Argument) -> torch.fx.Node:
@ -160,7 +160,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
def placeholder(
self,
target: str,
target: str, # type: ignore[override]
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
) -> ProxyValue:
@ -218,7 +218,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
raise ExportPassBaseError(f"Unsupported target type: {target}")
def get_attr(
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
) -> Argument:
return super().get_attr(target, args, kwargs)
@ -231,7 +231,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
raise ExportPassBaseError("call_module is not supported.")
def call_method(
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
) -> None:
raise ExportPassBaseError("call_method is not supported.")
@ -394,7 +394,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
)
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
interpreter = self.ExportInterpreter(self, graph_module)
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
)
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)

View File

@ -92,7 +92,7 @@ def _replace_with_hop(node: torch.fx.Node):
# Rename the name of getitem nodes to the actual name of its contents
# for passing verifier and better readability, also propagate metadata
for get_item_node in call_func_node.users.keys():
idx: int = get_item_node.args[1]
idx: int = get_item_node.args[1] # type: ignore[assignment]
output_node = output_args[idx]
get_item_node._rename(output_node.name)
get_item_node.meta = output_node.meta

View File

@ -186,7 +186,7 @@ def _extract_graph_with_inputs_outputs(
# joint_graph.nodes).
continue
elif node.op == "placeholder":
env[node] = InvalidNode
env[node] = InvalidNode # type: ignore[assignment]
elif node.op == "call_function":
all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
all_args = [
@ -195,7 +195,7 @@ def _extract_graph_with_inputs_outputs(
if isinstance(x, fx.Node)
]
if any(all_args):
env[node] = InvalidNode
env[node] = InvalidNode # type: ignore[assignment]
continue
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "get_attr":

View File

@ -769,7 +769,7 @@ def trace_flex_attention_backward(
)
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
block_mask = block_mask[:-1] + (mask_graph,)
proxy_mode.tracer.root.register_module("fw_graph", fw_graph)
proxy_mode.tracer.root.register_module("fw_graph", fw_graph) # type: ignore[arg-type]
proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
node_args = (

View File

@ -1432,7 +1432,7 @@ class CppVecOverrides(CppOverrides):
], f"{__name__} does not support {dtype}"
node: torch.fx.Node = V.interpreter.current_node
assert node and isinstance(node, torch.fx.Node)
opt_ctx_x = get_opt_ctx(node.args[1])
opt_ctx_x = get_opt_ctx(node.args[1]) # type: ignore[arg-type]
assert opt_ctx_x
assert opt_ctx_x.dtype is not None
assert isinstance(V.kernel, CppVecKernel)

View File

@ -222,7 +222,7 @@ class ConstantFolder(torch.fx.Interpreter):
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
self.node_replacements[node] = tensor
def run(self) -> Any:
def run(self) -> Any: # type: ignore[override]
env: Dict[torch.fx.Node, Any] = {}
self.insert_placerholder_values(env)
return super().run(initial_env=env)

View File

@ -141,7 +141,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
kwargs = {}
if hasattr(snode, "get_device"):
kwargs = {"device": snode.get_device()}
fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)
fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type]
def in_output(snode):
if isinstance(snode, FusedSchedulerNode):

View File

@ -154,17 +154,17 @@ def binary_folding_init():
return False
if isinstance(other, torch.fx.Node) and other.op == "get_attr":
other_meta_value = other.meta.get("val")
if not other_meta_value.is_floating_point():
if not other_meta_value.is_floating_point(): # type: ignore[union-attr]
return False
if (
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr]
!= weight_meta_value.dtype
):
if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
return False
if (
other_meta_value.dtype != torch.float
other_meta_value.dtype != torch.float # type: ignore[union-attr]
and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
):
return False

View File

@ -213,7 +213,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
new_node = graph.create_node(
op="call_function",
target=efficient_conv_bn_eval_decomposed,
args=args,
args=args, # type: ignore[arg-type]
name="efficient_conv_bn_eval",
)
@ -223,7 +223,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
# take care of the deletion order:
# delete bn_node first, and then conv_node
graph.erase_node(bn_node)
graph.erase_node(conv_node)
graph.erase_node(conv_node) # type: ignore[arg-type]
return
@ -304,7 +304,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
new_node = graph.create_node(
op="call_function",
target=efficient_conv_bn_eval_decomposed,
args=args,
args=args, # type: ignore[arg-type]
name="efficient_conv_bn_eval",
)
@ -314,7 +314,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
# take care of the deletion order:
# delete bn_node first, and then conv_node
graph.erase_node(bn_node)
graph.erase_node(conv_node)
graph.erase_node(conv_node) # type: ignore[arg-type]
return
@ -373,7 +373,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
# Find a pair of conv and bn computation nodes to optimize.
counters["inductor"]["efficient_conv_bn_eval"] += 1
with graph.inserting_before(conv_node):
with graph.inserting_before(conv_node): # type: ignore[arg-type]
# create `get_attr` node to access modules
# note that we directly call `create_node` to fill the `name`
# argument. `graph.get_attr` and
@ -403,4 +403,4 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
# take care of the deletion order:
# delete bn_node first, and then conv_node
graph.erase_node(bn_node)
graph.erase_node(conv_node)
graph.erase_node(conv_node) # type: ignore[arg-type]

View File

@ -223,5 +223,5 @@ def unnecessary_dtype_convert(match: Match, **kwargs):
"""Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
graph = match.graph
node = match.output_node()
node.replace_all_uses_with(node.args[0])
node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type]
graph.erase_node(node)

View File

@ -505,7 +505,7 @@ def pointless_view(match: Match, arg, size):
node = match.output_node()
arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr]
if size == arg_size:
node.replace_all_uses_with(node.args[0])
node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] # type: ignore[arg-type]
match.erase_nodes(graph)

View File

@ -1033,7 +1033,7 @@ def is_index_put_and_requires_h2d_sync_for_cuda_value(node):
# if the value we are putting is a cpu scalar.
# Therefore, when inductor sees an index_put_ with byte tensor indices,
# it should *not* convert the cpu scalar value into a cuda tensor.
args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)
args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs) # type: ignore[syntax, misc]
any_byte_bool_indices = False
indices = args_[1]
for i in indices:

View File

@ -1885,14 +1885,14 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
graph.erase_node(conv_node)
# Erase the dequant pattern
if dtype == torch.bfloat16:
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
graph.erase_node(dequant_node)
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type]
graph.erase_node(dequant_node) # type: ignore[arg-type]
# Erase the dequant per channel pattern
if clone_node is not None:
graph.erase_node(clone_node)
graph.erase_node(clone_node) # type: ignore[arg-type]
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
graph.erase_node(dequant_per_channel)
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type]
graph.erase_node(dequant_per_channel) # type: ignore[arg-type]
counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
match.nodes
@ -2579,8 +2579,8 @@ def quant_lift_up(graph_module: torch.fx.GraphModule):
new_args = map_arg(new_quant_node.args, maybe_replace_node)
new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
new_quant_node.args = new_args
new_quant_node.kwargs = new_kwargs
new_quant_node.args = new_args # type: ignore[assignment]
new_quant_node.kwargs = new_kwargs # type: ignore[assignment]
graph_module.graph.erase_node(quant_node)
graph_module.graph.lint()

View File

@ -288,7 +288,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
return
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
with graph.inserting_before(src):
with graph.inserting_before(src): # type: ignore[arg-type]
new_node = graph_call_function(
graph,
_generalized_scatter,
@ -311,7 +311,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
handle_views(new_src)
src.replace_all_uses_with(new_src) # type: ignore[union-attr]
graph.erase_node(src)
graph.erase_node(src) # type: ignore[arg-type]
for node in graph.nodes:
if _is_view_op(node.target):

View File

@ -188,7 +188,7 @@ def normalize_split_base(
new_split_node = graph.call_function(
torch.split,
args=new_args,
kwargs=new_kwargs,
kwargs=new_kwargs, # type: ignore[arg-type]
)
split_node.replace_all_uses_with(new_split_node)
new_split_node.meta.update(split_node.meta)
@ -373,7 +373,7 @@ def normalize_stack_default(match: Match, *args, **kwargs):
with graph.inserting_after(node):
new_node = graph.call_function(
node.target,
node.target, # type: ignore[arg-type]
args=(tensors,),
kwargs={"dim": dim},
)
@ -502,7 +502,7 @@ def merge_splits(
to_remove = []
with graph.inserting_before(first_split):
with graph.inserting_before(first_split): # type: ignore[arg-type]
# Add the new split node
new_split = graph.call_function(
torch.split,
@ -1431,7 +1431,7 @@ def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
# case 1: the cat uses all getitems from the split
if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type]
# replace the users of the cat node to be the input of the split node
cat_user.replace_all_uses_with(split_node.args[0])
cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type]
# remove the cat node
graph.erase_node(cat_user)
counters["inductor"]["mutate_cat_pass"] += 1

View File

@ -762,7 +762,7 @@ class GraphLowering(torch.fx.Interpreter):
raise KeyError(f"could not find {buffer_name}")
@dynamo_timed
def run(self, *args: Any) -> Any:
def run(self, *args: Any) -> Any: # type: ignore[override]
return super().run(*args)
def register_operation(self, op: ir.Operation) -> str:
@ -906,9 +906,9 @@ class GraphLowering(torch.fx.Interpreter):
)
def placeholder(
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
) -> Union[Expr, TensorBox, None]:
example = super().placeholder(target, args, kwargs)
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
self.graph_input_names.append(target)
if isinstance(example, SymTypes):
expr = example.node.expr
@ -963,7 +963,7 @@ class GraphLowering(torch.fx.Interpreter):
self.aligned_inputs.add(target)
return tensor
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg]
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override]
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
return super().call_function(target, args, kwargs)
@ -1040,7 +1040,7 @@ class GraphLowering(torch.fx.Interpreter):
return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr(
self, target: str, args: Tuple[()], kwargs: Dict[str, object]
self, target: str, args: Tuple[()], kwargs: Dict[str, object] # type: ignore[override]
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
@ -1080,9 +1080,9 @@ class GraphLowering(torch.fx.Interpreter):
raise AssertionError
def output(
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
) -> None:
result = super().output(target, args, kwargs)
result = super().output(target, args, kwargs) # type: ignore[arg-type]
if not isinstance(result, (tuple, list)):
# nested subgraphs can have singleton outputs
result = (result,)

View File

@ -6572,7 +6572,7 @@ class InterpreterShim(torch.fx.Interpreter):
self.graph = graph
self.submodules = submodules
self.extra_traceback = False
self.fetch_attr = submodules.__getitem__
self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign]
self.current_node = None
def run_node(self, n: torch.fx.Node) -> Any:

View File

@ -235,7 +235,7 @@ class Match:
if trace_fn is None:
trace_fn = functools.partial(fwd_only, run_dce=run_dce)
replacement = trace_fn(
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type]
)
ReplacementPatternEntry.replace_with_graph(
self,
@ -606,7 +606,7 @@ class _TargetArgsExpr(_TargetExpr):
from torch.fx.operator_schemas import normalize_function
normalized_args_and_kwargs = normalize_function(
node.target, node.args, node.kwargs
node.target, node.args, node.kwargs # type: ignore[arg-type]
)
if normalized_args_and_kwargs is None:
@ -1035,7 +1035,7 @@ class ReplacementPatternEntry(PatternEntry):
if node.op == "call_function":
target = node.target
args, kwargs = self.fetch_args_kwargs_from_env(node)
result = graph.call_function(target, args, kwargs)
result = graph.call_function(target, args, kwargs) # type: ignore[arg-type]
if "val" in node.meta and "val" not in result.meta:
result.meta["val"] = node.meta["val"]
if isinstance(node.meta["val"], torch.Tensor):
@ -1079,7 +1079,7 @@ class ReplacementPatternEntry(PatternEntry):
queue.extend(arg.all_input_nodes)
with graph.inserting_before(last_node):
replacement = Replacer(replacement_graph).run(*args)
replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type]
if isinstance(replacement, torch.fx.Node):
replacement = [replacement]
@ -1100,7 +1100,7 @@ class ReplacementPatternEntry(PatternEntry):
return
assert isinstance(old, torch.fx.Node)
if new is None:
old.replace_all_uses_with(None)
old.replace_all_uses_with(None) # type: ignore[arg-type]
graph.erase_node(old)
return
if isinstance(new, torch.fx.Node):
@ -1123,7 +1123,7 @@ class ReplacementPatternEntry(PatternEntry):
graph.erase_node(old)
return
new = typing.cast(Sequence[torch.fx.Node], new)
new = typing.cast(Sequence[torch.fx.Node], new) # type: ignore[redundant-cast]
# `new` is not a node: it's a list of nodes.
#
# This happens when we want to replace a node that has a single
@ -1232,7 +1232,7 @@ def register_replacement(
)
args = list(
torch.fx.map_arg(
torch.fx.map_arg( # type: ignore[arg-type]
[match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
)
)
@ -1665,8 +1665,8 @@ class PatternMatcherPass:
raise RuntimeError(
f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
)
if should_compute_mutation_region_ids(graph):
compute_mutation_region_ids(graph)
if should_compute_mutation_region_ids(graph): # type: ignore[arg-type]
compute_mutation_region_ids(graph) # type: ignore[arg-type]
get_mutation_region_id_partial = functools.partial(
get_mutation_region_id, graph
)
@ -1757,7 +1757,7 @@ def fx_to_pattern(
get_attr = _not_implemented
def placeholder(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
) -> Union[ExclusiveKeywordArg, KeywordArg]:
n = next(argnum)
if n < len(argnames):
@ -1774,7 +1774,7 @@ def fx_to_pattern(
return KeywordArg(name)
def call_function(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
) -> PatternExpr:
args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
if list in ignore_types:
@ -1793,7 +1793,7 @@ def fx_to_pattern(
rv.users = len(n.users)
return rv
pattern = Converter(gm).run()
pattern = Converter(gm).run() # type: ignore[arg-type]
if not isinstance(pattern, PatternExpr):
return MultiOutputPattern(pytree.tree_leaves(pattern))
return pattern

View File

@ -47,7 +47,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
def call_function(
self,
target: Callable[[Any], Any],
target: torch.fx.node.Target,
args: Any,
kwargs: Dict[str, Any],
) -> Any:
@ -70,9 +70,14 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
return lowerings[target](*args, **kwargs)
def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None:
def output(
self,
target: torch.fx.node.Target,
args: Tuple[torch.fx.node.Argument, ...],
kwargs: Dict[str, Any],
) -> None:
assert len(args) == 1
self.graph_outputs = args[0]
self.graph_outputs = args[0] # type: ignore[assignment]
@dataclass

View File

@ -350,7 +350,7 @@ def gen_gm_and_inputs(target, args, kwargs):
len(target._schema.returns) == 1
and str(target._schema.returns[0].type) == "Tensor"
):
node = (node,)
node = (node,) # type: ignore[assignment]
g.output(node)
gm = torch.fx.GraphModule({}, g)

View File

@ -2156,8 +2156,8 @@ class FakeTensorMode(TorchDispatchMode):
any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
schema_info = get_schema_info(func)
if any_constant and schema_info.is_mutable():
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
)
for k, v in new_kwargs.items():
k = k if (k != "input" or schema_info.has_argument(k)) else "self"

View File

@ -896,7 +896,7 @@ def prepare_n_shadows_model(
tracer = custom_tracer
mt = torch.fx.GraphModule(model, tracer.trace(model))
# this is necessary to ensure logger FQNs get populated
mt._node_name_to_scope = tracer.node_name_to_scope
mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment]
# run example input propagation, we need this to call prepare_fx on
# individual subgraphs
@ -998,7 +998,7 @@ def _prepare_n_shadows_add_loggers_model(
tracer = quantize_fx.QuantizationTracer([], [])
mt = torch.fx.GraphModule(model, tracer.trace(model))
# this is necessary to ensure logger FQNs get populated
mt._node_name_to_scope = tracer.node_name_to_scope
mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment]
# run example input propagation, we need this to call prepare_fx on
# individual subgraphs

View File

@ -694,13 +694,13 @@ def _insert_copy_of_node_a_after_input_node_c(
mod_a = getattr_from_fqn(gm_a, node_a.target)
setattr(gm_b, new_mod_copy_name, mod_a)
node_a_shadows_c = graph_c.create_node(
node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name
node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type]
)
return node_a_shadows_c
else:
assert node_a.op in ("call_function", "call_method")
node_a_shadows_c = graph_c.create_node(
node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name
node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type]
)
return node_a_shadows_c

View File

@ -406,20 +406,20 @@ def create_submodule_from_subgraph(
mod_name = f"mod_{cur_name_idx}"
setattr(gm, mod_name, orig_mod_copy)
cur_name_idx += 1
cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined, arg-type]
elif cur_node_orig.op == "call_function":
cur_node_copy = g.call_function(
cur_node_orig.target,
cur_args_copy,
cur_kwargs_copy, # type: ignore[possibly-undefined]
cur_node_orig.target, # type: ignore[arg-type]
cur_args_copy, # type: ignore[arg-type]
cur_kwargs_copy, # type: ignore[possibly-undefined, arg-type]
)
elif cur_node_orig.op == "call_method":
cur_node_copy = g.call_method(
cur_node_orig.target,
cur_args_copy,
cur_kwargs_copy, # type: ignore[possibly-undefined]
cur_node_orig.target, # type: ignore[arg-type]
cur_args_copy, # type: ignore[arg-type]
cur_kwargs_copy, # type: ignore[possibly-undefined, arg-type]
)
else:
@ -582,7 +582,7 @@ def create_one_transformed_and_logged_copy_of_subgraph(
new_args = tuple(new_args) # type: ignore[assignment]
new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs)
new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs) # type: ignore[arg-type]
# add a logger to parent graph to observe the shadow wrapper
logger_mod_orig = _get_logger_for_subgraph(

View File

@ -268,7 +268,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
BiasHook(module.parametrizations.weight[0], prune_bias)
)
def prune(self) -> None:
def prune(self) -> torch.fx.graph_module.GraphModule:
r"""
This function will FX symbolically trace the model and then find instances of the patterns
defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ).

View File

@ -25,7 +25,7 @@ def _match(
if isinstance(current, type) and issubclass(current, torch.nn.Module):
return (
node.op == "call_module"
and parametrize.type_before_parametrizations(modules[node.target])
and parametrize.type_before_parametrizations(modules[node.target]) # type: ignore[index]
== current
)
elif callable(current):

View File

@ -571,7 +571,7 @@ def _match_static_pattern(
match_key = type(_get_module(ref_node, modules))
else:
expected_op = "call_function"
match_key = ref_node.target
match_key = ref_node.target # type: ignore[assignment]
if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
return SKIP_LOWERING_VALUE
@ -591,7 +591,7 @@ def _match_static_pattern(
if not matched_dequantize:
return SKIP_LOWERING_VALUE
return (q_node, relu_node, ref_node)
return (q_node, relu_node, ref_node) # type: ignore[return-value]
def _match_static_pattern_with_two_inputs(
@ -689,8 +689,8 @@ def _lower_static_weighted_ref_module(
continue
else:
q_class = STATIC_LOWER_MODULE_MAP[ref_class]
output_scale = getattr(model, scale_node.target)
output_zero_point = getattr(model, zero_point_node.target)
output_scale = getattr(model, scale_node.target) # type: ignore[arg-type]
output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type]
q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
# replace reference module with quantized module
parent_name, module_name = _parent_name(ref_node.target)
@ -700,7 +700,7 @@ def _lower_static_weighted_ref_module(
assert len(ref_node.args) == 1
dq_node = ref_node.args[0]
assert isinstance(dq_node, Node)
ref_node.replace_input_with(dq_node, dq_node.args[0])
ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type]
q_node.replace_all_uses_with(ref_node)
model.graph.erase_node(q_node)
model.graph.erase_node(scale_node)
@ -749,8 +749,8 @@ def _lower_static_weighted_ref_module_with_two_inputs(
continue
else:
continue
output_scale = getattr(model, scale_node.target)
output_zero_point = getattr(model, zero_point_node.target)
output_scale = getattr(model, scale_node.target) # type: ignore[arg-type]
output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type]
q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
# replace reference module with quantized module
parent_name, module_name = _parent_name(ref_node.target)
@ -763,7 +763,7 @@ def _lower_static_weighted_ref_module_with_two_inputs(
continue
dq_node = arg
assert isinstance(dq_node, Node)
ref_node.replace_input_with(dq_node, dq_node.args[0])
ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type]
q_node.replace_all_uses_with(ref_node)
model.graph.erase_node(q_node)
@ -906,7 +906,7 @@ def _lower_static_weighted_ref_functional(
prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5]
else:
raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
with model.graph.inserting_before(output_scale_node):
with model.graph.inserting_before(output_scale_node): # type: ignore[arg-type]
# kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack)
# They are not needed for compute op (i.e., quantized::linear)
kwargs = func_node.kwargs
@ -1107,7 +1107,7 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: Dict[str, QConfi
dq_node = arg
assert isinstance(dq_node, Node)
dn_input = dq_node.args[0]
bop_node.replace_input_with(dq_node, dn_input)
bop_node.replace_input_with(dq_node, dn_input) # type: ignore[arg-type]
num_dq_nodes += 1
assert num_dq_nodes > 0

View File

@ -821,7 +821,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
last_getitem_index = last_getitem.args[1]
new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index]
for user in list(last_getitem.users.keys()):
user.replace_input_with(last_getitem, new_input)
user.replace_input_with(last_getitem, new_input) # type: ignore[arg-type]
def _get_observer_from_activation_post_process(

View File

@ -46,8 +46,8 @@ def _maybe_duplicate_dq(
new_args = map_arg(user.args, maybe_replace_node)
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
user.args = new_args
user.kwargs = new_kwargs
user.args = new_args # type: ignore[assignment]
user.kwargs = new_kwargs # type: ignore[assignment]
class DuplicateDQPass(PassBase):

View File

@ -107,8 +107,8 @@ def _find_q_dq_node_for_user(
q_node = None
if (
dq_node.args[0].op == "call_function"
and dq_node.args[0].target in _QUANTIZE_OPS
dq_node.args[0].op == "call_function" # type: ignore[union-attr]
and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr]
):
q_node = dq_node.args[0]
return (q_node, dq_node)
@ -353,7 +353,7 @@ def _get_aten_graph_module_for_pattern(
[x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
)
aten_pattern = capture_pre_autograd_graph(
pattern,
pattern, # type: ignore[arg-type]
example_inputs,
kwargs,
)
@ -373,7 +373,7 @@ def _get_aten_graph_module_for_pattern(
aten_pattern.graph.eliminate_dead_code()
aten_pattern.recompile()
return aten_pattern
return aten_pattern # type: ignore[return-value]
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:

View File

@ -983,13 +983,13 @@ def _annotate_cat(
inputs = cat_node.args[0]
input_qspec_map = {}
input_act0 = inputs[0]
input_act0 = inputs[0] # type: ignore[index]
if isinstance(input_act0, Node):
input_qspec_map[input_act0] = input_act_qspec
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
for input_act in inputs[1:]:
input_qspec_map[input_act] = shared_with_input0_qspec
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: ignore[arg-type]
for input_act in inputs[1:]: # type: ignore[index]
input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index]
output_act_qspec = shared_with_input0_qspec

View File

@ -117,7 +117,7 @@ def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.Grap
# pyre-ignore[6]
in_spec=None, # type: ignore[arg-type]
# pyre-ignore[16]
out_spec=gm._graph._codegen.pytree_info.out_spec,
out_spec=gm._graph._codegen.pytree_info.out_spec, # type: ignore[attr-defined]
)
)
gm.recompile()

View File

@ -92,7 +92,7 @@ def clone_subgraph(
with graph.inserting_before(target):
for node in subgraph:
cloned_node = graph.call_function(
node.target, node.args, node.kwargs, node.type
node.target, node.args, node.kwargs, node.type # type: ignore[arg-type]
)
# TODO: there are many flatten/unflatten in IterGraph that
# can be simplified with tree_map. Will simplify this in

View File

@ -366,7 +366,7 @@ class IterGraph(fx.Graph):
actual_target_node = self._lookup_node(target_node, graph)
assert actual_target_node is not None
for actual_node in actual_nodes:
actual_target_node.prepend(actual_node)
actual_target_node.prepend(actual_node) # type: ignore[arg-type]
def move_after(self, nodes: List[fx.Node], target_node: fx.Node) -> None:
for graph in self._all_graphs:
@ -374,7 +374,7 @@ class IterGraph(fx.Graph):
actual_target_node = self._lookup_node(target_node, graph)
for actual_node in actual_nodes:
assert actual_target_node is not None
actual_target_node.append(actual_node)
actual_target_node.append(actual_node) # type: ignore[arg-type]
actual_target_node = actual_node
def call_function(
@ -432,7 +432,7 @@ class IterGraph(fx.Graph):
self.setup_graph.erase_node(setup_node)
super().erase_node(to_erase)
cleanup_node = self._lookup_node(to_erase, self.cleanup_graph)
self.cleanup_graph.erase_node(cleanup_node)
self.cleanup_graph.erase_node(cleanup_node) # type: ignore[arg-type]
def placeholder(
self,
@ -558,7 +558,7 @@ class IterGraph(fx.Graph):
actual_replace_with = self._lookup_node(replace_with, graph)
assert actual_node is not None
ret = actual_node.replace_all_uses_with(
actual_replace_with,
actual_replace_with, # type: ignore[arg-type]
delete_user_cb,
propagate_meta=propagate_meta,
)

View File

@ -45,7 +45,7 @@ def _is_container_node(node: torch.fx.Node) -> bool:
"Malformed graph: a container node is used as input for non-getitem nodes."
"\nNode: {fmt_node}\nUsers: {fmt_users}".format(
fmt_node=node.format_node(),
fmt_users="\n".join(u.format_node() for u in node.users),
fmt_users="\n".join(u.format_node() for u in node.users), # type: ignore[misc]
)
)
return True

View File

@ -472,7 +472,7 @@ def _insert_reshard_gm(
input_node: input_arg,
},
)
node.replace_input_with(input_arg, output_node)
node.replace_input_with(input_arg, output_node) # type: ignore[arg-type]
def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:

View File

@ -96,11 +96,11 @@ class _ExecOrderTracer:
self.exec_info = _ExecutionInfo(root_module)
orig_call_module = tracer.call_module
orig_create_proxy = tracer.create_proxy
tracer.call_module = functools.partial(
tracer.call_module = functools.partial( # type: ignore[method-assign]
self._patched_call_module, orig_call_module, self.exec_info
)
fqn_to_param = dict(root_module.named_parameters())
tracer.create_proxy = functools.partial(
tracer.create_proxy = functools.partial( # type: ignore[method-assign]
self._patched_create_proxy,
orig_create_proxy,
self.exec_info,
@ -109,8 +109,8 @@ class _ExecOrderTracer:
try:
yield
finally:
tracer.call_module = orig_call_module
tracer.create_proxy = orig_create_proxy
tracer.call_module = orig_call_module # type: ignore[method-assign]
tracer.create_proxy = orig_create_proxy # type: ignore[method-assign]
def _patched_call_module(
self,
@ -216,8 +216,8 @@ class _ExecOrderTracer:
isinstance(arg, torch.fx.Proxy)
and arg.node.target in fqn_to_param
):
param = fqn_to_param[arg.node.target]
named_params.append((arg.node.target, param))
param = fqn_to_param[arg.node.target] # type: ignore[index]
named_params.append((arg.node.target, param)) # type: ignore[arg-type]
if param not in exec_info.visited_params:
exec_info.visited_params.add(param)
exec_info.param_forward_order.append(param)

View File

@ -214,7 +214,7 @@ def _insert_stage_symbolic_backward(
input_nodes = list(node.all_input_nodes)
grads_proxy = fx.Proxy(grads)
for i, input_node in enumerate(input_nodes):
assign_or_accumulate_grad(input_node, grads_proxy[i].node)
assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index]
return g
@ -416,15 +416,15 @@ class _LinearNodeList:
def __init__(self, node_list):
self.serialize_node_list = []
for node in node_list:
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type, return-value]
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type, return-value]
serialize_node = fx.Node(
graph=None,
graph=None, # type: ignore[arg-type]
name=node.name,
op=node.op,
target=node.target,
args=node_args,
kwargs=node_kwargs,
args=node_args, # type: ignore[arg-type]
kwargs=node_kwargs, # type: ignore[arg-type]
return_type=node.type,
)
serialize_node.meta = copy.copy(node.meta)
@ -447,8 +447,8 @@ class _LinearNodeList:
deser_node = graph.create_node(
op=node.op,
target=node.target,
args=node_args,
kwargs=node_kwargs,
args=node_args, # type: ignore[arg-type]
kwargs=node_kwargs, # type: ignore[arg-type]
name=node.name,
type_expr=node.type,
)
@ -731,7 +731,7 @@ class Pipe(torch.nn.Module):
# TODO: what does split do with module invocations? does it move the modules
# into the submodules?
split = split_module(traced, mod, split_callback)
split = split_module(traced, mod, split_callback) # type: ignore[arg-type]
# a (custom) tracer can produce dead code like orphan get_attr nodes
split.graph.eliminate_dead_code()

View File

@ -86,7 +86,7 @@ class TensorChunkSpec:
"""
args_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim),
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type, return-value]
)
return args_chunk_spec
@ -104,7 +104,7 @@ class TensorChunkSpec:
"""
kwargs_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim),
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type, return-value]
)
return kwargs_chunk_spec

View File

@ -1,14 +1,15 @@
# mypy: allow-untyped-defs
from typing import Any, Dict
from typing import Any, Dict, Callable, TypeVar
import textwrap
_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
def compatibility(is_backward_compatible : bool):
_T = TypeVar("_T")
def compatibility(is_backward_compatible : bool) -> Callable[[_T], _T]:
if is_backward_compatible:
def mark_back_compat(fn):
def mark_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
docstring += """
.. note::
@ -22,7 +23,7 @@ def compatibility(is_backward_compatible : bool):
return mark_back_compat
else:
def mark_not_back_compat(fn):
def mark_not_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
docstring += """
.. warning::

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from contextlib import contextmanager

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import builtins
import copy

View File

@ -165,7 +165,7 @@ class MetaTracer(torch.fx.Tracer):
meta_target = manual_meta_overrides.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == 'call_method':
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index] # type: ignore[index]
elif kind == 'call_module':
assert hasattr(self, 'orig_forward')
self._disable_module_getattr = True
@ -173,7 +173,7 @@ class MetaTracer(torch.fx.Tracer):
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in manual_meta_overrides:
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type]
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
@ -237,7 +237,7 @@ class MetaTracer(torch.fx.Tracer):
def proxy(self, node):
return MetaProxy(node, self)
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
assert isinstance(meta_args, dict)
self.meta_args = meta_args
@ -263,7 +263,7 @@ def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
meta_args : Optional[Dict[str, torch.Tensor]] = None,
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args)
graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
gm = torch.fx.GraphModule(tracer.root, graph, name)
return gm

View File

@ -153,10 +153,10 @@ def set_proxy_slot(
if isinstance(obj, Tensor):
# We DO want to clobber proxies whenever we run an inplace operation
# on a tensor, and it affects the metadata on the proxy.
tracer.tensor_tracker[obj] = proxy
tracer.tensor_tracker[obj] = proxy # type: ignore[has-type]
elif isinstance(obj, (_AnyScriptObject)):
# We DO want to clobber proxies, with a similar rationale as for tensors.
tracer.script_object_tracker[obj] = proxy
tracer.script_object_tracker[obj] = proxy # type: ignore[has-type]
else:
# NB: Never clobber pre-existing proxy. Although the proxies
# are in principle equivalent, when we do graph partitioning
@ -165,8 +165,8 @@ def set_proxy_slot(
# THEN later we allocate tangent inputs. Make sure if a SymInt
# is derivable from a primal that we use that.
assert isinstance(obj, py_sym_types), type(obj)
if obj not in tracer.symnode_tracker:
tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)
if obj not in tracer.symnode_tracker: # type: ignore[has-type]
tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy) # type: ignore[has-type]
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
assert isinstance(obj, (Tensor, SymNode)), type(obj)
@ -261,12 +261,12 @@ def get_proxy_slot(
tracker: Any
if isinstance(obj, Tensor):
tracker = tracer.tensor_tracker
tracker = tracer.tensor_tracker # type: ignore[has-type]
elif isinstance(obj, _AnyScriptObject):
tracker = tracer.script_object_tracker
tracker = tracer.script_object_tracker # type: ignore[has-type]
else:
assert isinstance(obj, py_sym_types), type(obj)
tracker = tracer.symnode_tracker
tracker = tracer.symnode_tracker # type: ignore[has-type]
if obj not in tracker:
if isinstance(default, _NoDefault):
@ -372,7 +372,7 @@ def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tr
)
try_set_proxy_slot(
tensor.storage_offset(),
lambda x: set_meta(tracer.create_proxy('call_function', torch.ops.aten.sym_storage_offset.default, (proxy,)), x)
lambda x: set_meta(tracer.create_proxy('call_function', torch.ops.aten.sym_storage_offset.default, (proxy,)), x) # type: ignore[call-arg]
)
set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
@ -767,7 +767,7 @@ class PythonKeyTracer(Tracer):
torch_fn_counts: Dict[OpOverload, int]
def __init__(self) -> None:
super().__init__(autowrap_modules=())
super().__init__(autowrap_modules=()) # type: ignore[arg-type]
self.tensor_tracker = WeakTensorKeyDictionary()
self.symnode_tracker = _SymNodeDict()
self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef)
@ -803,7 +803,7 @@ class PythonKeyTracer(Tracer):
elif isinstance(a, py_sym_types):
assert a.node.constant is not None
return a.node.constant
return super().create_arg(a)
return super().create_arg(a) # type: ignore[return-value]
@overload
def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]:
@ -867,7 +867,7 @@ def dispatch_trace(
tracer: Tracer,
concrete_args: Optional[Tuple[Any, ...]] = None,
) -> GraphModule:
graph = tracer.trace(root, concrete_args)
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
dedupe_symints(graph)
name = root.__class__.__name__ if isinstance(root, Module) else root.__name__
@ -966,7 +966,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# It's for passing the export verifier which needs to verify the meta['val']
# TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
# instead of hardcoding it here.
node = self.tracer.create_node("call_function", func, args, {})
node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type]
if func is torch._C._set_grad_enabled:
node.meta['val'] = None
return node
@ -1090,7 +1090,7 @@ class ProxySymDispatchMode(SymDispatchMode):
# func doesn't have a __torch_function__ that Proxy can interpose, so
# we gotta do it manually
n_out = self.tracer.create_node("call_function", func, n_args, {})
n_out = self.tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type]
p_out = fx.Proxy(n_out, self.tracer)
set_meta(p_out, out)
return p_out
@ -1149,7 +1149,7 @@ class DecompositionInterpreter(fx.Interpreter):
decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
**kwargs: object
) -> None:
super().__init__(module, **kwargs)
super().__init__(module, **kwargs) # type: ignore[arg-type]
self.new_graph = new_graph
self.tracer = _GraphAppendingTracerEx(self.new_graph)
# Blegh
@ -1164,23 +1164,23 @@ class DecompositionInterpreter(fx.Interpreter):
# distinguish between different calls to the same torch function.
self.tracer.torch_fn_counts = {}
def placeholder(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object:
out = super().placeholder(target, args, kwargs)
def placeholder(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object: # type: ignore[override]
out = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
# TODO handle case where the first character of target is '*'
return out
def get_attr(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object:
out = super().get_attr(target, args, kwargs)
def get_attr(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object: # type: ignore[override]
out = super().get_attr(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
return out
# call_function, call_method, call_module get traced automatically by the outer mode.
def output(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object:
out = super().output(target, args, kwargs)
def output(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object: # type: ignore[override]
out = super().output(target, args, kwargs) # type: ignore[arg-type]
def get_proxy_node(x: _ProxyTensor) -> fx.node.Node:
return x.proxy.node
@ -1194,7 +1194,7 @@ class DecompositionInterpreter(fx.Interpreter):
# Should enter the mode at least once for being able to restore it later
# See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
with decompose(self.decomposition_table), self.mode:
return super().run(*args, **kwargs)
return super().run(*args, **kwargs) # type: ignore[arg-type]
def wrapper_and_args_for_make_fx(
@ -1348,7 +1348,7 @@ class _ModuleStackTracer(PythonKeyTracer):
self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr)
return self.attr_proxy_map[attr_val]
def trace(
def trace( # type: ignore[override]
self,
root: Union[Module, Callable],
concrete_args: Optional[Dict[str, object]]
@ -1438,7 +1438,7 @@ class _ModuleStackTracer(PythonKeyTracer):
Add torch_fn by looking at torch_fn_metadata and torch_fn_counts.
Add stack_trace by filtering out forward() stack frames.
'''
node = super().create_node(*args, **kwargs)
node = super().create_node(*args, **kwargs) # type: ignore[arg-type]
# nn_module_stack
if node.op not in ["placeholder", "output"]:

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from collections import defaultdict
from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import copy

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# Nodes represent a definition of a value in our graph of operators.
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
from ._compatibility import compatibility
@ -769,7 +768,7 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
if isinstance(a, tuple):
t = tuple(map_aggregate(elem, fn) for elem in a)
# Support NamedTuple (if it has `_fields`) by repacking into original type.
return t if not hasattr(a, '_fields') else type(a)(*t)
return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type]
elif isinstance(a, list):
return immutable_list(map_aggregate(elem, fn) for elem in a)
elif isinstance(a, dict):

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import inspect

View File

@ -251,8 +251,8 @@ if HAS_PYDOT:
label += f"|target={self._typename(node.target)}" + r"\n"
if self.normalize_args:
try:
args, kwargs = normalize_function(
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True
args, kwargs = normalize_function( # type: ignore[misc]
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
)
except Exception:
# Fallback to not normalizing if there's an exception.

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, Dict, List, NamedTuple, Optional

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
import logging

View File

@ -302,7 +302,7 @@ class _MinimizerBase:
# Find submodule containing colored nodes
submodule_name: str = ""
for child_name, _ in split_module.named_children():
for child_name, _ in split_module.named_children(): # type: ignore[union-attr]
# Skip submodules we're not interested in at the moment
if "minimize" not in child_name:
continue
@ -319,7 +319,7 @@ class _MinimizerBase:
f"Minimize submodule was not found with nodes {nodes}"
)
return split_module, submodule_name
return split_module, submodule_name # type: ignore[return-value]
def _run_and_compare(
self,
@ -391,10 +391,10 @@ class _MinimizerBase:
report.append(f"Result mismatch for {result_key}")
if self.module_exporter:
self.module_exporter(
a_input, submodule, str(result_key[0]) + "_cpu",
a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index]
)
self.module_exporter(
b_input, submodule, str(result_key[0]) + "_acc",
b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index]
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import abc
import typing as t

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
from torch.fx.graph_module import GraphModule
from typing import Any, Callable, Dict, List, Tuple, Type
import torch

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import logging
import operator

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, List, Optional, Set
@ -163,14 +162,14 @@ def split_module(
)
if keep_original_node_name:
args = () if default_value is inspect.Signature.empty else (default_value,)
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) # type: ignore[arg-type]
else:
base_mod_env[node.name] = base_mod_graph.placeholder(
node.target, type_expr=node.type, default_value=default_value
node.target, type_expr=node.type, default_value=default_value # type: ignore[arg-type]
)
base_mod_env[node.name].meta = node.meta.copy()
elif node.op == "get_attr":
base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type]
base_mod_env[node.name].meta = node.meta.copy()
attr_val = m
for atom in node.target.split("."): # type: ignore[union-attr]

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import copy
from dataclasses import dataclass, field

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import argparse
import copy
@ -860,7 +859,7 @@ class _SplitterBase:
for node in self.module.graph.nodes:
if hasattr(node, "tag"):
del node.tag
return split_module
return split_module # type: ignore[return-value]
def __call__(self) -> torch.fx.GraphModule:
subgraphs = self.put_nodes_into_subgraphs()

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
import collections

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Dict, Tuple

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import copy
from queue import SimpleQueue

View File

@ -31,8 +31,8 @@ def _split_to_graph_and_name_node_map(
name_node_map, Dict
), "Expecting the input graph to have a dict output as the last element"
n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info
gm._graph._codegen.pytree_info = _PyTreeInfo(
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined]
orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
)
gm.recompile()

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from dataclasses import dataclass, field
from torch.fx.graph import Graph

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
from .graph_module import GraphModule
from .graph import Graph
from .node import Node
@ -319,8 +318,8 @@ def _replace_pattern(
# Hook the output Node of the replacement subgraph in to the
# original Graph at the correct location
assert len(match.returning_nodes) == len(copied_returning_nodes)
for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
assert len(match.returning_nodes) == len(copied_returning_nodes) # type: ignore[arg-type]
for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type]
gn.replace_all_uses_with(copied_node)
match_changed_node[gn] = copied_node
# Remove the original nodes

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import traceback
from contextlib import contextmanager

View File

@ -305,7 +305,7 @@ def jagged_torch_function(func, *args, **kwargs):
def _flatten_sig(input, start_dim=0, end_dim=-1):
pass
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
_flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -385,7 +385,7 @@ def tensor_attr_unsupported_getter(func, *args, **kwargs):
def is_contiguous_general(func, *args, **kwargs):
from torch._prims_common import is_contiguous_for_memory_format
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
@ -409,7 +409,7 @@ register_jagged_func(
@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -423,7 +423,7 @@ def linear_default(func, *args, **kwargs):
"self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -444,7 +444,7 @@ def linear_backward_default(func, *args, **kwargs):
def to_copy_default(func, *args, **kwargs):
from .nested_tensor import _tensor_symint_registry
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -476,7 +476,7 @@ register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
"self: jt_all",
)
def like_factory_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -492,7 +492,7 @@ def like_factory_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
def zero__default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -505,7 +505,7 @@ def zero__default(func, *args, **kwargs):
torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
)
def _softmax_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -521,7 +521,7 @@ def _softmax_default(func, *args, **kwargs):
"grad_output: jt, output: jt, dim: any, input_dtype: any",
)
def _softmax_backward(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
grad_out = new_kwargs.pop("grad_output")
@ -535,7 +535,7 @@ def _softmax_backward(func, *args, **kwargs):
torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
)
def native_dropout_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -552,7 +552,7 @@ def native_dropout_default(func, *args, **kwargs):
"grad_output: jt, mask: jt, scale: any",
)
def native_dropout_backward_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
grad_output = new_kwargs.pop("grad_output")
@ -565,7 +565,7 @@ def native_dropout_backward_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
def prod_dim_int(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -584,7 +584,7 @@ def prod_dim_int(func, *args, **kwargs):
torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any"
)
def split_tensor(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -602,7 +602,7 @@ def split_tensor(func, *args, **kwargs):
torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any"
)
def split_with_sizes_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -620,7 +620,7 @@ def split_with_sizes_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
def chunk_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -663,7 +663,7 @@ def chunk_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
def unbind_int(func, *args, **kwargs):
# Note that this specializes on the length of the offsets
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -697,7 +697,7 @@ def unbind_int(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
def squeeze_dim(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -710,7 +710,7 @@ def squeeze_dim(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
def unsqueeze_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -725,7 +725,7 @@ def unsqueeze_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
def cat_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -748,7 +748,7 @@ def cat_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
def matmul_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -773,7 +773,7 @@ def matmul_default(func, *args, **kwargs):
torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
)
def expand_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -790,7 +790,7 @@ def expand_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
def expand_as_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -802,7 +802,7 @@ def expand_as_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
def where_self(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -820,7 +820,7 @@ def where_self(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
def _pin_memory_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -831,7 +831,7 @@ def _pin_memory_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
def is_pinned_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -856,7 +856,7 @@ def sum_dim_IntList(func, *args, **kwargs):
Performs a sum along the provided tensor dimension.
Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
"""
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
@ -930,7 +930,7 @@ def sum_dim_IntList(func, *args, **kwargs):
torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
)
def transpose_int(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -977,7 +977,7 @@ def transpose_int(func, *args, **kwargs):
"self: jt_all, size: any",
)
def view_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1025,7 +1025,7 @@ def view_default(func, *args, **kwargs):
"input: jt, normalized_shape: any, weight: any?, bias: any?, eps: any",
)
def native_layer_norm_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1047,7 +1047,7 @@ def native_layer_norm_default(func, *args, **kwargs):
"grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
)
def native_layer_norm_backward_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
grad_out = new_kwargs.pop("grad_out")
@ -1061,7 +1061,7 @@ def native_layer_norm_backward_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any")
def select_int(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1076,7 +1076,7 @@ def select_int(func, *args, **kwargs):
"self: jt, dim: any?, start: any?, end: any?, step: any?",
)
def slice_tensor(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1092,7 +1092,7 @@ def slice_tensor(func, *args, **kwargs):
"dilation: any, transposed: any, output_padding: any, groups: any",
)
def convolution_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1109,7 +1109,7 @@ def mean_dim(func, *args, **kwargs):
Performs a mean along the provided tensor dimension.
Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
"""
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1165,7 +1165,7 @@ def mean_dim(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
def stack_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1199,7 +1199,7 @@ def stack_default(func, *args, **kwargs):
"weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
)
def embedding_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1220,7 +1220,7 @@ def embedding_default(func, *args, **kwargs):
"self: jt_all",
)
def values_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1236,7 +1236,7 @@ def values_default(func, *args, **kwargs):
"values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
)
def _nested_view_from_jagged_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1265,7 +1265,7 @@ def _nested_view_from_jagged_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
def _nested_get_offsets(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1275,7 +1275,7 @@ def _nested_get_offsets(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
def _nested_get_lengths(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1285,7 +1285,7 @@ def _nested_get_lengths(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
def _nested_get_ragged_idx(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1295,7 +1295,7 @@ def _nested_get_ragged_idx(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
def _nested_get_min_seqlen(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1305,7 +1305,7 @@ def _nested_get_min_seqlen(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
def _nested_get_max_seqlen(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)

View File

@ -243,7 +243,7 @@ def _fx_args_to_torch_args(
if isinstance(arg, torch.fx.Node):
fake_tensor = arg.meta.get("val")
if fake_tensor is None and arg.op == "get_attr":
fake_tensor = getattr(fx_graph_module, arg.target) # type: ignore[operator]
fake_tensor = getattr(fx_graph_module, arg.target) # type: ignore[operator, arg-type]
# NOTE: Currently, we are aware of
# FakeTensor/Tensor/SymInt/SymFloat/Symbool/int/float/bool could be in
# arg.meta["val"]/get_attr.
@ -254,8 +254,8 @@ def _fx_args_to_torch_args(
wrapped_args.append(real_tensor)
elif isinstance(fake_tensor, (int, float, bool)):
wrapped_args.append(fake_tensor)
elif symbolic_shapes.has_hint(fake_tensor):
wrapped_args.append(symbolic_shapes.hint_int(fake_tensor))
elif symbolic_shapes.has_hint(fake_tensor): # type: ignore[arg-type]
wrapped_args.append(symbolic_shapes.hint_int(fake_tensor)) # type: ignore[arg-type]
else:
raise ValueError(
f"Unexpected input argument type found inside fx.Node. arg: {arg}; "

View File

@ -61,7 +61,7 @@ def _create_tensor_proto_with_external_data(
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
tensor_proto.name = name
tensor_proto.data_type = scalar_type.onnx_type()
tensor_proto.data_type = scalar_type.onnx_type() # type: ignore[assignment]
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import dataclasses
import importlib
@ -647,7 +646,7 @@ Examples::
"""
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True) # type: ignore[arg-type]
@compatibility(is_backward_compatible=False)
class OrtBackendOptions:
"""Options for constructing an ``OrtBackend``, the ONNX Runtime

View File

@ -611,7 +611,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
return None
def get_arg_from_alias(output_alias, schema_info, args, kwargs):
new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(
new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs
)