mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update on " [WIP] export"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng amjames chauhang aakhundov coconutruben [ghstack-poisoned]
This commit is contained in:
@ -3858,6 +3858,8 @@ class DualWrapperCodegen(CodeGen):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, original_wrapper_code, autotuning_wrapper_code):
|
def __init__(self, original_wrapper_code, autotuning_wrapper_code):
|
||||||
|
from ..scheduler import BaseScheduling # noqa: TC001
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.original_wrapper_code = original_wrapper_code
|
self.original_wrapper_code = original_wrapper_code
|
||||||
self.original_backends: dict[torch.device, BaseScheduling] = {}
|
self.original_backends: dict[torch.device, BaseScheduling] = {}
|
||||||
|
@ -1393,7 +1393,7 @@ class _InProcessFxCompile(FxCompile):
|
|||||||
is_backward=is_backward,
|
is_backward=is_backward,
|
||||||
is_const_graph=True,
|
is_const_graph=True,
|
||||||
fx_wrapper=fx_wrapper,
|
fx_wrapper=fx_wrapper,
|
||||||
use_dual_wrapper=use_dual_wrapper,
|
use_dual_wrapper=use_dual_wrapper, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
V.set_graph_handler(const_graph),
|
V.set_graph_handler(const_graph),
|
||||||
@ -1428,7 +1428,7 @@ class _InProcessFxCompile(FxCompile):
|
|||||||
const_module=const_graph,
|
const_module=const_graph,
|
||||||
inputs_to_check=inputs_to_check,
|
inputs_to_check=inputs_to_check,
|
||||||
fx_wrapper=fx_wrapper,
|
fx_wrapper=fx_wrapper,
|
||||||
use_dual_wrapper=use_dual_wrapper,
|
use_dual_wrapper=use_dual_wrapper, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
metrics_helper = metrics.CachedMetricsHelper()
|
metrics_helper = metrics.CachedMetricsHelper()
|
||||||
|
|
||||||
|
@ -2130,15 +2130,15 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
dual_wrapper = DualWrapperCodegen(
|
dual_wrapper = DualWrapperCodegen(
|
||||||
original_wrapper_code, self.autotuning_wrapper_code
|
original_wrapper_code, self.autotuning_wrapper_code
|
||||||
)
|
)
|
||||||
self.wrapper_code = dual_wrapper
|
self.wrapper_code = dual_wrapper # type: ignore[assignment]
|
||||||
|
|
||||||
if self.const_module:
|
if self.const_module:
|
||||||
if hasattr(self.wrapper_code, "original_wrapper_code"):
|
if hasattr(self.wrapper_code, "original_wrapper_code"):
|
||||||
# DualWrapperCodegen case
|
# DualWrapperCodegen case
|
||||||
self.wrapper_code.original_wrapper_code._names_iter = (
|
self.wrapper_code.original_wrapper_code._names_iter = ( # type: ignore[attr-defined]
|
||||||
self.const_module.wrapper_code._names_iter
|
self.const_module.wrapper_code._names_iter
|
||||||
)
|
)
|
||||||
self.wrapper_code.autotuning_wrapper_code._names_iter = (
|
self.wrapper_code.autotuning_wrapper_code._names_iter = ( # type: ignore[attr-defined]
|
||||||
self.const_module.wrapper_code._names_iter
|
self.const_module.wrapper_code._names_iter
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -2392,8 +2392,8 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
if self.use_dual_wrapper:
|
if self.use_dual_wrapper:
|
||||||
# If we're doing full graph autotuning, we need to generate the autotuning wrapper code
|
# If we're doing full graph autotuning, we need to generate the autotuning wrapper code
|
||||||
# and the autotuning kernels
|
# and the autotuning kernels
|
||||||
original_wrapper_code = self.wrapper_code.original_wrapper_code
|
original_wrapper_code = self.wrapper_code.original_wrapper_code # type: ignore[attr-defined]
|
||||||
autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code
|
autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code # type: ignore[attr-defined]
|
||||||
|
|
||||||
original_cpp_wrapper = self.cpp_wrapper
|
original_cpp_wrapper = self.cpp_wrapper
|
||||||
self.wrapper_code = autotuning_wrapper_code
|
self.wrapper_code = autotuning_wrapper_code
|
||||||
|
@ -5474,7 +5474,7 @@ class Scheduler:
|
|||||||
available_buffer_names = OrderedSet(self.available_buffer_names)
|
available_buffer_names = OrderedSet(self.available_buffer_names)
|
||||||
completed_operations = OrderedSet(self.completed_operations)
|
completed_operations = OrderedSet(self.completed_operations)
|
||||||
|
|
||||||
def wrap_codegen_node(w, *args, **kwargs):
|
def wrap_codegen_node(w, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||||
self.enter_context(node)
|
self.enter_context(node)
|
||||||
self.current_node = node
|
self.current_node = node
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user