[WIP] export

ghstack-source-id: be382af7fe97c76c15abf8ed1ce0b43c8fe03568
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165823
@@ -518,9 +518,11 @@ def init_backend_registration() -> None:
         "cpu",
         lambda scheduling: cpu_backends[config.cpu_backend](scheduling),
         PythonWrapperCodegen,
-        CppWrapperCpuArrayRef
-        if config.aot_inductor.allow_stack_allocation
-        else CppWrapperCpu,
+        (
+            CppWrapperCpuArrayRef
+            if config.aot_inductor.allow_stack_allocation
+            else CppWrapperCpu
+        ),
         WrapperFxCodegen,
     )
@@ -12,6 +12,7 @@ import operator
 import random
 import re
 import tempfile
+from enum import auto, Enum
 from itertools import chain, count
 from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
@@ -181,6 +182,7 @@ def user_defined_kernel_grid_fn_code(
                 )
             )
             if config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
             else None
         ),
     )
@@ -190,6 +192,7 @@ def user_defined_kernel_grid_fn_code(
     if (
         wrapper
         and config.triton.autotune_at_compile_time
+        and not config.triton.autotune_full_graph
         and name not in wrapper.kernel_autotune_names
     ):
         wrapper.kernel_autotune_calls.writeline(example_grid or line)
@@ -198,12 +201,15 @@ def user_defined_kernel_grid_fn_code(
     writeline(f"def {fn_name}(meta):")
     kernel_autotune_calls_indent = (
         wrapper.kernel_autotune_calls.indent()
-        if wrapper and config.triton.autotune_at_compile_time
+        if wrapper
+        and config.triton.autotune_at_compile_time
+        and not config.triton.autotune_full_graph
         else contextlib.nullcontext()
     )
     with output.indent(), kernel_autotune_calls_indent:
         if (
             config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
             and original_fxnode_name
             and V.graph.autotuning_grids
             and original_fxnode_name in V.graph.autotuning_grids
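The hunks above, and most of the wrapper hunks that follow, repeat one gating change: code that used to run whenever `config.triton.autotune_at_compile_time` was set is now skipped when `config.triton.autotune_full_graph` is also set, since full-graph autotuning emits its code through a separate wrapper instead. A minimal sketch of that predicate, assuming it were factored into a helper (the PR itself inlines the condition at every call site; `should_emit_per_kernel_autotune` is a hypothetical name):

# Hypothetical helper; the PR inlines this condition rather than naming it.
def should_emit_per_kernel_autotune(
    autotune_at_compile_time: bool, autotune_full_graph: bool
) -> bool:
    # Per-kernel autotune code is generated only when compile-time autotuning
    # is on AND full-graph autotuning is off.
    return autotune_at_compile_time and not autotune_full_graph

assert should_emit_per_kernel_autotune(True, False) is True
assert should_emit_per_kernel_autotune(True, True) is False
assert should_emit_per_kernel_autotune(False, False) is False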
@@ -1080,7 +1086,10 @@ class PythonWrapperCodegen(CodeGen):
         @functools.cache
         def add_import_once(line: str) -> None:
             self.imports.writeline(line)
-            if config.triton.autotune_at_compile_time:
+            if (
+                config.triton.autotune_at_compile_time
+                and not config.triton.autotune_full_graph
+            ):
                 self.kernel_autotune_calls.writeline(line)
 
         self.add_import_once = add_import_once
@@ -1225,7 +1234,10 @@ class PythonWrapperCodegen(CodeGen):
             import triton.language as tl
             from {triton_heuristics.__name__} import start_graph, end_graph
         """
-        if config.triton.autotune_at_compile_time:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+        ):
             self.kernel_autotune_calls.splice(import_str)
             self.kernel_autotune_calls.writeline(
                 V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
@@ -1240,7 +1252,10 @@ class PythonWrapperCodegen(CodeGen):
         import_get_raw_stream_str = V.graph.device_ops.import_get_raw_stream_as(
             "get_raw_stream"
         )
-        if config.triton.autotune_at_compile_time:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+        ):
             if not self.kernel_autotune_calls.contains(import_get_raw_stream_str):
                 self.kernel_autotune_calls.writeline(import_get_raw_stream_str)
         if not V.graph.cpp_wrapper:
@@ -1398,9 +1413,10 @@ class PythonWrapperCodegen(CodeGen):
             self.write_get_raw_stream_header()
         name = f"stream{device_idx}"
         if config.triton.autotune_at_compile_time:
-            self.kernel_autotune_calls.writeline(
-                f"{name} = get_raw_stream({device_idx})"
-            )
+            if not config.triton.autotune_full_graph:
+                self.kernel_autotune_calls.writeline(
+                    f"{name} = get_raw_stream({device_idx})"
+                )
         if V.graph.cpp_wrapper:
             # For cpp wrapper, no need to continue codegen for the main body
             return name
@@ -1431,7 +1447,10 @@ class PythonWrapperCodegen(CodeGen):
         self.writeline(
             EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index)
         )
-        if config.triton.autotune_at_compile_time:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+        ):
             # mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block
             self.write_triton_header_once()
             self.kernel_autotune_calls.writeline(
@@ -1448,7 +1467,10 @@ class PythonWrapperCodegen(CodeGen):
 
     def codegen_device_guard_exit(self) -> None:
         self.writeline(ExitDeviceContextManagerLine())
-        if config.triton.autotune_at_compile_time:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+        ):
             self.kernel_autotune_calls.do_unindent()
 
     def generate_return(self, output_refs: list[str]) -> None:
@@ -1672,7 +1694,10 @@ class PythonWrapperCodegen(CodeGen):
 
     def _write_multi_kernel_defs(self) -> None:
         kernel_defs = self.multi_kernel_state.kernel_defs
-        if config.triton.autotune_at_compile_time:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+        ):
             self.kernel_autotune_defs.splice(kernel_defs)
         else:
             self.header.splice(kernel_defs)
@@ -1690,8 +1715,12 @@ class PythonWrapperCodegen(CodeGen):
 
         self.run_wrapper_ir_passes(is_inference)
 
-        if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
-            self.generate_reset_kernel_saved_flags()
+        if config.triton.store_cubin:
+            if (
+                not config.triton.autotune_at_compile_time
+                or config.triton.autotune_full_graph
+            ):
+                self.generate_reset_kernel_saved_flags()
 
         # At this point, we shouldn't generate any new memory planning lines.
         # Override writeline to point at the wrapper call, in case it gets called.
@@ -1713,10 +1742,17 @@ class PythonWrapperCodegen(CodeGen):
         if config.profile_bandwidth:
             self.generate_end_graph()
 
-        if config.triton.store_cubin and not config.triton.autotune_at_compile_time:
-            self.generate_save_uncompiled_kernels()
+        if config.triton.store_cubin:
+            if (
+                not config.triton.autotune_at_compile_time
+                or config.triton.autotune_full_graph
+            ):
+                self.generate_save_uncompiled_kernels()
 
-        if config.triton.autotune_at_compile_time:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+        ):
             self.generate_and_run_autotune_block()
 
         # cpp_wrapper currently doesn't support nvtx
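Note that the two `store_cubin` hunks invert the new predicate rather than repeating it: the reset/save paths should run exactly when per-kernel compile-time autotuning is not handling the kernels, i.e. `not (autotune_at_compile_time and not autotune_full_graph)`, which De Morgan's law rewrites as `not autotune_at_compile_time or autotune_full_graph`. A quick check of that equivalence (plain Python, not part of the patch):

from itertools import product

# Verify: not (A and not B)  ==  (not A) or B, for the two triton config flags.
for autotune_at_compile_time, autotune_full_graph in product([False, True], repeat=2):
    lhs = not (autotune_at_compile_time and not autotune_full_graph)
    rhs = (not autotune_at_compile_time) or autotune_full_graph
    assert lhs == rhs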
@@ -1976,17 +2012,20 @@ class PythonWrapperCodegen(CodeGen):
     def codegen_alloc_from_pool(
         self, name, offset, dtype, shape, stride
     ) -> tuple[str, list[str]]:
-        return "alloc_from_pool({})".format(
-            ", ".join(
-                [
-                    name,
-                    pexpr(offset),  # bytes not numel
-                    str(dtype),
-                    self.codegen_python_shape_tuple(shape),
-                    self.codegen_python_shape_tuple(stride),
-                ]
-            )
-        ), []
+        return (
+            "alloc_from_pool({})".format(
+                ", ".join(
+                    [
+                        name,
+                        pexpr(offset),  # bytes not numel
+                        str(dtype),
+                        self.codegen_python_shape_tuple(shape),
+                        self.codegen_python_shape_tuple(stride),
+                    ]
+                )
+            ),
+            [],
+        )
 
     def codegen_reinterpret_view(
         self,
@@ -2214,7 +2253,11 @@ class PythonWrapperCodegen(CodeGen):
     def _format_kernel_definition(
         kernel_name: str, kernel_body: str, metadata: Optional[str] = None
     ):
-        if config.triton.autotune_at_compile_time and metadata:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+            and metadata
+        ):
             # Generating autotune block
             # Need to replace C++ comment starter with Python comment starter
             metadata = re.sub(r"^// ", "# ", metadata, flags=re.MULTILINE)
@@ -2231,10 +2274,11 @@ class PythonWrapperCodegen(CodeGen):
         cpp_definition: Optional[str] = None,
     ):
         if config.triton.autotune_at_compile_time:
-            body = self._format_kernel_definition(
-                kernel_name, kernel_body, metadata=metadata
-            )
-            self.kernel_autotune_defs.splice(body)
+            if not config.triton.autotune_full_graph:
+                body = self._format_kernel_definition(
+                    kernel_name, kernel_body, metadata=metadata
+                )
+                self.kernel_autotune_defs.splice(body)
         if V.graph.cpp_wrapper:
             # For cpp wrapper, no need to continue codegen for the main body
             return
@@ -2564,7 +2608,10 @@ class PythonWrapperCodegen(CodeGen):
         else:
             raise AssertionError(ws.zero_mode)
 
-        if config.triton.autotune_at_compile_time:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+        ):
             self.kernel_autotune_calls.writeline(
                 PythonWrapperCodegen.make_allocation(
                     self,
@@ -2771,7 +2818,6 @@ class PythonWrapperCodegen(CodeGen):
                 if isinstance(arg, str)
             }
         )
 
         device = device or V.graph.get_current_device_or_throw()
         self.writeline(
             KernelCallLine(
@@ -2839,134 +2885,141 @@ class PythonWrapperCodegen(CodeGen):
             config.triton.autotune_at_compile_time
             and kernel_name not in self.kernel_autotune_names
         ):
-            # Create example args for autotune in a separate epilogue
-            assert arg_types is not None and len(call_args) == len(arg_types), (
-                "call_args and arg_types do not match"
-            )
-
-            autotune_args = None
-            if original_fxnode_name and V.graph.autotuning_mapping:
-                autotune_args = V.graph.autotuning_mapping.get(
-                    original_fxnode_name, None
-                )
-
-            def get_autotune_deletion_call() -> str:
-                """After all the autotune kernel calls have been written (i.e.
-                self.kernel_autotune_example_args is complete), returns a deletion call
-                for all autotune example tensors that are unnecessary after kernel_name
-                is called."""
-                tensors_to_delete = [
-                    tensor
-                    for tensor, kn in self.kernel_autotune_example_args.values()
-                    if kn == kernel_name
-                ]
-                if tensors_to_delete:
-                    return f"del {', '.join(tensors_to_delete)}\n"
-                return ""
-
-            def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args):
-                """We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args.
-                This is particularly useful for jagged cases, where the dimension is often
-                being passed in as an input."""
-
-                target_arg = raw_args[idx]
-                if target_arg in reused_args:
-                    return True
-
-                for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)):
-                    if i == idx or not isinstance(raw_arg, IRNode):
-                        continue
-
-                    triton_input = ""
-                    if autotune_args and raw_key in autotune_args:
-                        triton_input = self.get_autotuning_input_name(  # type: ignore[attr-defined]
-                            autotune_args[raw_key]
-                        )
-                    if triton_input == "":
-                        continue
-
-                    try:
-                        layout = raw_arg.get_layout()
-                        for dim, s in enumerate(layout.size):
-                            if s == target_arg:
-                                reused_args[target_arg] = f"{triton_input}.shape[{dim}]"
-                                return True
-                    except NotImplementedError:
-                        # If layout for this IRNode is not implemented, we could just skip.
-                        # Only raise for other Error cases.
-                        continue
-                return False
-
-            all_args = []
-            if raw_args is None:
-                # create a dummy raw_args for uniform behavior in the following loop
-                assert raw_keys is None, "keys are not None but args are"
-                raw_keys = [None] * len(call_args)
-                raw_args = [None] * len(call_args)
-            else:
-                assert len(raw_args) == len(call_args), (
-                    "call_args and raw_args do not match"
-                )
-
-            reused_args = {}
-            for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
-                # pyrefly: ignore  # no-matching-overload
-                zip(call_args, arg_types, raw_keys, raw_args)
-            ):
-                key = None
-                if isinstance(arg, str) and "=" in str(arg):
-                    # arg may be passed in a kwarg style, and then we need to extract its value
-                    key, arg = arg.split("=")
-
-                triton_input: Optional[str] = None
-                if autotune_args and raw_key in autotune_args:
-                    triton_input = self.get_autotuning_input_name(  # type: ignore[attr-defined]
-                        autotune_args[raw_key]
-                    )
-
-                if triton_input:
-                    arg_str = triton_input
-                    if not isinstance(arg_type, torch_dtype) and (
-                        issubclass(arg_type, sympy.Basic)
-                        or isinstance(arg, SymbolicCallArg)
-                    ):
-                        reused_args[raw_arg] = arg_str
-                elif raw_key == "" and infer_arg_by_inputs(
-                    raw_keys, raw_args, i, reused_args
-                ):
-                    # Empty raw_key means this is a arg that's not native to the triton kernel,
-                    # and is being added by inductor.
-                    arg_str = reused_args[raw_arg]
-                elif isinstance(arg_type, torch_dtype):
-                    # workspace allocation is already generated by `generate_workspace_allocation()`
-                    # in `TritonKernel.call_kernel()`.
-                    if re.match(r"^(workspace|semaphore)", arg):
-                        arg_str = arg
-                    elif arg not in self.kernel_autotune_example_args:
-                        arg_str = self.generate_example_arg_value(
-                            arg, arg_type, raw_arg
-                        )
-                    else:
-                        arg_str = self.kernel_autotune_example_args[arg][0]
-                    self.kernel_autotune_example_args[arg] = (arg_str, kernel_name)
-                else:
-                    arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg)
-                all_args.append(arg_str if key is None else f"{key}={arg_str}")
-
-            # Make sure kernel launch under a device guard because models don't always run on device 0
-            self.kernel_autotune_calls.writeline(
-                f"with {V.graph.device_ops.device_guard(device.index)}:"
-            )
-            self.kernel_autotune_calls.do_indent()
-            self.kernel_autotune_calls.writeline(
-                f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})"
-            )
-            self.kernel_autotune_calls.do_unindent()
-
-            self.kernel_autotune_calls.writeline(
-                DelayReplaceLine("<del_call>", get_autotune_deletion_call, "<del_call>")
-            )
-            self.kernel_autotune_names.add(kernel_name)
+            if not config.triton.autotune_full_graph:
+                # Create example args for autotune in a separate epilogue
+                assert arg_types is not None and len(call_args) == len(arg_types), (
+                    "call_args and arg_types do not match"
+                )
+
+                autotune_args = None
+                if original_fxnode_name and V.graph.autotuning_mapping:
+                    autotune_args = V.graph.autotuning_mapping.get(
+                        original_fxnode_name, None
+                    )
+
+                def get_autotune_deletion_call() -> str:
+                    """After all the autotune kernel calls have been written (i.e.
+                    self.kernel_autotune_example_args is complete), returns a deletion call
+                    for all autotune example tensors that are unnecessary after kernel_name
+                    is called."""
+                    tensors_to_delete = [
+                        tensor
+                        for tensor, kn in self.kernel_autotune_example_args.values()
+                        if kn == kernel_name
+                    ]
+                    if tensors_to_delete:
+                        return f"del {', '.join(tensors_to_delete)}\n"
+                    return ""
+
+                def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args):
+                    """We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args.
+                    This is particularly useful for jagged cases, where the dimension is often
+                    being passed in as an input."""
+
+                    target_arg = raw_args[idx]
+                    if target_arg in reused_args:
+                        return True
+
+                    for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)):
+                        if i == idx or not isinstance(raw_arg, IRNode):
+                            continue
+
+                        triton_input = ""
+                        if autotune_args and raw_key in autotune_args:
+                            triton_input = self.get_autotuning_input_name(  # type: ignore[attr-defined]
+                                autotune_args[raw_key]
+                            )
+                        if triton_input == "":
+                            continue
+
+                        try:
+                            layout = raw_arg.get_layout()
+                            for dim, s in enumerate(layout.size):
+                                if s == target_arg:
+                                    reused_args[target_arg] = (
+                                        f"{triton_input}.shape[{dim}]"
+                                    )
+                                    return True
+                        except NotImplementedError:
+                            # If layout for this IRNode is not implemented, we could just skip.
+                            # Only raise for other Error cases.
+                            continue
+                    return False
+
+                all_args = []
+                if raw_args is None:
+                    # create a dummy raw_args for uniform behavior in the following loop
+                    assert raw_keys is None, "keys are not None but args are"
+                    raw_keys = [None] * len(call_args)
+                    raw_args = [None] * len(call_args)
+                else:
+                    assert len(raw_args) == len(call_args), (
+                        "call_args and raw_args do not match"
+                    )
+
+                reused_args = {}
+                for i, (arg, arg_type, raw_key, raw_arg) in enumerate(
+                    # pyrefly: ignore  # no-matching-overload
+                    zip(call_args, arg_types, raw_keys, raw_args)
+                ):
+                    key = None
+                    if isinstance(arg, str) and "=" in str(arg):
+                        # arg may be passed in a kwarg style, and then we need to extract its value
+                        key, arg = arg.split("=")
+
+                    triton_input: Optional[str] = None
+                    if autotune_args and raw_key in autotune_args:
+                        triton_input = self.get_autotuning_input_name(  # type: ignore[attr-defined]
+                            autotune_args[raw_key]
+                        )
+
+                    if triton_input:
+                        arg_str = triton_input
+                        if not isinstance(arg_type, torch_dtype) and (
+                            issubclass(arg_type, sympy.Basic)
+                            or isinstance(arg, SymbolicCallArg)
+                        ):
+                            reused_args[raw_arg] = arg_str
+                    elif raw_key == "" and infer_arg_by_inputs(
+                        raw_keys, raw_args, i, reused_args
+                    ):
+                        # Empty raw_key means this is a arg that's not native to the triton kernel,
+                        # and is being added by inductor.
+                        arg_str = reused_args[raw_arg]
+                    elif isinstance(arg_type, torch_dtype):
+                        # workspace allocation is already generated by `generate_workspace_allocation()`
+                        # in `TritonKernel.call_kernel()`.
+                        if re.match(r"^(workspace|semaphore)", arg):
+                            arg_str = arg
+                        elif arg not in self.kernel_autotune_example_args:
+                            arg_str = self.generate_example_arg_value(
+                                arg, arg_type, raw_arg
+                            )
+                        else:
+                            arg_str = self.kernel_autotune_example_args[arg][0]
+                        self.kernel_autotune_example_args[arg] = (arg_str, kernel_name)
+                    else:
+                        arg_str = self.generate_example_arg_value(
+                            arg, arg_type, raw_arg
+                        )
+                    all_args.append(arg_str if key is None else f"{key}={arg_str}")
+
+                # Make sure kernel launch under a device guard because models don't always run on device 0
+                self.kernel_autotune_calls.writeline(
+                    f"with {V.graph.device_ops.device_guard(device.index)}:"
+                )
+                self.kernel_autotune_calls.do_indent()
+                self.kernel_autotune_calls.writeline(
+                    f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})"
+                )
+                self.kernel_autotune_calls.do_unindent()
+
+                self.kernel_autotune_calls.writeline(
+                    DelayReplaceLine(
+                        "<del_call>", get_autotune_deletion_call, "<del_call>"
+                    )
+                )
+                self.kernel_autotune_names.add(kernel_name)
         if V.graph.cpp_wrapper:
             # For cpp wrapper, no need to continue codegen for the main body
             return
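One detail from `infer_arg_by_inputs` above is worth spelling out: when a size argument is not a native kernel parameter, the codegen tries to recover it from the shape of another (tensor) input, so the autotune call can reference `that_input.shape[dim]` instead of baking in a constant. A minimal standalone sketch of that idea, with plain tensors standing in for `IRNode`s and `find_dim_source` as a made-up name:

from typing import Optional
import torch

def find_dim_source(target_size: int, named_inputs: dict) -> Optional[str]:
    """Return an expression like 'arg0.shape[1]' whose value equals target_size."""
    for name, tensor in named_inputs.items():
        for dim, size in enumerate(tensor.shape):
            if size == target_size:
                return f"{name}.shape[{dim}]"
    return None

inputs = {"arg0": torch.empty(8, 128), "arg1": torch.empty(32)}
assert find_dim_source(128, inputs) == "arg0.shape[1]"
assert find_dim_source(7, inputs) is None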
@@ -3317,9 +3370,12 @@ class PythonWrapperCodegen(CodeGen):
                 # In this case, we strip the first key path away.
                 return go(
                     outputs[0].get_name(),
-                    keypath[1:]
-                    if isinstance(out, ir.MultiOutput) and len(out.indices) != 0
-                    else keypath,
+                    (
+                        keypath[1:]
+                        if isinstance(out, ir.MultiOutput)
+                        and len(out.indices) != 0
+                        else keypath
+                    ),
                 )
             else:
                 assert isinstance(keypath[0], pytree.SequenceKey)
@@ -3787,3 +3843,137 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
         # V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
         # )
         self.parent_wrapper.write_get_raw_stream_header_once()
+
+
+class DualWrapperState(Enum):
+    DUAL = auto()
+    ORIGINAL = auto()
+    AUTOTUNING = auto()
+
+
+class DualWrapperCodegen(CodeGen):
+    """
+    A wrapper class that contains two wrapper_code instances and delegates method calls to both.
+    This allows generating code for both wrappers simultaneously.
+    """
+
+    def __init__(self, original_wrapper_code, autotuning_wrapper_code):
+        from ..scheduler import BaseScheduling  # noqa: TC001
+
+        super().__init__()
+        self.original_wrapper_code = original_wrapper_code
+        self.original_backends: dict[torch.device, BaseScheduling] = {}
+        self.autotuning_wrapper_code = autotuning_wrapper_code
+        self.autotuning_backends: dict[torch.device, BaseScheduling] = {}
+        self.state = DualWrapperState.DUAL
+
+        # Store original states.
+        self.removed_operations: OrderedSet[str] = OrderedSet()
+        self.removed_buffers: OrderedSet[str] = OrderedSet()
+        self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
+        self.mutated_buffers: OrderedSet[str] = OrderedSet()
+        self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
+        self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
+
+    def __getattr__(self, name):
+        """
+        Default handler for any method call not explicitly implemented.
+        Initially raises NotImplementedError, but can be overridden to delegate to both wrappers.
+        """
+        if hasattr(self.original_wrapper_code, name) and hasattr(
+            self.autotuning_wrapper_code, name
+        ):
+            attr1 = getattr(self.original_wrapper_code, name)
+            attr2 = getattr(self.autotuning_wrapper_code, name)
+
+            # Check if both attributes are callable (methods/functions)
+            if callable(attr1) and callable(attr2):
+
+                def dual_method(*args, **kwargs):
+                    tmp_wrapper_code = V.graph.wrapper_code
+                    tmp_cpp_wrapper = V.graph.cpp_wrapper
+
+                    # Call the method on both wrappers
+                    V.graph.wrapper_code = self.original_wrapper_code
+                    V.graph.scheduler.backends = self.original_backends
+                    self.state = DualWrapperState.ORIGINAL
+                    result1 = attr1(*args, **kwargs)
+
+                    V.graph.wrapper_code = self.autotuning_wrapper_code
+                    V.graph.scheduler.backends = self.autotuning_backends
+                    V.graph.cpp_wrapper = False
+                    self.state = DualWrapperState.AUTOTUNING
+                    result2 = attr2(*args, **kwargs)
+
+                    # Restore to original wrapper_code.
+                    V.graph.wrapper_code = tmp_wrapper_code
+                    V.graph.cpp_wrapper = tmp_cpp_wrapper
+                    self.state = DualWrapperState.DUAL
+
+                    # Check if results are the same, otherwise raise an error
+                    if result1 == result2:
+                        return result1
+                    else:
+                        raise RuntimeError(
+                            f"DualWrapperCodegen method '{name}' returned different results."
+                            f"original_wrapper_code v.s. autotuning_wrapper_code: {result1} != {result2}"
+                        )

+                return dual_method
+            else:
+                # Handle non-callable attributes (e.g., lists, integers, etc.)
+                if attr1 == attr2:
+                    return attr1
+                else:
+                    raise RuntimeError(
+                        f"DualWrapperCodegen attribute '{name}' has different values."
+                        f"original_wrapper_code v.s. autotuning_wrapper_code: {attr1} != {attr2}"
+                    )
+        else:
+            raise AttributeError(
+                f"'{self.__class__.__name__}' object has no attribute '{name}'"
+            )
+
+    def store_graph_states(self):
+        self.removed_operations = V.graph.removed_operations.copy()
+        self.removed_buffers = V.graph.removed_buffers.copy()
+        self.removed_inplace_buffers = V.graph.removed_inplace_buffers.copy()
+        self.mutated_buffers = V.graph.mutated_buffers.copy()
+        self.never_reuse_buffers = V.graph.never_reuse_buffers.copy()
+        self.inplaced_to_remove = V.graph.inplaced_to_remove.copy()
+
+    def restore_graph_states(self):
+        V.graph.removed_operations = self.removed_operations
+        V.graph.removed_buffers = self.removed_buffers
+        V.graph.removed_inplace_buffers = self.removed_inplace_buffers
+        V.graph.mutated_buffers = self.mutated_buffers
+        V.graph.never_reuse_buffers = self.never_reuse_buffers
+        V.graph.inplaced_to_remove = self.inplaced_to_remove
+
+    def for_each_wrapper(self, func, *args, **kwargs):
+        """
+        Apply a function to each wrapper (original_wrapper_code and autotuning_wrapper_code).
+        The function should take a wrapper as its input parameter.
+        """
+        # This should just be self, using tmp for ease of understanding.
+        tmp_wrapper_code = V.graph.wrapper_code
+        tmp_cpp_wrapper = V.graph.cpp_wrapper
+
+        V.graph.wrapper_code = self.original_wrapper_code
+        V.graph.scheduler.backends = self.original_backends
+        self.state = DualWrapperState.ORIGINAL
+        self.store_graph_states()
+        func(self.original_wrapper_code, *args, **kwargs)
+        self.restore_graph_states()
+
+        V.graph.wrapper_code = self.autotuning_wrapper_code
+        V.graph.scheduler.backends = self.autotuning_backends
+        V.graph.cpp_wrapper = False
+        self.state = DualWrapperState.AUTOTUNING
+        func(self.autotuning_wrapper_code, *args, **kwargs)
+
+        # Restore to original wrapper_code.
+        V.graph.wrapper_code = tmp_wrapper_code
+        V.graph.cpp_wrapper = tmp_cpp_wrapper
+        self.state = DualWrapperState.DUAL
+        V.graph.scheduler.backends = self.original_backends
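The delegation in `DualWrapperCodegen.__getattr__` is the core trick of this new class: any attribute not defined on the dual wrapper is looked up on both underlying wrappers, methods are fanned out to both, and disagreeing results are treated as an error. A stripped-down, self-contained sketch of that pattern, with no Inductor types involved and `DualProxy` as a made-up name:

class DualProxy:
    """Fan attribute access out to two delegates and require agreement."""

    def __init__(self, first, second):
        self._first = first
        self._second = second

    def __getattr__(self, name):
        a, b = getattr(self._first, name), getattr(self._second, name)
        if callable(a) and callable(b):
            def both(*args, **kwargs):
                r1, r2 = a(*args, **kwargs), b(*args, **kwargs)
                if r1 != r2:
                    raise RuntimeError(f"{name} disagreed: {r1!r} != {r2!r}")
                return r1
            return both
        if a != b:
            raise RuntimeError(f"attribute {name} disagreed: {a!r} != {b!r}")
        return a


class Recorder:
    def __init__(self):
        self.lines = []

    def writeline(self, line):
        self.lines.append(line)


left, right = Recorder(), Recorder()
dual = DualProxy(left, right)
dual.writeline("x = 1")  # runs on both recorders
assert left.lines == right.lines == ["x = 1"]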
@@ -1351,6 +1351,13 @@ class _InProcessFxCompile(FxCompile):
                 # See details in vllm/compilation/pass_manager.py.
                 log.warning("failed to log pt2_configs")
 
+        # We use dual wrapper to generate autotuning code alongside with the original codegen.
+        use_dual_wrapper = (
+            aot_mode
+            and config.triton.autotune_at_compile_time
+            and config.triton.autotune_full_graph
+        )
+
         with (
             V.set_fake_mode(fake_mode),
             maybe_disable_comprehensive_padding(example_inputs),
@@ -1386,6 +1393,7 @@ class _InProcessFxCompile(FxCompile):
                     is_backward=is_backward,
                     is_const_graph=True,
                     fx_wrapper=fx_wrapper,
+                    use_dual_wrapper=use_dual_wrapper,  # type: ignore[arg-type]
                 )
                 with (
                     V.set_graph_handler(const_graph),
@@ -1420,6 +1428,7 @@ class _InProcessFxCompile(FxCompile):
                 const_module=const_graph,
                 inputs_to_check=inputs_to_check,
                 fx_wrapper=fx_wrapper,
+                use_dual_wrapper=use_dual_wrapper,  # type: ignore[arg-type]
             )
             metrics_helper = metrics.CachedMetricsHelper()
 
@@ -1351,6 +1351,11 @@ class triton:
     # Side effect for this option is increased memory footprint during first pass compilation.
     autotune_with_sample_inputs: bool = False
 
+    # Autotune the full graph instead of just individual kernels.
+    # This provides a comprehensive autotuning approach across the entire computation graph.
+    # This option is mutually exclusive with autotune_with_sample_inputs.
+    autotune_full_graph: bool = True
+
     # Allows tiling reductions into multiple dimensions.
     # For best results, this should be used with prefer_nd_tiling.
     tile_reductions: bool = False
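Assuming the new knob is driven the same way as the other `torch._inductor.config.triton` flags, full-graph autotuning would be enabled purely through configuration; a later hunk in `codegen_with_cpp_wrapper` rejects enabling it together with `autotune_with_sample_inputs`. A hedged usage sketch:

import torch._inductor.config as inductor_config

# Autotune the whole generated graph at compile time (autotune_full_graph is the
# flag added by this PR; the other two flags already exist).
inductor_config.triton.autotune_at_compile_time = True
inductor_config.triton.autotune_full_graph = True

# Mutually exclusive with sampling-based autotuning; the graph lowering raises
# a ValueError if both are set.
inductor_config.triton.autotune_with_sample_inputs = False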
@@ -342,6 +342,7 @@ class GraphLowering(torch.fx.Interpreter):
         name: Optional[str] = None,
         inputs_to_check: Optional[Sequence[int]] = None,
         fx_wrapper: bool = False,
+        use_dual_wrapper: bool = False,
     ) -> None:
         super().__init__(gm)
         self.example_inputs = example_inputs
@@ -425,6 +426,7 @@ class GraphLowering(torch.fx.Interpreter):
         self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
         self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
         self.wrapper_code: PythonWrapperCodegen = None  # type: ignore[assignment]
+        self.autotuning_wrapper_code: PythonWrapperCodegen = None  # type: ignore[assignment]
 
         from torch._inductor.extern_node_serializer import extern_node_json_serializer
 
@@ -445,6 +447,7 @@ class GraphLowering(torch.fx.Interpreter):
         self.name = name  # type: ignore[assignment]
         self.cpp_wrapper = cpp_wrapper
         self.fx_wrapper = fx_wrapper
+        self.use_dual_wrapper = use_dual_wrapper
 
         # record multi_kernel choice for cpp_wrapper so the second pass knows
         # which sub-kernel is picked. Copy cpp_wrapper to another variable
@@ -2091,8 +2094,58 @@ class GraphLowering(torch.fx.Interpreter):
             partition_signatures,
         )
 
+        # We do not need tuning code for Triton if only_cpu is True.
+        if only_cpu:
+            self.use_dual_wrapper = False
+
+        # Create autotuning_wrapper_code for full graph autotuning if needed
+        if self.use_dual_wrapper:
+            if self.cpp_wrapper:
+                # If we're using cpp wrapper, create a separate Python wrapper for autotuning
+                python_wrapper_code_gen_cls = get_wrapper_codegen_for_device(
+                    self.device_type, cpp_wrapper=False, fx_wrapper=self.fx_wrapper
+                )
+                assert python_wrapper_code_gen_cls is not None, (
+                    f"Python wrapper for device {self.device_type} not supported"
+                )
+                self.autotuning_wrapper_code = python_wrapper_code_gen_cls.create(
+                    is_subgraph,
+                    subgraph_name,
+                    parent_wrapper_code,
+                    partition_signatures,
+                )
+            else:
+                # If we're already using Python wrapper, create a copy for autotuning
+                self.autotuning_wrapper_code = wrapper_code_gen_cls.create(
+                    is_subgraph,
+                    subgraph_name,
+                    parent_wrapper_code,
+                    partition_signatures,
+                )
+
+            # Create DualWrapperCodegen to handle both wrappers
+            from .codegen.wrapper import DualWrapperCodegen
+
+            original_wrapper_code = self.wrapper_code
+            dual_wrapper = DualWrapperCodegen(
+                original_wrapper_code, self.autotuning_wrapper_code
+            )
+            self.wrapper_code = dual_wrapper  # type: ignore[assignment]
+
         if self.const_module:
-            self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
+            if hasattr(self.wrapper_code, "original_wrapper_code"):
+                # DualWrapperCodegen case
+                self.wrapper_code.original_wrapper_code._names_iter = (  # type: ignore[attr-defined]
+                    self.const_module.wrapper_code._names_iter
+                )
+                self.wrapper_code.autotuning_wrapper_code._names_iter = (  # type: ignore[attr-defined]
+                    self.const_module.wrapper_code._names_iter
+                )
+            else:
+                # Regular wrapper case
+                self.wrapper_code._names_iter = (
+                    self.const_module.wrapper_code._names_iter
+                )
 
     def extract_autotune_inputs(
         self, example_inputs: list[Union[int, float, torch.Tensor]]
@@ -2189,6 +2242,73 @@ class GraphLowering(torch.fx.Interpreter):
         self.autotuning_inputs = returned_outputs[: len(kwargs_inputs)]
         self.autotuning_mapping = triton_inputs
 
+    def extract_real_inputs(self) -> list[Union[int, float, torch.Tensor]]:
+        def materialize(
+            x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
+        ) -> Union[int, float, torch.Tensor]:
+            if x is None:
+                # pyrefly: ignore  # bad-return
+                return None
+            elif isinstance(x, (torch.SymInt, torch.SymFloat)):
+                # Need concrete value to run dynamic shapes and tune the result
+                return x.node.hint
+            elif isinstance(x, FakeTensor):
+                return defake(x)
+            else:
+                assert isinstance(x, torch.Tensor), (
+                    "Unknown type when creating real inputs" + str(type(x))
+                )
+                return x
+
+        tracing_context = torch._guards.TracingContext.try_get()
+        if tracing_context is not None and not isinstance(V.real_inputs, NullHandler):
+            if tracing_context.output_strides:
+                tracing_context.output_strides.clear()
+
+            params_flat = [
+                param
+                for param in tracing_context.params_flat  # type: ignore[union-attr]
+                if param is not None
+            ]
+            real_inputs = [
+                materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
+            ]
+        else:
+            # In the backward pass, V.real_inputs is not OrderedSet.
+            # Generating random inputs based on self.example_inputs sometimes can be problematic,
+            # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
+            real_inputs = [
+                materialize(x)  # type:ignore[arg-type]
+                for x in (
+                    self.example_inputs  # type:ignore[union-attr]
+                    if isinstance(V.real_inputs, NullHandler)
+                    else V.real_inputs
+                )
+            ]
+
+        if self.mutated_inputs:
+            from .compile_fx import clone_preserve_strides
+
+            mutated_input_idxs = [
+                idx
+                for idx, name in enumerate(self.graph_inputs)
+                if name in self.mutated_inputs
+                and isinstance(real_inputs[idx], torch.Tensor)
+            ]
+            for idx in mutated_input_idxs:
+                # clone mutated Tensor inputs to avoid mutating them in
+                # the first pass of the CPP wrapper-based compilation, as
+                # this will lead to a side effect on the example inputs:
+                # e.g. if torch.compile(f)(x) if called on input-mutating
+                # f, the inputs x will be mutated twice in the process:
+                # once here, and again when running the compiled model;
+                # this will also lead to a numerically incorrect output
+                mutated_inp = real_inputs[idx]
+                assert isinstance(mutated_inp, torch.Tensor)
+                real_inputs[idx] = clone_preserve_strides(mutated_inp)
+                del mutated_inp
+        return real_inputs
+
     def codegen_with_cpp_wrapper(
         self,
     ) -> tuple[ValueWithLineMap, ValueWithLineMap]:
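The comment block above explains why mutated tensor inputs are cloned before the autotuning run: otherwise the inputs would be mutated once during autotuning and again when the compiled model runs. A self-contained sketch of stride-preserving cloning in the spirit of `clone_preserve_strides` (the real helper lives in `compile_fx`; this version is illustrative only):

import torch

def clone_preserve_strides_sketch(x: torch.Tensor) -> torch.Tensor:
    # A plain clone() keeps only the view's data; rebuild the spanned storage
    # region so the clone keeps the same size/stride layout as the original.
    needed = sum((s - 1) * st for s, st in zip(x.size(), x.stride())) + 1
    buffer = torch.as_strided(x, (needed,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())

base = torch.arange(12.0).reshape(3, 4)
view = base[:, ::2]                      # non-contiguous input
clone = clone_preserve_strides_sketch(view)
clone.mul_(0)                            # mutating the clone ...
assert base.sum().item() == 66.0         # ... leaves the original storage untouched
assert clone.stride() == view.stride()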
@@ -2196,76 +2316,15 @@ class GraphLowering(torch.fx.Interpreter):
         For GPU, Triton kernels are autotuned and stored as cubin files
         """
         if any(device in self.device_types for device in ["cuda", "xpu"]):
-
-            def extract_real_inputs() -> list[Union[int, float, torch.Tensor]]:
-                def materialize(
-                    x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
-                ) -> Union[int, float, torch.Tensor]:
-                    if x is None:
-                        # pyrefly: ignore  # bad-return
-                        return None
-                    elif isinstance(x, (torch.SymInt, torch.SymFloat)):
-                        # Need concrete value to run dynamic shapes and tune the result
-                        return x.node.hint
-                    elif isinstance(x, FakeTensor):
-                        return defake(x)
-                    else:
-                        assert isinstance(x, torch.Tensor), (
-                            "Unknown type when creating real inputs" + str(type(x))
-                        )
-                        return x
-
-                tracing_context = torch._guards.TracingContext.try_get()
-                if tracing_context is not None and not isinstance(
-                    V.real_inputs, NullHandler
-                ):
-                    if tracing_context.output_strides:
-                        tracing_context.output_strides.clear()
-
-                    params_flat = [
-                        param
-                        for param in tracing_context.params_flat  # type: ignore[union-attr]
-                        if param is not None
-                    ]
-                    real_inputs = [
-                        materialize(x)
-                        for x in itertools.chain(params_flat, V.real_inputs)
-                    ]
-                else:
-                    # In the backward pass, V.real_inputs is not OrderedSet.
-                    # Generating random inputs based on self.example_inputs sometimes can be problematic,
-                    # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
-                    real_inputs = [
-                        materialize(x)  # type:ignore[arg-type]
-                        for x in (
-                            self.example_inputs  # type:ignore[union-attr]
-                            if isinstance(V.real_inputs, NullHandler)
-                            else V.real_inputs
-                        )
-                    ]
-
-                if self.mutated_inputs:
-                    from .compile_fx import clone_preserve_strides
-
-                    mutated_input_idxs = [
-                        idx
-                        for idx, name in enumerate(self.graph_inputs)
-                        if name in self.mutated_inputs
-                        and isinstance(real_inputs[idx], torch.Tensor)
-                    ]
-                    for idx in mutated_input_idxs:
-                        # clone mutated Tensor inputs to avoid mutating them in
-                        # the first pass of the CPP wrapper-based compilation, as
-                        # this will lead to a side effect on the example inputs:
-                        # e.g. if torch.compile(f)(x) if called on input-mutating
-                        # f, the inputs x will be mutated twice in the process:
-                        # once here, and again when running the compiled model;
-                        # this will also lead to a numerically incorrect output
-                        mutated_inp = real_inputs[idx]
-                        assert isinstance(mutated_inp, torch.Tensor)
-                        real_inputs[idx] = clone_preserve_strides(mutated_inp)
-                        del mutated_inp
-                return real_inputs
+            # Validate mutual exclusivity between autotune_with_sample_inputs and autotune_full_graph
+            if (
+                config.triton.autotune_with_sample_inputs
+                and config.triton.autotune_full_graph
+            ):
+                raise ValueError(
+                    "autotune_with_sample_inputs and autotune_full_graph are mutually exclusive. "
+                    "Only one of these options can be enabled at a time."
+                )
 
             if config.triton.autotune_at_compile_time:
                 # If autotune_at_compile_time is True, we can do the codegen in one-pass
@@ -2277,7 +2336,7 @@ class GraphLowering(torch.fx.Interpreter):
                         user_defined_kernels = True
                         break
                 if user_defined_kernels:
-                    real_inputs = extract_real_inputs()
+                    real_inputs = self.extract_real_inputs()
                    self.extract_autotune_inputs(real_inputs)
                    return self.codegen()
                else:
@@ -2285,7 +2344,7 @@ class GraphLowering(torch.fx.Interpreter):
                 self.cpp_wrapper = False
                 compiled = self.compile_to_module().call
 
-                real_inputs = extract_real_inputs()
+                real_inputs = self.extract_real_inputs()
                 with torch.utils._python_dispatch._disable_current_modes():
                     compiled(real_inputs)
                     del real_inputs
@@ -2330,9 +2389,26 @@ class GraphLowering(torch.fx.Interpreter):
                 V.graph.all_codegen_kernel_names,
             )
 
-        result = self.wrapper_code.generate(self.is_inference)
-        self.wrapper_code.pop_codegened_graph()
-        return result
+        if self.use_dual_wrapper:
+            # If we're doing full graph autotuning, we need to generate the autotuning wrapper code
+            # and the autotuning kernels
+            original_wrapper_code = self.wrapper_code.original_wrapper_code  # type: ignore[attr-defined]
+            autotuning_wrapper_code = self.wrapper_code.autotuning_wrapper_code  # type: ignore[attr-defined]
+
+            original_cpp_wrapper = self.cpp_wrapper
+            self.wrapper_code = autotuning_wrapper_code
+            self.cpp_wrapper = False
+            autotuning_code, _ = self.wrapper_code.generate(self.is_inference)
+            autotuning_module = self._compile_to_module_lines(autotuning_code)
+            real_inputs = self.extract_real_inputs()
+            autotuning_module.call(real_inputs)
+            del real_inputs
+            self.cpp_wrapper = original_cpp_wrapper
+            self.wrapper_code = original_wrapper_code
+
+        result = self.wrapper_code.generate(self.is_inference)
+        self.wrapper_code.pop_codegened_graph()
+        return result
 
     def codegen_subgraph(self, parent_graph: GraphLowering) -> None:
         """
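Putting the `codegen()` change above in sequence: when `use_dual_wrapper` is set, the graph first generates and immediately runs a Python autotuning module against real inputs, and only then generates the wrapper that is actually returned. A schematic of that control flow, with the Inductor internals replaced by stand-in callables (`generate_autotune_module` and `generate_final_wrapper` are made-up names):

from typing import Any, Callable, Sequence

def dual_codegen(
    use_dual_wrapper: bool,
    generate_autotune_module: Callable[[], Callable[[Sequence[Any]], None]],
    real_inputs: Sequence[Any],
    generate_final_wrapper: Callable[[], str],
) -> str:
    # Pass 1 (optional): build the autotuning-only module and run it so every
    # Triton kernel is tuned with real tensors before the final codegen.
    if use_dual_wrapper:
        autotune_call = generate_autotune_module()
        autotune_call(real_inputs)
    # Pass 2: generate the wrapper that is actually compiled and returned.
    return generate_final_wrapper()

# Toy usage
tuned = []
result = dual_codegen(
    True,
    lambda: (lambda inputs: tuned.append(len(inputs))),
    [1, 2, 3],
    lambda: "final wrapper source",
)
assert tuned == [3] and result == "final wrapper source"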
@@ -2416,7 +2492,10 @@ class GraphLowering(torch.fx.Interpreter):
     ) -> CompiledModule:
         from .codecache import PyCodeCache
 
-        if config.triton.autotune_at_compile_time:
+        if (
+            config.triton.autotune_at_compile_time
+            and not config.triton.autotune_full_graph
+        ):
             # sanitize docstrings in kernel defs (#155006)
             kernel_autotune_defs = self.wrapper_code.kernel_autotune_defs.getvalue()
             kernel_autotune_defs = kernel_autotune_defs.replace('"""', '\\"\\"\\"')
@@ -5469,74 +5469,102 @@ class Scheduler:
                     node.get_name(),
                 )
 
-            self.enter_context(node)
-
-            if device := node.get_device():
-                if (
-                    device != self.current_device
-                    or node.is_extern()
-                    or node.is_template()
-                ):
-                    self.flush()
-                if device != self.current_device:
-                    if self.current_device and device_need_guard(
-                        self.current_device.type
-                    ):
-                        V.graph.wrapper_code.codegen_device_guard_exit()
-                    self.current_device = device
-                    if device_need_guard(device.type):
-                        assert device.index is not None, "device should have an index"
-                        V.graph.wrapper_code.codegen_device_guard_enter(device.index)
-
-            self.current_node = node
-            self.buffer_names_to_free.update(node.last_usage)
-
-            if node.is_template():
-                prologue, template_node, epilogue = node.get_prologue_template_epilogue(
-                    list(node.get_nodes())
-                )
-                # pyrefly: ignore  # unbound-name
-                self.get_backend(device).codegen_template(
-                    template_node, epilogue, prologue
-                )
-            elif node.is_extern():
-                node = typing.cast(ExternKernelSchedulerNode, node)
-                self.codegen_extern_call(node)
-            elif node.is_foreach():
-                node = typing.cast(ForeachKernelSchedulerNode, node)
-                # pyrefly: ignore  # unbound-name
-                backend_ = self.get_backend(device)
-                from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
-                from .codegen.simd import SIMDScheduling
-
-                if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
-                    backend = backend_
-                else:
-                    raise AssertionError(f"{type(self)=}")
-                backend.codegen_combo_kernel(node)
-            elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
-                # pyrefly: ignore  # unbound-name
-                self.get_backend(device).codegen_node(node)
-            else:
-                assert isinstance(node, NopKernelSchedulerNode)
-                node.mark_run()
-
-            # pyrefly: ignore  # unbound-name
-            if config.triton.debug_sync_kernel:
-                # pyrefly: ignore  # unbound-name
-                self.get_backend(device).codegen_sync()
-
-            self.available_buffer_names.update(node.get_buffer_names())
-            self.completed_operations.update(node.get_operation_names())
-
-            if not isinstance(node, NopKernelSchedulerNode):
-                device = node.get_device()
-                if (
-                    device is not None
-                    and device.type != "meta"
-                    and self.get_backend(device).ready_to_flush()
-                ):
-                    self.flush()
+            current_device = self.current_device
+            buffer_names_to_free = OrderedSet(self.buffer_names_to_free)
+            available_buffer_names = OrderedSet(self.available_buffer_names)
+            completed_operations = OrderedSet(self.completed_operations)
+
+            def wrap_codegen_node(w, *args, **kwargs):  # type: ignore[no-untyped-def]
+                self.enter_context(node)
+                self.current_node = node
+
+                curr_node = node
+                self.current_device = kwargs["current_device"]
+                self.buffer_names_to_free = OrderedSet(kwargs["buffer_names_to_free"])
+                self.available_buffer_names = OrderedSet(
+                    kwargs["available_buffer_names"]
+                )
+                self.completed_operations = OrderedSet(kwargs["completed_operations"])
+
+                if device := node.get_device():
+                    if (
+                        device != self.current_device
+                        or node.is_extern()
+                        or node.is_template()
+                    ):
+                        self.flush()
+                    if device != self.current_device:
+                        if self.current_device and device_need_guard(
+                            self.current_device.type
+                        ):
+                            w.codegen_device_guard_exit()
+                        self.current_device = device
+                        if device_need_guard(device.type):
+                            assert device.index is not None, (
+                                "device should have an index"
+                            )
+                            w.codegen_device_guard_enter(device.index)
+
+                self.buffer_names_to_free.update(node.last_usage)
+                if node.is_template():
+                    prologue, template_node, epilogue = (
+                        node.get_prologue_template_epilogue(list(node.get_nodes()))
+                    )
+                    # pyrefly: ignore  # unbound-name
+                    self.get_backend(device).codegen_template(
+                        template_node, epilogue, prologue
+                    )
+                elif node.is_extern():
+                    curr_node = typing.cast(ExternKernelSchedulerNode, curr_node)
+                    self.codegen_extern_call(curr_node)
+                elif node.is_foreach():
+                    curr_node = typing.cast(ForeachKernelSchedulerNode, curr_node)
+                    # pyrefly: ignore  # unbound-name
+                    backend_ = self.get_backend(device)
+                    from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
+                    from .codegen.simd import SIMDScheduling
+
+                    if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
+                        backend = backend_
+                    else:
+                        raise AssertionError(f"{type(self)=}")
+                    backend.codegen_combo_kernel(node)
+                elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
+                    # pyrefly: ignore  # unbound-name
+                    self.get_backend(device).codegen_node(node)
+                else:
+                    assert isinstance(node, NopKernelSchedulerNode)
+                    curr_node.mark_run()
+
+                # pyrefly: ignore  # unbound-name
+                if config.triton.debug_sync_kernel:
+                    # pyrefly: ignore  # unbound-name
+                    self.get_backend(device).codegen_sync()
+
+                self.available_buffer_names.update(curr_node.get_buffer_names())
+                self.completed_operations.update(curr_node.get_operation_names())
+
+                if not isinstance(node, NopKernelSchedulerNode):
+                    device = node.get_device()
+                    if (
+                        device is not None
+                        and device.type != "meta"
+                        and self.get_backend(device).ready_to_flush()
+                    ):
+                        self.flush()
+
+            from .codegen.wrapper import DualWrapperCodegen
+
+            states = {
+                "current_device": current_device,
+                "buffer_names_to_free": buffer_names_to_free,
+                "available_buffer_names": available_buffer_names,
+                "completed_operations": completed_operations,
+            }
+            if isinstance(V.graph.wrapper_code, DualWrapperCodegen):
+                V.graph.wrapper_code.for_each_wrapper(wrap_codegen_node, **states)
+            else:
+                wrap_codegen_node(V.graph.wrapper_code, **states)
 
         if self.current_device != self.default_device_context:
             # when default_device_context is not None, we are codegen
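The scheduler change above snapshots `current_device`, `buffer_names_to_free`, `available_buffer_names`, and `completed_operations` before codegen and re-seeds them from the snapshot inside `wrap_codegen_node`, so the second (autotuning) pass over the same node starts from the same scheduler state as the first. A minimal illustration of why the snapshot matters, with ordinary sets standing in for `OrderedSet`:

def codegen_once(snapshot: set, node_buffers: set) -> set:
    # Each codegen pass mutates scheduler state as it emits code for the node,
    # so it works on a copy of the snapshot rather than the live state.
    state = set(snapshot)
    state |= node_buffers
    return state

initial = {"buf0"}
first_pass = codegen_once(initial, {"buf1"})
second_pass = codegen_once(initial, {"buf1"})  # autotuning pass sees the same start

assert first_pass == second_pass == {"buf0", "buf1"}
assert initial == {"buf0"}  # the snapshot itself is untouched between passes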