Revert "multi-kernel matmuls based on varying hint sizes (#156628)"

This reverts commit 6c795306378c47341d58109da03371bba2bec46e.

Reverted https://github.com/pytorch/pytorch/pull/156628 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but some ROCm jobs went crazy after this landed, so I am trying to see if reverting helps ([comment](https://github.com/pytorch/pytorch/pull/156628#issuecomment-3064617123))
PyTorch MergeBot
2025-07-12 03:48:39 +00:00
parent 2eff14c445
commit 9c189ed29a
14 changed files with 139 additions and 635 deletions


@@ -294,15 +294,6 @@ RecordedEventsType = list[tuple[str, list[Any], dict[str, Any]]]
class TritonTemplateKernel(TritonKernel):
"""
A specialized kernel class for Triton templates that handles code generation
for templated Triton kernels.
This class extends TritonKernel to provide additional functionality for
template-based kernel generation, including support for subgraphs, workspace
arguments, and prologue/epilogue fusion.
"""
def __init__(
self,
kernel_name,
@@ -323,7 +314,6 @@ class TritonTemplateKernel(TritonKernel):
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
workspace_arg: Optional[WorkspaceArg] = None,
prologue_loads_all_inputs=False,
hint_override: Optional[int] = None,
) -> None:
numel = sympy_product(output_node.get_size())
super().__init__(
@@ -332,7 +322,6 @@ class TritonTemplateKernel(TritonKernel):
"r0_": sympy.S.One,
},
features=SIMDKernelFeatures([], numel),
hint_override=hint_override,
)
self.input_nodes = input_nodes
self.output_node = output_node
@@ -1107,26 +1096,16 @@ class TritonTemplateKernel(TritonKernel):
def codegen_range_tree(self):
pass # ignore default codegen
def additional_call_args_and_types(self):
if isinstance(self.grid_fn, SymbolicGridFn):
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
return (grid_args, map(type, grid_args))
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
return (grid_args, map(type, grid_args))
return ((), ())
def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
wrapper = V.graph.wrapper_code
_, call_args, _, arg_types = self.args.python_argdefs()
additional_call_args, additional_arg_types = (
self.additional_call_args_and_types()
)
if not additional_call_args:
grid_args = ()
if isinstance(self.grid_fn, SymbolicGridFn):
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
else:
assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn"
wrapper.add_import_once(f"import {self.grid_fn.__module__}")
meta = wrapper.add_meta_once(self.meta)
@@ -1135,9 +1114,9 @@ class TritonTemplateKernel(TritonKernel):
f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})"
)
arg_types.append(None)
call_args.extend(additional_call_args)
arg_types.extend(additional_arg_types)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
call_args.extend(grid_args)
arg_types.extend(map(type, grid_args))
if self.workspace_arg is not None:
wrapper.generate_workspace_allocation(self.workspace_arg)
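For orientation, the grid_fn referenced above maps the template's call sizes to the three launch-grid values the assertion expects. A minimal sketch of such a grid function for a matmul template, with hypothetical block sizes pulled from the kernel meta (illustrative only, not inductor's actual grid function):

import math

def matmul_grid(m, n, k, meta):
    # Hypothetical grid_fn: derive the 3 launch-grid values from the
    # problem sizes and the block sizes carried in the template meta.
    block_m = meta.get("BLOCK_M", 64)
    block_n = meta.get("BLOCK_N", 64)
    tiles = math.ceil(m / block_m) * math.ceil(n / block_n)
    return (tiles, 1, 1)

# e.g. matmul_grid(1024, 1024, 512, {"BLOCK_M": 128, "BLOCK_N": 64}) -> (128, 1, 1)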
@@ -1223,7 +1202,6 @@ class GeneratedCodeCache:
num_consumer_groups: int,
num_buffers_warp_spec: int,
kwargs: dict[str, Any],
hint_override: Optional[int] = None,
) -> Optional[str]:
def layout_key(layout: ir.Layout) -> str:
assert not isinstance(layout, ir.FlexibleLayout)
@@ -1274,7 +1252,6 @@ class GeneratedCodeCache:
"num_buffers_warp_spec": num_buffers_warp_spec,
"epilogue_fn_hash": epilogue_fn_hash,
"kwargs": kwargs,
"hint_override": hint_override,
}
)
@@ -1379,7 +1356,6 @@ class TritonTemplate(KernelTemplate):
layout: ir.Layout,
kwargs: dict[str, Any],
generate_with_caching,
hint_override: Optional[int] = None,
) -> Optional[GenerateAndLoadResult]:
"""Generate the python code and load it into the current process"""
caching_enabled = (
@@ -1452,7 +1428,6 @@ class TritonTemplate(KernelTemplate):
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
hint_override=hint_override,
**kernel_options,
)
@@ -1572,7 +1547,6 @@ class TritonTemplate(KernelTemplate):
call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None,
workspace_arg: Optional[WorkspaceArg] = None,
generate_with_caching=False,
hint_override: Optional[int] = None,
**kwargs,
):
"""This function generates a TritonTemplateCaller
@@ -1617,7 +1591,6 @@ class TritonTemplate(KernelTemplate):
layout,
kwargs,
generate_with_caching and self._cache_codegen_enabled_for_template,
hint_override=hint_override,
)
# May happen as a result of div by 0.
@@ -1639,7 +1612,6 @@ class TritonTemplate(KernelTemplate):
extra_args = V.graph.sizevars.size_hints(
map(sympy.expand, result.kernel_args_sizevars_keys),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
@@ -1663,14 +1635,13 @@ class TritonTemplate(KernelTemplate):
options = result.kernel_options
def make_kernel_render(out_node, hint_override: Optional[int] = None):
def make_kernel_render(out_node):
assert result is not None
kernel = self.kernel_type(
kernel_name=str(Placeholder.KERNEL_NAME),
output_node=out_node,
workspace_arg=workspace_arg,
use_jit=False,
hint_override=hint_override,
**options,
)
render = functools.partial(
@@ -1686,7 +1657,6 @@ class TritonTemplate(KernelTemplate):
*V.graph.sizevars.size_hints(
call_sizes,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
kwargs,
)
@@ -1738,7 +1708,6 @@ class TritonTemplate(KernelTemplate):
mutated_inputs=mutated_inputs,
workspace_arg=workspace_arg,
allowed_prologue_inps=result.prologue_supported_inputs,
hint_override=hint_override,
)
@@ -1814,7 +1783,6 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
mutated_inputs=None,
workspace_arg: Optional[WorkspaceArg] = None,
allowed_prologue_inps: Optional[OrderedSet[str]] = None,
hint_override: Optional[int] = None,
) -> None:
super().__init__(name, input_nodes, layout, description)
self.make_kernel_render = make_kernel_render
@@ -1834,7 +1802,6 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
self.allowed_prologue_inps = (
allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
)
self.hint_override = hint_override
def benchmark(self, *args, out):
assert self.bmreq is not None
@@ -2289,7 +2256,7 @@ class AlgorithmSelectorCache(PersistentCache):
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
has_autotuned = False
def benchmark(choices, hint_override: Optional[int] = None):
def benchmark(choices):
nonlocal has_autotuned
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
has_autotuned = True
@@ -2297,13 +2264,13 @@ class AlgorithmSelectorCache(PersistentCache):
# TODO(nmacchioni): remove this layer of abstraction
# construct `benchmark_fn` which should pick between in-process and sub-process autotuning
benchmark_fn = self.make_benchmark_fn(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
choices, input_nodes, layout, input_gen_fns
)
# `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
# maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
return benchmark_fn(choices)
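As the comment above notes, benchmark_fn(choices) returns a dict mapping each choice to its measured runtime in milliseconds; the autotuner ultimately keeps the fastest one. A simplified illustration of consuming such a timings dict (hypothetical helper, not the selection logic in this file):

def pick_fastest(timings):
    # timings: dict[ChoiceCaller, float], runtimes in milliseconds as
    # returned by benchmark_fn(choices) above.
    # Drop NaN (t != t) and inf entries, which indicate failed benchmarks.
    valid = {c: t for c, t in timings.items() if t == t and t != float("inf")}
    if not valid:
        raise RuntimeError("no successful autotune timings")
    return min(valid, key=valid.__getitem__)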
def autotune(choices, hint_override: Optional[int] = None):
def autotune(choices):
log.debug("Starting autotuning")
with dynamo_timed(
@@ -2325,13 +2292,13 @@ class AlgorithmSelectorCache(PersistentCache):
),
},
):
return benchmark(choices, hint_override=hint_override)
return benchmark(choices)
if config.autotune_in_subproc:
# Initialize the subprocess pool so it will warm up early.
torch._inductor.autotune_process.get_tuning_process_pool()
def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None):
def do_autotuning(choices, precompile_fn):
precompile_start_ts = time.time()
with dynamo_timed(
f"{name}_template_precompiling",
@@ -2352,8 +2319,7 @@ class AlgorithmSelectorCache(PersistentCache):
candidates,
name,
inputs_key,
lambda choices: autotune(choices, hint_override=hint_override),
hint_override=hint_override,
autotune,
)
choices = self.prune_choices_postscreen(
choices, timings, name, inputs_key, self.prescreening_cache
@@ -2366,8 +2332,7 @@ class AlgorithmSelectorCache(PersistentCache):
choices,
name,
inputs_key,
lambda choices: autotune(choices, hint_override=hint_override),
hint_override=hint_override,
autotune,
)
autotune_elapse = time.time() - autotune_start_ts
@@ -2390,7 +2355,6 @@ class AlgorithmSelectorCache(PersistentCache):
autotune_elapse,
precompile_elapse,
prescreening_elapse,
hint_override=hint_override,
)
def profiler_bench_function():
@@ -2425,16 +2389,8 @@ class AlgorithmSelectorCache(PersistentCache):
if return_multi_template and (config.max_autotune or config.max_autotune_gemm):
def get_timings(hint_override: Optional[int] = None):
filtered_choices = [
c
for c in choices
if not hasattr(c, "hint_override")
or c.hint_override == hint_override
]
timings = do_autotuning(
filtered_choices, precompile_fn, hint_override=hint_override
)
def get_timings():
timings = do_autotuning(choices, precompile_fn)
min_extern_choice = float("inf")
for choice, timing in timings.items():
if isinstance(choice, ExternKernelCaller):
@@ -2671,7 +2627,6 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
) -> AutotuneArgs:
"""
Factory method to create AutotuneArgs from a list of ChoiceCallers.
@@ -2681,9 +2636,7 @@ class AlgorithmSelectorCache(PersistentCache):
# de-duplicate args
unique_example_inputs = {
x.get_name(): input_gen_fns.get(
i, lambda x: cls.benchmark_example_value(x, hint_override=hint_override)
)(x)
x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
for i, x in enumerate(input_nodes)
}
example_inputs = list(unique_example_inputs.values())
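The comprehension above de-duplicates example inputs by buffer name and lets a caller-supplied input_gen_fns entry (keyed by input position) override the default benchmark value. The same pattern written out as an explicit loop, as a hypothetical standalone helper for illustration:

def build_example_inputs(input_nodes, input_gen_fns, default_fn):
    # Key by buffer name so a buffer used several times yields one tensor;
    # a per-position entry in input_gen_fns overrides the default generator.
    input_gen_fns = input_gen_fns or {}
    unique = {}
    for i, node in enumerate(input_nodes):
        gen = input_gen_fns.get(i, default_fn)
        unique[node.get_name()] = gen(node)
    return list(unique.values())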
@@ -2696,23 +2649,20 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
input_node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hints(
input_node.get_stride(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hint(
input_node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
)
)
for input_node in input_nodes
]
out = cls.benchmark_example_value(layout, hint_override=hint_override)
out = cls.benchmark_example_value(layout)
out_extern = torch.as_strided(
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
)
@@ -2821,11 +2771,8 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
) -> dict[ChoiceCaller, float]:
inputs = cls.get_inputs(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
inputs = cls.get_inputs(choices, input_nodes, layout, input_gen_fns)
return cls.benchmark_choices(choices, inputs)
@classmethod
@@ -2835,7 +2782,6 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
):
from . import autotune_process
@@ -2845,7 +2791,7 @@ class AlgorithmSelectorCache(PersistentCache):
triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
timings = cls.benchmark_in_current_process(
extern, input_nodes, layout, input_gen_fns, hint_override=hint_override
extern, input_nodes, layout, input_gen_fns
)
timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type]
return timings
@@ -2857,7 +2803,6 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
):
if DEBUG:
print(f"{len(choices)} tuning requests:")
@@ -2868,7 +2813,6 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
)
else:
return functools.partial(
@@ -2876,7 +2820,6 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
)
@staticmethod
@@ -3035,7 +2978,6 @@ class AlgorithmSelectorCache(PersistentCache):
elapse: float,
precompile_elapse: float,
prescreening_elapse: Optional[float] = None,
hint_override: Optional[int] = None,
):
V.debug.log_autotuning_results(
name, input_nodes, timings, elapse, precompile_elapse
@@ -3050,7 +2992,6 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
n.get_size(),
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
),
)
)
@@ -3138,7 +3079,7 @@ class AlgorithmSelectorCache(PersistentCache):
)
@staticmethod
def benchmark_example_value(node, hint_override: Optional[int] = None):
def benchmark_example_value(node):
"""
Convert an ir.Buffer into a concrete torch.Tensor we can use for
benchmarking.
@@ -3157,12 +3098,10 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hints(
node.get_stride(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
node.get_device(),
node.get_dtype(),
@@ -3170,7 +3109,6 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
V.graph.get_allocation_size(node),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
)
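For context on the method above: benchmark_example_value materializes a concrete tensor from the buffer's size-hinted shape, strides, dtype, device, and allocation size so that a choice can actually be benchmarked. A minimal sketch of the same idea using public torch APIs (illustrative only, not the function in this file):

import torch

def example_tensor(size, stride, device, dtype, offset=0):
    # Allocate backing storage large enough for the strided view, then
    # expose it with the requested size/stride/offset, mirroring how a
    # benchmarking input is shaped from concrete size hints.
    needed = offset + sum((s - 1) * st for s, st in zip(size, stride)) + 1
    storage = torch.empty(max(needed, 1), device=device, dtype=dtype)
    return torch.as_strided(storage, size, stride, offset)

# e.g. a 128x64 row-major fp16 CUDA tensor:
# example_tensor((128, 64), (64, 1), "cuda", torch.float16)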