remove allow-untyped-defs from torch/_higher_order_ops/run_const_graph.py (#157847)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157847
Approved by: https://github.com/Skylion007, https://github.com/zou3519
This commit is contained in:
Bob Ren
2025-07-08 22:25:13 -07:00
committed by PyTorch MergeBot
parent 5221448574
commit 066bf29334

View File

@ -1,18 +1,24 @@
# mypy: allow-untyped-defs
from typing import Any, TYPE_CHECKING
import torch
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
if TYPE_CHECKING:
from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils import _pytree as pytree
class RunConstGraph(HigherOrderOperator):
def __init__(self):
def __init__(self) -> None:
super().__init__("run_const_graph")
def __call__(self, graph, args):
def __call__(self, graph: torch.fx.GraphModule, args: tuple[object, ...]) -> object:
return super().__call__(graph, args)
@ -20,12 +26,14 @@ run_const_graph = RunConstGraph()
@run_const_graph.py_impl(ProxyTorchDispatchMode)
def run_const_graph_dispatch_mode(mode, graph, args):
def run_const_graph_dispatch_mode(
mode: ProxyTorchDispatchMode, graph: torch.fx.GraphModule, args: tuple[object, ...]
) -> object:
const_gm, weights = graph, args
p_args = pytree.tree_map(mode.tracer.unwrap_proxy, (graph, args))
p_args = pytree.tree_map(mode.tracer.unwrap_proxy, (graph, args)) # type: ignore[union-attr]
assert isinstance(const_gm, torch.fx.GraphModule)
assert not hasattr(mode.tracer.root, "_const_graph")
mode.tracer.root.register_module("_const_graph", const_gm)
assert not hasattr(mode.tracer.root, "_const_graph") # type: ignore[union-attr]
mode.tracer.root.register_module("_const_graph", const_gm) # type: ignore[union-attr]
proxy = mode.tracer.create_proxy("call_function", run_const_graph, p_args, {})
@ -34,12 +42,14 @@ def run_const_graph_dispatch_mode(mode, graph, args):
@run_const_graph.py_functionalize_impl
def run_const_graph_functional(ctx, graph, args):
def run_const_graph_functional(
ctx: "BaseFunctionalizeAPI", graph: torch.fx.GraphModule, args: tuple[Any, ...]
) -> Any:
unwrapped_args = ctx.unwrap_tensors(args)
with ctx.redispatch_to_next():
out = run_const_graph(*unwrapped_args)
return ctx.wrap_tensors(out)
out = run_const_graph(graph, unwrapped_args)
return ctx.wrap_tensors(out) # type: ignore[arg-type]
run_const_graph.py_autograd_impl(
@ -48,13 +58,17 @@ run_const_graph.py_autograd_impl(
@run_const_graph.py_impl(FakeTensorMode)
def run_const_graph_fake_tensor_mode(mode, graph, args):
def run_const_graph_fake_tensor_mode(
mode: FakeTensorMode, graph: torch.fx.GraphModule, args: tuple[object, ...]
) -> object:
assert isinstance(graph, torch.fx.GraphModule)
with mode:
return graph(*args)
@run_const_graph.py_impl(DispatchKey.CPU)
def run_const_graph_cpu(graph, args):
def run_const_graph_cpu(
graph: torch.fx.GraphModule, args: tuple[object, ...]
) -> object:
assert isinstance(graph, torch.fx.GraphModule)
return graph(*args)