[WIP] export

ghstack-source-id: be382af7fe97c76c15abf8ed1ce0b43c8fe03568
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165823
Mu-Chu Lee
2025-10-18 00:03:45 -07:00
parent aaac8cb0f5
commit ba9c2f3e13
6 changed files with 601 additions and 288 deletions
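
This change adds a triton.autotune_full_graph config option and a DualWrapperCodegen that holds two wrapper codegen instances: the original (possibly C++) wrapper and a Python-only wrapper used to autotune the whole graph at compile time. GraphLowering gains a use_dual_wrapper flag (set for AOT compiles when both autotune_at_compile_time and autotune_full_graph are enabled, and turned off for CPU-only graphs), extract_real_inputs is hoisted out of codegen_with_cpp_wrapper into a method so the autotuning module can be compiled and run against real inputs during generate(), and the scheduler's per-node codegen is refactored into wrap_codegen_node so it can run once per wrapper.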

View File

@ -518,9 +518,11 @@ def init_backend_registration() -> None:
"cpu",
lambda scheduling: cpu_backends[config.cpu_backend](scheduling),
PythonWrapperCodegen,
CppWrapperCpuArrayRef
if config.aot_inductor.allow_stack_allocation
else CppWrapperCpu,
(
CppWrapperCpuArrayRef
if config.aot_inductor.allow_stack_allocation
else CppWrapperCpu
),
WrapperFxCodegen,
)

View File

@ -12,6 +12,7 @@ import operator
import random
import re
import tempfile
from enum import auto, Enum
from itertools import chain, count
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
@ -181,6 +182,7 @@ def user_defined_kernel_grid_fn_code(
)
)
if config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
else None
),
)
@ -190,6 +192,7 @@ def user_defined_kernel_grid_fn_code(
if (
wrapper
and config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
and name not in wrapper.kernel_autotune_names
):
wrapper.kernel_autotune_calls.writeline(example_grid or line)
@ -198,12 +201,15 @@ def user_defined_kernel_grid_fn_code(
writeline(f"def {fn_name}(meta):")
kernel_autotune_calls_indent = (
wrapper.kernel_autotune_calls.indent()
if wrapper and config.triton.autotune_at_compile_time
if wrapper
and config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
else contextlib.nullcontext()
)
with output.indent(), kernel_autotune_calls_indent:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
and original_fxnode_name
and V.graph.autotuning_grids
and original_fxnode_name in V.graph.autotuning_grids
@ -1080,7 +1086,10 @@ class PythonWrapperCodegen(CodeGen):
@functools.cache
def add_import_once(line: str) -> None:
self.imports.writeline(line)
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
self.kernel_autotune_calls.writeline(line)
self.add_import_once = add_import_once
@ -1225,7 +1234,10 @@ class PythonWrapperCodegen(CodeGen):
import triton.language as tl
from {triton_heuristics.__name__} import start_graph, end_graph
"""
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
self.kernel_autotune_calls.splice(import_str)
self.kernel_autotune_calls.writeline(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
@ -1240,7 +1252,10 @@ class PythonWrapperCodegen(CodeGen):
import_get_raw_stream_str = V.graph.device_ops.import_get_raw_stream_as(
"get_raw_stream"
)
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
if not self.kernel_autotune_calls.contains(import_get_raw_stream_str):
self.kernel_autotune_calls.writeline(import_get_raw_stream_str)
if not V.graph.cpp_wrapper:
@ -1398,9 +1413,10 @@ class PythonWrapperCodegen(CodeGen):
self.write_get_raw_stream_header()
name = f"stream{device_idx}"
if config.triton.autotune_at_compile_time:
self.kernel_autotune_calls.writeline(
f"{name} = get_raw_stream({device_idx})"
)
if not config.triton.autotune_full_graph:
self.kernel_autotune_calls.writeline(
f"{name} = get_raw_stream({device_idx})"
)
if V.graph.cpp_wrapper:
# For cpp wrapper, no need to continue codegen for the main body
return name
@ -1431,7 +1447,10 @@ class PythonWrapperCodegen(CodeGen):
self.writeline(
EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index)
)
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
# mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block
self.write_triton_header_once()
self.kernel_autotune_calls.writeline(
@ -1448,7 +1467,10 @@ class PythonWrapperCodegen(CodeGen):
def codegen_device_guard_exit(self) -> None:
self.writeline(ExitDeviceContextManagerLine())
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
self.kernel_autotune_calls.do_unindent()
def generate_return(self, output_refs: list[str]) -> None:
@ -1672,7 +1694,10 @@ class PythonWrapperCodegen(CodeGen):
def _write_multi_kernel_defs(self) -> None:
kernel_defs = self.multi_kernel_state.kernel_defs
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
self.kernel_autotune_defs.splice(kernel_defs)
else:
self.header.splice(kernel_defs)
@ -1690,8 +1715,12 @@ class PythonWrapperCodegen(CodeGen):
self.run_wrapper_ir_passes(is_inference)
if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
self.generate_reset_kernel_saved_flags()
if config.triton.store_cubin:
if (
not config.triton.autotune_at_compile_time
or config.triton.autotune_full_graph
):
self.generate_reset_kernel_saved_flags()
# At this point, we shouldn't generate any new memory planning lines.
# Override writeline to point at the wrapper call, in case it gets called.
@ -1713,10 +1742,17 @@ class PythonWrapperCodegen(CodeGen):
if config.profile_bandwidth:
self.generate_end_graph()
if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
self.generate_save_uncompiled_kernels()
if config.triton.store_cubin:
if (
not config.triton.autotune_at_compile_time
or config.triton.autotune_full_graph
):
self.generate_save_uncompiled_kernels()
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
self.generate_and_run_autotune_block()
# cpp_wrapper currently doesn't support nvtx
@ -1976,17 +2012,20 @@ class PythonWrapperCodegen(CodeGen):
def codegen_alloc_from_pool(
self, name, offset, dtype, shape, stride
) -> tuple[str, list[str]]:
return "alloc_from_pool({})".format(
", ".join(
[
name,
pexpr(offset), # bytes not numel
str(dtype),
self.codegen_python_shape_tuple(shape),
self.codegen_python_shape_tuple(stride),
]
)
), []
return (
"alloc_from_pool({})".format(
", ".join(
[
name,
pexpr(offset), # bytes not numel
str(dtype),
self.codegen_python_shape_tuple(shape),
self.codegen_python_shape_tuple(stride),
]
)
),
[],
)
def codegen_reinterpret_view(
self,
@ -2214,7 +2253,11 @@ class PythonWrapperCodegen(CodeGen):
def _format_kernel_definition(
kernel_name: str, kernel_body: str, metadata: Optional[str] = None
):
if config.triton.autotune_at_compile_time and metadata:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
and metadata
):
# Generating autotune block
# Need to replace C++ comment starter with Python comment starter
metadata = re.sub(r"^// ", "# ", metadata, flags=re.MULTILINE)
@ -2231,10 +2274,11 @@ class PythonWrapperCodegen(CodeGen):
cpp_definition: Optional[str] = None,
):
if config.triton.autotune_at_compile_time:
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=metadata
)
self.kernel_autotune_defs.splice(body)
if not config.triton.autotune_full_graph:
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=metadata
)
self.kernel_autotune_defs.splice(body)
if V.graph.cpp_wrapper:
# For cpp wrapper, no need to continue codegen for the main body
return
@ -2564,7 +2608,10 @@ class PythonWrapperCodegen(CodeGen):
else:
raise AssertionError(ws.zero_mode)
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
self.kernel_autotune_calls.writeline(
PythonWrapperCodegen.make_allocation(
self,
@ -2771,7 +2818,6 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(arg, str)
}
)
device = device or V.graph.get_current_device_or_throw()
self.writeline(
KernelCallLine(
@ -2839,134 +2885,141 @@ class PythonWrapperCodegen(CodeGen):
config.triton.autotune_at_compile_time
and kernel_name not in self.kernel_autotune_names
):
# Create example args for autotune in a separate epilogue
assert arg_types is not None and len(call_args) == len(arg_types), (
"call_args and arg_types do not match"
)
autotune_args = None
if original_fxnode_name and V.graph.autotuning_mapping:
autotune_args = V.graph.autotuning_mapping.get(
original_fxnode_name, None
if not config.triton.autotune_full_graph:
# Create example args for autotune in a separate epilogue
assert arg_types is not None and len(call_args) == len(arg_types), (
"call_args and arg_types do not match"
)
def get_autotune_deletion_call() -> str:
"""After all the autotune kernel calls have been written (i.e.
self.kernel_autotune_example_args is complete), returns a deletion call
for all autotune example tensors that are unnecessary after kernel_name
is called."""
tensors_to_delete = [
tensor
for tensor, kn in self.kernel_autotune_example_args.values()
if kn == kernel_name
]
if tensors_to_delete:
return f"del {', '.join(tensors_to_delete)}\n"
return ""
autotune_args = None
if original_fxnode_name and V.graph.autotuning_mapping:
autotune_args = V.graph.autotuning_mapping.get(
original_fxnode_name, None
)
def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args):
"""We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args.
This is particularly useful for jagged cases, where the dimension is often
being passed in as an input."""
def get_autotune_deletion_call() -> str:
"""After all the autotune kernel calls have been written (i.e.
self.kernel_autotune_example_args is complete), returns a deletion call
for all autotune example tensors that are unnecessary after kernel_name
is called."""
tensors_to_delete = [
tensor
for tensor, kn in self.kernel_autotune_example_args.values()
if kn == kernel_name
]
if tensors_to_delete:
return f"del {', '.join(tensors_to_delete)}\n"
return ""
target_arg = raw_args[idx]
if target_arg in reused_args:
return True
def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args):
"""We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args.
This is particularly useful for jagged cases, where the dimension is often
being passed in as an input."""
for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)):
if i == idx or not isinstance(raw_arg, IRNode):
continue
target_arg = raw_args[idx]
if target_arg in reused_args:
return True
triton_input = ""
for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)):
if i == idx or not isinstance(raw_arg, IRNode):
continue
triton_input = ""
if autotune_args and raw_key in autotune_args:
triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined]
autotune_args[raw_key]
)
if triton_input == "":
continue
try:
layout = raw_arg.get_layout()
for dim, s in enumerate(layout.size):
if s == target_arg:
reused_args[target_arg] = (
f"{triton_input}.shape[{dim}]"
)
return True
except NotImplementedError:
# If layout for this IRNode is not implemented, we could just skip.
# Only raise for other Error cases.
continue
return False
all_args = []
if raw_args is None:
# create a dummy raw_args for uniform behavior in the following loop
assert raw_keys is None, "keys are not None but args are"
raw_keys = [None] * len(call_args)
raw_args = [None] * len(call_args)
else:
assert len(raw_args) == len(call_args), (
"call_args and raw_args do not match"
)
reused_args = {}
for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
# pyrefly: ignore # no-matching-overload
zip(call_args, arg_types, raw_keys, raw_args)
):
key = None
if isinstance(arg, str) and "=" in str(arg):
# arg may be passed in a kwarg style, and then we need to extract its value
key, arg = arg.split("=")
triton_input: Optional[str] = None
if autotune_args and raw_key in autotune_args:
triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined]
autotune_args[raw_key]
)
if triton_input == "":
continue
try:
layout = raw_arg.get_layout()
for dim, s in enumerate(layout.size):
if s == target_arg:
reused_args[target_arg] = f"{triton_input}.shape[{dim}]"
return True
except NotImplementedError:
# If layout for this IRNode is not implemented, we could just skip.
# Only raise for other Error cases.
continue
return False
all_args = []
if raw_args is None:
# create a dummy raw_args for uniform behavior in the following loop
assert raw_keys is None, "keys are not None but args are"
raw_keys = [None] * len(call_args)
raw_args = [None] * len(call_args)
else:
assert len(raw_args) == len(call_args), (
"call_args and raw_args do not match"
)
reused_args = {}
for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
# pyrefly: ignore # no-matching-overload
zip(call_args, arg_types, raw_keys, raw_args)
):
key = None
if isinstance(arg, str) and "=" in str(arg):
# arg may be passed in a kwarg style, and then we need to extract its value
key, arg = arg.split("=")
triton_input: Optional[str] = None
if autotune_args and raw_key in autotune_args:
triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined]
autotune_args[raw_key]
)
if triton_input:
arg_str = triton_input
if not isinstance(arg_type, torch_dtype) and (
issubclass(arg_type, sympy.Basic)
or isinstance(arg, SymbolicCallArg)
if triton_input:
arg_str = triton_input
if not isinstance(arg_type, torch_dtype) and (
issubclass(arg_type, sympy.Basic)
or isinstance(arg, SymbolicCallArg)
):
reused_args[raw_arg] = arg_str
elif raw_key == "" and infer_arg_by_inputs(
raw_keys, raw_args, i, reused_args
):
reused_args[raw_arg] = arg_str
elif raw_key == "" and infer_arg_by_inputs(
raw_keys, raw_args, i, reused_args
):
# Empty raw_key means this is an arg that's not native to the triton kernel,
# and is being added by inductor.
arg_str = reused_args[raw_arg]
elif isinstance(arg_type, torch_dtype):
# workspace allocation is already generated by `generate_workspace_allocation()`
# in `TritonKernel.call_kernel()`.
if re.match(r"^(workspace|semaphore)", arg):
arg_str = arg
elif arg not in self.kernel_autotune_example_args:
# Empty raw_key means this is an arg that's not native to the triton kernel,
# and is being added by inductor.
arg_str = reused_args[raw_arg]
elif isinstance(arg_type, torch_dtype):
# workspace allocation is already generated by `generate_workspace_allocation()`
# in `TritonKernel.call_kernel()`.
if re.match(r"^(workspace|semaphore)", arg):
arg_str = arg
elif arg not in self.kernel_autotune_example_args:
arg_str = self.generate_example_arg_value(
arg, arg_type, raw_arg
)
else:
arg_str = self.kernel_autotune_example_args[arg][0]
self.kernel_autotune_example_args[arg] = (arg_str, kernel_name)
else:
arg_str = self.generate_example_arg_value(
arg, arg_type, raw_arg
)
else:
arg_str = self.kernel_autotune_example_args[arg][0]
self.kernel_autotune_example_args[arg] = (arg_str, kernel_name)
else:
arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg)
all_args.append(arg_str if key is None else f"{key}={arg_str}")
all_args.append(arg_str if key is None else f"{key}={arg_str}")
# Make sure the kernel launches under a device guard because models don't always run on device 0
self.kernel_autotune_calls.writeline(
f"with {V.graph.device_ops.device_guard(device.index)}:"
)
self.kernel_autotune_calls.do_indent()
self.kernel_autotune_calls.writeline(
f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})"
)
self.kernel_autotune_calls.do_unindent()
# Make sure the kernel launches under a device guard because models don't always run on device 0
self.kernel_autotune_calls.writeline(
f"with {V.graph.device_ops.device_guard(device.index)}:"
)
self.kernel_autotune_calls.do_indent()
self.kernel_autotune_calls.writeline(
f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})"
)
self.kernel_autotune_calls.do_unindent()
self.kernel_autotune_calls.writeline(
DelayReplaceLine("<del_call>", get_autotune_deletion_call, "<del_call>")
)
self.kernel_autotune_names.add(kernel_name)
self.kernel_autotune_calls.writeline(
DelayReplaceLine(
"<del_call>", get_autotune_deletion_call, "<del_call>"
)
)
self.kernel_autotune_names.add(kernel_name)
if V.graph.cpp_wrapper:
# For cpp wrapper, no need to continue codegen for the main body
return
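
The DelayReplaceLine written above is a deferred line: a placeholder is emitted while kernel calls are still being generated, and get_autotune_deletion_call only produces the final "del ..." text once kernel_autotune_example_args is complete. A standalone sketch of that deferred-line idea, assuming nothing from Inductor (DeferredLine and Buffer are illustrative names, not the real classes):

class DeferredLine:
    def __init__(self, resolver):
        self.resolver = resolver  # called only when the buffer is rendered

    def render(self):
        return self.resolver()


class Buffer:
    def __init__(self):
        self.lines = []

    def writeline(self, line):
        self.lines.append(line)

    def getvalue(self):
        return "\n".join(
            line.render() if isinstance(line, DeferredLine) else line
            for line in self.lines
        )


dead = []  # filled in later, once we know which example tensors are no longer needed
buf = Buffer()
buf.writeline("kernel_a.run(buf0, buf1, stream=stream0)")
buf.writeline(DeferredLine(lambda: f"del {', '.join(dead)}" if dead else ""))
buf.writeline("kernel_b.run(buf2, stream=stream0)")
dead.extend(["buf0", "buf1"])  # discovered after kernel_b was emitted
print(buf.getvalue())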
@ -3317,9 +3370,12 @@ class PythonWrapperCodegen(CodeGen):
# In this case, we strip the first key path away.
return go(
outputs[0].get_name(),
keypath[1:]
if isinstance(out, ir.MultiOutput) and len(out.indices) != 0
else keypath,
(
keypath[1:]
if isinstance(out, ir.MultiOutput)
and len(out.indices) != 0
else keypath
),
)
else:
assert isinstance(keypath[0], pytree.SequenceKey)
@ -3787,3 +3843,137 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
# V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
# )
self.parent_wrapper.write_get_raw_stream_header_once()
class DualWrapperState(Enum):
DUAL = auto()
ORIGINAL = auto()
AUTOTUNING = auto()
class DualWrapperCodegen(CodeGen):
"""
A wrapper class that contains two wrapper_code instances and delegates method calls to both.
This allows code to be generated into both wrappers simultaneously.
"""
def __init__(self, original_wrapper_code, autotuning_wrapper_code):
from ..scheduler import BaseScheduling # noqa: TC001
super().__init__()
self.original_wrapper_code = original_wrapper_code
self.original_backends: dict[torch.device, BaseScheduling] = {}
self.autotuning_wrapper_code = autotuning_wrapper_code
self.autotuning_backends: dict[torch.device, BaseScheduling] = {}
self.state = DualWrapperState.DUAL
# Store original states.
self.removed_operations: OrderedSet[str] = OrderedSet()
self.removed_buffers: OrderedSet[str] = OrderedSet()
self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
self.mutated_buffers: OrderedSet[str] = OrderedSet()
self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
def __getattr__(self, name):
"""
Default handler for any method call not explicitly implemented.
Delegates the lookup to both wrappers: methods are invoked on both and must return the same result; non-callable attributes must compare equal.
"""
if hasattr(self.original_wrapper_code, name) and hasattr(
self.autotuning_wrapper_code, name
):
attr1 = getattr(self.original_wrapper_code, name)
attr2 = getattr(self.autotuning_wrapper_code, name)
# Check if both attributes are callable (methods/functions)
if callable(attr1) and callable(attr2):
def dual_method(*args, **kwargs):
tmp_wrapper_code = V.graph.wrapper_code
tmp_cpp_wrapper = V.graph.cpp_wrapper
# Call the method on both wrappers
V.graph.wrapper_code = self.original_wrapper_code
V.graph.scheduler.backends = self.original_backends
self.state = DualWrapperState.ORIGINAL
result1 = attr1(*args, **kwargs)
V.graph.wrapper_code = self.autotuning_wrapper_code
V.graph.scheduler.backends = self.autotuning_backends
V.graph.cpp_wrapper = False
self.state = DualWrapperState.AUTOTUNING
result2 = attr2(*args, **kwargs)
# Restore to original wrapper_code.
V.graph.wrapper_code = tmp_wrapper_code
V.graph.cpp_wrapper = tmp_cpp_wrapper
self.state = DualWrapperState.DUAL
# Check if results are the same, otherwise raise an error
if result1 == result2:
return result1
else:
raise RuntimeError(
f"DualWrapperCodegen method '{name}' returned different results."
f"original_wrapper_code v.s. autotuning_wrapper_code: {result1} != {result2}"
)
return dual_method
else:
# Handle non-callable attributes (e.g., lists, integers, etc.)
if attr1 == attr2:
return attr1
else:
raise RuntimeError(
f"DualWrapperCodegen attribute '{name}' has different values."
f"original_wrapper_code v.s. autotuning_wrapper_code: {attr1} != {attr2}"
)
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def store_graph_states(self):
self.removed_operations = V.graph.removed_operations.copy()
self.removed_buffers = V.graph.removed_buffers.copy()
self.removed_inplace_buffers = V.graph.removed_inplace_buffers.copy()
self.mutated_buffers = V.graph.mutated_buffers.copy()
self.never_reuse_buffers = V.graph.never_reuse_buffers.copy()
self.inplaced_to_remove = V.graph.inplaced_to_remove.copy()
def restore_graph_states(self):
V.graph.removed_operations = self.removed_operations
V.graph.removed_buffers = self.removed_buffers
V.graph.removed_inplace_buffers = self.removed_inplace_buffers
V.graph.mutated_buffers = self.mutated_buffers
V.graph.never_reuse_buffers = self.never_reuse_buffers
V.graph.inplaced_to_remove = self.inplaced_to_remove
def for_each_wrapper(self, func, *args, **kwargs):
"""
Apply a function to each wrapper (original_wrapper_code and autotuning_wrapper_code).
The function should take a wrapper as its input parameter.
"""
# This should just be self, using tmp for ease of understanding.
tmp_wrapper_code = V.graph.wrapper_code
tmp_cpp_wrapper = V.graph.cpp_wrapper
V.graph.wrapper_code = self.original_wrapper_code
V.graph.scheduler.backends = self.original_backends
self.state = DualWrapperState.ORIGINAL
self.store_graph_states()
func(self.original_wrapper_code, *args, **kwargs)
self.restore_graph_states()
V.graph.wrapper_code = self.autotuning_wrapper_code
V.graph.scheduler.backends = self.autotuning_backends
V.graph.cpp_wrapper = False
self.state = DualWrapperState.AUTOTUNING
func(self.autotuning_wrapper_code, *args, **kwargs)
# Restore to original wrapper_code.
V.graph.wrapper_code = tmp_wrapper_code
V.graph.cpp_wrapper = tmp_cpp_wrapper
self.state = DualWrapperState.DUAL
V.graph.scheduler.backends = self.original_backends
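
The core of DualWrapperCodegen is the __getattr__ fallback above: any call not defined on the dual object is executed on both underlying wrappers (with V.graph.wrapper_code and the scheduler backends swapped in between) and the two results must agree. A minimal standalone sketch of that delegation pattern, assuming nothing from Inductor (Recorder and DualDelegate are illustrative names only):

class Recorder:
    def __init__(self):
        self.lines = []

    def writeline(self, line):
        self.lines.append(line)
        return len(self.lines)


class DualDelegate:
    def __init__(self, first, second):
        self._first = first
        self._second = second

    def __getattr__(self, name):
        # Only reached for attributes not found on the dual object itself.
        a, b = getattr(self._first, name), getattr(self._second, name)
        if callable(a) and callable(b):
            def dual(*args, **kwargs):
                r1 = a(*args, **kwargs)
                r2 = b(*args, **kwargs)
                if r1 != r2:
                    raise RuntimeError(f"{name} diverged: {r1} != {r2}")
                return r1
            return dual
        if a != b:
            raise RuntimeError(f"attribute {name} diverged: {a} != {b}")
        return a


dual = DualDelegate(Recorder(), Recorder())
dual.writeline("kernel.run(buf0)")  # executed on both recorders, results must match
assert dual.lines == ["kernel.run(buf0)"]

The real class additionally swaps V.graph.cpp_wrapper and the scheduler backends around each call, and for_each_wrapper does the same while snapshotting and restoring the graph's removed/mutated buffer sets between the two passes.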

View File

@ -1351,6 +1351,13 @@ class _InProcessFxCompile(FxCompile):
# See details in vllm/compilation/pass_manager.py.
log.warning("failed to log pt2_configs")
# We use a dual wrapper to generate autotuning code alongside the original codegen.
use_dual_wrapper = (
aot_mode
and config.triton.autotune_at_compile_time
and config.triton.autotune_full_graph
)
with (
V.set_fake_mode(fake_mode),
maybe_disable_comprehensive_padding(example_inputs),
@ -1386,6 +1393,7 @@ class _InProcessFxCompile(FxCompile):
is_backward=is_backward,
is_const_graph=True,
fx_wrapper=fx_wrapper,
use_dual_wrapper=use_dual_wrapper, # type: ignore[arg-type]
)
with (
V.set_graph_handler(const_graph),
@ -1420,6 +1428,7 @@ class _InProcessFxCompile(FxCompile):
const_module=const_graph,
inputs_to_check=inputs_to_check,
fx_wrapper=fx_wrapper,
use_dual_wrapper=use_dual_wrapper, # type: ignore[arg-type]
)
metrics_helper = metrics.CachedMetricsHelper()

View File

@ -1351,6 +1351,11 @@ class triton:
# Side effect for this option is increased memory footprint during first pass compilation.
autotune_with_sample_inputs: bool = False
# Autotune the full graph instead of just individual kernels.
# A separate Python autotuning wrapper is generated and run over the whole graph at compile time.
# This option is mutually exclusive with autotune_with_sample_inputs.
autotune_full_graph: bool = True
# Allows tiling reductions into multiple dimensions.
# For best results, this should be used with prefer_nd_tiling.
tile_reductions: bool = False
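
As a usage sketch only (assuming a CUDA build and the AOTInductor entry point torch._inductor.aoti_compile_and_package; per the compile_fx.py change above, the new flag is only acted on in AOT mode), enabling full-graph compile-time autotuning might look like this:

import torch
from torch._inductor import config as inductor_config

# Autotune the entire graph at compile time via the new flag.
inductor_config.triton.autotune_at_compile_time = True
inductor_config.triton.autotune_full_graph = True
# Must stay disabled: mutually exclusive with autotune_full_graph.
inductor_config.triton.autotune_with_sample_inputs = False

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

# The dual wrapper is only created for AOT compiles, so go through AOTInductor.
ep = torch.export.export(M(), (torch.randn(8, 8, device="cuda"),))
package_path = torch._inductor.aoti_compile_and_package(ep)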

View File

@ -342,6 +342,7 @@ class GraphLowering(torch.fx.Interpreter):
name: Optional[str] = None,
inputs_to_check: Optional[Sequence[int]] = None,
fx_wrapper: bool = False,
use_dual_wrapper: bool = False,
) -> None:
super().__init__(gm)
self.example_inputs = example_inputs
@ -425,6 +426,7 @@ class GraphLowering(torch.fx.Interpreter):
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
self.autotuning_wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
from torch._inductor.extern_node_serializer import extern_node_json_serializer
@ -445,6 +447,7 @@ class GraphLowering(torch.fx.Interpreter):
self.name = name # type: ignore[assignment]
self.cpp_wrapper = cpp_wrapper
self.fx_wrapper = fx_wrapper
self.use_dual_wrapper = use_dual_wrapper
# record multi_kernel choice for cpp_wrapper so the second pass knows
# which sub-kernel is picked. Copy cpp_wrapper to another variable
@ -2091,8 +2094,58 @@ class GraphLowering(torch.fx.Interpreter):
partition_signatures,
)
# We do not need tuning code for Triton if only_cpu is True.
if only_cpu:
self.use_dual_wrapper = False
# Create autotuning_wrapper_code for full graph autotuning if needed
if self.use_dual_wrapper:
if self.cpp_wrapper:
# If we're using cpp wrapper, create a separate Python wrapper for autotuning
python_wrapper_code_gen_cls = get_wrapper_codegen_for_device(
self.device_type, cpp_wrapper=False, fx_wrapper=self.fx_wrapper
)
assert python_wrapper_code_gen_cls is not None, (
f"Python wrapper for device {self.device_type} not supported"
)
self.autotuning_wrapper_code = python_wrapper_code_gen_cls.create(
is_subgraph,
subgraph_name,
parent_wrapper_code,
partition_signatures,
)
else:
# If we're already using Python wrapper, create a copy for autotuning
self.autotuning_wrapper_code = wrapper_code_gen_cls.create(
is_subgraph,
subgraph_name,
parent_wrapper_code,
partition_signatures,
)
# Create DualWrapperCodegen to handle both wrappers
from .codegen.wrapper import DualWrapperCodegen
original_wrapper_code = self.wrapper_code
dual_wrapper = DualWrapperCodegen(
original_wrapper_code, self.autotuning_wrapper_code
)
self.wrapper_code = dual_wrapper # type: ignore[assignment]
if self.const_module:
self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
if hasattr(self.wrapper_code, "original_wrapper_code"):
# DualWrapperCodegen case
self.wrapper_code.original_wrapper_code._names_iter = ( # type: ignore[attr-defined]
self.const_module.wrapper_code._names_iter
)
self.wrapper_code.autotuning_wrapper_code._names_iter = ( # type: ignore[attr-defined]
self.const_module.wrapper_code._names_iter
)
else:
# Regular wrapper case
self.wrapper_code._names_iter = (
self.const_module.wrapper_code._names_iter
)
def extract_autotune_inputs(
self, example_inputs: list[Union[int, float, torch.Tensor]]
@ -2189,6 +2242,73 @@ class GraphLowering(torch.fx.Interpreter):
self.autotuning_inputs = returned_outputs[: len(kwargs_inputs)]
self.autotuning_mapping = triton_inputs
def extract_real_inputs(self) -> list[Union[int, float, torch.Tensor]]:
def materialize(
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
) -> Union[int, float, torch.Tensor]:
if x is None:
# pyrefly: ignore # bad-return
return None
elif isinstance(x, (torch.SymInt, torch.SymFloat)):
# Need concrete value to run dynamic shapes and tune the result
return x.node.hint
elif isinstance(x, FakeTensor):
return defake(x)
else:
assert isinstance(x, torch.Tensor), (
"Unknown type when creating real inputs" + str(type(x))
)
return x
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None and not isinstance(V.real_inputs, NullHandler):
if tracing_context.output_strides:
tracing_context.output_strides.clear()
params_flat = [
param
for param in tracing_context.params_flat # type: ignore[union-attr]
if param is not None
]
real_inputs = [
materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
]
else:
# In the backward pass, V.real_inputs is not OrderedSet.
# Generating random inputs based on self.example_inputs sometimes can be problematic,
# e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
real_inputs = [
materialize(x) # type:ignore[arg-type]
for x in (
self.example_inputs # type:ignore[union-attr]
if isinstance(V.real_inputs, NullHandler)
else V.real_inputs
)
]
if self.mutated_inputs:
from .compile_fx import clone_preserve_strides
mutated_input_idxs = [
idx
for idx, name in enumerate(self.graph_inputs)
if name in self.mutated_inputs
and isinstance(real_inputs[idx], torch.Tensor)
]
for idx in mutated_input_idxs:
# clone mutated Tensor inputs to avoid mutating them in
# the first pass of the CPP wrapper-based compilation, as
# this will lead to a side effect on the example inputs:
# e.g. if torch.compile(f)(x) is called on input-mutating
# f, the inputs x will be mutated twice in the process:
# once here, and again when running the compiled model;
# this will also lead to a numerically incorrect output
mutated_inp = real_inputs[idx]
assert isinstance(mutated_inp, torch.Tensor)
real_inputs[idx] = clone_preserve_strides(mutated_inp)
del mutated_inp
return real_inputs
def codegen_with_cpp_wrapper(
self,
) -> tuple[ValueWithLineMap, ValueWithLineMap]:
@ -2196,76 +2316,15 @@ class GraphLowering(torch.fx.Interpreter):
For GPU, Triton kernels are autotuned and stored as cubin files
"""
if any(device in self.device_types for device in ["cuda", "xpu"]):
def extract_real_inputs() -> list[Union[int, float, torch.Tensor]]:
def materialize(
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
) -> Union[int, float, torch.Tensor]:
if x is None:
# pyrefly: ignore # bad-return
return None
elif isinstance(x, (torch.SymInt, torch.SymFloat)):
# Need concrete value to run dynamic shapes and tune the result
return x.node.hint
elif isinstance(x, FakeTensor):
return defake(x)
else:
assert isinstance(x, torch.Tensor), (
"Unknown type when creating real inputs" + str(type(x))
)
return x
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None and not isinstance(
V.real_inputs, NullHandler
):
if tracing_context.output_strides:
tracing_context.output_strides.clear()
params_flat = [
param
for param in tracing_context.params_flat # type: ignore[union-attr]
if param is not None
]
real_inputs = [
materialize(x)
for x in itertools.chain(params_flat, V.real_inputs)
]
else:
# In the backward pass, V.real_inputs is not OrderedSet.
# Generating random inputs based on self.example_inputs sometimes can be problematic,
# e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
real_inputs = [
materialize(x) # type:ignore[arg-type]
for x in (
self.example_inputs # type:ignore[union-attr]
if isinstance(V.real_inputs, NullHandler)
else V.real_inputs
)
]
if self.mutated_inputs:
from .compile_fx import clone_preserve_strides
mutated_input_idxs = [
idx
for idx, name in enumerate(self.graph_inputs)
if name in self.mutated_inputs
and isinstance(real_inputs[idx], torch.Tensor)
]
for idx in mutated_input_idxs:
# clone mutated Tensor inputs to avoid mutating them in
# the first pass of the CPP wrapper-based compilation, as
# this will lead to a side effect on the example inputs:
# e.g. if torch.compile(f)(x) is called on input-mutating
# f, the inputs x will be mutated twice in the process:
# once here, and again when running the compiled model;
# this will also lead to a numerically incorrect output
mutated_inp = real_inputs[idx]
assert isinstance(mutated_inp, torch.Tensor)
real_inputs[idx] = clone_preserve_strides(mutated_inp)
del mutated_inp
return real_inputs
# Validate mutual exclusivity between autotune_with_sample_inputs and autotune_full_graph
if (
config.triton.autotune_with_sample_inputs
and config.triton.autotune_full_graph
):
raise ValueError(
"autotune_with_sample_inputs and autotune_full_graph are mutually exclusive. "
"Only one of these options can be enabled at a time."
)
if config.triton.autotune_at_compile_time:
# If autotune_at_compile_time is True, we can do the codegen in one-pass
@ -2277,7 +2336,7 @@ class GraphLowering(torch.fx.Interpreter):
user_defined_kernels = True
break
if user_defined_kernels:
real_inputs = extract_real_inputs()
real_inputs = self.extract_real_inputs()
self.extract_autotune_inputs(real_inputs)
return self.codegen()
else:
@ -2285,7 +2344,7 @@ class GraphLowering(torch.fx.Interpreter):
self.cpp_wrapper = False
compiled = self.compile_to_module().call
real_inputs = extract_real_inputs()
real_inputs = self.extract_real_inputs()
with torch.utils._python_dispatch._disable_current_modes():
compiled(real_inputs)
del real_inputs
@ -2330,9 +2389,26 @@ class GraphLowering(torch.fx.Interpreter):
V.graph.all_codegen_kernel_names,
)
result = self.wrapper_code.generate(self.is_inference)
self.wrapper_code.pop_codegened_graph()
return result
if self.use_dual_wrapper:
# If we're doing full graph autotuning, we need to generate the autotuning wrapper code
# and the autotuning kernels
original_wrapper_code = self.wrapper_code.original_wrapper_code # type: ignore[attr-defined]
autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code # type: ignore[attr-defined]
original_cpp_wrapper = self.cpp_wrapper
self.wrapper_code = autotuning_wrapper_code
self.cpp_wrapper = False
autotuning_code, _ = self.wrapper_code.generate(self.is_inference)
autotuning_module = self._compile_to_module_lines(autotuning_code)
real_inputs = self.extract_real_inputs()
autotuning_module.call(real_inputs)
del real_inputs
self.cpp_wrapper = original_cpp_wrapper
self.wrapper_code = original_wrapper_code
result = self.wrapper_code.generate(self.is_inference)
self.wrapper_code.pop_codegened_graph()
return result
def codegen_subgraph(self, parent_graph: GraphLowering) -> None:
"""
@ -2416,7 +2492,10 @@ class GraphLowering(torch.fx.Interpreter):
) -> CompiledModule:
from .codecache import PyCodeCache
if config.triton.autotune_at_compile_time:
if (
config.triton.autotune_at_compile_time
and not config.triton.autotune_full_graph
):
# sanitize docstrings in kernel defs (#155006)
kernel_autotune_defs = self.wrapper_code.kernel_autotune_defs.getvalue()
kernel_autotune_defs = kernel_autotune_defs.replace('"""', '\\"\\"\\"')

View File

@ -5469,74 +5469,102 @@ class Scheduler:
node.get_name(),
)
self.enter_context(node)
current_device = self.current_device
buffer_names_to_free = OrderedSet(self.buffer_names_to_free)
available_buffer_names = OrderedSet(self.available_buffer_names)
completed_operations = OrderedSet(self.completed_operations)
if device := node.get_device():
if (
device != self.current_device
or node.is_extern()
or node.is_template()
):
self.flush()
if device != self.current_device:
if self.current_device and device_need_guard(
self.current_device.type
def wrap_codegen_node(w, *args, **kwargs): # type: ignore[no-untyped-def]
self.enter_context(node)
self.current_node = node
curr_node = node
self.current_device = kwargs["current_device"]
self.buffer_names_to_free = OrderedSet(kwargs["buffer_names_to_free"])
self.available_buffer_names = OrderedSet(
kwargs["available_buffer_names"]
)
self.completed_operations = OrderedSet(kwargs["completed_operations"])
if device := node.get_device():
if (
device != self.current_device
or node.is_extern()
or node.is_template()
):
V.graph.wrapper_code.codegen_device_guard_exit()
self.current_device = device
if device_need_guard(device.type):
assert device.index is not None, "device should have an index"
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
self.flush()
if device != self.current_device:
if self.current_device and device_need_guard(
self.current_device.type
):
w.codegen_device_guard_exit()
self.current_device = device
if device_need_guard(device.type):
assert device.index is not None, (
"device should have an index"
)
w.codegen_device_guard_enter(device.index)
self.current_node = node
self.buffer_names_to_free.update(node.last_usage)
self.buffer_names_to_free.update(node.last_usage)
if node.is_template():
prologue, template_node, epilogue = (
node.get_prologue_template_epilogue(list(node.get_nodes()))
)
# pyrefly: ignore # unbound-name
self.get_backend(device).codegen_template(
template_node, epilogue, prologue
)
elif node.is_extern():
curr_node = typing.cast(ExternKernelSchedulerNode, curr_node)
self.codegen_extern_call(curr_node)
elif node.is_foreach():
curr_node = typing.cast(ForeachKernelSchedulerNode, curr_node)
# pyrefly: ignore # unbound-name
backend_ = self.get_backend(device)
from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
from .codegen.simd import SIMDScheduling
if node.is_template():
prologue, template_node, epilogue = node.get_prologue_template_epilogue(
list(node.get_nodes())
)
# pyrefly: ignore # unbound-name
self.get_backend(device).codegen_template(
template_node, epilogue, prologue
)
elif node.is_extern():
node = typing.cast(ExternKernelSchedulerNode, node)
self.codegen_extern_call(node)
elif node.is_foreach():
node = typing.cast(ForeachKernelSchedulerNode, node)
# pyrefly: ignore # unbound-name
backend_ = self.get_backend(device)
from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
from .codegen.simd import SIMDScheduling
if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
backend = backend_
if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
backend = backend_
else:
raise AssertionError(f"{type(self)=}")
backend.codegen_combo_kernel(node)
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
# pyrefly: ignore # unbound-name
self.get_backend(device).codegen_node(node)
else:
raise AssertionError(f"{type(self)=}")
backend.codegen_combo_kernel(node)
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
assert isinstance(node, NopKernelSchedulerNode)
curr_node.mark_run()
# pyrefly: ignore # unbound-name
self.get_backend(device).codegen_node(node)
if config.triton.debug_sync_kernel:
# pyrefly: ignore # unbound-name
self.get_backend(device).codegen_sync()
self.available_buffer_names.update(curr_node.get_buffer_names())
self.completed_operations.update(curr_node.get_operation_names())
if not isinstance(node, NopKernelSchedulerNode):
device = node.get_device()
if (
device is not None
and device.type != "meta"
and self.get_backend(device).ready_to_flush()
):
self.flush()
from .codegen.wrapper import DualWrapperCodegen
states = {
"current_device": current_device,
"buffer_names_to_free": buffer_names_to_free,
"available_buffer_names": available_buffer_names,
"completed_operations": completed_operations,
}
if isinstance(V.graph.wrapper_code, DualWrapperCodegen):
V.graph.wrapper_code.for_each_wrapper(wrap_codegen_node, **states)
else:
assert isinstance(node, NopKernelSchedulerNode)
node.mark_run()
# pyrefly: ignore # unbound-name
if config.triton.debug_sync_kernel:
# pyrefly: ignore # unbound-name
self.get_backend(device).codegen_sync()
self.available_buffer_names.update(node.get_buffer_names())
self.completed_operations.update(node.get_operation_names())
if not isinstance(node, NopKernelSchedulerNode):
device = node.get_device()
if (
device is not None
and device.type != "meta"
and self.get_backend(device).ready_to_flush()
):
self.flush()
wrap_codegen_node(V.graph.wrapper_code, **states)
if self.current_device != self.default_device_context:
# when default_device_context is not None, we are codegen