mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] typing for decorators - fx/_compatibility (#131568)
See #131429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568 Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
709ddf7a9d
commit
193f62fde9
@ -336,7 +336,7 @@ class AutogradCompilerInstance:
|
||||
|
||||
def bind_tensors_to_proxies(self, tensors, proxies):
|
||||
if isinstance(proxies, torch.fx.Proxy):
|
||||
proxies = [proxies[i] for i in range(len(tensors))]
|
||||
proxies = [proxies[i] for i in range(len(tensors))] # type: ignore[index]
|
||||
assert len(tensors) == len(proxies)
|
||||
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
|
||||
|
||||
|
||||
@ -1069,7 +1069,7 @@ def replay(filename: str) -> None:
|
||||
record.globals = dict(itertools.chain(record.globals.items(), globals().items()))
|
||||
|
||||
try:
|
||||
_compile(
|
||||
_compile( # type: ignore[call-arg]
|
||||
record.code,
|
||||
record.globals,
|
||||
record.locals,
|
||||
|
||||
@ -1512,7 +1512,7 @@ def export(
|
||||
# Running graph with interpreter is needed for propagating the stack_trace
|
||||
def graph_with_interpreter(*args):
|
||||
with torch.fx.traceback.preserve_node_meta():
|
||||
return torch.fx.Interpreter(graph).run(*args)
|
||||
return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type]
|
||||
|
||||
with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), (
|
||||
fake_mode
|
||||
@ -1536,9 +1536,9 @@ def export(
|
||||
|
||||
assert graph is not None
|
||||
for node in graph.graph.find_nodes(op="get_attr"):
|
||||
if isinstance(getattr(graph, node.target), torch.Tensor):
|
||||
if isinstance(getattr(graph, node.target), torch.Tensor): # type: ignore[arg-type]
|
||||
node.meta["val"] = fake_mode.from_tensor(
|
||||
getattr(graph, node.target), static_shapes=True
|
||||
getattr(graph, node.target), static_shapes=True # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if same_signature:
|
||||
|
||||
@ -2039,7 +2039,7 @@ def get_real_value(node, tracer):
|
||||
return cache[node]
|
||||
|
||||
op = node.op
|
||||
args, kwargs = torch.fx.node.map_arg(
|
||||
args, kwargs = torch.fx.node.map_arg( # type: ignore[misc]
|
||||
(node.args, node.kwargs),
|
||||
lambda n: get_real_value(n, tracer),
|
||||
)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import dataclasses
|
||||
|
||||
@ -942,7 +942,7 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
|
||||
self.name_to_node,
|
||||
# Dummy node.
|
||||
torch.fx.Node(
|
||||
None,
|
||||
None, # type: ignore[arg-type]
|
||||
"mock",
|
||||
"call_function",
|
||||
lambda: None,
|
||||
|
||||
@ -68,7 +68,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
self.fake_tensor_mode: Optional[FakeTensorMode] = None
|
||||
self.submodules: Dict[torch.nn.Module, str] = {}
|
||||
|
||||
def trace(self) -> None:
|
||||
def trace(self) -> None: # type: ignore[override]
|
||||
raise ExportPassBaseError("ExportTracer doesn't support trace().")
|
||||
|
||||
def create_arg(self, a: Argument) -> torch.fx.Node:
|
||||
@ -160,7 +160,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
|
||||
def placeholder(
|
||||
self,
|
||||
target: str,
|
||||
target: str, # type: ignore[override]
|
||||
args: Tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
) -> ProxyValue:
|
||||
@ -218,7 +218,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
raise ExportPassBaseError(f"Unsupported target type: {target}")
|
||||
|
||||
def get_attr(
|
||||
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
|
||||
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
|
||||
) -> Argument:
|
||||
return super().get_attr(target, args, kwargs)
|
||||
|
||||
@ -231,7 +231,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
raise ExportPassBaseError("call_module is not supported.")
|
||||
|
||||
def call_method(
|
||||
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
|
||||
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
|
||||
) -> None:
|
||||
raise ExportPassBaseError("call_method is not supported.")
|
||||
|
||||
@ -394,7 +394,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
)
|
||||
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
|
||||
interpreter = self.ExportInterpreter(self, graph_module)
|
||||
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
|
||||
prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
|
||||
torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
|
||||
)
|
||||
inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
|
||||
|
||||
@ -92,7 +92,7 @@ def _replace_with_hop(node: torch.fx.Node):
|
||||
# Rename the name of getitem nodes to the actual name of its contents
|
||||
# for passing verifier and better readability, also propagate metadata
|
||||
for get_item_node in call_func_node.users.keys():
|
||||
idx: int = get_item_node.args[1]
|
||||
idx: int = get_item_node.args[1] # type: ignore[assignment]
|
||||
output_node = output_args[idx]
|
||||
get_item_node._rename(output_node.name)
|
||||
get_item_node.meta = output_node.meta
|
||||
|
||||
@ -186,7 +186,7 @@ def _extract_graph_with_inputs_outputs(
|
||||
# joint_graph.nodes).
|
||||
continue
|
||||
elif node.op == "placeholder":
|
||||
env[node] = InvalidNode
|
||||
env[node] = InvalidNode # type: ignore[assignment]
|
||||
elif node.op == "call_function":
|
||||
all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
|
||||
all_args = [
|
||||
@ -195,7 +195,7 @@ def _extract_graph_with_inputs_outputs(
|
||||
if isinstance(x, fx.Node)
|
||||
]
|
||||
if any(all_args):
|
||||
env[node] = InvalidNode
|
||||
env[node] = InvalidNode # type: ignore[assignment]
|
||||
continue
|
||||
env[node] = new_graph.node_copy(node, lambda x: env[x])
|
||||
elif node.op == "get_attr":
|
||||
|
||||
@ -769,7 +769,7 @@ def trace_flex_attention_backward(
|
||||
)
|
||||
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
|
||||
block_mask = block_mask[:-1] + (mask_graph,)
|
||||
proxy_mode.tracer.root.register_module("fw_graph", fw_graph)
|
||||
proxy_mode.tracer.root.register_module("fw_graph", fw_graph) # type: ignore[arg-type]
|
||||
proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
|
||||
proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
|
||||
node_args = (
|
||||
|
||||
@ -1432,7 +1432,7 @@ class CppVecOverrides(CppOverrides):
|
||||
], f"{__name__} does not support {dtype}"
|
||||
node: torch.fx.Node = V.interpreter.current_node
|
||||
assert node and isinstance(node, torch.fx.Node)
|
||||
opt_ctx_x = get_opt_ctx(node.args[1])
|
||||
opt_ctx_x = get_opt_ctx(node.args[1]) # type: ignore[arg-type]
|
||||
assert opt_ctx_x
|
||||
assert opt_ctx_x.dtype is not None
|
||||
assert isinstance(V.kernel, CppVecKernel)
|
||||
|
||||
@ -222,7 +222,7 @@ class ConstantFolder(torch.fx.Interpreter):
|
||||
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
|
||||
self.node_replacements[node] = tensor
|
||||
|
||||
def run(self) -> Any:
|
||||
def run(self) -> Any: # type: ignore[override]
|
||||
env: Dict[torch.fx.Node, Any] = {}
|
||||
self.insert_placerholder_values(env)
|
||||
return super().run(initial_env=env)
|
||||
|
||||
@ -141,7 +141,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
|
||||
kwargs = {}
|
||||
if hasattr(snode, "get_device"):
|
||||
kwargs = {"device": snode.get_device()}
|
||||
fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)
|
||||
fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type]
|
||||
|
||||
def in_output(snode):
|
||||
if isinstance(snode, FusedSchedulerNode):
|
||||
|
||||
@ -154,17 +154,17 @@ def binary_folding_init():
|
||||
return False
|
||||
if isinstance(other, torch.fx.Node) and other.op == "get_attr":
|
||||
other_meta_value = other.meta.get("val")
|
||||
if not other_meta_value.is_floating_point():
|
||||
if not other_meta_value.is_floating_point(): # type: ignore[union-attr]
|
||||
return False
|
||||
if (
|
||||
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)
|
||||
torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr]
|
||||
!= weight_meta_value.dtype
|
||||
):
|
||||
if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
|
||||
return False
|
||||
|
||||
if (
|
||||
other_meta_value.dtype != torch.float
|
||||
other_meta_value.dtype != torch.float # type: ignore[union-attr]
|
||||
and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
|
||||
):
|
||||
return False
|
||||
|
||||
@ -213,7 +213,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
|
||||
new_node = graph.create_node(
|
||||
op="call_function",
|
||||
target=efficient_conv_bn_eval_decomposed,
|
||||
args=args,
|
||||
args=args, # type: ignore[arg-type]
|
||||
name="efficient_conv_bn_eval",
|
||||
)
|
||||
|
||||
@ -223,7 +223,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
|
||||
# take care of the deletion order:
|
||||
# delete bn_node first, and then conv_node
|
||||
graph.erase_node(bn_node)
|
||||
graph.erase_node(conv_node)
|
||||
graph.erase_node(conv_node) # type: ignore[arg-type]
|
||||
|
||||
return
|
||||
|
||||
@ -304,7 +304,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
|
||||
new_node = graph.create_node(
|
||||
op="call_function",
|
||||
target=efficient_conv_bn_eval_decomposed,
|
||||
args=args,
|
||||
args=args, # type: ignore[arg-type]
|
||||
name="efficient_conv_bn_eval",
|
||||
)
|
||||
|
||||
@ -314,7 +314,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
|
||||
# take care of the deletion order:
|
||||
# delete bn_node first, and then conv_node
|
||||
graph.erase_node(bn_node)
|
||||
graph.erase_node(conv_node)
|
||||
graph.erase_node(conv_node) # type: ignore[arg-type]
|
||||
|
||||
return
|
||||
|
||||
@ -373,7 +373,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
|
||||
# Find a pair of conv and bn computation nodes to optimize.
|
||||
counters["inductor"]["efficient_conv_bn_eval"] += 1
|
||||
|
||||
with graph.inserting_before(conv_node):
|
||||
with graph.inserting_before(conv_node): # type: ignore[arg-type]
|
||||
# create `get_attr` node to access modules
|
||||
# note that we directly call `create_node` to fill the `name`
|
||||
# argument. `graph.get_attr` and
|
||||
@ -403,4 +403,4 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
|
||||
# take care of the deletion order:
|
||||
# delete bn_node first, and then conv_node
|
||||
graph.erase_node(bn_node)
|
||||
graph.erase_node(conv_node)
|
||||
graph.erase_node(conv_node) # type: ignore[arg-type]
|
||||
|
||||
@ -223,5 +223,5 @@ def unnecessary_dtype_convert(match: Match, **kwargs):
|
||||
"""Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
|
||||
graph = match.graph
|
||||
node = match.output_node()
|
||||
node.replace_all_uses_with(node.args[0])
|
||||
node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type]
|
||||
graph.erase_node(node)
|
||||
|
||||
@ -505,7 +505,7 @@ def pointless_view(match: Match, arg, size):
|
||||
node = match.output_node()
|
||||
arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr]
|
||||
if size == arg_size:
|
||||
node.replace_all_uses_with(node.args[0])
|
||||
node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] # type: ignore[arg-type]
|
||||
match.erase_nodes(graph)
|
||||
|
||||
|
||||
|
||||
@ -1033,7 +1033,7 @@ def is_index_put_and_requires_h2d_sync_for_cuda_value(node):
|
||||
# if the value we are putting is a cpu scalar.
|
||||
# Therefore, when inductor sees an index_put_ with byte tensor indices,
|
||||
# it should *not* convert the cpu scalar value into a cuda tensor.
|
||||
args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)
|
||||
args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs) # type: ignore[syntax, misc]
|
||||
any_byte_bool_indices = False
|
||||
indices = args_[1]
|
||||
for i in indices:
|
||||
|
||||
@ -1885,14 +1885,14 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
|
||||
graph.erase_node(conv_node)
|
||||
# Erase the dequant pattern
|
||||
if dtype == torch.bfloat16:
|
||||
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
|
||||
graph.erase_node(dequant_node)
|
||||
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type]
|
||||
graph.erase_node(dequant_node) # type: ignore[arg-type]
|
||||
# Erase the dequant per channel pattern
|
||||
if clone_node is not None:
|
||||
graph.erase_node(clone_node)
|
||||
graph.erase_node(clone_node) # type: ignore[arg-type]
|
||||
if dtype == torch.bfloat16:
|
||||
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
|
||||
graph.erase_node(dequant_per_channel)
|
||||
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type]
|
||||
graph.erase_node(dequant_per_channel) # type: ignore[arg-type]
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
@ -2579,8 +2579,8 @@ def quant_lift_up(graph_module: torch.fx.GraphModule):
|
||||
|
||||
new_args = map_arg(new_quant_node.args, maybe_replace_node)
|
||||
new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
|
||||
new_quant_node.args = new_args
|
||||
new_quant_node.kwargs = new_kwargs
|
||||
new_quant_node.args = new_args # type: ignore[assignment]
|
||||
new_quant_node.kwargs = new_kwargs # type: ignore[assignment]
|
||||
graph_module.graph.erase_node(quant_node)
|
||||
|
||||
graph_module.graph.lint()
|
||||
|
||||
@ -288,7 +288,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
|
||||
return
|
||||
|
||||
src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
|
||||
with graph.inserting_before(src):
|
||||
with graph.inserting_before(src): # type: ignore[arg-type]
|
||||
new_node = graph_call_function(
|
||||
graph,
|
||||
_generalized_scatter,
|
||||
@ -311,7 +311,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
|
||||
handle_views(new_src)
|
||||
src.replace_all_uses_with(new_src) # type: ignore[union-attr]
|
||||
|
||||
graph.erase_node(src)
|
||||
graph.erase_node(src) # type: ignore[arg-type]
|
||||
|
||||
for node in graph.nodes:
|
||||
if _is_view_op(node.target):
|
||||
|
||||
@ -188,7 +188,7 @@ def normalize_split_base(
|
||||
new_split_node = graph.call_function(
|
||||
torch.split,
|
||||
args=new_args,
|
||||
kwargs=new_kwargs,
|
||||
kwargs=new_kwargs, # type: ignore[arg-type]
|
||||
)
|
||||
split_node.replace_all_uses_with(new_split_node)
|
||||
new_split_node.meta.update(split_node.meta)
|
||||
@ -373,7 +373,7 @@ def normalize_stack_default(match: Match, *args, **kwargs):
|
||||
|
||||
with graph.inserting_after(node):
|
||||
new_node = graph.call_function(
|
||||
node.target,
|
||||
node.target, # type: ignore[arg-type]
|
||||
args=(tensors,),
|
||||
kwargs={"dim": dim},
|
||||
)
|
||||
@ -502,7 +502,7 @@ def merge_splits(
|
||||
|
||||
to_remove = []
|
||||
|
||||
with graph.inserting_before(first_split):
|
||||
with graph.inserting_before(first_split): # type: ignore[arg-type]
|
||||
# Add the new split node
|
||||
new_split = graph.call_function(
|
||||
torch.split,
|
||||
@ -1431,7 +1431,7 @@ def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
|
||||
# case 1: the cat uses all getitems from the split
|
||||
if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type]
|
||||
# replace the users of the cat node to be the input of the split node
|
||||
cat_user.replace_all_uses_with(split_node.args[0])
|
||||
cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type]
|
||||
# remove the cat node
|
||||
graph.erase_node(cat_user)
|
||||
counters["inductor"]["mutate_cat_pass"] += 1
|
||||
|
||||
@ -762,7 +762,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
raise KeyError(f"could not find {buffer_name}")
|
||||
|
||||
@dynamo_timed
|
||||
def run(self, *args: Any) -> Any:
|
||||
def run(self, *args: Any) -> Any: # type: ignore[override]
|
||||
return super().run(*args)
|
||||
|
||||
def register_operation(self, op: ir.Operation) -> str:
|
||||
@ -906,9 +906,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
)
|
||||
|
||||
def placeholder(
|
||||
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
|
||||
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
||||
) -> Union[Expr, TensorBox, None]:
|
||||
example = super().placeholder(target, args, kwargs)
|
||||
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
||||
self.graph_input_names.append(target)
|
||||
if isinstance(example, SymTypes):
|
||||
expr = example.node.expr
|
||||
@ -963,7 +963,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.aligned_inputs.add(target)
|
||||
return tensor
|
||||
|
||||
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg]
|
||||
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override]
|
||||
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
|
||||
return super().call_function(target, args, kwargs)
|
||||
|
||||
@ -1040,7 +1040,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
return len(t.shape) == 1 and t.shape[0] <= 8
|
||||
|
||||
def get_attr(
|
||||
self, target: str, args: Tuple[()], kwargs: Dict[str, object]
|
||||
self, target: str, args: Tuple[()], kwargs: Dict[str, object] # type: ignore[override]
|
||||
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
|
||||
# this is a constant
|
||||
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
|
||||
@ -1080,9 +1080,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
raise AssertionError
|
||||
|
||||
def output(
|
||||
self, target: str, args: Tuple[object], kwargs: Dict[str, object]
|
||||
self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
||||
) -> None:
|
||||
result = super().output(target, args, kwargs)
|
||||
result = super().output(target, args, kwargs) # type: ignore[arg-type]
|
||||
if not isinstance(result, (tuple, list)):
|
||||
# nested subgraphs can have singleton outputs
|
||||
result = (result,)
|
||||
|
||||
@ -6572,7 +6572,7 @@ class InterpreterShim(torch.fx.Interpreter):
|
||||
self.graph = graph
|
||||
self.submodules = submodules
|
||||
self.extra_traceback = False
|
||||
self.fetch_attr = submodules.__getitem__
|
||||
self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign]
|
||||
self.current_node = None
|
||||
|
||||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
|
||||
@ -235,7 +235,7 @@ class Match:
|
||||
if trace_fn is None:
|
||||
trace_fn = functools.partial(fwd_only, run_dce=run_dce)
|
||||
replacement = trace_fn(
|
||||
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
|
||||
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type]
|
||||
)
|
||||
ReplacementPatternEntry.replace_with_graph(
|
||||
self,
|
||||
@ -606,7 +606,7 @@ class _TargetArgsExpr(_TargetExpr):
|
||||
from torch.fx.operator_schemas import normalize_function
|
||||
|
||||
normalized_args_and_kwargs = normalize_function(
|
||||
node.target, node.args, node.kwargs
|
||||
node.target, node.args, node.kwargs # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if normalized_args_and_kwargs is None:
|
||||
@ -1035,7 +1035,7 @@ class ReplacementPatternEntry(PatternEntry):
|
||||
if node.op == "call_function":
|
||||
target = node.target
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
||||
result = graph.call_function(target, args, kwargs)
|
||||
result = graph.call_function(target, args, kwargs) # type: ignore[arg-type]
|
||||
if "val" in node.meta and "val" not in result.meta:
|
||||
result.meta["val"] = node.meta["val"]
|
||||
if isinstance(node.meta["val"], torch.Tensor):
|
||||
@ -1079,7 +1079,7 @@ class ReplacementPatternEntry(PatternEntry):
|
||||
queue.extend(arg.all_input_nodes)
|
||||
|
||||
with graph.inserting_before(last_node):
|
||||
replacement = Replacer(replacement_graph).run(*args)
|
||||
replacement = Replacer(replacement_graph).run(*args) # type: ignore[arg-type]
|
||||
if isinstance(replacement, torch.fx.Node):
|
||||
replacement = [replacement]
|
||||
|
||||
@ -1100,7 +1100,7 @@ class ReplacementPatternEntry(PatternEntry):
|
||||
return
|
||||
assert isinstance(old, torch.fx.Node)
|
||||
if new is None:
|
||||
old.replace_all_uses_with(None)
|
||||
old.replace_all_uses_with(None) # type: ignore[arg-type]
|
||||
graph.erase_node(old)
|
||||
return
|
||||
if isinstance(new, torch.fx.Node):
|
||||
@ -1123,7 +1123,7 @@ class ReplacementPatternEntry(PatternEntry):
|
||||
graph.erase_node(old)
|
||||
return
|
||||
|
||||
new = typing.cast(Sequence[torch.fx.Node], new)
|
||||
new = typing.cast(Sequence[torch.fx.Node], new) # type: ignore[redundant-cast]
|
||||
# `new` is not a node: it's a list of nodes.
|
||||
#
|
||||
# This happens when we want to replace a node that has a single
|
||||
@ -1232,7 +1232,7 @@ def register_replacement(
|
||||
)
|
||||
|
||||
args = list(
|
||||
torch.fx.map_arg(
|
||||
torch.fx.map_arg( # type: ignore[arg-type]
|
||||
[match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
|
||||
)
|
||||
)
|
||||
@ -1665,8 +1665,8 @@ class PatternMatcherPass:
|
||||
raise RuntimeError(
|
||||
f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
|
||||
)
|
||||
if should_compute_mutation_region_ids(graph):
|
||||
compute_mutation_region_ids(graph)
|
||||
if should_compute_mutation_region_ids(graph): # type: ignore[arg-type]
|
||||
compute_mutation_region_ids(graph) # type: ignore[arg-type]
|
||||
get_mutation_region_id_partial = functools.partial(
|
||||
get_mutation_region_id, graph
|
||||
)
|
||||
@ -1757,7 +1757,7 @@ def fx_to_pattern(
|
||||
get_attr = _not_implemented
|
||||
|
||||
def placeholder(
|
||||
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
|
||||
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
|
||||
) -> Union[ExclusiveKeywordArg, KeywordArg]:
|
||||
n = next(argnum)
|
||||
if n < len(argnames):
|
||||
@ -1774,7 +1774,7 @@ def fx_to_pattern(
|
||||
return KeywordArg(name)
|
||||
|
||||
def call_function(
|
||||
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
|
||||
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
|
||||
) -> PatternExpr:
|
||||
args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
|
||||
if list in ignore_types:
|
||||
@ -1793,7 +1793,7 @@ def fx_to_pattern(
|
||||
rv.users = len(n.users)
|
||||
return rv
|
||||
|
||||
pattern = Converter(gm).run()
|
||||
pattern = Converter(gm).run() # type: ignore[arg-type]
|
||||
if not isinstance(pattern, PatternExpr):
|
||||
return MultiOutputPattern(pytree.tree_leaves(pattern))
|
||||
return pattern
|
||||
|
||||
@ -47,7 +47,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
|
||||
|
||||
def call_function(
|
||||
self,
|
||||
target: Callable[[Any], Any],
|
||||
target: torch.fx.node.Target,
|
||||
args: Any,
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Any:
|
||||
@ -70,9 +70,14 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
|
||||
|
||||
return lowerings[target](*args, **kwargs)
|
||||
|
||||
def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None:
|
||||
def output(
|
||||
self,
|
||||
target: torch.fx.node.Target,
|
||||
args: Tuple[torch.fx.node.Argument, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
assert len(args) == 1
|
||||
self.graph_outputs = args[0]
|
||||
self.graph_outputs = args[0] # type: ignore[assignment]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -350,7 +350,7 @@ def gen_gm_and_inputs(target, args, kwargs):
|
||||
len(target._schema.returns) == 1
|
||||
and str(target._schema.returns[0].type) == "Tensor"
|
||||
):
|
||||
node = (node,)
|
||||
node = (node,) # type: ignore[assignment]
|
||||
g.output(node)
|
||||
|
||||
gm = torch.fx.GraphModule({}, g)
|
||||
|
||||
@ -2156,8 +2156,8 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
|
||||
schema_info = get_schema_info(func)
|
||||
if any_constant and schema_info.is_mutable():
|
||||
_, new_kwargs = normalize_function(
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
|
||||
)
|
||||
for k, v in new_kwargs.items():
|
||||
k = k if (k != "input" or schema_info.has_argument(k)) else "self"
|
||||
|
||||
@ -896,7 +896,7 @@ def prepare_n_shadows_model(
|
||||
tracer = custom_tracer
|
||||
mt = torch.fx.GraphModule(model, tracer.trace(model))
|
||||
# this is necessary to ensure logger FQNs get populated
|
||||
mt._node_name_to_scope = tracer.node_name_to_scope
|
||||
mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment]
|
||||
|
||||
# run example input propagation, we need this to call prepare_fx on
|
||||
# individual subgraphs
|
||||
@ -998,7 +998,7 @@ def _prepare_n_shadows_add_loggers_model(
|
||||
tracer = quantize_fx.QuantizationTracer([], [])
|
||||
mt = torch.fx.GraphModule(model, tracer.trace(model))
|
||||
# this is necessary to ensure logger FQNs get populated
|
||||
mt._node_name_to_scope = tracer.node_name_to_scope
|
||||
mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment]
|
||||
|
||||
# run example input propagation, we need this to call prepare_fx on
|
||||
# individual subgraphs
|
||||
|
||||
@ -694,13 +694,13 @@ def _insert_copy_of_node_a_after_input_node_c(
|
||||
mod_a = getattr_from_fqn(gm_a, node_a.target)
|
||||
setattr(gm_b, new_mod_copy_name, mod_a)
|
||||
node_a_shadows_c = graph_c.create_node(
|
||||
node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name
|
||||
node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type]
|
||||
)
|
||||
return node_a_shadows_c
|
||||
else:
|
||||
assert node_a.op in ("call_function", "call_method")
|
||||
node_a_shadows_c = graph_c.create_node(
|
||||
node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name
|
||||
node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name # type: ignore[arg-type]
|
||||
)
|
||||
return node_a_shadows_c
|
||||
|
||||
|
||||
@ -406,20 +406,20 @@ def create_submodule_from_subgraph(
|
||||
mod_name = f"mod_{cur_name_idx}"
|
||||
setattr(gm, mod_name, orig_mod_copy)
|
||||
cur_name_idx += 1
|
||||
cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined]
|
||||
cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined, arg-type]
|
||||
|
||||
elif cur_node_orig.op == "call_function":
|
||||
cur_node_copy = g.call_function(
|
||||
cur_node_orig.target,
|
||||
cur_args_copy,
|
||||
cur_kwargs_copy, # type: ignore[possibly-undefined]
|
||||
cur_node_orig.target, # type: ignore[arg-type]
|
||||
cur_args_copy, # type: ignore[arg-type]
|
||||
cur_kwargs_copy, # type: ignore[possibly-undefined, arg-type]
|
||||
)
|
||||
|
||||
elif cur_node_orig.op == "call_method":
|
||||
cur_node_copy = g.call_method(
|
||||
cur_node_orig.target,
|
||||
cur_args_copy,
|
||||
cur_kwargs_copy, # type: ignore[possibly-undefined]
|
||||
cur_node_orig.target, # type: ignore[arg-type]
|
||||
cur_args_copy, # type: ignore[arg-type]
|
||||
cur_kwargs_copy, # type: ignore[possibly-undefined, arg-type]
|
||||
)
|
||||
|
||||
else:
|
||||
@ -582,7 +582,7 @@ def create_one_transformed_and_logged_copy_of_subgraph(
|
||||
|
||||
new_args = tuple(new_args) # type: ignore[assignment]
|
||||
|
||||
new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs)
|
||||
new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs) # type: ignore[arg-type]
|
||||
|
||||
# add a logger to parent graph to observe the shadow wrapper
|
||||
logger_mod_orig = _get_logger_for_subgraph(
|
||||
|
||||
@ -268,7 +268,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
|
||||
BiasHook(module.parametrizations.weight[0], prune_bias)
|
||||
)
|
||||
|
||||
def prune(self) -> None:
|
||||
def prune(self) -> torch.fx.graph_module.GraphModule:
|
||||
r"""
|
||||
This function will FX symbolically trace the model and then find instances of the patterns
|
||||
defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ).
|
||||
|
||||
@ -25,7 +25,7 @@ def _match(
|
||||
if isinstance(current, type) and issubclass(current, torch.nn.Module):
|
||||
return (
|
||||
node.op == "call_module"
|
||||
and parametrize.type_before_parametrizations(modules[node.target])
|
||||
and parametrize.type_before_parametrizations(modules[node.target]) # type: ignore[index]
|
||||
== current
|
||||
)
|
||||
elif callable(current):
|
||||
|
||||
@ -571,7 +571,7 @@ def _match_static_pattern(
|
||||
match_key = type(_get_module(ref_node, modules))
|
||||
else:
|
||||
expected_op = "call_function"
|
||||
match_key = ref_node.target
|
||||
match_key = ref_node.target # type: ignore[assignment]
|
||||
if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
|
||||
return SKIP_LOWERING_VALUE
|
||||
|
||||
@ -591,7 +591,7 @@ def _match_static_pattern(
|
||||
if not matched_dequantize:
|
||||
return SKIP_LOWERING_VALUE
|
||||
|
||||
return (q_node, relu_node, ref_node)
|
||||
return (q_node, relu_node, ref_node) # type: ignore[return-value]
|
||||
|
||||
|
||||
def _match_static_pattern_with_two_inputs(
|
||||
@ -689,8 +689,8 @@ def _lower_static_weighted_ref_module(
|
||||
continue
|
||||
else:
|
||||
q_class = STATIC_LOWER_MODULE_MAP[ref_class]
|
||||
output_scale = getattr(model, scale_node.target)
|
||||
output_zero_point = getattr(model, zero_point_node.target)
|
||||
output_scale = getattr(model, scale_node.target) # type: ignore[arg-type]
|
||||
output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type]
|
||||
q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
|
||||
# replace reference module with quantized module
|
||||
parent_name, module_name = _parent_name(ref_node.target)
|
||||
@ -700,7 +700,7 @@ def _lower_static_weighted_ref_module(
|
||||
assert len(ref_node.args) == 1
|
||||
dq_node = ref_node.args[0]
|
||||
assert isinstance(dq_node, Node)
|
||||
ref_node.replace_input_with(dq_node, dq_node.args[0])
|
||||
ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type]
|
||||
q_node.replace_all_uses_with(ref_node)
|
||||
model.graph.erase_node(q_node)
|
||||
model.graph.erase_node(scale_node)
|
||||
@ -749,8 +749,8 @@ def _lower_static_weighted_ref_module_with_two_inputs(
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
output_scale = getattr(model, scale_node.target)
|
||||
output_zero_point = getattr(model, zero_point_node.target)
|
||||
output_scale = getattr(model, scale_node.target) # type: ignore[arg-type]
|
||||
output_zero_point = getattr(model, zero_point_node.target) # type: ignore[arg-type]
|
||||
q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
|
||||
# replace reference module with quantized module
|
||||
parent_name, module_name = _parent_name(ref_node.target)
|
||||
@ -763,7 +763,7 @@ def _lower_static_weighted_ref_module_with_two_inputs(
|
||||
continue
|
||||
dq_node = arg
|
||||
assert isinstance(dq_node, Node)
|
||||
ref_node.replace_input_with(dq_node, dq_node.args[0])
|
||||
ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type]
|
||||
|
||||
q_node.replace_all_uses_with(ref_node)
|
||||
model.graph.erase_node(q_node)
|
||||
@ -906,7 +906,7 @@ def _lower_static_weighted_ref_functional(
|
||||
prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5]
|
||||
else:
|
||||
raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
|
||||
with model.graph.inserting_before(output_scale_node):
|
||||
with model.graph.inserting_before(output_scale_node): # type: ignore[arg-type]
|
||||
# kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack)
|
||||
# They are not needed for compute op (i.e., quantized::linear)
|
||||
kwargs = func_node.kwargs
|
||||
@ -1107,7 +1107,7 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: Dict[str, QConfi
|
||||
dq_node = arg
|
||||
assert isinstance(dq_node, Node)
|
||||
dn_input = dq_node.args[0]
|
||||
bop_node.replace_input_with(dq_node, dn_input)
|
||||
bop_node.replace_input_with(dq_node, dn_input) # type: ignore[arg-type]
|
||||
num_dq_nodes += 1
|
||||
assert num_dq_nodes > 0
|
||||
|
||||
|
||||
@ -821,7 +821,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
|
||||
last_getitem_index = last_getitem.args[1]
|
||||
new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index]
|
||||
for user in list(last_getitem.users.keys()):
|
||||
user.replace_input_with(last_getitem, new_input)
|
||||
user.replace_input_with(last_getitem, new_input) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _get_observer_from_activation_post_process(
|
||||
|
||||
@ -46,8 +46,8 @@ def _maybe_duplicate_dq(
|
||||
|
||||
new_args = map_arg(user.args, maybe_replace_node)
|
||||
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
|
||||
user.args = new_args
|
||||
user.kwargs = new_kwargs
|
||||
user.args = new_args # type: ignore[assignment]
|
||||
user.kwargs = new_kwargs # type: ignore[assignment]
|
||||
|
||||
|
||||
class DuplicateDQPass(PassBase):
|
||||
|
||||
@ -107,8 +107,8 @@ def _find_q_dq_node_for_user(
|
||||
|
||||
q_node = None
|
||||
if (
|
||||
dq_node.args[0].op == "call_function"
|
||||
and dq_node.args[0].target in _QUANTIZE_OPS
|
||||
dq_node.args[0].op == "call_function" # type: ignore[union-attr]
|
||||
and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr]
|
||||
):
|
||||
q_node = dq_node.args[0]
|
||||
return (q_node, dq_node)
|
||||
@ -353,7 +353,7 @@ def _get_aten_graph_module_for_pattern(
|
||||
[x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
|
||||
)
|
||||
aten_pattern = capture_pre_autograd_graph(
|
||||
pattern,
|
||||
pattern, # type: ignore[arg-type]
|
||||
example_inputs,
|
||||
kwargs,
|
||||
)
|
||||
@ -373,7 +373,7 @@ def _get_aten_graph_module_for_pattern(
|
||||
aten_pattern.graph.eliminate_dead_code()
|
||||
aten_pattern.recompile()
|
||||
|
||||
return aten_pattern
|
||||
return aten_pattern # type: ignore[return-value]
|
||||
|
||||
|
||||
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
|
||||
|
||||
@ -983,13 +983,13 @@ def _annotate_cat(
|
||||
inputs = cat_node.args[0]
|
||||
|
||||
input_qspec_map = {}
|
||||
input_act0 = inputs[0]
|
||||
input_act0 = inputs[0] # type: ignore[index]
|
||||
if isinstance(input_act0, Node):
|
||||
input_qspec_map[input_act0] = input_act_qspec
|
||||
|
||||
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
|
||||
for input_act in inputs[1:]:
|
||||
input_qspec_map[input_act] = shared_with_input0_qspec
|
||||
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: ignore[arg-type]
|
||||
for input_act in inputs[1:]: # type: ignore[index]
|
||||
input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index]
|
||||
|
||||
output_act_qspec = shared_with_input0_qspec
|
||||
|
||||
|
||||
@ -117,7 +117,7 @@ def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.Grap
|
||||
# pyre-ignore[6]
|
||||
in_spec=None, # type: ignore[arg-type]
|
||||
# pyre-ignore[16]
|
||||
out_spec=gm._graph._codegen.pytree_info.out_spec,
|
||||
out_spec=gm._graph._codegen.pytree_info.out_spec, # type: ignore[attr-defined]
|
||||
)
|
||||
)
|
||||
gm.recompile()
|
||||
|
||||
@ -92,7 +92,7 @@ def clone_subgraph(
|
||||
with graph.inserting_before(target):
|
||||
for node in subgraph:
|
||||
cloned_node = graph.call_function(
|
||||
node.target, node.args, node.kwargs, node.type
|
||||
node.target, node.args, node.kwargs, node.type # type: ignore[arg-type]
|
||||
)
|
||||
# TODO: there are many flatten/unflatten in IterGraph that
|
||||
# can be simplified with tree_map. Will simplify this in
|
||||
|
||||
@ -366,7 +366,7 @@ class IterGraph(fx.Graph):
|
||||
actual_target_node = self._lookup_node(target_node, graph)
|
||||
assert actual_target_node is not None
|
||||
for actual_node in actual_nodes:
|
||||
actual_target_node.prepend(actual_node)
|
||||
actual_target_node.prepend(actual_node) # type: ignore[arg-type]
|
||||
|
||||
def move_after(self, nodes: List[fx.Node], target_node: fx.Node) -> None:
|
||||
for graph in self._all_graphs:
|
||||
@ -374,7 +374,7 @@ class IterGraph(fx.Graph):
|
||||
actual_target_node = self._lookup_node(target_node, graph)
|
||||
for actual_node in actual_nodes:
|
||||
assert actual_target_node is not None
|
||||
actual_target_node.append(actual_node)
|
||||
actual_target_node.append(actual_node) # type: ignore[arg-type]
|
||||
actual_target_node = actual_node
|
||||
|
||||
def call_function(
|
||||
@ -432,7 +432,7 @@ class IterGraph(fx.Graph):
|
||||
self.setup_graph.erase_node(setup_node)
|
||||
super().erase_node(to_erase)
|
||||
cleanup_node = self._lookup_node(to_erase, self.cleanup_graph)
|
||||
self.cleanup_graph.erase_node(cleanup_node)
|
||||
self.cleanup_graph.erase_node(cleanup_node) # type: ignore[arg-type]
|
||||
|
||||
def placeholder(
|
||||
self,
|
||||
@ -558,7 +558,7 @@ class IterGraph(fx.Graph):
|
||||
actual_replace_with = self._lookup_node(replace_with, graph)
|
||||
assert actual_node is not None
|
||||
ret = actual_node.replace_all_uses_with(
|
||||
actual_replace_with,
|
||||
actual_replace_with, # type: ignore[arg-type]
|
||||
delete_user_cb,
|
||||
propagate_meta=propagate_meta,
|
||||
)
|
||||
|
||||
@ -45,7 +45,7 @@ def _is_container_node(node: torch.fx.Node) -> bool:
|
||||
"Malformed graph: a container node is used as input for non-getitem nodes."
|
||||
"\nNode: {fmt_node}\nUsers: {fmt_users}".format(
|
||||
fmt_node=node.format_node(),
|
||||
fmt_users="\n".join(u.format_node() for u in node.users),
|
||||
fmt_users="\n".join(u.format_node() for u in node.users), # type: ignore[misc]
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
@ -472,7 +472,7 @@ def _insert_reshard_gm(
|
||||
input_node: input_arg,
|
||||
},
|
||||
)
|
||||
node.replace_input_with(input_arg, output_node)
|
||||
node.replace_input_with(input_arg, output_node) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:
|
||||
|
||||
@ -96,11 +96,11 @@ class _ExecOrderTracer:
|
||||
self.exec_info = _ExecutionInfo(root_module)
|
||||
orig_call_module = tracer.call_module
|
||||
orig_create_proxy = tracer.create_proxy
|
||||
tracer.call_module = functools.partial(
|
||||
tracer.call_module = functools.partial( # type: ignore[method-assign]
|
||||
self._patched_call_module, orig_call_module, self.exec_info
|
||||
)
|
||||
fqn_to_param = dict(root_module.named_parameters())
|
||||
tracer.create_proxy = functools.partial(
|
||||
tracer.create_proxy = functools.partial( # type: ignore[method-assign]
|
||||
self._patched_create_proxy,
|
||||
orig_create_proxy,
|
||||
self.exec_info,
|
||||
@ -109,8 +109,8 @@ class _ExecOrderTracer:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tracer.call_module = orig_call_module
|
||||
tracer.create_proxy = orig_create_proxy
|
||||
tracer.call_module = orig_call_module # type: ignore[method-assign]
|
||||
tracer.create_proxy = orig_create_proxy # type: ignore[method-assign]
|
||||
|
||||
def _patched_call_module(
|
||||
self,
|
||||
@ -216,8 +216,8 @@ class _ExecOrderTracer:
|
||||
isinstance(arg, torch.fx.Proxy)
|
||||
and arg.node.target in fqn_to_param
|
||||
):
|
||||
param = fqn_to_param[arg.node.target]
|
||||
named_params.append((arg.node.target, param))
|
||||
param = fqn_to_param[arg.node.target] # type: ignore[index]
|
||||
named_params.append((arg.node.target, param)) # type: ignore[arg-type]
|
||||
if param not in exec_info.visited_params:
|
||||
exec_info.visited_params.add(param)
|
||||
exec_info.param_forward_order.append(param)
|
||||
|
||||
@ -214,7 +214,7 @@ def _insert_stage_symbolic_backward(
|
||||
input_nodes = list(node.all_input_nodes)
|
||||
grads_proxy = fx.Proxy(grads)
|
||||
for i, input_node in enumerate(input_nodes):
|
||||
assign_or_accumulate_grad(input_node, grads_proxy[i].node)
|
||||
assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index]
|
||||
|
||||
return g
|
||||
|
||||
@ -416,15 +416,15 @@ class _LinearNodeList:
|
||||
def __init__(self, node_list):
|
||||
self.serialize_node_list = []
|
||||
for node in node_list:
|
||||
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))
|
||||
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))
|
||||
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type, return-value]
|
||||
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type, return-value]
|
||||
serialize_node = fx.Node(
|
||||
graph=None,
|
||||
graph=None, # type: ignore[arg-type]
|
||||
name=node.name,
|
||||
op=node.op,
|
||||
target=node.target,
|
||||
args=node_args,
|
||||
kwargs=node_kwargs,
|
||||
args=node_args, # type: ignore[arg-type]
|
||||
kwargs=node_kwargs, # type: ignore[arg-type]
|
||||
return_type=node.type,
|
||||
)
|
||||
serialize_node.meta = copy.copy(node.meta)
|
||||
@ -447,8 +447,8 @@ class _LinearNodeList:
|
||||
deser_node = graph.create_node(
|
||||
op=node.op,
|
||||
target=node.target,
|
||||
args=node_args,
|
||||
kwargs=node_kwargs,
|
||||
args=node_args, # type: ignore[arg-type]
|
||||
kwargs=node_kwargs, # type: ignore[arg-type]
|
||||
name=node.name,
|
||||
type_expr=node.type,
|
||||
)
|
||||
@ -731,7 +731,7 @@ class Pipe(torch.nn.Module):
|
||||
|
||||
# TODO: what does split do with module invocations? does it move the modules
|
||||
# into the submodules?
|
||||
split = split_module(traced, mod, split_callback)
|
||||
split = split_module(traced, mod, split_callback) # type: ignore[arg-type]
|
||||
# a (custom) tracer can produce dead code like orphan get_attr nodes
|
||||
split.graph.eliminate_dead_code()
|
||||
|
||||
|
||||
@ -86,7 +86,7 @@ class TensorChunkSpec:
|
||||
"""
|
||||
args_chunk_spec = map_aggregate(
|
||||
chunk_dims,
|
||||
lambda dim: TensorChunkSpec(dim),
|
||||
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type, return-value]
|
||||
)
|
||||
return args_chunk_spec
|
||||
|
||||
@ -104,7 +104,7 @@ class TensorChunkSpec:
|
||||
"""
|
||||
kwargs_chunk_spec = map_aggregate(
|
||||
chunk_dims,
|
||||
lambda dim: TensorChunkSpec(dim),
|
||||
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type, return-value]
|
||||
)
|
||||
return kwargs_chunk_spec
|
||||
|
||||
|
||||
@ -1,14 +1,15 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Callable, TypeVar
|
||||
import textwrap
|
||||
|
||||
_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
|
||||
_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
|
||||
|
||||
def compatibility(is_backward_compatible : bool):
|
||||
_T = TypeVar("_T")
|
||||
|
||||
def compatibility(is_backward_compatible : bool) -> Callable[[_T], _T]:
|
||||
if is_backward_compatible:
|
||||
|
||||
def mark_back_compat(fn):
|
||||
def mark_back_compat(fn: _T) -> _T:
|
||||
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
|
||||
docstring += """
|
||||
.. note::
|
||||
@ -22,7 +23,7 @@ def compatibility(is_backward_compatible : bool):
|
||||
return mark_back_compat
|
||||
else:
|
||||
|
||||
def mark_not_back_compat(fn):
|
||||
def mark_not_back_compat(fn: _T) -> _T:
|
||||
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
|
||||
docstring += """
|
||||
.. warning::
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import builtins
|
||||
import copy
|
||||
|
||||
@ -165,7 +165,7 @@ class MetaTracer(torch.fx.Tracer):
|
||||
meta_target = manual_meta_overrides.get(target, target)
|
||||
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||
elif kind == 'call_method':
|
||||
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)
|
||||
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index] # type: ignore[index]
|
||||
elif kind == 'call_module':
|
||||
assert hasattr(self, 'orig_forward')
|
||||
self._disable_module_getattr = True
|
||||
@ -173,7 +173,7 @@ class MetaTracer(torch.fx.Tracer):
|
||||
mod = self.root.get_submodule(target)
|
||||
mod_type = type(mod)
|
||||
if mod_type in manual_meta_overrides:
|
||||
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)
|
||||
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type]
|
||||
else:
|
||||
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
|
||||
finally:
|
||||
@ -237,7 +237,7 @@ class MetaTracer(torch.fx.Tracer):
|
||||
def proxy(self, node):
|
||||
return MetaProxy(node, self)
|
||||
|
||||
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):
|
||||
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
|
||||
assert isinstance(meta_args, dict)
|
||||
self.meta_args = meta_args
|
||||
|
||||
@ -263,7 +263,7 @@ def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
|
||||
meta_args : Optional[Dict[str, torch.Tensor]] = None,
|
||||
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
|
||||
tracer = MetaTracer()
|
||||
graph = tracer.trace(root, meta_args, concrete_args)
|
||||
graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
|
||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||
gm = torch.fx.GraphModule(tracer.root, graph, name)
|
||||
return gm
|
||||
|
||||
@ -153,10 +153,10 @@ def set_proxy_slot(
|
||||
if isinstance(obj, Tensor):
|
||||
# We DO want to clobber proxies whenever we run an inplace operation
|
||||
# on a tensor, and it affects the metadata on the proxy.
|
||||
tracer.tensor_tracker[obj] = proxy
|
||||
tracer.tensor_tracker[obj] = proxy # type: ignore[has-type]
|
||||
elif isinstance(obj, (_AnyScriptObject)):
|
||||
# We DO want to clobber proxies, with a similar rationale as for tensors.
|
||||
tracer.script_object_tracker[obj] = proxy
|
||||
tracer.script_object_tracker[obj] = proxy # type: ignore[has-type]
|
||||
else:
|
||||
# NB: Never clobber pre-existing proxy. Although the proxies
|
||||
# are in principle equivalent, when we do graph partitioning
|
||||
@ -165,8 +165,8 @@ def set_proxy_slot(
|
||||
# THEN later we allocate tangent inputs. Make sure if a SymInt
|
||||
# is derivable from a primal that we use that.
|
||||
assert isinstance(obj, py_sym_types), type(obj)
|
||||
if obj not in tracer.symnode_tracker:
|
||||
tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)
|
||||
if obj not in tracer.symnode_tracker: # type: ignore[has-type]
|
||||
tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy) # type: ignore[has-type]
|
||||
|
||||
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
|
||||
assert isinstance(obj, (Tensor, SymNode)), type(obj)
|
||||
@ -261,12 +261,12 @@ def get_proxy_slot(
|
||||
|
||||
tracker: Any
|
||||
if isinstance(obj, Tensor):
|
||||
tracker = tracer.tensor_tracker
|
||||
tracker = tracer.tensor_tracker # type: ignore[has-type]
|
||||
elif isinstance(obj, _AnyScriptObject):
|
||||
tracker = tracer.script_object_tracker
|
||||
tracker = tracer.script_object_tracker # type: ignore[has-type]
|
||||
else:
|
||||
assert isinstance(obj, py_sym_types), type(obj)
|
||||
tracker = tracer.symnode_tracker
|
||||
tracker = tracer.symnode_tracker # type: ignore[has-type]
|
||||
|
||||
if obj not in tracker:
|
||||
if isinstance(default, _NoDefault):
|
||||
@ -372,7 +372,7 @@ def track_tensor(tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tr
|
||||
)
|
||||
try_set_proxy_slot(
|
||||
tensor.storage_offset(),
|
||||
lambda x: set_meta(tracer.create_proxy('call_function', torch.ops.aten.sym_storage_offset.default, (proxy,)), x)
|
||||
lambda x: set_meta(tracer.create_proxy('call_function', torch.ops.aten.sym_storage_offset.default, (proxy,)), x) # type: ignore[call-arg]
|
||||
)
|
||||
set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
|
||||
|
||||
@ -767,7 +767,7 @@ class PythonKeyTracer(Tracer):
|
||||
torch_fn_counts: Dict[OpOverload, int]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(autowrap_modules=())
|
||||
super().__init__(autowrap_modules=()) # type: ignore[arg-type]
|
||||
self.tensor_tracker = WeakTensorKeyDictionary()
|
||||
self.symnode_tracker = _SymNodeDict()
|
||||
self.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef)
|
||||
@ -803,7 +803,7 @@ class PythonKeyTracer(Tracer):
|
||||
elif isinstance(a, py_sym_types):
|
||||
assert a.node.constant is not None
|
||||
return a.node.constant
|
||||
return super().create_arg(a)
|
||||
return super().create_arg(a) # type: ignore[return-value]
|
||||
|
||||
@overload
|
||||
def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]:
|
||||
@ -867,7 +867,7 @@ def dispatch_trace(
|
||||
tracer: Tracer,
|
||||
concrete_args: Optional[Tuple[Any, ...]] = None,
|
||||
) -> GraphModule:
|
||||
graph = tracer.trace(root, concrete_args)
|
||||
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
|
||||
from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints
|
||||
dedupe_symints(graph)
|
||||
name = root.__class__.__name__ if isinstance(root, Module) else root.__name__
|
||||
@ -966,7 +966,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
|
||||
# It's for passing the export verifier which needs to verify the meta['val']
|
||||
# TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
|
||||
# instead of hardcoding it here.
|
||||
node = self.tracer.create_node("call_function", func, args, {})
|
||||
node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type]
|
||||
if func is torch._C._set_grad_enabled:
|
||||
node.meta['val'] = None
|
||||
return node
|
||||
@ -1090,7 +1090,7 @@ class ProxySymDispatchMode(SymDispatchMode):
|
||||
|
||||
# func doesn't have a __torch_function__ that Proxy can interpose, so
|
||||
# we gotta do it manually
|
||||
n_out = self.tracer.create_node("call_function", func, n_args, {})
|
||||
n_out = self.tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type]
|
||||
p_out = fx.Proxy(n_out, self.tracer)
|
||||
set_meta(p_out, out)
|
||||
return p_out
|
||||
@ -1149,7 +1149,7 @@ class DecompositionInterpreter(fx.Interpreter):
|
||||
decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
|
||||
**kwargs: object
|
||||
) -> None:
|
||||
super().__init__(module, **kwargs)
|
||||
super().__init__(module, **kwargs) # type: ignore[arg-type]
|
||||
self.new_graph = new_graph
|
||||
self.tracer = _GraphAppendingTracerEx(self.new_graph)
|
||||
# Blegh
|
||||
@ -1164,23 +1164,23 @@ class DecompositionInterpreter(fx.Interpreter):
|
||||
# distinguish between different calls to the same torch function.
|
||||
self.tracer.torch_fn_counts = {}
|
||||
|
||||
def placeholder(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object:
|
||||
out = super().placeholder(target, args, kwargs)
|
||||
def placeholder(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object: # type: ignore[override]
|
||||
out = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
||||
proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
|
||||
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
|
||||
# TODO handle case where the first character of target is '*'
|
||||
return out
|
||||
|
||||
def get_attr(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object:
|
||||
out = super().get_attr(target, args, kwargs)
|
||||
def get_attr(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object: # type: ignore[override]
|
||||
out = super().get_attr(target, args, kwargs) # type: ignore[arg-type]
|
||||
proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
|
||||
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
|
||||
return out
|
||||
|
||||
# call_function, call_method, call_module get traced automatically by the outer mode.
|
||||
|
||||
def output(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object:
|
||||
out = super().output(target, args, kwargs)
|
||||
def output(self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]) -> object: # type: ignore[override]
|
||||
out = super().output(target, args, kwargs) # type: ignore[arg-type]
|
||||
|
||||
def get_proxy_node(x: _ProxyTensor) -> fx.node.Node:
|
||||
return x.proxy.node
|
||||
@ -1194,7 +1194,7 @@ class DecompositionInterpreter(fx.Interpreter):
|
||||
# Should enter the mode at least once for being able to restore it later
|
||||
# See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
|
||||
with decompose(self.decomposition_table), self.mode:
|
||||
return super().run(*args, **kwargs)
|
||||
return super().run(*args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def wrapper_and_args_for_make_fx(
|
||||
@ -1348,7 +1348,7 @@ class _ModuleStackTracer(PythonKeyTracer):
|
||||
self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr)
|
||||
return self.attr_proxy_map[attr_val]
|
||||
|
||||
def trace(
|
||||
def trace( # type: ignore[override]
|
||||
self,
|
||||
root: Union[Module, Callable],
|
||||
concrete_args: Optional[Dict[str, object]]
|
||||
@ -1438,7 +1438,7 @@ class _ModuleStackTracer(PythonKeyTracer):
|
||||
Add torch_fn by looking at torch_fn_metadata and torch_fn_counts.
|
||||
Add stack_trace by filtering out forward() stack frames.
|
||||
'''
|
||||
node = super().create_node(*args, **kwargs)
|
||||
node = super().create_node(*args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
# nn_module_stack
|
||||
if node.op not in ["placeholder", "output"]:
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from collections import defaultdict
|
||||
from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import copy
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from .graph_module import GraphModule
|
||||
from ._lazy_graph_module import _make_graph_module
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# Nodes represent a definition of a value in our graph of operators.
|
||||
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
|
||||
from ._compatibility import compatibility
|
||||
@ -769,7 +768,7 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
|
||||
if isinstance(a, tuple):
|
||||
t = tuple(map_aggregate(elem, fn) for elem in a)
|
||||
# Support NamedTuple (if it has `_fields`) by repacking into original type.
|
||||
return t if not hasattr(a, '_fields') else type(a)(*t)
|
||||
return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type]
|
||||
elif isinstance(a, list):
|
||||
return immutable_list(map_aggregate(elem, fn) for elem in a)
|
||||
elif isinstance(a, dict):
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
import inspect
|
||||
|
||||
@ -251,8 +251,8 @@ if HAS_PYDOT:
|
||||
label += f"|target={self._typename(node.target)}" + r"\n"
|
||||
if self.normalize_args:
|
||||
try:
|
||||
args, kwargs = normalize_function(
|
||||
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True
|
||||
args, kwargs = normalize_function( # type: ignore[misc]
|
||||
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
|
||||
)
|
||||
except Exception:
|
||||
# Fallback to not normalizing if there's an exception.
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Dict, List, NamedTuple, Optional
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
@ -302,7 +302,7 @@ class _MinimizerBase:
|
||||
|
||||
# Find submodule containing colored nodes
|
||||
submodule_name: str = ""
|
||||
for child_name, _ in split_module.named_children():
|
||||
for child_name, _ in split_module.named_children(): # type: ignore[union-attr]
|
||||
# Skip submodules we're not interested in at the moment
|
||||
if "minimize" not in child_name:
|
||||
continue
|
||||
@ -319,7 +319,7 @@ class _MinimizerBase:
|
||||
f"Minimize submodule was not found with nodes {nodes}"
|
||||
)
|
||||
|
||||
return split_module, submodule_name
|
||||
return split_module, submodule_name # type: ignore[return-value]
|
||||
|
||||
def _run_and_compare(
|
||||
self,
|
||||
@ -391,10 +391,10 @@ class _MinimizerBase:
|
||||
report.append(f"Result mismatch for {result_key}")
|
||||
if self.module_exporter:
|
||||
self.module_exporter(
|
||||
a_input, submodule, str(result_key[0]) + "_cpu",
|
||||
a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index]
|
||||
)
|
||||
self.module_exporter(
|
||||
b_input, submodule, str(result_key[0]) + "_acc",
|
||||
b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index]
|
||||
)
|
||||
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import abc
|
||||
import typing as t
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
import torch
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
import operator
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
@ -163,14 +162,14 @@ def split_module(
|
||||
)
|
||||
if keep_original_node_name:
|
||||
args = () if default_value is inspect.Signature.empty else (default_value,)
|
||||
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)
|
||||
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) # type: ignore[arg-type]
|
||||
else:
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(
|
||||
node.target, type_expr=node.type, default_value=default_value
|
||||
node.target, type_expr=node.type, default_value=default_value # type: ignore[arg-type]
|
||||
)
|
||||
base_mod_env[node.name].meta = node.meta.copy()
|
||||
elif node.op == "get_attr":
|
||||
base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
|
||||
base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type]
|
||||
base_mod_env[node.name].meta = node.meta.copy()
|
||||
attr_val = m
|
||||
for atom in node.target.split("."): # type: ignore[union-attr]
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import argparse
|
||||
import copy
|
||||
@ -860,7 +859,7 @@ class _SplitterBase:
|
||||
for node in self.module.graph.nodes:
|
||||
if hasattr(node, "tag"):
|
||||
del node.tag
|
||||
return split_module
|
||||
return split_module # type: ignore[return-value]
|
||||
|
||||
def __call__(self) -> torch.fx.GraphModule:
|
||||
subgraphs = self.put_nodes_into_subgraphs()
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
|
||||
import collections
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, Tuple
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
from queue import SimpleQueue
|
||||
|
||||
@ -31,8 +31,8 @@ def _split_to_graph_and_name_node_map(
|
||||
name_node_map, Dict
|
||||
), "Expecting the input graph to have a dict output as the last element"
|
||||
n.args = (flattened,)
|
||||
orig_pytree_info = gm._graph._codegen.pytree_info
|
||||
gm._graph._codegen.pytree_info = _PyTreeInfo(
|
||||
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
|
||||
gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined]
|
||||
orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
|
||||
)
|
||||
gm.recompile()
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from dataclasses import dataclass, field
|
||||
from torch.fx.graph import Graph
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
from .graph_module import GraphModule
|
||||
from .graph import Graph
|
||||
from .node import Node
|
||||
@ -319,8 +318,8 @@ def _replace_pattern(
|
||||
|
||||
# Hook the output Node of the replacement subgraph in to the
|
||||
# original Graph at the correct location
|
||||
assert len(match.returning_nodes) == len(copied_returning_nodes)
|
||||
for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
|
||||
assert len(match.returning_nodes) == len(copied_returning_nodes) # type: ignore[arg-type]
|
||||
for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type]
|
||||
gn.replace_all_uses_with(copied_node)
|
||||
match_changed_node[gn] = copied_node
|
||||
# Remove the original nodes
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
|
||||
@ -305,7 +305,7 @@ def jagged_torch_function(func, *args, **kwargs):
|
||||
def _flatten_sig(input, start_dim=0, end_dim=-1):
|
||||
pass
|
||||
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
_flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -385,7 +385,7 @@ def tensor_attr_unsupported_getter(func, *args, **kwargs):
|
||||
def is_contiguous_general(func, *args, **kwargs):
|
||||
from torch._prims_common import is_contiguous_for_memory_format
|
||||
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
inp = new_kwargs.pop("input")
|
||||
@ -409,7 +409,7 @@ register_jagged_func(
|
||||
|
||||
@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
|
||||
def linear_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -423,7 +423,7 @@ def linear_default(func, *args, **kwargs):
|
||||
"self: jt, grad_output: jt, weight: t, output_mask: any",
|
||||
)
|
||||
def linear_backward_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -444,7 +444,7 @@ def linear_backward_default(func, *args, **kwargs):
|
||||
def to_copy_default(func, *args, **kwargs):
|
||||
from .nested_tensor import _tensor_symint_registry
|
||||
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -476,7 +476,7 @@ register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
|
||||
"self: jt_all",
|
||||
)
|
||||
def like_factory_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -492,7 +492,7 @@ def like_factory_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
|
||||
def zero__default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -505,7 +505,7 @@ def zero__default(func, *args, **kwargs):
|
||||
torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
|
||||
)
|
||||
def _softmax_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -521,7 +521,7 @@ def _softmax_default(func, *args, **kwargs):
|
||||
"grad_output: jt, output: jt, dim: any, input_dtype: any",
|
||||
)
|
||||
def _softmax_backward(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
grad_out = new_kwargs.pop("grad_output")
|
||||
@ -535,7 +535,7 @@ def _softmax_backward(func, *args, **kwargs):
|
||||
torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
|
||||
)
|
||||
def native_dropout_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -552,7 +552,7 @@ def native_dropout_default(func, *args, **kwargs):
|
||||
"grad_output: jt, mask: jt, scale: any",
|
||||
)
|
||||
def native_dropout_backward_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
grad_output = new_kwargs.pop("grad_output")
|
||||
@ -565,7 +565,7 @@ def native_dropout_backward_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
|
||||
def prod_dim_int(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -584,7 +584,7 @@ def prod_dim_int(func, *args, **kwargs):
|
||||
torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any"
|
||||
)
|
||||
def split_tensor(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -602,7 +602,7 @@ def split_tensor(func, *args, **kwargs):
|
||||
torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any"
|
||||
)
|
||||
def split_with_sizes_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -620,7 +620,7 @@ def split_with_sizes_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
|
||||
def chunk_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -663,7 +663,7 @@ def chunk_default(func, *args, **kwargs):
|
||||
@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
|
||||
def unbind_int(func, *args, **kwargs):
|
||||
# Note that this specializes on the length of the offsets
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -697,7 +697,7 @@ def unbind_int(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
|
||||
def squeeze_dim(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -710,7 +710,7 @@ def squeeze_dim(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
|
||||
def unsqueeze_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -725,7 +725,7 @@ def unsqueeze_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
|
||||
def cat_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -748,7 +748,7 @@ def cat_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
|
||||
def matmul_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -773,7 +773,7 @@ def matmul_default(func, *args, **kwargs):
|
||||
torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
|
||||
)
|
||||
def expand_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -790,7 +790,7 @@ def expand_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
|
||||
def expand_as_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -802,7 +802,7 @@ def expand_as_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
|
||||
def where_self(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -820,7 +820,7 @@ def where_self(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
|
||||
def _pin_memory_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -831,7 +831,7 @@ def _pin_memory_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
|
||||
def is_pinned_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -856,7 +856,7 @@ def sum_dim_IntList(func, *args, **kwargs):
|
||||
Performs a sum along the provided tensor dimension.
|
||||
Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
|
||||
"""
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
inp = new_kwargs.pop("input")
|
||||
@ -930,7 +930,7 @@ def sum_dim_IntList(func, *args, **kwargs):
|
||||
torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
|
||||
)
|
||||
def transpose_int(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -977,7 +977,7 @@ def transpose_int(func, *args, **kwargs):
|
||||
"self: jt_all, size: any",
|
||||
)
|
||||
def view_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1025,7 +1025,7 @@ def view_default(func, *args, **kwargs):
|
||||
"input: jt, normalized_shape: any, weight: any?, bias: any?, eps: any",
|
||||
)
|
||||
def native_layer_norm_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1047,7 +1047,7 @@ def native_layer_norm_default(func, *args, **kwargs):
|
||||
"grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
|
||||
)
|
||||
def native_layer_norm_backward_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
grad_out = new_kwargs.pop("grad_out")
|
||||
@ -1061,7 +1061,7 @@ def native_layer_norm_backward_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any")
|
||||
def select_int(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1076,7 +1076,7 @@ def select_int(func, *args, **kwargs):
|
||||
"self: jt, dim: any?, start: any?, end: any?, step: any?",
|
||||
)
|
||||
def slice_tensor(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1092,7 +1092,7 @@ def slice_tensor(func, *args, **kwargs):
|
||||
"dilation: any, transposed: any, output_padding: any, groups: any",
|
||||
)
|
||||
def convolution_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1109,7 +1109,7 @@ def mean_dim(func, *args, **kwargs):
|
||||
Performs a mean along the provided tensor dimension.
|
||||
Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
|
||||
"""
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1165,7 +1165,7 @@ def mean_dim(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
|
||||
def stack_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1199,7 +1199,7 @@ def stack_default(func, *args, **kwargs):
|
||||
"weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
|
||||
)
|
||||
def embedding_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1220,7 +1220,7 @@ def embedding_default(func, *args, **kwargs):
|
||||
"self: jt_all",
|
||||
)
|
||||
def values_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1236,7 +1236,7 @@ def values_default(func, *args, **kwargs):
|
||||
"values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
|
||||
)
|
||||
def _nested_view_from_jagged_default(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1265,7 +1265,7 @@ def _nested_view_from_jagged_default(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
|
||||
def _nested_get_offsets(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1275,7 +1275,7 @@ def _nested_get_offsets(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
|
||||
def _nested_get_lengths(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1285,7 +1285,7 @@ def _nested_get_lengths(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
|
||||
def _nested_get_ragged_idx(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1295,7 +1295,7 @@ def _nested_get_ragged_idx(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
|
||||
def _nested_get_min_seqlen(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
@ -1305,7 +1305,7 @@ def _nested_get_min_seqlen(func, *args, **kwargs):
|
||||
|
||||
@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
|
||||
def _nested_get_max_seqlen(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
_, new_kwargs = normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
|
||||
@ -243,7 +243,7 @@ def _fx_args_to_torch_args(
|
||||
if isinstance(arg, torch.fx.Node):
|
||||
fake_tensor = arg.meta.get("val")
|
||||
if fake_tensor is None and arg.op == "get_attr":
|
||||
fake_tensor = getattr(fx_graph_module, arg.target) # type: ignore[operator]
|
||||
fake_tensor = getattr(fx_graph_module, arg.target) # type: ignore[operator, arg-type]
|
||||
# NOTE: Currently, we are aware of
|
||||
# FakeTensor/Tensor/SymInt/SymFloat/Symbool/int/float/bool could be in
|
||||
# arg.meta["val"]/get_attr.
|
||||
@ -254,8 +254,8 @@ def _fx_args_to_torch_args(
|
||||
wrapped_args.append(real_tensor)
|
||||
elif isinstance(fake_tensor, (int, float, bool)):
|
||||
wrapped_args.append(fake_tensor)
|
||||
elif symbolic_shapes.has_hint(fake_tensor):
|
||||
wrapped_args.append(symbolic_shapes.hint_int(fake_tensor))
|
||||
elif symbolic_shapes.has_hint(fake_tensor): # type: ignore[arg-type]
|
||||
wrapped_args.append(symbolic_shapes.hint_int(fake_tensor)) # type: ignore[arg-type]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected input argument type found inside fx.Node. arg: {arg}; "
|
||||
|
||||
@ -61,7 +61,7 @@ def _create_tensor_proto_with_external_data(
|
||||
|
||||
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
|
||||
tensor_proto.name = name
|
||||
tensor_proto.data_type = scalar_type.onnx_type()
|
||||
tensor_proto.data_type = scalar_type.onnx_type() # type: ignore[assignment]
|
||||
|
||||
tensor_proto.dims.extend(tensor.shape)
|
||||
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import dataclasses
|
||||
import importlib
|
||||
@ -647,7 +646,7 @@ Examples::
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@dataclasses.dataclass(frozen=True) # type: ignore[arg-type]
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class OrtBackendOptions:
|
||||
"""Options for constructing an ``OrtBackend``, the ONNX Runtime
|
||||
|
||||
@ -611,7 +611,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||
return None
|
||||
|
||||
def get_arg_from_alias(output_alias, schema_info, args, kwargs):
|
||||
new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(
|
||||
new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc]
|
||||
func, args=args, kwargs=kwargs
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user