[BE] typing for decorators - fx/_compatibility (part 1) (#134202)

Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without
conflicting with some other intermediate change, so these 'type: ignore'
comments are being landed in pytorch in advance of actually being needed.
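
For reference, the change being prepared for is of this shape (a minimal
sketch assuming the eventual signature, not the actual code from D61493706):
once the 'compatibility' decorator in torch/fx/_compatibility.py is annotated
to return 'Callable[[F], F]' instead of an untyped callable, decorated
functions keep their real signatures rather than decaying to 'Any', and mypy
starts checking every call site, which is what surfaces the errors suppressed
by the ignores below.

    from typing import Any, Callable, TypeVar

    F = TypeVar("F", bound=Callable[..., Any])

    def compatibility(is_backward_compatible: bool) -> Callable[[F], F]:
        # Annotating the return type as Callable[[F], F] preserves the
        # decorated function's signature, so mypy begins type-checking
        # call sites of every @compatibility-decorated API.
        def decorator(fn: F) -> F:
            # Attribute name is illustrative, not the real implementation.
            fn._is_backward_compatible = is_backward_compatible  # type: ignore[attr-defined]
            return fn
        return decorator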

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
Author: Aaron Orenstein
Date: 2024-08-22 09:42:18 -07:00
Committed by: PyTorch MergeBot
commit d95aedf5fd (parent 44fa9f991c)
52 changed files with 206 additions and 207 deletions

@ -445,7 +445,7 @@ class AutogradCompilerInstance:
def bind_tensors_to_proxies(self, tensors, proxies):
if isinstance(proxies, torch.fx.Proxy):
proxies = [proxies[i] for i in range(len(tensors))]
proxies = [proxies[i] for i in range(len(tensors))] # type: ignore[index]
assert len(tensors) == len(proxies)
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)

@ -1539,7 +1539,7 @@ def export(
# Running graph with interpreter is needed for propagating the stack_trace
def graph_with_interpreter(*args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args)
return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type]
with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
try:
@ -1561,9 +1561,9 @@ def export(
assert graph is not None
for node in graph.graph.find_nodes(op="get_attr"):
if isinstance(getattr(graph, node.target), torch.Tensor):
if isinstance(getattr(graph, node.target), torch.Tensor): # type: ignore[arg-type]
node.meta["val"] = fake_mode.from_tensor(
getattr(graph, node.target), static_shapes=True
getattr(graph, node.target), static_shapes=True # type: ignore[arg-type]
)
if same_signature:

@ -2231,7 +2231,7 @@ def get_real_value(node, tracer):
return cache[node]
op = node.op
args, kwargs = torch.fx.node.map_arg(
args, kwargs = torch.fx.node.map_arg( # type: ignore[misc]
(node.args, node.kwargs),
lambda n: get_real_value(n, tracer),
)

@ -1343,7 +1343,7 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
self.name_to_node,
# Dummy node.
torch.fx.Node(
None,
None, # type: ignore[arg-type]
"mock",
"call_function",
lambda: None,

@ -68,7 +68,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self.fake_tensor_mode: Optional[FakeTensorMode] = None
self.submodules: Dict[torch.nn.Module, str] = {}
def trace(self) -> None:
def trace(self) -> None: # type: ignore[override]
raise ExportPassBaseError("ExportTracer doesn't support trace().")
def create_arg(self, a: Argument) -> torch.fx.Node:
@ -160,7 +160,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
def placeholder(
self,
target: str,
target: str, # type: ignore[override]
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
) -> ProxyValue:
@ -218,7 +218,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
raise ExportPassBaseError(f"Unsupported target type: {target}")
def get_attr(
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
) -> Argument:
return super().get_attr(target, args, kwargs)
@ -231,7 +231,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
raise ExportPassBaseError("call_module is not supported.")
def call_method(
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
) -> None:
raise ExportPassBaseError("call_method is not supported.")
@ -394,7 +394,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
)
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
interpreter = self.ExportInterpreter(self, graph_module)
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
)
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)

@ -64,7 +64,7 @@ def _replace_with_hop_helper(
# Rename the name of getitem nodes to the actual name of its contents
# for passing verifier and better readability, also propagate metadata
for get_item_node in call_func_node.users.keys():
idx: int = get_item_node.args[1]
idx: int = get_item_node.args[1] # type: ignore[assignment]
output_node = output_args[idx]
get_item_node._rename(output_node.name)
get_item_node.meta = output_node.meta

@ -177,7 +177,7 @@ def _extract_graph_with_inputs_outputs(
for node in joint_graph.nodes:
if _must_be_in_backward(node) and subgraph != "backward":
env[node] = InvalidNode
env[node] = InvalidNode # type: ignore[assignment]
continue
if node in env:
@ -186,7 +186,7 @@ def _extract_graph_with_inputs_outputs(
# joint_graph.nodes).
continue
elif node.op == "placeholder":
env[node] = InvalidNode
env[node] = InvalidNode # type: ignore[assignment]
elif node.op == "call_function":
all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
all_args = [
@ -195,7 +195,7 @@ def _extract_graph_with_inputs_outputs(
if isinstance(x, fx.Node)
]
if any(all_args):
env[node] = InvalidNode
env[node] = InvalidNode # type: ignore[assignment]
continue
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "get_attr":

@ -816,7 +816,7 @@ def trace_flex_attention_backward(
)
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
block_mask = block_mask[:-1] + (mask_graph,)
proxy_mode.tracer.root.register_module("fw_graph", fw_graph)
proxy_mode.tracer.root.register_module("fw_graph", fw_graph) # type: ignore[arg-type]
proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
node_args = (

@ -222,7 +222,7 @@ class ConstantFolder(torch.fx.Interpreter):
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
self.node_replacements[node] = tensor
def run(self) -> Any:
def run(self) -> Any: # type: ignore[override]
env: Dict[torch.fx.Node, Any] = {}
self.insert_placerholder_values(env)
return super().run(initial_env=env)

@ -144,7 +144,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
kwargs = {}
if hasattr(snode, "get_device"):
kwargs = {"device": snode.get_device()}
fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)
fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type]
def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
if isinstance(snode, FusedSchedulerNode):

@ -154,17 +154,17 @@ def binary_folding_init():
return False
if isinstance(other, torch.fx.Node) and other.op == "get_attr":
other_meta_value = other.meta.get("val")
if not other_meta_value.is_floating_point():
if not other_meta_value.is_floating_point(): # type: ignore[union-attr]
return False
if (
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr]
!= weight_meta_value.dtype
):
if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
return False
if (
other_meta_value.dtype != torch.float
other_meta_value.dtype != torch.float # type: ignore[union-attr]
and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
):
return False

@ -213,7 +213,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
new_node = graph.create_node(
op="call_function",
target=efficient_conv_bn_eval_decomposed,
args=args,
args=args, # type: ignore[arg-type]
name="efficient_conv_bn_eval",
)
@ -223,7 +223,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
# take care of the deletion order:
# delete bn_node first, and then conv_node
graph.erase_node(bn_node)
graph.erase_node(conv_node)
graph.erase_node(conv_node) # type: ignore[arg-type]
return
@ -304,7 +304,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
new_node = graph.create_node(
op="call_function",
target=efficient_conv_bn_eval_decomposed,
args=args,
args=args, # type: ignore[arg-type]
name="efficient_conv_bn_eval",
)
@ -314,7 +314,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
# take care of the deletion order:
# delete bn_node first, and then conv_node
graph.erase_node(bn_node)
graph.erase_node(conv_node)
graph.erase_node(conv_node) # type: ignore[arg-type]
return
@ -373,7 +373,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
# Find a pair of conv and bn computation nodes to optimize.
counters["inductor"]["efficient_conv_bn_eval"] += 1
with graph.inserting_before(conv_node):
with graph.inserting_before(conv_node): # type: ignore[arg-type]
# create `get_attr` node to access modules
# note that we directly call `create_node` to fill the `name`
# argument. `graph.get_attr` and
@ -403,4 +403,4 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
# take care of the deletion order:
# delete bn_node first, and then conv_node
graph.erase_node(bn_node)
graph.erase_node(conv_node)
graph.erase_node(conv_node) # type: ignore[arg-type]

@ -223,5 +223,5 @@ def unnecessary_dtype_convert(match: Match, **kwargs):
"""Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
graph = match.graph
node = match.output_node()
node.replace_all_uses_with(node.args[0])
node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type]
graph.erase_node(node)

@ -511,7 +511,7 @@ def pointless_view(match: Match, arg, size):
node = match.output_node()
arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr]
if size == arg_size:
node.replace_all_uses_with(node.args[0])
node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type]
match.erase_nodes(graph)

@ -1033,7 +1033,7 @@ def is_index_put_and_requires_h2d_sync_for_gpu_value(node):
# if the value we are putting is a cpu scalar.
# Therefore, when inductor sees an index_put_ with byte tensor indices,
# it should *not* convert the cpu scalar value into a gpu tensor.
args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)
args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs) # type: ignore[misc]
any_byte_bool_indices = False
indices = args_[1]
for i in indices:

@ -1887,14 +1887,14 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
graph.erase_node(conv_node)
# Erase the dequant pattern
if dtype == torch.bfloat16:
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
graph.erase_node(dequant_node)
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type]
graph.erase_node(dequant_node) # type: ignore[arg-type]
# Erase the dequant per channel pattern
if clone_node is not None:
graph.erase_node(clone_node)
graph.erase_node(clone_node) # type: ignore[arg-type]
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
graph.erase_node(dequant_per_channel)
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type]
graph.erase_node(dequant_per_channel) # type: ignore[arg-type]
counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
match.nodes
@ -2581,8 +2581,8 @@ def quant_lift_up(graph_module: torch.fx.GraphModule):
new_args = map_arg(new_quant_node.args, maybe_replace_node)
new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
new_quant_node.args = new_args
new_quant_node.kwargs = new_kwargs
new_quant_node.args = new_args # type: ignore[assignment]
new_quant_node.kwargs = new_kwargs # type: ignore[assignment]
graph_module.graph.erase_node(quant_node)
graph_module.graph.lint()

@ -293,7 +293,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
return
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
with graph.inserting_before(src):
with graph.inserting_before(src): # type: ignore[arg-type]
new_node = graph_call_function(
graph,
_generalized_scatter,
@ -316,7 +316,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
handle_views(new_src)
src.replace_all_uses_with(new_src) # type: ignore[union-attr]
graph.erase_node(src)
graph.erase_node(src) # type: ignore[arg-type]
for node in graph.nodes:
if _is_view_op(node.target):

@ -193,7 +193,7 @@ def normalize_split_base(
new_split_node = graph.call_function(
torch.split,
args=new_args,
kwargs=new_kwargs,
kwargs=new_kwargs, # type: ignore[arg-type]
)
split_node.replace_all_uses_with(new_split_node)
new_split_node.meta.update(split_node.meta)
@ -378,7 +378,7 @@ def normalize_stack_default(match: Match, *args, **kwargs):
with graph.inserting_after(node):
new_node = graph.call_function(
node.target,
node.target, # type: ignore[arg-type]
args=(tensors,),
kwargs={"dim": dim},
)
@ -529,7 +529,7 @@ def merge_splits(
to_remove = []
with graph.inserting_before(first_split):
with graph.inserting_before(first_split): # type: ignore[arg-type]
# Add the new split node
new_split = graph.call_function(
torch.split,
@ -1535,7 +1535,7 @@ def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
# case 1: the cat uses all getitems from the split
if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type]
# replace the users of the cat node to be the input of the split node
cat_user.replace_all_uses_with(split_node.args[0])
cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type]
# remove the cat node
graph.erase_node(cat_user)
counters["inductor"]["mutate_cat_pass"] += 1
@ -1947,7 +1947,7 @@ def remove_split_unbind_children(graph: torch.fx.Graph, inputs: List[torch.fx.No
# check the split node to remove if it has no users
for node in nodes:
if len(node.users.keys()) == 0: # type: ignore[union-attr]
graph.erase_node(node)
graph.erase_node(node) # type: ignore[arg-type]
# ############pattern to be optimized is#########

@ -766,7 +766,7 @@ class GraphLowering(torch.fx.Interpreter):
return self.graph_inputs[buffer_name].get_numel()
raise KeyError(f"could not find {buffer_name}")
def run(self, *args: Any) -> Any:
def run(self, *args: Any) -> Any: # type: ignore[override]
with dynamo_timed("GraphLowering.run"):
return super().run(*args)
@ -911,9 +911,9 @@ class GraphLowering(torch.fx.Interpreter):
)
def placeholder(
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
) -> Union[Expr, TensorBox, None]:
example = super().placeholder(target, args, kwargs)
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
self.graph_input_names.append(target)
if isinstance(example, SymTypes):
expr = example.node.expr
@ -968,7 +968,7 @@ class GraphLowering(torch.fx.Interpreter):
self.aligned_inputs.add(target)
return tensor
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg]
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override]
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
return super().call_function(target, args, kwargs)
@ -1023,7 +1023,7 @@ class GraphLowering(torch.fx.Interpreter):
return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr(
self, target: str, args: Tuple[()], kwargs: Dict[str, object]
self, target: str, args: Tuple[()], kwargs: Dict[str, object] # type: ignore[override]
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
@ -1063,9 +1063,9 @@ class GraphLowering(torch.fx.Interpreter):
raise AssertionError
def output(
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
) -> None:
result = super().output(target, args, kwargs)
result = super().output(target, args, kwargs) # type: ignore[arg-type]
if not isinstance(result, (tuple, list)):
# nested subgraphs can have singleton outputs
result = (result,)

@ -6668,7 +6668,7 @@ class InterpreterShim(torch.fx.Interpreter):
self.graph = graph
self.submodules = submodules
self.extra_traceback = False
self.fetch_attr = submodules.__getitem__
self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign]
self.current_node = None
def run_node(self, n: torch.fx.Node) -> Any:

@ -236,7 +236,7 @@ class Match:
if trace_fn is None:
trace_fn = functools.partial(fwd_only, run_dce=run_dce)
replacement = trace_fn(
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type]
)
ReplacementPatternEntry.replace_with_graph(
self,
@ -607,7 +607,7 @@ class _TargetArgsExpr(_TargetExpr):
from torch.fx.operator_schemas import normalize_function
normalized_args_and_kwargs = normalize_function(
node.target, node.args, node.kwargs
node.target, node.args, node.kwargs # type: ignore[arg-type]
)
if normalized_args_and_kwargs is None:
@ -1036,7 +1036,7 @@ class ReplacementPatternEntry(PatternEntry):
if node.op == "call_function":
target = node.target
args, kwargs = self.fetch_args_kwargs_from_env(node)
result = graph.call_function(target, args, kwargs)
result = graph.call_function(target, args, kwargs) # type: ignore[arg-type]
if "val" in node.meta and "val" not in result.meta:
result.meta["val"] = node.meta["val"]
if isinstance(node.meta["val"], torch.Tensor):
@ -1080,7 +1080,7 @@ class ReplacementPatternEntry(PatternEntry):
queue.extend(arg.all_input_nodes)
with graph.inserting_before(last_node):
replacement = Replacer(replacement_graph).run(*args)
replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type]
if isinstance(replacement, torch.fx.Node):
replacement = [replacement]
@ -1101,7 +1101,7 @@ class ReplacementPatternEntry(PatternEntry):
return
assert isinstance(old, torch.fx.Node)
if new is None:
old.replace_all_uses_with(None)
old.replace_all_uses_with(None) # type: ignore[arg-type]
graph.erase_node(old)
return
if isinstance(new, torch.fx.Node):
@ -1124,7 +1124,6 @@ class ReplacementPatternEntry(PatternEntry):
graph.erase_node(old)
return
new = typing.cast(Sequence[torch.fx.Node], new)
# `new` is not a node: it's a list of nodes.
#
# This happens when we want to replace a node that has a single
@ -1157,7 +1156,7 @@ class ReplacementPatternEntry(PatternEntry):
idx = maybe_getitem(user)
if idx is None:
raise AssertionError("can't handle")
replace(user, new[idx])
replace(user, new[idx]) # type: ignore[index]
graph.erase_node(old)
if len(output_nodes) == len(replacement):
@ -1233,7 +1232,7 @@ def register_replacement(
)
args = list(
torch.fx.map_arg(
torch.fx.map_arg( # type: ignore[arg-type]
[match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
)
)
@ -1672,8 +1671,8 @@ class PatternMatcherPass:
raise RuntimeError(
f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
)
if should_compute_mutation_region_ids(graph):
compute_mutation_region_ids(graph)
if should_compute_mutation_region_ids(graph): # type: ignore[arg-type]
compute_mutation_region_ids(graph) # type: ignore[arg-type]
get_mutation_region_id_partial = functools.partial(
get_mutation_region_id, graph
)
@ -1764,7 +1763,7 @@ def fx_to_pattern(
get_attr = _not_implemented
def placeholder(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
) -> Union[ExclusiveKeywordArg, KeywordArg]:
n = next(argnum)
if n < len(argnames):
@ -1781,7 +1780,7 @@ def fx_to_pattern(
return KeywordArg(name)
def call_function(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
) -> PatternExpr:
args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
if list in ignore_types:
@ -1800,7 +1799,7 @@ def fx_to_pattern(
rv.users = len(n.users)
return rv
pattern = Converter(gm).run()
pattern = Converter(gm).run() # type: ignore[arg-type]
if not isinstance(pattern, PatternExpr):
return MultiOutputPattern(pytree.tree_leaves(pattern))
return pattern

@ -47,7 +47,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
def call_function(
self,
target: Callable[[Any], Any],
target: Callable[[Any], Any], # type: ignore[override]
args: Any,
kwargs: Dict[str, Any],
) -> Any:
@ -70,7 +70,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
return lowerings[target](*args, **kwargs)
def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None:
def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override]
assert len(args) == 1
self.graph_outputs = args[0]

@ -371,7 +371,7 @@ def gen_gm_and_inputs(target, args, kwargs):
len(target._schema.returns) == 1
and str(target._schema.returns[0].type) == "Tensor"
):
node = (node,)
node = (node,) # type: ignore[assignment]
g.output(node)
gm = torch.fx.GraphModule({}, g)

@ -2216,8 +2216,8 @@ class FakeTensorMode(TorchDispatchMode):
any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
schema_info = get_schema_info(func)
if any_constant and schema_info.is_mutable():
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
)
for k, v in new_kwargs.items():
k = k if (k != "input" or schema_info.has_argument(k)) else "self"

@ -896,7 +896,7 @@ def prepare_n_shadows_model(
tracer = custom_tracer
mt = torch.fx.GraphModule(model, tracer.trace(model))
# this is necessary to ensure logger FQNs get populated
mt._node_name_to_scope = tracer.node_name_to_scope
mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment]
# run example input propagation, we need this to call prepare_fx on
# individual subgraphs
@ -998,7 +998,7 @@ def _prepare_n_shadows_add_loggers_model(
tracer = quantize_fx.QuantizationTracer([], [])
mt = torch.fx.GraphModule(model, tracer.trace(model))
# this is necessary to ensure logger FQNs get populated
mt._node_name_to_scope = tracer.node_name_to_scope
mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment]
# run example input propagation, we need this to call prepare_fx on
# individual subgraphs

@ -694,13 +694,13 @@ def _insert_copy_of_node_a_after_input_node_c(
mod_a = getattr_from_fqn(gm_a, node_a.target)
setattr(gm_b, new_mod_copy_name, mod_a)
node_a_shadows_c = graph_c.create_node(
node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name
node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type]
)
return node_a_shadows_c
else:
assert node_a.op in ("call_function", "call_method")
node_a_shadows_c = graph_c.create_node(
node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name
node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type]
)
return node_a_shadows_c

@ -406,19 +406,19 @@ def create_submodule_from_subgraph(
mod_name = f"mod_{cur_name_idx}"
setattr(gm, mod_name, orig_mod_copy)
cur_name_idx += 1
cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined,arg-type]
elif cur_node_orig.op == "call_function":
cur_node_copy = g.call_function(
cur_node_orig.target,
cur_args_copy,
cur_node_orig.target, # type: ignore[arg-type]
cur_args_copy, # type: ignore[arg-type]
cur_kwargs_copy, # type: ignore[possibly-undefined]
)
elif cur_node_orig.op == "call_method":
cur_node_copy = g.call_method(
cur_node_orig.target,
cur_args_copy,
cur_node_orig.target, # type: ignore[arg-type]
cur_args_copy, # type: ignore[arg-type]
cur_kwargs_copy, # type: ignore[possibly-undefined]
)
@ -582,7 +582,7 @@ def create_one_transformed_and_logged_copy_of_subgraph(
new_args = tuple(new_args) # type: ignore[assignment]
new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs)
new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs) # type: ignore[arg-type]
# add a logger to parent graph to observe the shadow wrapper
logger_mod_orig = _get_logger_for_subgraph(

@ -311,4 +311,4 @@ class BaseStructuredSparsifier(BaseSparsifier):
self.traced.graph.lint()
self.traced.recompile()
return self.traced
return self.traced # type: ignore[return-value]

@ -25,7 +25,7 @@ def _match(
if isinstance(current, type) and issubclass(current, torch.nn.Module):
return (
node.op == "call_module"
and parametrize.type_before_parametrizations(modules[node.target])
and parametrize.type_before_parametrizations(modules[node.target]) # type: ignore[index]
== current
)
elif callable(current):

@ -569,7 +569,7 @@ def _match_static_pattern(
match_key = type(_get_module(ref_node, modules))
else:
expected_op = "call_function"
match_key = ref_node.target
match_key = ref_node.target # type: ignore[assignment]
if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
return SKIP_LOWERING_VALUE
@ -589,7 +589,7 @@ def _match_static_pattern(
if not matched_dequantize:
return SKIP_LOWERING_VALUE
return (q_node, relu_node, ref_node)
return (q_node, relu_node, ref_node) # type: ignore[return-value]
def _match_static_pattern_with_two_inputs(
@ -687,8 +687,8 @@ def _lower_static_weighted_ref_module(
continue
else:
q_class = STATIC_LOWER_MODULE_MAP[ref_class]
output_scale = getattr(model, scale_node.target)
output_zero_point = getattr(model, zero_point_node.target)
output_scale = getattr(model, scale_node.target) # type: ignore[arg-type]
output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type]
q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
# replace reference module with quantized module
parent_name, module_name = _parent_name(ref_node.target)
@ -698,7 +698,7 @@ def _lower_static_weighted_ref_module(
assert len(ref_node.args) == 1
dq_node = ref_node.args[0]
assert isinstance(dq_node, Node)
ref_node.replace_input_with(dq_node, dq_node.args[0])
ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type]
q_node.replace_all_uses_with(ref_node)
model.graph.erase_node(q_node)
model.graph.erase_node(scale_node)
@ -747,8 +747,8 @@ def _lower_static_weighted_ref_module_with_two_inputs(
continue
else:
continue
output_scale = getattr(model, scale_node.target)
output_zero_point = getattr(model, zero_point_node.target)
output_scale = getattr(model, scale_node.target) # type: ignore[arg-type]
output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type]
q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
# replace reference module with quantized module
parent_name, module_name = _parent_name(ref_node.target)
@ -761,7 +761,7 @@ def _lower_static_weighted_ref_module_with_two_inputs(
continue
dq_node = arg
assert isinstance(dq_node, Node)
ref_node.replace_input_with(dq_node, dq_node.args[0])
ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type]
q_node.replace_all_uses_with(ref_node)
model.graph.erase_node(q_node)
@ -904,7 +904,7 @@ def _lower_static_weighted_ref_functional(
prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5]
else:
raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
with model.graph.inserting_before(output_scale_node):
with model.graph.inserting_before(output_scale_node): # type: ignore[arg-type]
# kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack)
# They are not needed for compute op (i.e., quantized::linear)
kwargs = func_node.kwargs
@ -1105,7 +1105,7 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: Dict[str, QConfi
dq_node = arg
assert isinstance(dq_node, Node)
dn_input = dq_node.args[0]
bop_node.replace_input_with(dq_node, dn_input)
bop_node.replace_input_with(dq_node, dn_input) # type: ignore[arg-type]
num_dq_nodes += 1
assert num_dq_nodes > 0

@ -821,7 +821,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
last_getitem_index = last_getitem.args[1]
new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index]
for user in list(last_getitem.users.keys()):
user.replace_input_with(last_getitem, new_input)
user.replace_input_with(last_getitem, new_input) # type: ignore[arg-type]
def _get_observer_from_activation_post_process(

@ -46,8 +46,8 @@ def _maybe_duplicate_dq(
new_args = map_arg(user.args, maybe_replace_node)
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
user.args = new_args
user.kwargs = new_kwargs
user.args = new_args # type: ignore[assignment]
user.kwargs = new_kwargs # type: ignore[assignment]
class DuplicateDQPass(PassBase):

@ -107,8 +107,8 @@ def _find_q_dq_node_for_user(
q_node = None
if (
dq_node.args[0].op == "call_function"
and dq_node.args[0].target in _QUANTIZE_OPS
dq_node.args[0].op == "call_function" # type: ignore[union-attr]
and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr]
):
q_node = dq_node.args[0]
return (q_node, dq_node)
@ -362,7 +362,7 @@ def _get_aten_graph_module_for_pattern(
[x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
)
aten_pattern = capture_pre_autograd_graph(
pattern,
pattern, # type: ignore[arg-type]
example_inputs,
kwargs,
)
@ -382,7 +382,7 @@ def _get_aten_graph_module_for_pattern(
aten_pattern.graph.eliminate_dead_code()
aten_pattern.recompile()
return aten_pattern
return aten_pattern # type: ignore[return-value]
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:

@ -979,13 +979,13 @@ def _annotate_cat(
inputs = cat_node.args[0]
input_qspec_map = {}
input_act0 = inputs[0]
input_act0 = inputs[0] # type: ignore[index]
if isinstance(input_act0, Node):
input_qspec_map[input_act0] = input_act_qspec
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
for input_act in inputs[1:]:
input_qspec_map[input_act] = shared_with_input0_qspec
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: ignore[arg-type]
for input_act in inputs[1:]: # type: ignore[index]
input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index]
output_act_qspec = shared_with_input0_qspec

@ -474,7 +474,7 @@ def _insert_reshard_gm(
input_node: input_arg,
},
)
node.replace_input_with(input_arg, output_node)
node.replace_input_with(input_arg, output_node) # type: ignore[arg-type]
def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:

@ -213,7 +213,7 @@ def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape:
if isinstance(sizes[0], int):
return cast(Shape, sizes)
elif len(sizes) == 1:
return cast(Shape, sizes[0]) # type: ignore[redundant-cast]
return sizes[0]
else:
raise RuntimeError("Size must be int... or tuple")

@ -96,11 +96,11 @@ class _ExecOrderTracer:
self.exec_info = _ExecutionInfo(root_module)
orig_call_module = tracer.call_module
orig_create_proxy = tracer.create_proxy
tracer.call_module = functools.partial(
tracer.call_module = functools.partial( # type: ignore[method-assign]
self._patched_call_module, orig_call_module, self.exec_info
)
fqn_to_param = dict(root_module.named_parameters())
tracer.create_proxy = functools.partial(
tracer.create_proxy = functools.partial( # type: ignore[method-assign]
self._patched_create_proxy,
orig_create_proxy,
self.exec_info,
@ -109,8 +109,8 @@ class _ExecOrderTracer:
try:
yield
finally:
tracer.call_module = orig_call_module
tracer.create_proxy = orig_create_proxy
tracer.call_module = orig_call_module # type: ignore[method-assign]
tracer.create_proxy = orig_create_proxy # type: ignore[method-assign]
def _patched_call_module(
self,
@ -216,8 +216,8 @@ class _ExecOrderTracer:
isinstance(arg, torch.fx.Proxy)
and arg.node.target in fqn_to_param
):
param = fqn_to_param[arg.node.target]
named_params.append((arg.node.target, param))
param = fqn_to_param[arg.node.target] # type: ignore[index]
named_params.append((arg.node.target, param)) # type: ignore[arg-type]
if param not in exec_info.visited_params:
exec_info.visited_params.add(param)
exec_info.param_forward_order.append(param)

@ -214,7 +214,7 @@ def _insert_stage_symbolic_backward(
input_nodes = list(node.all_input_nodes)
grads_proxy = fx.Proxy(grads)
for i, input_node in enumerate(input_nodes):
assign_or_accumulate_grad(input_node, grads_proxy[i].node)
assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index]
return g
@ -416,15 +416,15 @@ class _LinearNodeList:
def __init__(self, node_list):
self.serialize_node_list = []
for node in node_list:
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
serialize_node = fx.Node(
graph=None,
graph=None, # type: ignore[arg-type]
name=node.name,
op=node.op,
target=node.target,
args=node_args,
kwargs=node_kwargs,
args=node_args, # type: ignore[arg-type]
kwargs=node_kwargs, # type: ignore[arg-type]
return_type=node.type,
)
serialize_node.meta = copy.copy(node.meta)
@ -447,8 +447,8 @@ class _LinearNodeList:
deser_node = graph.create_node(
op=node.op,
target=node.target,
args=node_args,
kwargs=node_kwargs,
args=node_args, # type: ignore[arg-type]
kwargs=node_kwargs, # type: ignore[arg-type]
name=node.name,
type_expr=node.type,
)
@ -731,7 +731,7 @@ class Pipe(torch.nn.Module):
# TODO: what does split do with module invocations? does it move the modules
# into the submodules?
split = split_module(traced, mod, split_callback)
split = split_module(traced, mod, split_callback) # type: ignore[arg-type]
# a (custom) tracer can produce dead code like orphan get_attr nodes
split.graph.eliminate_dead_code()

@ -86,7 +86,7 @@ class TensorChunkSpec:
"""
args_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim),
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
)
return args_chunk_spec
@ -104,7 +104,7 @@ class TensorChunkSpec:
"""
kwargs_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim),
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
)
return kwargs_chunk_spec

@ -681,7 +681,7 @@ def _export_to_aten_ir(
if fake_mode:
insert_deferred_runtime_asserts(
gm,
fake_mode.shape_env,
fake_mode.shape_env, # type: ignore[arg-type]
f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
export=True,
)

@ -165,7 +165,7 @@ class MetaTracer(torch.fx.Tracer):
meta_target = manual_meta_overrides.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == 'call_method':
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index]
elif kind == 'call_module':
assert hasattr(self, 'orig_forward')
self._disable_module_getattr = True
@ -173,7 +173,7 @@ class MetaTracer(torch.fx.Tracer):
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in manual_meta_overrides:
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type]
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
@ -237,7 +237,7 @@ class MetaTracer(torch.fx.Tracer):
def proxy(self, node):
return MetaProxy(node, self)
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
assert isinstance(meta_args, dict)
self.meta_args = meta_args
@ -263,7 +263,7 @@ def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
meta_args : Optional[Dict[str, torch.Tensor]] = None,
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args)
graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
gm = torch.fx.GraphModule(tracer.root, graph, name)
return gm

@ -990,7 +990,7 @@ class PythonKeyTracer(Tracer):
torch_fn_counts: Dict[OpOverload, int]
def __init__(self) -> None:
super().__init__(autowrap_modules=())
super().__init__(autowrap_modules=()) # type: ignore[arg-type]
self.tensor_tracker = WeakTensorKeyDictionary()
self.symnode_tracker = _SymNodeDict()
self.script_object_tracker = WeakIdKeyDictionary(
@ -1036,7 +1036,7 @@ class PythonKeyTracer(Tracer):
elif isinstance(a, py_sym_types):
assert a.node.constant is not None
return a.node.constant
return super().create_arg(a)
return super().create_arg(a) # type: ignore[return-value]
@overload
def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]:
@ -1103,7 +1103,7 @@ def dispatch_trace(
tracer: Tracer,
concrete_args: Optional[Tuple[Any, ...]] = None,
) -> GraphModule:
graph = tracer.trace(root, concrete_args)
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
# NB: be careful not to DCE .item() calls
def impure_pred(n: fx.Node) -> bool:
@ -1228,7 +1228,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# It's for passing the export verifier which needs to verify the meta['val']
# TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
# instead of hardcoding it here.
node = self.tracer.create_node("call_function", func, args, {})
node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type]
if func is torch._C._set_grad_enabled:
node.meta["val"] = None
return node
@ -1323,7 +1323,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
# func doesn't have a __torch_function__ that Proxy can interpose, so
# we gotta do it manually
n_out = self.tracer.create_node("call_function", func, n_args, {})
n_out = self.tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type]
p_out = fx.Proxy(n_out, self.tracer)
set_meta(p_out, out)
return p_out
@ -1382,7 +1382,7 @@ class DecompositionInterpreter(fx.Interpreter):
decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
**kwargs: object,
) -> None:
super().__init__(module, **kwargs)
super().__init__(module, **kwargs) # type: ignore[arg-type]
self.new_graph = new_graph
self.tracer = _GraphAppendingTracerEx(self.new_graph)
# Blegh
@ -1399,18 +1399,18 @@ class DecompositionInterpreter(fx.Interpreter):
self.tracer.torch_fn_counts = {}
def placeholder(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override]
) -> object:
out = super().placeholder(target, args, kwargs)
out = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
# TODO handle case where the first character of target is '*'
return out
def get_attr(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override]
) -> object:
out = super().get_attr(target, args, kwargs)
out = super().get_attr(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
return out
@ -1418,9 +1418,9 @@ class DecompositionInterpreter(fx.Interpreter):
# call_function, call_method, call_module get traced automatically by the outer mode.
def output(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override]
) -> object:
out = super().output(target, args, kwargs)
out = super().output(target, args, kwargs) # type: ignore[arg-type]
def get_proxy_node(x: _ProxyTensor) -> fx.node.Node:
return x.proxy.node
@ -1435,7 +1435,7 @@ class DecompositionInterpreter(fx.Interpreter):
# Should enter the mode at least once for being able to restore it later
# See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
with decompose(self.decomposition_table), self.mode:
return super().run(*args, **kwargs)
return super().run(*args, **kwargs) # type: ignore[arg-type]
def wrapper_and_args_for_make_fx(
@ -1596,7 +1596,7 @@ class _ModuleStackTracer(PythonKeyTracer):
self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr)
return self.attr_proxy_map[attr_val]
def trace(
def trace( # type: ignore[override]
self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]]
) -> fx.Graph:
res = super().trace(root, concrete_args)
@ -1690,7 +1690,7 @@ class _ModuleStackTracer(PythonKeyTracer):
Add torch_fn by looking at torch_fn_metadata and torch_fn_counts.
Add stack_trace by filtering out forward() stack frames.
"""
node = super().create_node(*args, **kwargs)
node = super().create_node(*args, **kwargs) # type: ignore[arg-type]
# nn_module_stack
if node.op not in ["placeholder", "output"]:

@ -768,7 +768,7 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
if isinstance(a, tuple):
t = tuple(map_aggregate(elem, fn) for elem in a)
# Support NamedTuple (if it has `_fields`) by repacking into original type.
return t if not hasattr(a, '_fields') else type(a)(*t)
return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type]
elif isinstance(a, list):
return immutable_list(map_aggregate(elem, fn) for elem in a)
elif isinstance(a, dict):

@ -251,8 +251,8 @@ if HAS_PYDOT:
label += f"|target={self._typename(node.target)}" + r"\n"
if self.normalize_args:
try:
args, kwargs = normalize_function(
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True
args, kwargs = normalize_function( # type: ignore[misc]
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
)
except Exception:
# Fallback to not normalizing if there's an exception.

@ -299,7 +299,7 @@ class _MinimizerBase:
# Find submodule containing colored nodes
submodule_name: str = ""
for child_name, _ in split_module.named_children():
for child_name, _ in split_module.named_children(): # type: ignore[union-attr]
# Skip submodules we're not interested in at the moment
if "minimize" not in child_name:
continue
@ -316,7 +316,7 @@ class _MinimizerBase:
f"Minimize submodule was not found with nodes {nodes}"
)
return split_module, submodule_name
return split_module, submodule_name # type: ignore[return-value]
def _run_and_compare(
self,
@ -388,10 +388,10 @@ class _MinimizerBase:
report.append(f"Result mismatch for {result_key}")
if self.module_exporter:
self.module_exporter(
a_input, submodule, str(result_key[0]) + "_cpu",
a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index]
)
self.module_exporter(
b_input, submodule, str(result_key[0]) + "_acc",
b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index]
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")

@ -163,14 +163,14 @@ def split_module(
)
if keep_original_node_name:
args = () if default_value is inspect.Signature.empty else (default_value,)
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) # type: ignore[arg-type]
else:
base_mod_env[node.name] = base_mod_graph.placeholder(
node.target, type_expr=node.type, default_value=default_value
node.target, type_expr=node.type, default_value=default_value # type: ignore[arg-type]
)
base_mod_env[node.name].meta = node.meta.copy()
elif node.op == "get_attr":
base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type]
base_mod_env[node.name].meta = node.meta.copy()
attr_val = m
for atom in node.target.split("."): # type: ignore[union-attr]

@ -869,7 +869,7 @@ class _SplitterBase:
for node in self.module.graph.nodes:
if hasattr(node, "tag"):
del node.tag
return split_module
return split_module # type: ignore[return-value]
def __call__(self) -> torch.fx.GraphModule:
subgraphs = self.put_nodes_into_subgraphs()

@ -32,8 +32,8 @@ def _split_to_graph_and_name_node_map(
name_node_map, Dict
), "Expecting the input graph to have a dict output as the last element"
n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info
gm._graph._codegen.pytree_info = _PyTreeInfo(
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined]
orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
)
gm.recompile()

@ -319,8 +319,8 @@ def _replace_pattern(
# Hook the output Node of the replacement subgraph in to the
# original Graph at the correct location
assert len(match.returning_nodes) == len(copied_returning_nodes)
for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
assert len(match.returning_nodes) == len(copied_returning_nodes) # type: ignore[arg-type]
for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type]
gn.replace_all_uses_with(copied_node)
match_changed_node[gn] = copied_node
# Remove the original nodes

@ -324,7 +324,7 @@ def jagged_torch_function(func, *args, **kwargs):
def _flatten_sig(input, start_dim=0, end_dim=-1):
pass
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
_flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -404,7 +404,7 @@ def tensor_attr_unsupported_getter(func, *args, **kwargs):
def is_contiguous_general(func, *args, **kwargs):
from torch._prims_common import is_contiguous_for_memory_format
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
@ -459,7 +459,7 @@ def clone_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -473,7 +473,7 @@ def linear_default(func, *args, **kwargs):
"self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -505,7 +505,7 @@ def to_dtype(func, *args, **kwargs):
def to_copy_default(func, *args, **kwargs):
from .nested_tensor import _tensor_symint_registry
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -539,7 +539,7 @@ def to_copy_default(func, *args, **kwargs):
torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
)
def copy_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
@ -567,7 +567,7 @@ register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
"self: jt_all",
)
def like_factory_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -583,7 +583,7 @@ def like_factory_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
def zero__default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -596,7 +596,7 @@ def zero__default(func, *args, **kwargs):
torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
)
def _softmax_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -672,7 +672,7 @@ def _softmax_default(func, *args, **kwargs):
"grad_output: jt, output: jt, dim: any, input_dtype: any",
)
def _softmax_backward(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
grad_out = new_kwargs.pop("grad_output")
@ -686,7 +686,7 @@ def _softmax_backward(func, *args, **kwargs):
torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
)
def native_dropout_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -703,7 +703,7 @@ def native_dropout_default(func, *args, **kwargs):
"grad_output: jt, mask: jt, scale: any",
)
def native_dropout_backward_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
grad_output = new_kwargs.pop("grad_output")
@ -716,7 +716,7 @@ def native_dropout_backward_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
def prod_dim_int(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -735,7 +735,7 @@ def prod_dim_int(func, *args, **kwargs):
torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any"
)
def split_tensor(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -753,7 +753,7 @@ def split_tensor(func, *args, **kwargs):
torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any"
)
def split_with_sizes_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -773,7 +773,7 @@ def split_with_sizes_default(func, *args, **kwargs):
torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
)
def narrow(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
@ -790,7 +790,7 @@ def narrow(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
def chunk_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -833,7 +833,7 @@ def chunk_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
def unbind_int(func, *args, **kwargs):
# Note that this specializes on the length of the offsets
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -867,7 +867,7 @@ def unbind_int(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
def squeeze_dim(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -880,7 +880,7 @@ def squeeze_dim(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
def unsqueeze_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -895,7 +895,7 @@ def unsqueeze_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
def cat_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -918,7 +918,7 @@ def cat_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
def matmul_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -943,7 +943,7 @@ def matmul_default(func, *args, **kwargs):
torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
)
def expand_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -960,7 +960,7 @@ def expand_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
def expand_as_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -972,7 +972,7 @@ def expand_as_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
def where_self(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -990,7 +990,7 @@ def where_self(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
def _pin_memory_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1001,7 +1001,7 @@ def _pin_memory_default(func, *args, **kwargs):
@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
def is_pinned_default(func, *args, **kwargs):
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
@ -1026,7 +1026,7 @@ def sum_dim_IntList(func, *args, **kwargs):
Performs a sum along the provided tensor dimension.
Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
"""
_, new_kwargs = normalize_function(
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
@@ -1117,7 +1117,7 @@ def sum_dim_IntList(func, *args, **kwargs):
     torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
 )
 def transpose_int(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1164,7 +1164,7 @@ def transpose_int(func, *args, **kwargs):
     "self: jt_all, size: any",
 )
 def view_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1212,7 +1212,7 @@ def view_default(func, *args, **kwargs):
     "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
 )
 def native_layer_norm_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1313,7 +1313,7 @@ def native_layer_norm_default(func, *args, **kwargs):
     "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
 )
 def native_layer_norm_backward_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     grad_out = new_kwargs.pop("grad_out")
@@ -1327,7 +1327,7 @@ def native_layer_norm_backward_default(func, *args, **kwargs):
 @register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any")
 def select_int(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1349,7 +1349,7 @@ def select_int(func, *args, **kwargs):
     "self: jt, dim: any?, start: any?, end: any?, step: any?",
 )
 def slice_tensor(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1365,7 +1365,7 @@ def slice_tensor(func, *args, **kwargs):
     "dilation: any, transposed: any, output_padding: any, groups: any",
 )
 def convolution_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1382,7 +1382,7 @@ def mean_dim(func, *args, **kwargs):
     Performs a mean along the provided tensor dimension.
     Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
     """
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
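mean_dim follows the same contract as sum_dim_IntList above; when the ragged dimension is reduced away, each batch element is averaged over its own length. A short sketch under the same illustrative setup as the sum example:

import torch

nt = torch.nested.nested_tensor(
    [torch.ones(2, 4), torch.ones(3, 4)], layout=torch.jagged
)
print(nt.mean(dim=1))  # dense (2, 4); each row divided by its own length (2 or 3)
print(nt.mean(dim=2))  # regular dim reduced -> still a nested tensor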
@@ -1438,7 +1438,7 @@ def mean_dim(func, *args, **kwargs):
 @register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
 def stack_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1472,7 +1472,7 @@ def stack_default(func, *args, **kwargs):
     "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
 )
 def embedding_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1493,7 +1493,7 @@ def embedding_default(func, *args, **kwargs):
     "self: jt_all",
 )
 def values_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1520,7 +1520,7 @@ def all_default(func, *args, **kwargs):
     "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
 )
 def _nested_view_from_jagged_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1549,7 +1549,7 @@ def _nested_view_from_jagged_default(func, *args, **kwargs):
 @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
 def _nested_get_offsets(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1559,7 +1559,7 @@ def _nested_get_offsets(func, *args, **kwargs):
 @register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
 def _nested_get_lengths(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1569,7 +1569,7 @@ def _nested_get_lengths(func, *args, **kwargs):
 @register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
 def _nested_get_ragged_idx(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1579,7 +1579,7 @@ def _nested_get_ragged_idx(func, *args, **kwargs):
 @register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
 def _nested_get_min_seqlen(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1589,7 +1589,7 @@ def _nested_get_min_seqlen(func, *args, **kwargs):
 @register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
 def _nested_get_max_seqlen(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
@@ -1600,7 +1600,7 @@ def _nested_get_max_seqlen(func, *args, **kwargs):
 # If a section of the Nested Tensor is fully masked out we still retain the section with a length of 0
 @register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any")
 def masked_select_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
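The comment in this hunk is the behavioral contract worth spelling out: masked_select keeps one (possibly empty) component per batch element rather than dropping fully masked-out rows. A sketch under the assumption that the mask is a boolean jagged tensor with the same ragged structure as the input:

import torch

vals = torch.nested.nested_tensor(
    [torch.tensor([1, 2]), torch.tensor([3, 4, 5])], layout=torch.jagged
)
mask = torch.nested.nested_tensor(
    [torch.tensor([True, False]), torch.tensor([False, False, False])],
    layout=torch.jagged,
)
out = vals.masked_select(mask)
# Components: [1] and []; the second row is fully masked out but survives
# as a zero-length section instead of disappearing.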


@@ -242,7 +242,7 @@ def _fx_args_to_torch_args(
         if isinstance(arg, torch.fx.Node):
             fake_tensor = arg.meta.get("val")
             if fake_tensor is None and arg.op == "get_attr":
-                fake_tensor = getattr(fx_graph_module, arg.target)  # type: ignore[operator]
+                fake_tensor = getattr(fx_graph_module, arg.target)  # type: ignore[operator, arg-type]
             # NOTE: Currently, we are aware of
             # FakeTensor/Tensor/SymInt/SymFloat/Symbool/int/float/bool could be in
             # arg.meta["val"]/get_attr.
@@ -253,8 +253,8 @@ def _fx_args_to_torch_args(
                 wrapped_args.append(real_tensor)
             elif isinstance(fake_tensor, (int, float, bool)):
                 wrapped_args.append(fake_tensor)
-            elif symbolic_shapes.has_hint(fake_tensor):
-                wrapped_args.append(symbolic_shapes.hint_int(fake_tensor))
+            elif symbolic_shapes.has_hint(fake_tensor):  # type: ignore[arg-type]
+                wrapped_args.append(symbolic_shapes.hint_int(fake_tensor))  # type: ignore[arg-type]
             else:
                 raise ValueError(
                     f"Unexpected input argument type found inside fx.Node. arg: {arg}; "


@@ -615,7 +615,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
         return None

     def get_arg_from_alias(output_alias, schema_info, args, kwargs):
-        new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(
+        new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(  # type: ignore[misc]
             func, args=args, kwargs=kwargs
         )