Update on " [WIP] export"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
Author: Mu-Chu Lee
Date: 2025-10-18 00:03:45 -07:00
parent 9c0d445930
commit 93382b3969
4 changed files with 10 additions and 8 deletions

@@ -3858,6 +3858,8 @@ class DualWrapperCodegen(CodeGen):
     """

     def __init__(self, original_wrapper_code, autotuning_wrapper_code):
+        from ..scheduler import BaseScheduling  # noqa: TC001
+
         super().__init__()
         self.original_wrapper_code = original_wrapper_code
         self.original_backends: dict[torch.device, BaseScheduling] = {}
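
Note for readers skimming the diff: the class above is built around a fan-out pattern. A minimal sketch of that pattern, assuming a writeline-style API; only original_wrapper_code, autotuning_wrapper_code, and the original_backends annotation come from the diff, the rest is illustrative:

# Sketch only: a dual wrapper holds two real wrapper-codegen objects and
# forwards each call to both, so one lowering pass can emit the normal
# wrapper and the autotuning wrapper in lockstep.
class DualWrapperSketch:
    def __init__(self, original_wrapper_code, autotuning_wrapper_code):
        self.original_wrapper_code = original_wrapper_code
        self.autotuning_wrapper_code = autotuning_wrapper_code
        self.original_backends = {}    # per-device BaseScheduling, as in the diff
        self.autotuning_backends = {}  # assumed counterpart, not shown in the diff

    def writeline(self, line):  # 'writeline' is an assumed method name
        self.original_wrapper_code.writeline(line)
        self.autotuning_wrapper_code.writeline(line)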

@@ -1393,7 +1393,7 @@ class _InProcessFxCompile(FxCompile):
                 is_backward=is_backward,
                 is_const_graph=True,
                 fx_wrapper=fx_wrapper,
-                use_dual_wrapper=use_dual_wrapper,
+                use_dual_wrapper=use_dual_wrapper,  # type: ignore[arg-type]
             )
             with (
                 V.set_graph_handler(const_graph),
@@ -1428,7 +1428,7 @@ class _InProcessFxCompile(FxCompile):
                 const_module=const_graph,
                 inputs_to_check=inputs_to_check,
                 fx_wrapper=fx_wrapper,
-                use_dual_wrapper=use_dual_wrapper,
+                use_dual_wrapper=use_dual_wrapper,  # type: ignore[arg-type]
             )

             metrics_helper = metrics.CachedMetricsHelper()
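
The two `# type: ignore[arg-type]` comments suggest use_dual_wrapper is inferred with a looser type than GraphLowering's parameter accepts (e.g. Optional[bool] vs. bool). A generic illustration of that error class, not the actual signatures from this PR:

from typing import Optional

def make_graph(use_dual_wrapper: bool) -> None:  # hypothetical signature
    pass

flag: Optional[bool] = None
make_graph(use_dual_wrapper=flag)  # mypy: error ... [arg-type]
make_graph(use_dual_wrapper=flag)  # type: ignore[arg-type]  # what this diff does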

@@ -2130,15 +2130,15 @@ class GraphLowering(torch.fx.Interpreter):
             dual_wrapper = DualWrapperCodegen(
                 original_wrapper_code, self.autotuning_wrapper_code
             )
-            self.wrapper_code = dual_wrapper
+            self.wrapper_code = dual_wrapper  # type: ignore[assignment]
             if self.const_module:
                 if hasattr(self.wrapper_code, "original_wrapper_code"):
                     # DualWrapperCodegen case
-                    self.wrapper_code.original_wrapper_code._names_iter = (
+                    self.wrapper_code.original_wrapper_code._names_iter = (  # type: ignore[attr-defined]
                         self.const_module.wrapper_code._names_iter
                     )
-                    self.wrapper_code.autotuning_wrapper_code._names_iter = (
+                    self.wrapper_code.autotuning_wrapper_code._names_iter = (  # type: ignore[attr-defined]
                         self.const_module.wrapper_code._names_iter
                     )
                 else:
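
Sharing _names_iter keeps the const module's wrapper and both inner wrappers drawing from a single name counter, so the two codegen passes cannot hand out the same buffer name twice. Assuming the counter is an itertools.count-style iterator (an assumption, not shown in the diff), the collision being avoided looks like:

import itertools

shared = itertools.count()        # one counter shared across wrappers
const_buf = f"buf{next(shared)}"  # "buf0", allocated by the const graph
main_buf = f"buf{next(shared)}"   # "buf1", allocated by the main graph

# with independent counters, both graphs would start at "buf0":
a, b = itertools.count(), itertools.count()
assert f"buf{next(a)}" == f"buf{next(b)}" == "buf0"
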
@@ -2392,8 +2392,8 @@ class GraphLowering(torch.fx.Interpreter):
         if self.use_dual_wrapper:
             # If we're doing full graph autotuning, we need to generate the autotuning wrapper code
             # and the autotuning kernels
-            original_wrapper_code = self.wrapper_code.original_wrapper_code
-            autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code
+            original_wrapper_code = self.wrapper_code.original_wrapper_code  # type: ignore[attr-defined]
+            autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code  # type: ignore[attr-defined]
             original_cpp_wrapper = self.cpp_wrapper
             self.wrapper_code = autotuning_wrapper_code
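
This hunk is the swap half of a swap-and-restore sequence: stash both inner wrappers and the cpp_wrapper flag, then point self.wrapper_code at the autotuning wrapper. The restore side is outside the visible context; a sketch of the presumed full flow, with the generate call as a hypothetical stand-in:

original_wrapper_code = self.wrapper_code.original_wrapper_code
autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code
original_cpp_wrapper = self.cpp_wrapper

self.wrapper_code = autotuning_wrapper_code    # emit autotuning code first
autotuning_out = self.wrapper_code.generate()  # hypothetical call
self.wrapper_code = original_wrapper_code      # assumed restore, not in the diff
self.cpp_wrapper = original_cpp_wrapper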

@@ -5474,7 +5474,7 @@ class Scheduler:
         available_buffer_names = OrderedSet(self.available_buffer_names)
         completed_operations = OrderedSet(self.completed_operations)

-        def wrap_codegen_node(w, *args, **kwargs):
+        def wrap_codegen_node(w, *args, **kwargs):  # type: ignore[no-untyped-def]
             self.enter_context(node)
             self.current_node = node
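
The OrderedSet copies above snapshot mutable scheduler state before wrap_codegen_node is created, so a deferred call can run under the state that existed at scheduling time. A self-contained sketch of the same capture-then-replay pattern; only the snapshot lines and the wrapper body mirror the diff, the rest is illustrative:

class MiniScheduler:
    def __init__(self):
        self.available_buffer_names = {"buf0"}
        self.current_node = None

    def defer(self, node, codegen):
        # snapshot mutable state now, like the OrderedSet copies in the diff
        available_buffer_names = set(self.available_buffer_names)

        def wrap_codegen_node(*args, **kwargs):
            # replayed later: restore the captured view of the scheduler
            self.current_node = node
            self.available_buffer_names = available_buffer_names
            return codegen(*args, **kwargs)

        return wrap_codegen_node

sched = MiniScheduler()
replay = sched.defer("node0", lambda: sorted(sched.available_buffer_names))
sched.available_buffer_names.add("buf1")  # state keeps mutating afterwards
print(replay())                           # ['buf0'] - snapshot restored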