diff --git a/torch/_higher_order_ops/run_const_graph.py b/torch/_higher_order_ops/run_const_graph.py index 7f24e7c589f4..ed7c5278f5fe 100644 --- a/torch/_higher_order_ops/run_const_graph.py +++ b/torch/_higher_order_ops/run_const_graph.py @@ -1,18 +1,24 @@ -# mypy: allow-untyped-defs +from typing import Any, TYPE_CHECKING + import torch from torch._C import DispatchKey from torch._higher_order_ops.utils import autograd_not_implemented from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode + + +if TYPE_CHECKING: + from torch._subclasses.functional_tensor import BaseFunctionalizeAPI + from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree from torch.utils import _pytree as pytree class RunConstGraph(HigherOrderOperator): - def __init__(self): + def __init__(self) -> None: super().__init__("run_const_graph") - def __call__(self, graph, args): + def __call__(self, graph: torch.fx.GraphModule, args: tuple[object, ...]) -> object: return super().__call__(graph, args) @@ -20,12 +26,14 @@ run_const_graph = RunConstGraph() @run_const_graph.py_impl(ProxyTorchDispatchMode) -def run_const_graph_dispatch_mode(mode, graph, args): +def run_const_graph_dispatch_mode( + mode: ProxyTorchDispatchMode, graph: torch.fx.GraphModule, args: tuple[object, ...] +) -> object: const_gm, weights = graph, args - p_args = pytree.tree_map(mode.tracer.unwrap_proxy, (graph, args)) + p_args = pytree.tree_map(mode.tracer.unwrap_proxy, (graph, args)) # type: ignore[union-attr] assert isinstance(const_gm, torch.fx.GraphModule) - assert not hasattr(mode.tracer.root, "_const_graph") - mode.tracer.root.register_module("_const_graph", const_gm) + assert not hasattr(mode.tracer.root, "_const_graph") # type: ignore[union-attr] + mode.tracer.root.register_module("_const_graph", const_gm) # type: ignore[union-attr] proxy = mode.tracer.create_proxy("call_function", run_const_graph, p_args, {}) @@ -34,12 +42,14 @@ def run_const_graph_dispatch_mode(mode, graph, args): @run_const_graph.py_functionalize_impl -def run_const_graph_functional(ctx, graph, args): +def run_const_graph_functional( + ctx: "BaseFunctionalizeAPI", graph: torch.fx.GraphModule, args: tuple[Any, ...] +) -> Any: unwrapped_args = ctx.unwrap_tensors(args) with ctx.redispatch_to_next(): - out = run_const_graph(*unwrapped_args) - return ctx.wrap_tensors(out) + out = run_const_graph(graph, unwrapped_args) + return ctx.wrap_tensors(out) # type: ignore[arg-type] run_const_graph.py_autograd_impl( @@ -48,13 +58,17 @@ run_const_graph.py_autograd_impl( @run_const_graph.py_impl(FakeTensorMode) -def run_const_graph_fake_tensor_mode(mode, graph, args): +def run_const_graph_fake_tensor_mode( + mode: FakeTensorMode, graph: torch.fx.GraphModule, args: tuple[object, ...] +) -> object: assert isinstance(graph, torch.fx.GraphModule) with mode: return graph(*args) @run_const_graph.py_impl(DispatchKey.CPU) -def run_const_graph_cpu(graph, args): +def run_const_graph_cpu( + graph: torch.fx.GraphModule, args: tuple[object, ...] +) -> object: assert isinstance(graph, torch.fx.GraphModule) return graph(*args)