mirror of https://github.com/pytorch/pytorch.git
Update on "[WIP] export"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng amjames chauhang aakhundov coconutruben [ghstack-poisoned]
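
For context, a minimal sketch of the dual-wrapper shape this diff annotates (hypothetical stand-in names, not the actual torch._inductor classes): one codegen object carries both the original and the autotuning wrapper code, plus per-device backend state, so callers can swap between the two sides.

    # Hypothetical stand-ins; the real classes live in torch._inductor.
    class WrapperCodegen:
        """Stand-in for Inductor's per-graph wrapper code generator."""

        def __init__(self, name: str) -> None:
            self.name = name


    class DualWrapper:
        """Holds an original and an autotuning wrapper side by side."""

        def __init__(self, original: WrapperCodegen, autotuning: WrapperCodegen) -> None:
            self.original_wrapper_code = original
            self.autotuning_wrapper_code = autotuning
            # Mirrors `self.original_backends: dict[torch.device, BaseScheduling]`
            # in the diff, with plain strings standing in for devices.
            self.original_backends: dict[str, object] = {}


    dual = DualWrapper(WrapperCodegen("original"), WrapperCodegen("autotuning"))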
@@ -3858,6 +3858,8 @@ class DualWrapperCodegen(CodeGen):
     """
 
     def __init__(self, original_wrapper_code, autotuning_wrapper_code):
+        from ..scheduler import BaseScheduling  # noqa: TC001
+
         super().__init__()
         self.original_wrapper_code = original_wrapper_code
         self.original_backends: dict[torch.device, BaseScheduling] = {}
@@ -1393,7 +1393,7 @@ class _InProcessFxCompile(FxCompile):
                     is_backward=is_backward,
                     is_const_graph=True,
                     fx_wrapper=fx_wrapper,
-                    use_dual_wrapper=use_dual_wrapper,
+                    use_dual_wrapper=use_dual_wrapper,  # type: ignore[arg-type]
                 )
                 with (
                     V.set_graph_handler(const_graph),
@@ -1428,7 +1428,7 @@ class _InProcessFxCompile(FxCompile):
                 const_module=const_graph,
                 inputs_to_check=inputs_to_check,
                 fx_wrapper=fx_wrapper,
-                use_dual_wrapper=use_dual_wrapper,
+                use_dual_wrapper=use_dual_wrapper,  # type: ignore[arg-type]
             )
             metrics_helper = metrics.CachedMetricsHelper()
 
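Both _InProcessFxCompile hunks add the same `# type: ignore[arg-type]`. That error code fires when an argument's inferred type doesn't match the parameter's annotation. A hypothetical minimal reproduction (guessing that the local is Optional while the parameter wants a plain bool; the real mismatch in the diff may differ):

    from typing import Optional

    def compile_graph(use_dual_wrapper: bool) -> None:
        pass

    maybe_dual: Optional[bool] = None
    # Without the ignore, mypy reports: Argument "use_dual_wrapper" has
    # incompatible type "Optional[bool]"; expected "bool"  [arg-type]
    compile_graph(use_dual_wrapper=maybe_dual)  # type: ignore[arg-type]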
@@ -2130,15 +2130,15 @@ class GraphLowering(torch.fx.Interpreter):
             dual_wrapper = DualWrapperCodegen(
                 original_wrapper_code, self.autotuning_wrapper_code
             )
-            self.wrapper_code = dual_wrapper
+            self.wrapper_code = dual_wrapper  # type: ignore[assignment]
 
         if self.const_module:
             if hasattr(self.wrapper_code, "original_wrapper_code"):
                 # DualWrapperCodegen case
-                self.wrapper_code.original_wrapper_code._names_iter = (
+                self.wrapper_code.original_wrapper_code._names_iter = (  # type: ignore[attr-defined]
                     self.const_module.wrapper_code._names_iter
                 )
-                self.wrapper_code.autotuning_wrapper_code._names_iter = (
+                self.wrapper_code.autotuning_wrapper_code._names_iter = (  # type: ignore[attr-defined]
                     self.const_module.wrapper_code._names_iter
                 )
             else:
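The GraphLowering hunk needs `# type: ignore[attr-defined]` because, assuming `wrapper_code` is annotated with a base wrapper type, the dual wrapper's subclass-only attributes aren't visible on the declared type. A sketch with hypothetical names; isinstance() narrowing is the ignore-free alternative:

    class BaseWrapper:
        pass


    class DualWrapper(BaseWrapper):
        def __init__(self) -> None:
            self.original_wrapper_code = BaseWrapper()


    def rewire(code: BaseWrapper) -> None:
        # Accessing a subclass-only attribute on the base-typed name:
        inner = code.original_wrapper_code  # type: ignore[attr-defined]
        if isinstance(code, DualWrapper):
            inner = code.original_wrapper_code  # OK: narrowed to DualWrapper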
@@ -2392,8 +2392,8 @@ class GraphLowering(torch.fx.Interpreter):
         if self.use_dual_wrapper:
             # If we're doing full graph autotuning, we need to generate the autotuning wrapper code
             # and the autotuning kernels
-            original_wrapper_code = self.wrapper_code.original_wrapper_code
-            autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code
+            original_wrapper_code = self.wrapper_code.original_wrapper_code  # type: ignore[attr-defined]
+            autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code  # type: ignore[attr-defined]
 
             original_cpp_wrapper = self.cpp_wrapper
             self.wrapper_code = autotuning_wrapper_code
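The swap in this hunk (stash self.cpp_wrapper, point self.wrapper_code at the autotuning side) suggests a save/codegen/restore cycle. One way to structure that shape, as a hedged sketch with hypothetical names rather than the commit's actual control flow:

    from contextlib import contextmanager
    from typing import Any, Iterator

    @contextmanager
    def autotuning_wrapper(graph: Any) -> Iterator[None]:
        """Temporarily codegen against the autotuning side, then restore."""
        saved_wrapper = graph.wrapper_code
        saved_cpp = graph.cpp_wrapper
        graph.wrapper_code = graph.wrapper_code.autotuning_wrapper_code
        try:
            yield
        finally:
            graph.wrapper_code = saved_wrapper
            graph.cpp_wrapper = saved_cpp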
@@ -5474,7 +5474,7 @@ class Scheduler:
         available_buffer_names = OrderedSet(self.available_buffer_names)
         completed_operations = OrderedSet(self.completed_operations)
 
-        def wrap_codegen_node(w, *args, **kwargs):
+        def wrap_codegen_node(w, *args, **kwargs):  # type: ignore[no-untyped-def]
             self.enter_context(node)
             self.current_node = node
 
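The Scheduler hunk silences `no-untyped-def`, which mypy raises under --disallow-untyped-defs for functions lacking annotations. Two ways to satisfy it, shown with a hypothetical body (the real wrap_codegen_node also touches scheduler state):

    from typing import Any, Callable

    # Targeted ignore, as in the diff:
    def wrap_codegen_node(w, *args, **kwargs):  # type: ignore[no-untyped-def]
        return w(*args, **kwargs)

    # Fully annotated alternative; no ignore needed:
    def wrap_codegen_node_typed(
        w: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> Any:
        return w(*args, **kwargs)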