[BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054. This corresponds to the pytorch mypy changes from D61493706. Updating takes so long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change, so these 'type: ignore' comments are landed for pytorch ahead of the typing changes that will actually need them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
commit d95aedf5fd, parent 44fa9f991c, committed by PyTorch MergeBot
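Every hunk below follows the same pattern: a `# type: ignore[...]` comment is appended to a line that mypy will start flagging once the fx decorators (such as the one in torch/fx/_compatibility.py) gain real signatures. As a rough illustration of why annotating a decorator surfaces new errors at existing call sites, here is a minimal sketch; it is not the actual torch.fx implementation (which also records backward-compatibility metadata), just the typing idea:

from typing import Callable, TypeVar
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")

# Untyped decorator: mypy infers the decorated function as Any,
# so errors at its call sites are silently ignored.
def compatibility_untyped(is_backward_compatible):
    def decorator(fn):
        return fn
    return decorator

# Typed decorator: the ParamSpec preserves the wrapped signature,
# so mypy resumes checking every call through the decorated function.
# That is what makes the pre-landed `# type: ignore[...]` suppressions
# in the hunks below necessary.
def compatibility(
    is_backward_compatible: bool,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    def decorator(fn: Callable[_P, _R]) -> Callable[_P, _R]:
        return fn
    return decorator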
@@ -445,7 +445,7 @@ class AutogradCompilerInstance:
 
     def bind_tensors_to_proxies(self, tensors, proxies):
         if isinstance(proxies, torch.fx.Proxy):
-            proxies = [proxies[i] for i in range(len(tensors))]
+            proxies = [proxies[i] for i in range(len(tensors))]  # type: ignore[index]
         assert len(tensors) == len(proxies)
         track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
 
@@ -1539,7 +1539,7 @@ def export(
         # Running graph with interpreter is needed for propagating the stack_trace
         def graph_with_interpreter(*args):
             with torch.fx.traceback.preserve_node_meta():
-                return torch.fx.Interpreter(graph).run(*args)
+                return torch.fx.Interpreter(graph).run(*args)  # type: ignore[arg-type]
 
         with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
             try:
@@ -1561,9 +1561,9 @@ def export(
 
         assert graph is not None
         for node in graph.graph.find_nodes(op="get_attr"):
-            if isinstance(getattr(graph, node.target), torch.Tensor):
+            if isinstance(getattr(graph, node.target), torch.Tensor):  # type: ignore[arg-type]
                 node.meta["val"] = fake_mode.from_tensor(
-                    getattr(graph, node.target), static_shapes=True
+                    getattr(graph, node.target), static_shapes=True  # type: ignore[arg-type]
                 )
 
         if same_signature:
@@ -2231,7 +2231,7 @@ def get_real_value(node, tracer):
         return cache[node]
 
     op = node.op
-    args, kwargs = torch.fx.node.map_arg(
+    args, kwargs = torch.fx.node.map_arg(  # type: ignore[misc]
        (node.args, node.kwargs),
        lambda n: get_real_value(n, tracer),
    )
@@ -1343,7 +1343,7 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
            self.name_to_node,
            # Dummy node.
            torch.fx.Node(
-                None,
+                None,  # type: ignore[arg-type]
                "mock",
                "call_function",
                lambda: None,
@@ -68,7 +68,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            self.submodules: Dict[torch.nn.Module, str] = {}
 
-        def trace(self) -> None:
+        def trace(self) -> None:  # type: ignore[override]
            raise ExportPassBaseError("ExportTracer doesn't support trace().")
 
        def create_arg(self, a: Argument) -> torch.fx.Node:
@@ -160,7 +160,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
 
        def placeholder(
            self,
-            target: str,
+            target: str,  # type: ignore[override]
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
        ) -> ProxyValue:
@@ -218,7 +218,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
            raise ExportPassBaseError(f"Unsupported target type: {target}")
 
        def get_attr(
-            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
+            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]  # type: ignore[override]
        ) -> Argument:
            return super().get_attr(target, args, kwargs)
 
@@ -231,7 +231,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
            raise ExportPassBaseError("call_module is not supported.")
 
        def call_method(
-            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
+            self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]  # type: ignore[override]
        ) -> None:
            raise ExportPassBaseError("call_method is not supported.")
 
@@ -394,7 +394,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
        )
        self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
        interpreter = self.ExportInterpreter(self, graph_module)
-        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(
+        prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter(  # type: ignore[assignment]
            torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
        )
        inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
@@ -64,7 +64,7 @@ def _replace_with_hop_helper(
    # Rename the name of getitem nodes to the actual name of its contents
    # for passing verifier and better readability, also propagate metadata
    for get_item_node in call_func_node.users.keys():
-        idx: int = get_item_node.args[1]
+        idx: int = get_item_node.args[1]  # type: ignore[assignment]
        output_node = output_args[idx]
        get_item_node._rename(output_node.name)
        get_item_node.meta = output_node.meta
@@ -177,7 +177,7 @@ def _extract_graph_with_inputs_outputs(
 
    for node in joint_graph.nodes:
        if _must_be_in_backward(node) and subgraph != "backward":
-            env[node] = InvalidNode
+            env[node] = InvalidNode  # type: ignore[assignment]
            continue
 
        if node in env:
@@ -186,7 +186,7 @@ def _extract_graph_with_inputs_outputs(
            # joint_graph.nodes).
            continue
        elif node.op == "placeholder":
-            env[node] = InvalidNode
+            env[node] = InvalidNode  # type: ignore[assignment]
        elif node.op == "call_function":
            all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
            all_args = [
@@ -195,7 +195,7 @@ def _extract_graph_with_inputs_outputs(
                if isinstance(x, fx.Node)
            ]
            if any(all_args):
-                env[node] = InvalidNode
+                env[node] = InvalidNode  # type: ignore[assignment]
                continue
            env[node] = new_graph.node_copy(node, lambda x: env[x])
        elif node.op == "get_attr":
@@ -816,7 +816,7 @@ def trace_flex_attention_backward(
    )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
-    proxy_mode.tracer.root.register_module("fw_graph", fw_graph)
+    proxy_mode.tracer.root.register_module("fw_graph", fw_graph)  # type: ignore[arg-type]
    proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
    proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
    node_args = (
@@ -222,7 +222,7 @@ class ConstantFolder(torch.fx.Interpreter):
    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor
 
-    def run(self) -> Any:
+    def run(self) -> Any:  # type: ignore[override]
        env: Dict[torch.fx.Node, Any] = {}
        self.insert_placerholder_values(env)
        return super().run(initial_env=env)
@@ -144,7 +144,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
        kwargs = {}
        if hasattr(snode, "get_device"):
            kwargs = {"device": snode.get_device()}
-        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)
+        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)  # type: ignore[arg-type]
 
    def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
        if isinstance(snode, FusedSchedulerNode):
@@ -154,17 +154,17 @@ def binary_folding_init():
            return False
        if isinstance(other, torch.fx.Node) and other.op == "get_attr":
            other_meta_value = other.meta.get("val")
-            if not other_meta_value.is_floating_point():
+            if not other_meta_value.is_floating_point():  # type: ignore[union-attr]
                return False
            if (
-                torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)
+                torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype)  # type: ignore[union-attr]
                != weight_meta_value.dtype
            ):
                if not conv_node.meta.get("_allow_conv_mixed_dtype_folding", False):
                    return False
 
                if (
-                    other_meta_value.dtype != torch.float
+                    other_meta_value.dtype != torch.float  # type: ignore[union-attr]
                    and weight_meta_value.dtype not in (torch.float16, torch.bfloat16)
                ):
                    return False
@@ -213,7 +213,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
-            args=args,
+            args=args,  # type: ignore[arg-type]
            name="efficient_conv_bn_eval",
        )
 
@@ -223,7 +223,7 @@ def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs
    # take care of the deletion order:
    # delete bn_node first, and then conv_node
    graph.erase_node(bn_node)
-    graph.erase_node(conv_node)
+    graph.erase_node(conv_node)  # type: ignore[arg-type]
 
    return
 
@@ -304,7 +304,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
        new_node = graph.create_node(
            op="call_function",
            target=efficient_conv_bn_eval_decomposed,
-            args=args,
+            args=args,  # type: ignore[arg-type]
            name="efficient_conv_bn_eval",
        )
 
@@ -314,7 +314,7 @@ def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwa
    # take care of the deletion order:
    # delete bn_node first, and then conv_node
    graph.erase_node(bn_node)
-    graph.erase_node(conv_node)
+    graph.erase_node(conv_node)  # type: ignore[arg-type]
 
    return
 
@@ -373,7 +373,7 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
    # Find a pair of conv and bn computation nodes to optimize.
    counters["inductor"]["efficient_conv_bn_eval"] += 1
 
-    with graph.inserting_before(conv_node):
+    with graph.inserting_before(conv_node):  # type: ignore[arg-type]
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `graph.get_attr` and
@@ -403,4 +403,4 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
    # take care of the deletion order:
    # delete bn_node first, and then conv_node
    graph.erase_node(bn_node)
-    graph.erase_node(conv_node)
+    graph.erase_node(conv_node)  # type: ignore[arg-type]
@@ -223,5 +223,5 @@ def unnecessary_dtype_convert(match: Match, **kwargs):
    """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding"""
    graph = match.graph
    node = match.output_node()
-    node.replace_all_uses_with(node.args[0])
+    node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
    graph.erase_node(node)
@@ -511,7 +511,7 @@ def pointless_view(match: Match, arg, size):
    node = match.output_node()
    arg_size = list(node.args[0].meta["val"].shape)  # type: ignore[union-attr]
    if size == arg_size:
-        node.replace_all_uses_with(node.args[0])
+        node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
        match.erase_nodes(graph)
 
 
@@ -1033,7 +1033,7 @@ def is_index_put_and_requires_h2d_sync_for_gpu_value(node):
    # if the value we are putting is a cpu scalar.
    # Therefore, when inductor sees an index_put_ with byte tensor indices,
    # it should *not* convert the cpu scalar value into a gpu tensor.
-    args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)
+    args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)  # type: ignore[misc]
    any_byte_bool_indices = False
    indices = args_[1]
    for i in indices:
@@ -1887,14 +1887,14 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
        graph.erase_node(conv_node)
        # Erase the dequant pattern
        if dtype == torch.bfloat16:
-            graph.erase_node(convert_to_bf16)  # type: ignore[possibly-undefined]
-        graph.erase_node(dequant_node)
+            graph.erase_node(convert_to_bf16)  # type: ignore[possibly-undefined, arg-type]
+        graph.erase_node(dequant_node)  # type: ignore[arg-type]
        # Erase the dequant per channel pattern
        if clone_node is not None:
-            graph.erase_node(clone_node)
+            graph.erase_node(clone_node)  # type: ignore[arg-type]
        if dtype == torch.bfloat16:
-            graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
-        graph.erase_node(dequant_per_channel)
+            graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
+        graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
        counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
        counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
            match.nodes
@@ -2581,8 +2581,8 @@ def quant_lift_up(graph_module: torch.fx.GraphModule):
 
            new_args = map_arg(new_quant_node.args, maybe_replace_node)
            new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
-            new_quant_node.args = new_args
-            new_quant_node.kwargs = new_kwargs
+            new_quant_node.args = new_args  # type: ignore[assignment]
+            new_quant_node.kwargs = new_kwargs  # type: ignore[assignment]
            graph_module.graph.erase_node(quant_node)
 
    graph_module.graph.lint()
@@ -293,7 +293,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
            return
 
        src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]
-        with graph.inserting_before(src):
+        with graph.inserting_before(src):  # type: ignore[arg-type]
            new_node = graph_call_function(
                graph,
                _generalized_scatter,
@@ -316,7 +316,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
        handle_views(new_src)
        src.replace_all_uses_with(new_src)  # type: ignore[union-attr]
 
-        graph.erase_node(src)
+        graph.erase_node(src)  # type: ignore[arg-type]
 
    for node in graph.nodes:
        if _is_view_op(node.target):
@@ -193,7 +193,7 @@ def normalize_split_base(
        new_split_node = graph.call_function(
            torch.split,
            args=new_args,
-            kwargs=new_kwargs,
+            kwargs=new_kwargs,  # type: ignore[arg-type]
        )
    split_node.replace_all_uses_with(new_split_node)
    new_split_node.meta.update(split_node.meta)
@@ -378,7 +378,7 @@ def normalize_stack_default(match: Match, *args, **kwargs):
 
    with graph.inserting_after(node):
        new_node = graph.call_function(
-            node.target,
+            node.target,  # type: ignore[arg-type]
            args=(tensors,),
            kwargs={"dim": dim},
        )
@@ -529,7 +529,7 @@ def merge_splits(
 
    to_remove = []
 
-    with graph.inserting_before(first_split):
+    with graph.inserting_before(first_split):  # type: ignore[arg-type]
        # Add the new split node
        new_split = graph.call_function(
            torch.split,
@@ -1535,7 +1535,7 @@ def mutate_cat_node(match: Match, split_sections: List[int], dim: int):
        # case 1: the cat uses all getitems from the split
        if len(split_sections) == len(cat_user.args[0]):  # type: ignore[arg-type]
            # replace the users of the cat node to be the input of the split node
-            cat_user.replace_all_uses_with(split_node.args[0])
+            cat_user.replace_all_uses_with(split_node.args[0])  # type: ignore[arg-type]
            # remove the cat node
            graph.erase_node(cat_user)
            counters["inductor"]["mutate_cat_pass"] += 1
@@ -1947,7 +1947,7 @@ def remove_split_unbind_children(graph: torch.fx.Graph, inputs: List[torch.fx.No
    # check the split node to remove if it has no users
    for node in nodes:
        if len(node.users.keys()) == 0:  # type: ignore[union-attr]
-            graph.erase_node(node)
+            graph.erase_node(node)  # type: ignore[arg-type]
 
 
 # ############pattern to be optimized is#########
@@ -766,7 +766,7 @@ class GraphLowering(torch.fx.Interpreter):
            return self.graph_inputs[buffer_name].get_numel()
        raise KeyError(f"could not find {buffer_name}")
 
-    def run(self, *args: Any) -> Any:
+    def run(self, *args: Any) -> Any:  # type: ignore[override]
        with dynamo_timed("GraphLowering.run"):
            return super().run(*args)
 
@@ -911,9 +911,9 @@ class GraphLowering(torch.fx.Interpreter):
        )
 
    def placeholder(
-        self, target: str, args: Tuple[object], kwargs: Dict[str, object]
+        self, target: str, args: Tuple[object], kwargs: Dict[str, object]  # type: ignore[override]
    ) -> Union[Expr, TensorBox, None]:
-        example = super().placeholder(target, args, kwargs)
+        example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
        self.graph_input_names.append(target)
        if isinstance(example, SymTypes):
            expr = example.node.expr
@@ -968,7 +968,7 @@ class GraphLowering(torch.fx.Interpreter):
            self.aligned_inputs.add(target)
        return tensor
 
-    def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any:  # type: ignore[type-arg]
+    def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any:  # type: ignore[type-arg, override]
        if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
            return super().call_function(target, args, kwargs)
 
@@ -1023,7 +1023,7 @@ class GraphLowering(torch.fx.Interpreter):
            return len(t.shape) == 1 and t.shape[0] <= 8
 
    def get_attr(
-        self, target: str, args: Tuple[()], kwargs: Dict[str, object]
+        self, target: str, args: Tuple[()], kwargs: Dict[str, object]  # type: ignore[override]
    ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
        # this is a constant
        value = getattr_recursive(self.module, target)  # type: ignore[arg-type]
@@ -1063,9 +1063,9 @@ class GraphLowering(torch.fx.Interpreter):
            raise AssertionError
 
    def output(
-        self, target: str, args: Tuple[object], kwargs: Dict[str, object]
+        self, target: str, args: Tuple[object], kwargs: Dict[str, object]  # type: ignore[override]
    ) -> None:
-        result = super().output(target, args, kwargs)
+        result = super().output(target, args, kwargs)  # type: ignore[arg-type]
        if not isinstance(result, (tuple, list)):
            # nested subgraphs can have singleton outputs
            result = (result,)
@@ -6668,7 +6668,7 @@ class InterpreterShim(torch.fx.Interpreter):
        self.graph = graph
        self.submodules = submodules
        self.extra_traceback = False
-        self.fetch_attr = submodules.__getitem__
+        self.fetch_attr = submodules.__getitem__  # type: ignore[method-assign]
        self.current_node = None
 
    def run_node(self, n: torch.fx.Node) -> Any:
@@ -236,7 +236,7 @@ class Match:
        if trace_fn is None:
            trace_fn = functools.partial(fwd_only, run_dce=run_dce)
        replacement = trace_fn(
-            replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
+            replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])  # type: ignore[arg-type]
        )
        ReplacementPatternEntry.replace_with_graph(
            self,
@@ -607,7 +607,7 @@ class _TargetArgsExpr(_TargetExpr):
        from torch.fx.operator_schemas import normalize_function
 
        normalized_args_and_kwargs = normalize_function(
-            node.target, node.args, node.kwargs
+            node.target, node.args, node.kwargs  # type: ignore[arg-type]
        )
 
        if normalized_args_and_kwargs is None:
@@ -1036,7 +1036,7 @@ class ReplacementPatternEntry(PatternEntry):
            if node.op == "call_function":
                target = node.target
                args, kwargs = self.fetch_args_kwargs_from_env(node)
-                result = graph.call_function(target, args, kwargs)
+                result = graph.call_function(target, args, kwargs)  # type: ignore[arg-type]
                if "val" in node.meta and "val" not in result.meta:
                    result.meta["val"] = node.meta["val"]
                    if isinstance(node.meta["val"], torch.Tensor):
@@ -1080,7 +1080,7 @@ class ReplacementPatternEntry(PatternEntry):
                    queue.extend(arg.all_input_nodes)
 
        with graph.inserting_before(last_node):
-            replacement = Replacer(replacement_graph).run(*args)
+            replacement = Replacer(replacement_graph).run(*args)  # type: ignore[arg-type]
            if isinstance(replacement, torch.fx.Node):
                replacement = [replacement]
 
@@ -1101,7 +1101,7 @@ class ReplacementPatternEntry(PatternEntry):
                return
            assert isinstance(old, torch.fx.Node)
            if new is None:
-                old.replace_all_uses_with(None)
+                old.replace_all_uses_with(None)  # type: ignore[arg-type]
                graph.erase_node(old)
                return
            if isinstance(new, torch.fx.Node):
@@ -1124,7 +1124,6 @@ class ReplacementPatternEntry(PatternEntry):
                graph.erase_node(old)
                return
 
-            new = typing.cast(Sequence[torch.fx.Node], new)
            # `new` is not a node: it's a list of nodes.
            #
            # This happens when we want to replace a node that has a single
@@ -1157,7 +1156,7 @@ class ReplacementPatternEntry(PatternEntry):
                idx = maybe_getitem(user)
                if idx is None:
                    raise AssertionError("can't handle")
-                replace(user, new[idx])
+                replace(user, new[idx])  # type: ignore[index]
            graph.erase_node(old)
 
        if len(output_nodes) == len(replacement):
@@ -1233,7 +1232,7 @@ def register_replacement(
        )
 
        args = list(
-            torch.fx.map_arg(
+            torch.fx.map_arg(  # type: ignore[arg-type]
                [match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
            )
        )
@@ -1672,8 +1671,8 @@ class PatternMatcherPass:
            raise RuntimeError(
                f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
            )
-        if should_compute_mutation_region_ids(graph):
-            compute_mutation_region_ids(graph)
+        if should_compute_mutation_region_ids(graph):  # type: ignore[arg-type]
+            compute_mutation_region_ids(graph)  # type: ignore[arg-type]
        get_mutation_region_id_partial = functools.partial(
            get_mutation_region_id, graph
        )
@@ -1764,7 +1763,7 @@ def fx_to_pattern(
        get_attr = _not_implemented
 
        def placeholder(
-            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
+            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
        ) -> Union[ExclusiveKeywordArg, KeywordArg]:
            n = next(argnum)
            if n < len(argnames):
@@ -1781,7 +1780,7 @@ def fx_to_pattern(
            return KeywordArg(name)
 
        def call_function(
-            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]
+            self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
        ) -> PatternExpr:
            args, kwargs = pytree.tree_map(process_arg, (args, kwargs))
            if list in ignore_types:
@@ -1800,7 +1799,7 @@ def fx_to_pattern(
            rv.users = len(n.users)
            return rv
 
-    pattern = Converter(gm).run()
+    pattern = Converter(gm).run()  # type: ignore[arg-type]
    if not isinstance(pattern, PatternExpr):
        return MultiOutputPattern(pytree.tree_leaves(pattern))
    return pattern
@@ -47,7 +47,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
 
    def call_function(
        self,
-        target: Callable[[Any], Any],
+        target: Callable[[Any], Any],  # type: ignore[override]
        args: Any,
        kwargs: Dict[str, Any],
    ) -> Any:
@@ -70,7 +70,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
 
        return lowerings[target](*args, **kwargs)
 
-    def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None:
+    def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None:  # type: ignore[override]
        assert len(args) == 1
        self.graph_outputs = args[0]
 
@@ -371,7 +371,7 @@ def gen_gm_and_inputs(target, args, kwargs):
        len(target._schema.returns) == 1
        and str(target._schema.returns[0].type) == "Tensor"
    ):
-        node = (node,)
+        node = (node,)  # type: ignore[assignment]
    g.output(node)
 
    gm = torch.fx.GraphModule({}, g)
@@ -2216,8 +2216,8 @@ class FakeTensorMode(TorchDispatchMode):
        any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
        schema_info = get_schema_info(func)
        if any_constant and schema_info.is_mutable():
-            _, new_kwargs = normalize_function(
-                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+            _, new_kwargs = normalize_function(  # type: ignore[misc]
+                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True  # type: ignore[arg-type]
            )
            for k, v in new_kwargs.items():
                k = k if (k != "input" or schema_info.has_argument(k)) else "self"
@@ -896,7 +896,7 @@ def prepare_n_shadows_model(
        tracer = custom_tracer
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
-    mt._node_name_to_scope = tracer.node_name_to_scope
+    mt._node_name_to_scope = tracer.node_name_to_scope  # type: ignore[assignment]
 
    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
@@ -998,7 +998,7 @@ def _prepare_n_shadows_add_loggers_model(
    tracer = quantize_fx.QuantizationTracer([], [])
    mt = torch.fx.GraphModule(model, tracer.trace(model))
    # this is necessary to ensure logger FQNs get populated
-    mt._node_name_to_scope = tracer.node_name_to_scope
+    mt._node_name_to_scope = tracer.node_name_to_scope  # type: ignore[assignment]
 
    # run example input propagation, we need this to call prepare_fx on
    # individual subgraphs
@@ -694,13 +694,13 @@ def _insert_copy_of_node_a_after_input_node_c(
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
-            node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name
+            node_a.op, new_mod_copy_name, new_args, new_kwargs, node_a_shadows_c_name  # type: ignore[arg-type]
        )
        return node_a_shadows_c
    else:
        assert node_a.op in ("call_function", "call_method")
        node_a_shadows_c = graph_c.create_node(
-            node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name
+            node_a.op, node_a.target, new_args, new_kwargs, node_a_shadows_c_name  # type: ignore[arg-type]
        )
        return node_a_shadows_c
 
@@ -406,19 +406,19 @@ def create_submodule_from_subgraph(
            mod_name = f"mod_{cur_name_idx}"
            setattr(gm, mod_name, orig_mod_copy)
            cur_name_idx += 1
-            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined]
+            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined,arg-type]
 
        elif cur_node_orig.op == "call_function":
            cur_node_copy = g.call_function(
-                cur_node_orig.target,
-                cur_args_copy,
+                cur_node_orig.target,  # type: ignore[arg-type]
+                cur_args_copy,  # type: ignore[arg-type]
                cur_kwargs_copy,  # type: ignore[possibly-undefined]
            )
 
        elif cur_node_orig.op == "call_method":
            cur_node_copy = g.call_method(
-                cur_node_orig.target,
-                cur_args_copy,
+                cur_node_orig.target,  # type: ignore[arg-type]
+                cur_args_copy,  # type: ignore[arg-type]
                cur_kwargs_copy,  # type: ignore[possibly-undefined]
            )
 
@@ -582,7 +582,7 @@ def create_one_transformed_and_logged_copy_of_subgraph(
 
    new_args = tuple(new_args)  # type: ignore[assignment]
 
-    new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs)
+    new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs)  # type: ignore[arg-type]
 
    # add a logger to parent graph to observe the shadow wrapper
    logger_mod_orig = _get_logger_for_subgraph(
@@ -311,4 +311,4 @@ class BaseStructuredSparsifier(BaseSparsifier):
 
        self.traced.graph.lint()
        self.traced.recompile()
-        return self.traced
+        return self.traced  # type: ignore[return-value]
@@ -25,7 +25,7 @@ def _match(
    if isinstance(current, type) and issubclass(current, torch.nn.Module):
        return (
            node.op == "call_module"
-            and parametrize.type_before_parametrizations(modules[node.target])
+            and parametrize.type_before_parametrizations(modules[node.target])  # type: ignore[index]
            == current
        )
    elif callable(current):
@@ -569,7 +569,7 @@ def _match_static_pattern(
        match_key = type(_get_module(ref_node, modules))
    else:
        expected_op = "call_function"
-        match_key = ref_node.target
+        match_key = ref_node.target  # type: ignore[assignment]
    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
        return SKIP_LOWERING_VALUE
 
@@ -589,7 +589,7 @@ def _match_static_pattern(
    if not matched_dequantize:
        return SKIP_LOWERING_VALUE
 
-    return (q_node, relu_node, ref_node)
+    return (q_node, relu_node, ref_node)  # type: ignore[return-value]
 
 
 def _match_static_pattern_with_two_inputs(
@@ -687,8 +687,8 @@ def _lower_static_weighted_ref_module(
            continue
        else:
            q_class = STATIC_LOWER_MODULE_MAP[ref_class]
-        output_scale = getattr(model, scale_node.target)
-        output_zero_point = getattr(model, zero_point_node.target)
+        output_scale = getattr(model, scale_node.target)  # type: ignore[arg-type]
+        output_zero_point = getattr(model, zero_point_node.target)  # type: ignore[arg-type]
        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
        # replace reference module with quantized module
        parent_name, module_name = _parent_name(ref_node.target)
@@ -698,7 +698,7 @@ def _lower_static_weighted_ref_module(
        assert len(ref_node.args) == 1
        dq_node = ref_node.args[0]
        assert isinstance(dq_node, Node)
-        ref_node.replace_input_with(dq_node, dq_node.args[0])
+        ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
        q_node.replace_all_uses_with(ref_node)
        model.graph.erase_node(q_node)
        model.graph.erase_node(scale_node)
@@ -747,8 +747,8 @@ def _lower_static_weighted_ref_module_with_two_inputs(
            continue
        else:
            continue
-        output_scale = getattr(model, scale_node.target)
-        output_zero_point = getattr(model, zero_point_node.target)
+        output_scale = getattr(model, scale_node.target)  # type: ignore[arg-type]
+        output_zero_point = getattr(model, zero_point_node.target)  # type: ignore[arg-type]
        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
        # replace reference module with quantized module
        parent_name, module_name = _parent_name(ref_node.target)
@@ -761,7 +761,7 @@ def _lower_static_weighted_ref_module_with_two_inputs(
            continue
        dq_node = arg
        assert isinstance(dq_node, Node)
-        ref_node.replace_input_with(dq_node, dq_node.args[0])
+        ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
 
    q_node.replace_all_uses_with(ref_node)
    model.graph.erase_node(q_node)
@@ -904,7 +904,7 @@ def _lower_static_weighted_ref_functional(
            prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5]
        else:
            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
-        with model.graph.inserting_before(output_scale_node):
+        with model.graph.inserting_before(output_scale_node):  # type: ignore[arg-type]
            # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack)
            # They are not needed for compute op (i.e., quantized::linear)
            kwargs = func_node.kwargs
@@ -1105,7 +1105,7 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: Dict[str, QConfi
            dq_node = arg
            assert isinstance(dq_node, Node)
            dn_input = dq_node.args[0]
-            bop_node.replace_input_with(dq_node, dn_input)
+            bop_node.replace_input_with(dq_node, dn_input)  # type: ignore[arg-type]
            num_dq_nodes += 1
        assert num_dq_nodes > 0
 
@@ -821,7 +821,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
        last_getitem_index = last_getitem.args[1]
        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
        for user in list(last_getitem.users.keys()):
-            user.replace_input_with(last_getitem, new_input)
+            user.replace_input_with(last_getitem, new_input)  # type: ignore[arg-type]
 
 
 def _get_observer_from_activation_post_process(
@@ -46,8 +46,8 @@ def _maybe_duplicate_dq(
 
        new_args = map_arg(user.args, maybe_replace_node)
        new_kwargs = map_arg(user.kwargs, maybe_replace_node)
-        user.args = new_args
-        user.kwargs = new_kwargs
+        user.args = new_args  # type: ignore[assignment]
+        user.kwargs = new_kwargs  # type: ignore[assignment]
 
 
 class DuplicateDQPass(PassBase):
@@ -107,8 +107,8 @@ def _find_q_dq_node_for_user(
 
    q_node = None
    if (
-        dq_node.args[0].op == "call_function"
-        and dq_node.args[0].target in _QUANTIZE_OPS
+        dq_node.args[0].op == "call_function"  # type: ignore[union-attr]
+        and dq_node.args[0].target in _QUANTIZE_OPS  # type: ignore[union-attr]
    ):
        q_node = dq_node.args[0]
    return (q_node, dq_node)
@@ -362,7 +362,7 @@ def _get_aten_graph_module_for_pattern(
            [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
        )
    aten_pattern = capture_pre_autograd_graph(
-        pattern,
+        pattern,  # type: ignore[arg-type]
        example_inputs,
        kwargs,
    )
@@ -382,7 +382,7 @@ def _get_aten_graph_module_for_pattern(
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()
 
-    return aten_pattern
+    return aten_pattern  # type: ignore[return-value]
 
 
 def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
@@ -979,13 +979,13 @@ def _annotate_cat(
        inputs = cat_node.args[0]
 
        input_qspec_map = {}
-        input_act0 = inputs[0]
+        input_act0 = inputs[0]  # type: ignore[index]
        if isinstance(input_act0, Node):
            input_qspec_map[input_act0] = input_act_qspec
 
-        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
-        for input_act in inputs[1:]:
-            input_qspec_map[input_act] = shared_with_input0_qspec
+        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))  # type: ignore[arg-type]
+        for input_act in inputs[1:]:  # type: ignore[index]
+            input_qspec_map[input_act] = shared_with_input0_qspec  # type: ignore[index]
 
        output_act_qspec = shared_with_input0_qspec
 
@@ -474,7 +474,7 @@ def _insert_reshard_gm(
                input_node: input_arg,
            },
        )
-        node.replace_input_with(input_arg, output_node)
+        node.replace_input_with(input_arg, output_node)  # type: ignore[arg-type]
 
 
 def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:
@@ -213,7 +213,7 @@ def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape:
    if isinstance(sizes[0], int):
        return cast(Shape, sizes)
    elif len(sizes) == 1:
-        return cast(Shape, sizes[0])  # type: ignore[redundant-cast]
+        return sizes[0]
    else:
        raise RuntimeError("Size must be int... or tuple")
 
@@ -96,11 +96,11 @@ class _ExecOrderTracer:
        self.exec_info = _ExecutionInfo(root_module)
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
-        tracer.call_module = functools.partial(
+        tracer.call_module = functools.partial(  # type: ignore[method-assign]
            self._patched_call_module, orig_call_module, self.exec_info
        )
        fqn_to_param = dict(root_module.named_parameters())
-        tracer.create_proxy = functools.partial(
+        tracer.create_proxy = functools.partial(  # type: ignore[method-assign]
            self._patched_create_proxy,
            orig_create_proxy,
            self.exec_info,
@@ -109,8 +109,8 @@ class _ExecOrderTracer:
        try:
            yield
        finally:
-            tracer.call_module = orig_call_module
-            tracer.create_proxy = orig_create_proxy
+            tracer.call_module = orig_call_module  # type: ignore[method-assign]
+            tracer.create_proxy = orig_create_proxy  # type: ignore[method-assign]
 
    def _patched_call_module(
        self,
@@ -216,8 +216,8 @@ class _ExecOrderTracer:
                isinstance(arg, torch.fx.Proxy)
                and arg.node.target in fqn_to_param
            ):
-                param = fqn_to_param[arg.node.target]
-                named_params.append((arg.node.target, param))
+                param = fqn_to_param[arg.node.target]  # type: ignore[index]
+                named_params.append((arg.node.target, param))  # type: ignore[arg-type]
                if param not in exec_info.visited_params:
                    exec_info.visited_params.add(param)
                    exec_info.param_forward_order.append(param)
@ -214,7 +214,7 @@ def _insert_stage_symbolic_backward(
|
|||||||
input_nodes = list(node.all_input_nodes)
|
input_nodes = list(node.all_input_nodes)
|
||||||
grads_proxy = fx.Proxy(grads)
|
grads_proxy = fx.Proxy(grads)
|
||||||
for i, input_node in enumerate(input_nodes):
|
for i, input_node in enumerate(input_nodes):
|
||||||
assign_or_accumulate_grad(input_node, grads_proxy[i].node)
|
assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index]
|
||||||
|
|
||||||
return g
|
return g
|
||||||
|
|
||||||
@ -416,15 +416,15 @@ class _LinearNodeList:
|
|||||||
def __init__(self, node_list):
|
def __init__(self, node_list):
|
||||||
self.serialize_node_list = []
|
self.serialize_node_list = []
|
||||||
for node in node_list:
|
for node in node_list:
|
||||||
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))
|
node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
|
||||||
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))
|
node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
|
||||||
serialize_node = fx.Node(
|
serialize_node = fx.Node(
|
||||||
graph=None,
|
graph=None, # type: ignore[arg-type]
|
||||||
name=node.name,
|
name=node.name,
|
||||||
op=node.op,
|
op=node.op,
|
||||||
target=node.target,
|
target=node.target,
|
||||||
args=node_args,
|
args=node_args, # type: ignore[arg-type]
|
||||||
kwargs=node_kwargs,
|
kwargs=node_kwargs, # type: ignore[arg-type]
|
||||||
return_type=node.type,
|
return_type=node.type,
|
||||||
)
|
)
|
||||||
serialize_node.meta = copy.copy(node.meta)
|
serialize_node.meta = copy.copy(node.meta)
|
||||||
@@ -447,8 +447,8 @@ class _LinearNodeList:
             deser_node = graph.create_node(
                 op=node.op,
                 target=node.target,
-                args=node_args,
-                kwargs=node_kwargs,
+                args=node_args,  # type: ignore[arg-type]
+                kwargs=node_kwargs,  # type: ignore[arg-type]
                 name=node.name,
                 type_expr=node.type,
             )
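These two hunks suppress `arg-type` because the serializer deliberately steps outside the `fx.Node` contract: it builds detached nodes (`graph=None`) whose args hold name references instead of real `Node` objects. A hedged sketch of the same node-reference trick, using a stand-in dataclass rather than the patch's `_NodeReference`:

    from dataclasses import dataclass

    import torch
    import torch.fx as fx

    @dataclass
    class NodeRef:
        # A picklable pointer to a node, identified by name.
        name: str

    def to_refs(args):
        # map_arg walks nested args and replaces every fx.Node leaf. mypy
        # rejects the callback returning a non-Argument type; that mismatch
        # is what the ignore[arg-type,return-value] comments silence.
        return fx.node.map_arg(args, lambda n: NodeRef(n.name))

    g = fx.Graph()
    x = g.placeholder("x")
    y = g.call_function(torch.relu, (x,))
    print(to_refs(y.args))  # (NodeRef(name='x'),)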
@@ -731,7 +731,7 @@ class Pipe(torch.nn.Module):

         # TODO: what does split do with module invocations? does it move the modules
         # into the submodules?
-        split = split_module(traced, mod, split_callback)
+        split = split_module(traced, mod, split_callback)  # type: ignore[arg-type]
         # a (custom) tracer can produce dead code like orphan get_attr nodes
         split.graph.eliminate_dead_code()

@@ -86,7 +86,7 @@ class TensorChunkSpec:
         """
         args_chunk_spec = map_aggregate(
             chunk_dims,
-            lambda dim: TensorChunkSpec(dim),
+            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
         )
         return args_chunk_spec

@@ -104,7 +104,7 @@ class TensorChunkSpec:
         """
         kwargs_chunk_spec = map_aggregate(
             chunk_dims,
-            lambda dim: TensorChunkSpec(dim),
+            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
         )
         return kwargs_chunk_spec

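Both hunks wrap chunk dims via `map_aggregate`, which applies a callback to every leaf of a nested argument structure while preserving the nesting; the ignores arise because the callback returns a spec object rather than an fx `Argument`. A small demonstration of the traversal itself, with tuples standing in for `TensorChunkSpec`:

    from torch.fx.node import map_aggregate

    # Wrap every leaf dim; nesting is preserved.
    chunk_dims = (0, (1, 2))
    specs = map_aggregate(chunk_dims, lambda dim: ("spec", dim))
    print(specs)  # (('spec', 0), (('spec', 1), ('spec', 2)))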
@@ -681,7 +681,7 @@ def _export_to_aten_ir(
     if fake_mode:
         insert_deferred_runtime_asserts(
             gm,
-            fake_mode.shape_env,
+            fake_mode.shape_env,  # type: ignore[arg-type]
             f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
             export=True,
         )
@@ -165,7 +165,7 @@ class MetaTracer(torch.fx.Tracer):
                 meta_target = manual_meta_overrides.get(target, target)
                 meta_out = meta_target(*args_metas, **kwargs_metas)
             elif kind == 'call_method':
-                meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)
+                meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)  # type: ignore[index]
             elif kind == 'call_module':
                 assert hasattr(self, 'orig_forward')
                 self._disable_module_getattr = True
@@ -173,7 +173,7 @@ class MetaTracer(torch.fx.Tracer):
                     mod = self.root.get_submodule(target)
                     mod_type = type(mod)
                     if mod_type in manual_meta_overrides:
-                        meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)
+                        meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)  # type: ignore[misc, arg-type]
                     else:
                         meta_out = self.orig_forward(*args_metas, **kwargs_metas)
                 finally:
@@ -237,7 +237,7 @@ class MetaTracer(torch.fx.Tracer):
     def proxy(self, node):
         return MetaProxy(node, self)

-    def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):
+    def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):  # type: ignore[override]
         assert isinstance(meta_args, dict)
         self.meta_args = meta_args

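`ignore[override]` appears here (and again for `_ModuleStackTracer.trace` further down) because the subclass adds a required `meta_args` parameter, which is incompatible with the inherited `Tracer.trace` signature under the Liskov substitution rules mypy enforces. A toy reproduction with hypothetical classes:

    class Base:
        def trace(self, root: object, concrete_args: object = None) -> object:
            return root

    class Narrowed(Base):
        # mypy: Signature of "trace" incompatible with supertype "Base"  [override]
        def trace(  # type: ignore[override]
            self, root: object, meta_args: dict, concrete_args: object = None
        ) -> object:
            return meta_args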
@@ -263,7 +263,7 @@ def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
                    meta_args : Optional[Dict[str, torch.Tensor]] = None,
                    concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
     tracer = MetaTracer()
-    graph = tracer.trace(root, meta_args, concrete_args)
+    graph = tracer.trace(root, meta_args, concrete_args)  # type: ignore[arg-type]
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     gm = torch.fx.GraphModule(tracer.root, graph, name)
     return gm
@@ -990,7 +990,7 @@ class PythonKeyTracer(Tracer):
     torch_fn_counts: Dict[OpOverload, int]

     def __init__(self) -> None:
-        super().__init__(autowrap_modules=())
+        super().__init__(autowrap_modules=())  # type: ignore[arg-type]
         self.tensor_tracker = WeakTensorKeyDictionary()
         self.symnode_tracker = _SymNodeDict()
         self.script_object_tracker = WeakIdKeyDictionary(
@@ -1036,7 +1036,7 @@ class PythonKeyTracer(Tracer):
         elif isinstance(a, py_sym_types):
             assert a.node.constant is not None
             return a.node.constant
-        return super().create_arg(a)
+        return super().create_arg(a)  # type: ignore[return-value]

     @overload
     def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]:
@@ -1103,7 +1103,7 @@ def dispatch_trace(
     tracer: Tracer,
     concrete_args: Optional[Tuple[Any, ...]] = None,
 ) -> GraphModule:
-    graph = tracer.trace(root, concrete_args)
+    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]

     # NB: be careful not to DCE .item() calls
     def impure_pred(n: fx.Node) -> bool:
@@ -1228,7 +1228,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
             # It's for passing the export verifier which needs to verify the meta['val']
             # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
             # instead of hardcoding it here.
-            node = self.tracer.create_node("call_function", func, args, {})
+            node = self.tracer.create_node("call_function", func, args, {})  # type: ignore[arg-type]
             if func is torch._C._set_grad_enabled:
                 node.meta["val"] = None
             return node
@@ -1323,7 +1323,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):

         # func doesn't have a __torch_function__ that Proxy can interpose, so
         # we gotta do it manually
-        n_out = self.tracer.create_node("call_function", func, n_args, {})
+        n_out = self.tracer.create_node("call_function", func, n_args, {})  # type: ignore[arg-type]
         p_out = fx.Proxy(n_out, self.tracer)
         set_meta(p_out, out)
         return p_out
@@ -1382,7 +1382,7 @@ class DecompositionInterpreter(fx.Interpreter):
         decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
         **kwargs: object,
     ) -> None:
-        super().__init__(module, **kwargs)
+        super().__init__(module, **kwargs)  # type: ignore[arg-type]
         self.new_graph = new_graph
         self.tracer = _GraphAppendingTracerEx(self.new_graph)
         # Blegh
@@ -1399,18 +1399,18 @@ class DecompositionInterpreter(fx.Interpreter):
         self.tracer.torch_fn_counts = {}

     def placeholder(
-        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]
+        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
     ) -> object:
-        out = super().placeholder(target, args, kwargs)
+        out = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
         proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
         track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
         # TODO handle case where the first character of target is '*'
         return out

     def get_attr(
-        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]
+        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
     ) -> object:
-        out = super().get_attr(target, args, kwargs)
+        out = super().get_attr(target, args, kwargs)  # type: ignore[arg-type]
         proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
         track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
         return out
@@ -1418,9 +1418,9 @@ class DecompositionInterpreter(fx.Interpreter):
     # call_function, call_method, call_module get traced automatically by the outer mode.

     def output(
-        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]
+        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
     ) -> object:
-        out = super().output(target, args, kwargs)
+        out = super().output(target, args, kwargs)  # type: ignore[arg-type]

         def get_proxy_node(x: _ProxyTensor) -> fx.node.Node:
             return x.proxy.node
@@ -1435,7 +1435,7 @@ class DecompositionInterpreter(fx.Interpreter):
         # Should enter the mode at least once for being able to restore it later
         # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
         with decompose(self.decomposition_table), self.mode:
-            return super().run(*args, **kwargs)
+            return super().run(*args, **kwargs)  # type: ignore[arg-type]


 def wrapper_and_args_for_make_fx(
@@ -1596,7 +1596,7 @@ class _ModuleStackTracer(PythonKeyTracer):
             self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr)
         return self.attr_proxy_map[attr_val]

-    def trace(
+    def trace(  # type: ignore[override]
         self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]]
     ) -> fx.Graph:
         res = super().trace(root, concrete_args)
@@ -1690,7 +1690,7 @@ class _ModuleStackTracer(PythonKeyTracer):
         Add torch_fn by looking at torch_fn_metadata and torch_fn_counts.
         Add stack_trace by filtering out forward() stack frames.
         """
-        node = super().create_node(*args, **kwargs)
+        node = super().create_node(*args, **kwargs)  # type: ignore[arg-type]

         # nn_module_stack
         if node.op not in ["placeholder", "output"]:
@@ -768,7 +768,7 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
     if isinstance(a, tuple):
         t = tuple(map_aggregate(elem, fn) for elem in a)
         # Support NamedTuple (if it has `_fields`) by repacking into original type.
-        return t if not hasattr(a, '_fields') else type(a)(*t)
+        return t if not hasattr(a, '_fields') else type(a)(*t)  # type: ignore[arg-type]
     elif isinstance(a, list):
         return immutable_list(map_aggregate(elem, fn) for elem in a)
     elif isinstance(a, dict):
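The `type(a)(*t)` repack is what keeps NamedTuples intact: a bare `tuple(...)` of the mapped leaves would lose the subclass, so the original constructor is reapplied positionally. mypy cannot prove that `type(a)` accepts those arguments, hence the ignore. A quick check of the behavior:

    from collections import namedtuple

    from torch.fx.node import map_aggregate

    Point = namedtuple("Point", ["x", "y"])
    p = map_aggregate(Point(1, 2), lambda v: v * 10)
    print(p, type(p).__name__)  # Point(x=10, y=20) Point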
@@ -251,8 +251,8 @@ if HAS_PYDOT:
             label += f"|target={self._typename(node.target)}" + r"\n"
             if self.normalize_args:
                 try:
-                    args, kwargs = normalize_function(
-                        node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True
+                    args, kwargs = normalize_function(  # type: ignore[misc]
+                        node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True  # type: ignore[arg-type]
                     )
                 except Exception:
                     # Fallback to not normalizing if there's an exception.
@@ -299,7 +299,7 @@ class _MinimizerBase:

         # Find submodule containing colored nodes
         submodule_name: str = ""
-        for child_name, _ in split_module.named_children():
+        for child_name, _ in split_module.named_children():  # type: ignore[union-attr]
             # Skip submodules we're not interested in at the moment
             if "minimize" not in child_name:
                 continue
@@ -316,7 +316,7 @@ class _MinimizerBase:
                 f"Minimize submodule was not found with nodes {nodes}"
             )

-        return split_module, submodule_name
+        return split_module, submodule_name  # type: ignore[return-value]

     def _run_and_compare(
         self,
@@ -388,10 +388,10 @@ class _MinimizerBase:
                 report.append(f"Result mismatch for {result_key}")
                 if self.module_exporter:
                     self.module_exporter(
-                        a_input, submodule, str(result_key[0]) + "_cpu",
+                        a_input, submodule, str(result_key[0]) + "_cpu",  # type: ignore[index]
                     )
                     self.module_exporter(
-                        b_input, submodule, str(result_key[0]) + "_acc",
+                        b_input, submodule, str(result_key[0]) + "_acc",  # type: ignore[index]
                     )
                 raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")

@@ -163,14 +163,14 @@ def split_module(
             )
             if keep_original_node_name:
                 args = () if default_value is inspect.Signature.empty else (default_value,)
-                base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)
+                base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)  # type: ignore[arg-type]
             else:
                 base_mod_env[node.name] = base_mod_graph.placeholder(
-                    node.target, type_expr=node.type, default_value=default_value
+                    node.target, type_expr=node.type, default_value=default_value  # type: ignore[arg-type]
                 )
             base_mod_env[node.name].meta = node.meta.copy()
         elif node.op == "get_attr":
-            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
+            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)  # type: ignore[arg-type]
             base_mod_env[node.name].meta = node.meta.copy()
             attr_val = m
             for atom in node.target.split("."):  # type: ignore[union-attr]
@@ -869,7 +869,7 @@ class _SplitterBase:
         for node in self.module.graph.nodes:
             if hasattr(node, "tag"):
                 del node.tag
-        return split_module
+        return split_module  # type: ignore[return-value]

     def __call__(self) -> torch.fx.GraphModule:
         subgraphs = self.put_nodes_into_subgraphs()
@@ -32,8 +32,8 @@ def _split_to_graph_and_name_node_map(
         name_node_map, Dict
     ), "Expecting the input graph to have a dict output as the last element"
     n.args = (flattened,)
-    orig_pytree_info = gm._graph._codegen.pytree_info
-    gm._graph._codegen.pytree_info = _PyTreeInfo(
+    orig_pytree_info = gm._graph._codegen.pytree_info  # type: ignore[attr-defined]
+    gm._graph._codegen.pytree_info = _PyTreeInfo(  # type: ignore[attr-defined]
        orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
     )
     gm.recompile()
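`attr-defined` is suppressed here because `pytree_info` lives on the `_PyTreeCodeGen` subclass, while the declared type of `_codegen` is the base `CodeGen`. The typed alternative is an `isinstance` narrow, which this call site skips; a hedged sketch of that alternative:

    import torch
    import torch.fx as fx
    from torch.fx.graph import _PyTreeCodeGen

    gm = fx.symbolic_trace(torch.nn.ReLU())
    codegen = gm.graph._codegen
    # Plain symbolic tracing installs a bare CodeGen, so this narrow is False;
    # export-style graphs carry a _PyTreeCodeGen with pytree_info on it.
    if isinstance(codegen, _PyTreeCodeGen):
        print(codegen.pytree_info)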
@@ -319,8 +319,8 @@ def _replace_pattern(

         # Hook the output Node of the replacement subgraph in to the
         # original Graph at the correct location
-        assert len(match.returning_nodes) == len(copied_returning_nodes)
-        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
+        assert len(match.returning_nodes) == len(copied_returning_nodes)  # type: ignore[arg-type]
+        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):  # type: ignore[arg-type]
             gn.replace_all_uses_with(copied_node)
             match_changed_node[gn] = copied_node
         # Remove the original nodes
@@ -324,7 +324,7 @@ def jagged_torch_function(func, *args, **kwargs):
         def _flatten_sig(input, start_dim=0, end_dim=-1):
             pass

-        _, new_kwargs = normalize_function(
+        _, new_kwargs = normalize_function(  # type: ignore[misc]
             _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
         )

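The nested-tensor hunks above and below all make the same one-line change: `ignore[misc]` on `normalize_function`, whose `Optional` return and overloads trip mypy when the result is immediately destructured. At runtime the call simply binds the incoming args to the target's signature and hands back a kwargs-only view. A minimal sketch against a plain Python function (the `scale` helper is hypothetical, mirroring the `_flatten_sig` trick above):

    from torch.fx.operator_schemas import normalize_function

    def scale(input, factor=1.0, bias=0.0):
        return input * factor + bias

    # Bind positional args to the signature and return them as kwargs only.
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        scale, args=(3.0, 2.0), kwargs={"bias": 1.0},
        normalize_to_only_use_kwargs=True,
    )
    print(new_kwargs)  # expected: {'input': 3.0, 'factor': 2.0, 'bias': 1.0}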
@@ -404,7 +404,7 @@ def tensor_attr_unsupported_getter(func, *args, **kwargs):
 def is_contiguous_general(func, *args, **kwargs):
     from torch._prims_common import is_contiguous_for_memory_format

-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
@@ -459,7 +459,7 @@ def clone_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
 def linear_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -473,7 +473,7 @@ def linear_default(func, *args, **kwargs):
     "self: jt, grad_output: jt, weight: t, output_mask: any",
 )
 def linear_backward_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -505,7 +505,7 @@ def to_dtype(func, *args, **kwargs):
 def to_copy_default(func, *args, **kwargs):
     from .nested_tensor import _tensor_symint_registry

-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -539,7 +539,7 @@ def to_copy_default(func, *args, **kwargs):
     torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
 )
 def copy_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
@@ -567,7 +567,7 @@ register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
     "self: jt_all",
 )
 def like_factory_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -583,7 +583,7 @@ def like_factory_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
 def zero__default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -596,7 +596,7 @@ def zero__default(func, *args, **kwargs):
     torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
 )
 def _softmax_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -672,7 +672,7 @@ def _softmax_default(func, *args, **kwargs):
     "grad_output: jt, output: jt, dim: any, input_dtype: any",
 )
 def _softmax_backward(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     grad_out = new_kwargs.pop("grad_output")
@@ -686,7 +686,7 @@ def _softmax_backward(func, *args, **kwargs):
     torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
 )
 def native_dropout_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -703,7 +703,7 @@ def native_dropout_default(func, *args, **kwargs):
     "grad_output: jt, mask: jt, scale: any",
 )
 def native_dropout_backward_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     grad_output = new_kwargs.pop("grad_output")
@@ -716,7 +716,7 @@ def native_dropout_backward_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?")
 def prod_dim_int(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -735,7 +735,7 @@ def prod_dim_int(func, *args, **kwargs):
     torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any"
 )
 def split_tensor(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -753,7 +753,7 @@ def split_tensor(func, *args, **kwargs):
     torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any"
 )
 def split_with_sizes_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -773,7 +773,7 @@ def split_with_sizes_default(func, *args, **kwargs):
     torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
 )
 def narrow(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
@@ -790,7 +790,7 @@ def narrow(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
 def chunk_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -833,7 +833,7 @@ def chunk_default(func, *args, **kwargs):
 @register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
 def unbind_int(func, *args, **kwargs):
     # Note that this specializes on the length of the offsets
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -867,7 +867,7 @@ def unbind_int(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
 def squeeze_dim(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -880,7 +880,7 @@ def squeeze_dim(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any")
 def unsqueeze_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -895,7 +895,7 @@ def unsqueeze_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
 def cat_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -918,7 +918,7 @@ def cat_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any")
 def matmul_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -943,7 +943,7 @@ def matmul_default(func, *args, **kwargs):
     torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?"
 )
 def expand_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -960,7 +960,7 @@ def expand_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
 def expand_as_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -972,7 +972,7 @@ def expand_as_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt")
 def where_self(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -990,7 +990,7 @@ def where_self(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
 def _pin_memory_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1001,7 +1001,7 @@ def _pin_memory_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
 def is_pinned_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1026,7 +1026,7 @@ def sum_dim_IntList(func, *args, **kwargs):
     Performs a sum along the provided tensor dimension.
     Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
     """
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
@@ -1117,7 +1117,7 @@ def sum_dim_IntList(func, *args, **kwargs):
     torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
 )
 def transpose_int(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1164,7 +1164,7 @@ def transpose_int(func, *args, **kwargs):
     "self: jt_all, size: any",
 )
 def view_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1212,7 +1212,7 @@ def view_default(func, *args, **kwargs):
     "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
 )
 def native_layer_norm_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1313,7 +1313,7 @@ def native_layer_norm_default(func, *args, **kwargs):
     "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
 )
 def native_layer_norm_backward_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     grad_out = new_kwargs.pop("grad_out")
@@ -1327,7 +1327,7 @@ def native_layer_norm_backward_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any")
 def select_int(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1349,7 +1349,7 @@ def select_int(func, *args, **kwargs):
     "self: jt, dim: any?, start: any?, end: any?, step: any?",
 )
 def slice_tensor(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1365,7 +1365,7 @@ def slice_tensor(func, *args, **kwargs):
     "dilation: any, transposed: any, output_padding: any, groups: any",
 )
 def convolution_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1382,7 +1382,7 @@ def mean_dim(func, *args, **kwargs):
     Performs a mean along the provided tensor dimension.
     Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
     """
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1438,7 +1438,7 @@ def mean_dim(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
 def stack_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1472,7 +1472,7 @@ def stack_default(func, *args, **kwargs):
     "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
 )
 def embedding_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1493,7 +1493,7 @@ def embedding_default(func, *args, **kwargs):
     "self: jt_all",
 )
 def values_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1520,7 +1520,7 @@ def all_default(func, *args, **kwargs):
     "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
 )
 def _nested_view_from_jagged_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1549,7 +1549,7 @@ def _nested_view_from_jagged_default(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
 def _nested_get_offsets(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1559,7 +1559,7 @@ def _nested_get_offsets(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
 def _nested_get_lengths(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1569,7 +1569,7 @@ def _nested_get_lengths(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
 def _nested_get_ragged_idx(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1579,7 +1579,7 @@ def _nested_get_ragged_idx(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
 def _nested_get_min_seqlen(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1589,7 +1589,7 @@ def _nested_get_min_seqlen(func, *args, **kwargs):

 @register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
 def _nested_get_max_seqlen(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )

@@ -1600,7 +1600,7 @@ def _nested_get_max_seqlen(func, *args, **kwargs):
 # If a section of the Nested Tensor is fully masked out we still retain the section with a length of 0
 @register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any")
 def masked_select_default(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
@@ -242,7 +242,7 @@ def _fx_args_to_torch_args(
         if isinstance(arg, torch.fx.Node):
             fake_tensor = arg.meta.get("val")
             if fake_tensor is None and arg.op == "get_attr":
-                fake_tensor = getattr(fx_graph_module, arg.target)  # type: ignore[operator]
+                fake_tensor = getattr(fx_graph_module, arg.target)  # type: ignore[operator, arg-type]
             # NOTE: Currently, we are aware of
             # FakeTensor/Tensor/SymInt/SymFloat/Symbool/int/float/bool could be in
             # arg.meta["val"]/get_attr.
@@ -253,8 +253,8 @@ def _fx_args_to_torch_args(
                 wrapped_args.append(real_tensor)
             elif isinstance(fake_tensor, (int, float, bool)):
                 wrapped_args.append(fake_tensor)
-            elif symbolic_shapes.has_hint(fake_tensor):
-                wrapped_args.append(symbolic_shapes.hint_int(fake_tensor))
+            elif symbolic_shapes.has_hint(fake_tensor):  # type: ignore[arg-type]
+                wrapped_args.append(symbolic_shapes.hint_int(fake_tensor))  # type: ignore[arg-type]
             else:
                 raise ValueError(
                     f"Unexpected input argument type found inside fx.Node. arg: {arg}; "
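`has_hint` and `hint_int` from `torch.fx.experimental.symbolic_shapes` accept either symbolic values or plain Python numbers at runtime, while their annotations are narrower than what this call site can prove about `fake_tensor`, hence the `arg-type` ignores. Minimal usage, with a plain int standing in for a traced `SymInt`:

    from torch.fx.experimental import symbolic_shapes

    n = 7  # during export this would typically be a torch.SymInt
    if symbolic_shapes.has_hint(n):  # True for concrete ints
        print(symbolic_shapes.hint_int(n))  # 7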
@@ -615,7 +615,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
         return None

     def get_arg_from_alias(output_alias, schema_info, args, kwargs):
-        new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(
+        new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(  # type: ignore[misc]
             func, args=args, kwargs=kwargs
         )
