Revert "multi-kernel matmuls based on varying hint sizes (#156628)"
This reverts commit 6c795306378c47341d58109da03371bba2bec46e. Reverted https://github.com/pytorch/pytorch/pull/156628 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but some ROCm jobs went crazy after this landed, so I am trying to see if reverting helps ([comment](https://github.com/pytorch/pytorch/pull/156628#issuecomment-3064617123))
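For context: the reverted PR taught Inductor's autotuner to benchmark matmul templates under several candidate size hints ("hint_override") and keep one compiled kernel per hint, choosing among them based on the runtime shape. Below is a minimal, self-contained sketch of that idea; all names (HINT_SIZES, tune_matmul_for_hint, multi_kernel_matmul) are hypothetical and not Inductor's actual API.

import torch

HINT_SIZES = [64, 256, 4096]  # hypothetical candidate hints for a dynamic dim

def tune_matmul_for_hint(hint: int):
    # Stand-in for "autotune a matmul template with this size hint"; real code
    # would select block sizes, num_warps, etc. based on the hint.
    def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a @ b  # placeholder for the hint-specialized kernel
    kernel.hint = hint
    return kernel

KERNELS = {h: tune_matmul_for_hint(h) for h in HINT_SIZES}

def multi_kernel_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Dispatch to the kernel whose tuning hint is closest to the runtime size.
    m = a.shape[0]
    best = min(HINT_SIZES, key=lambda h: abs(h - m))
    return KERNELS[best](a, b)

print(multi_kernel_matmul(torch.randn(128, 32), torch.randn(32, 16)).shape)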
torch/_inductor/select_algorithm.py
@@ -294,15 +294,6 @@ RecordedEventsType = list[tuple[str, list[Any], dict[str, Any]]]
 
 
 class TritonTemplateKernel(TritonKernel):
-    """
-    A specialized kernel class for Triton templates that handles code generation
-    for templated Triton kernels.
-
-    This class extends TritonKernel to provide additional functionality for
-    template-based kernel generation, including support for subgraphs, workspace
-    arguments, and prologue/epilogue fusion.
-    """
-
     def __init__(
         self,
         kernel_name,
@@ -323,7 +314,6 @@ class TritonTemplateKernel(TritonKernel):
         subgraphs: Optional[list[ir.ComputedBuffer]] = None,
         workspace_arg: Optional[WorkspaceArg] = None,
         prologue_loads_all_inputs=False,
-        hint_override: Optional[int] = None,
     ) -> None:
         numel = sympy_product(output_node.get_size())
         super().__init__(
@@ -332,7 +322,6 @@ class TritonTemplateKernel(TritonKernel):
                 "r0_": sympy.S.One,
             },
             features=SIMDKernelFeatures([], numel),
-            hint_override=hint_override,
         )
         self.input_nodes = input_nodes
         self.output_node = output_node
@@ -1107,26 +1096,16 @@ class TritonTemplateKernel(TritonKernel):
     def codegen_range_tree(self):
         pass  # ignore default codegen
 
-    def additional_call_args_and_types(self):
-        if isinstance(self.grid_fn, SymbolicGridFn):
-            grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
-            assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
-            return (grid_args, map(type, grid_args))
-        elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
-            grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
-            assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
-            return (grid_args, map(type, grid_args))
-        return ((), ())
-
     def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
         wrapper = V.graph.wrapper_code
         _, call_args, _, arg_types = self.args.python_argdefs()
-
-        additional_call_args, additional_arg_types = (
-            self.additional_call_args_and_types()
-        )
-
-        if not additional_call_args:
+        grid_args = ()
+        if isinstance(self.grid_fn, SymbolicGridFn):
+            grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
+        elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
+            grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
+        else:
            assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn"
            wrapper.add_import_once(f"import {self.grid_fn.__module__}")
            meta = wrapper.add_meta_once(self.meta)
@@ -1135,9 +1114,9 @@ class TritonTemplateKernel(TritonKernel):
                 f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})"
             )
             arg_types.append(None)
-
-        call_args.extend(additional_call_args)
-        arg_types.extend(additional_arg_types)
+        assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
+        call_args.extend(grid_args)
+        arg_types.extend(map(type, grid_args))
 
         if self.workspace_arg is not None:
             wrapper.generate_workspace_allocation(self.workspace_arg)
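The asserts in the hunks above encode a contract: whichever grid function is used (symbolic or concrete), it must return three launch-grid values that get appended to the kernel's call arguments. A toy sketch of such a grid function, assuming made-up block sizes (this is not the template's real grid_fn):

def ceildiv(a: int, b: int) -> int:
    # Ceiling division without floats.
    return -(a // -b)

def mm_grid(m: int, n: int, meta: dict):
    # One program per (BLOCK_M x BLOCK_N) output tile; z dimension unused.
    return (ceildiv(m, meta["BLOCK_M"]) * ceildiv(n, meta["BLOCK_N"]), 1, 1)

grid_args = mm_grid(512, 768, {"BLOCK_M": 64, "BLOCK_N": 64})
assert len(grid_args) == 3  # same invariant the diff asserts for grid_fn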
@@ -1223,7 +1202,6 @@ class GeneratedCodeCache:
         num_consumer_groups: int,
         num_buffers_warp_spec: int,
         kwargs: dict[str, Any],
-        hint_override: Optional[int] = None,
     ) -> Optional[str]:
         def layout_key(layout: ir.Layout) -> str:
             assert not isinstance(layout, ir.FlexibleLayout)
@@ -1274,7 +1252,6 @@ class GeneratedCodeCache:
                 "num_buffers_warp_spec": num_buffers_warp_spec,
                 "epilogue_fn_hash": epilogue_fn_hash,
                 "kwargs": kwargs,
-                "hint_override": hint_override,
             }
         )
 
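The two hunks above drop "hint_override" from GeneratedCodeCache's cache key. While the feature existed, the key had to include the hint so that code generated for different size hints occupied distinct cache entries. A hedged, simplified sketch of that keying (hypothetical helper, not the cache's real serialization):

import json

def cache_key(template_name: str, kwargs: dict, hint_override=None) -> str:
    # json.dumps with sort_keys gives a deterministic string key.
    return json.dumps(
        {"template": template_name, "kwargs": kwargs, "hint_override": hint_override},
        sort_keys=True,
    )

k64 = cache_key("mm", {"BLOCK_M": 64}, hint_override=64)
k4096 = cache_key("mm", {"BLOCK_M": 64}, hint_override=4096)
assert k64 != k4096  # different hints must not share one cached kernel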
@@ -1379,7 +1356,6 @@ class TritonTemplate(KernelTemplate):
         layout: ir.Layout,
         kwargs: dict[str, Any],
         generate_with_caching,
-        hint_override: Optional[int] = None,
     ) -> Optional[GenerateAndLoadResult]:
         """Generate the python code and load it into the current process"""
         caching_enabled = (
@@ -1452,7 +1428,6 @@ class TritonTemplate(KernelTemplate):
             output_node=fake_out,
             workspace_arg=workspace_arg,
             use_jit=False,
-            hint_override=hint_override,
             **kernel_options,
         )
 
@@ -1572,7 +1547,6 @@ class TritonTemplate(KernelTemplate):
         call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None,
         workspace_arg: Optional[WorkspaceArg] = None,
         generate_with_caching=False,
-        hint_override: Optional[int] = None,
         **kwargs,
     ):
         """This function generates a TritonTemplateCaller
@@ -1617,7 +1591,6 @@ class TritonTemplate(KernelTemplate):
             layout,
             kwargs,
             generate_with_caching and self._cache_codegen_enabled_for_template,
-            hint_override=hint_override,
         )
 
         # May happen as result of dev by 0.
@@ -1639,7 +1612,6 @@ class TritonTemplate(KernelTemplate):
         extra_args = V.graph.sizevars.size_hints(
             map(sympy.expand, result.kernel_args_sizevars_keys),
             fallback=config.unbacked_symint_fallback,
-            hint_override=hint_override,
         )
 
         kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
@@ -1663,14 +1635,13 @@ class TritonTemplate(KernelTemplate):
 
         options = result.kernel_options
 
-        def make_kernel_render(out_node, hint_override: Optional[int] = None):
+        def make_kernel_render(out_node):
             assert result is not None
             kernel = self.kernel_type(
                 kernel_name=str(Placeholder.KERNEL_NAME),
                 output_node=out_node,
                 workspace_arg=workspace_arg,
                 use_jit=False,
-                hint_override=hint_override,
                 **options,
             )
             render = functools.partial(
@@ -1686,7 +1657,6 @@ class TritonTemplate(KernelTemplate):
                 *V.graph.sizevars.size_hints(
                     call_sizes,
                     fallback=config.unbacked_symint_fallback,
-                    hint_override=hint_override,
                 ),
                 kwargs,
             )
@@ -1738,7 +1708,6 @@ class TritonTemplate(KernelTemplate):
             mutated_inputs=mutated_inputs,
             workspace_arg=workspace_arg,
             allowed_prologue_inps=result.prologue_supported_inputs,
-            hint_override=hint_override,
         )
 
 
@@ -1814,7 +1783,6 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
         mutated_inputs=None,
         workspace_arg: Optional[WorkspaceArg] = None,
         allowed_prologue_inps: Optional[OrderedSet[str]] = None,
-        hint_override: Optional[int] = None,
     ) -> None:
         super().__init__(name, input_nodes, layout, description)
         self.make_kernel_render = make_kernel_render
@@ -1834,7 +1802,6 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
         self.allowed_prologue_inps = (
             allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
         )
-        self.hint_override = hint_override
 
     def benchmark(self, *args, out):
         assert self.bmreq is not None
@@ -2289,7 +2256,7 @@ class AlgorithmSelectorCache(PersistentCache):
         # TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
         has_autotuned = False
 
-        def benchmark(choices, hint_override: Optional[int] = None):
+        def benchmark(choices):
             nonlocal has_autotuned
             # TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
             has_autotuned = True
@@ -2297,13 +2264,13 @@ class AlgorithmSelectorCache(PersistentCache):
             # TODO(nmacchioni): remove this layer of abstraction
             # construct `benchmark_fn` which should pick between in-process and sub-process autotuning
             benchmark_fn = self.make_benchmark_fn(
-                choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
+                choices, input_nodes, layout, input_gen_fns
             )
             # `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
             # maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
             return benchmark_fn(choices)
 
-        def autotune(choices, hint_override: Optional[int] = None):
+        def autotune(choices):
             log.debug("Starting autotuning")
 
             with dynamo_timed(
@@ -2325,13 +2292,13 @@ class AlgorithmSelectorCache(PersistentCache):
                     ),
                 },
             ):
-                return benchmark(choices, hint_override=hint_override)
+                return benchmark(choices)
 
         if config.autotune_in_subproc:
             # Initialize the suprocess pool so it will warmup early.
             torch._inductor.autotune_process.get_tuning_process_pool()
 
-        def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None):
+        def do_autotuning(choices, precompile_fn):
             precompile_start_ts = time.time()
             with dynamo_timed(
                 f"{name}_template_precompiling",
@@ -2352,8 +2319,7 @@ class AlgorithmSelectorCache(PersistentCache):
                     candidates,
                     name,
                     inputs_key,
-                    lambda choices: autotune(choices, hint_override=hint_override),
-                    hint_override=hint_override,
+                    autotune,
                 )
                 choices = self.prune_choices_postscreen(
                     choices, timings, name, inputs_key, self.prescreening_cache
@@ -2366,8 +2332,7 @@ class AlgorithmSelectorCache(PersistentCache):
                 choices,
                 name,
                 inputs_key,
-                lambda choices: autotune(choices, hint_override=hint_override),
-                hint_override=hint_override,
+                autotune,
             )
 
             autotune_elapse = time.time() - autotune_start_ts
@@ -2390,7 +2355,6 @@ class AlgorithmSelectorCache(PersistentCache):
                 autotune_elapse,
                 precompile_elapse,
                 prescreening_elapse,
-                hint_override=hint_override,
             )
 
         def profiler_bench_function():
@@ -2425,16 +2389,8 @@ class AlgorithmSelectorCache(PersistentCache):
 
         if return_multi_template and (config.max_autotune or config.max_autotune_gemm):
 
-            def get_timings(hint_override: Optional[int] = None):
-                filtered_choices = [
-                    c
-                    for c in choices
-                    if not hasattr(c, "hint_override")
-                    or c.hint_override == hint_override
-                ]
-                timings = do_autotuning(
-                    filtered_choices, precompile_fn, hint_override=hint_override
-                )
+            def get_timings():
+                timings = do_autotuning(choices, precompile_fn)
                 min_extern_choice = float("inf")
                 for choice, timing in timings.items():
                     if isinstance(choice, ExternKernelCaller):
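The deleted get_timings() above autotuned once per hint: it kept a choice either when it carried no hint_override attribute (e.g. extern kernels, valid for any shape) or when its hint matched the current pass. A self-contained sketch of that filter with toy choice objects (hypothetical names):

from typing import Optional

class Choice:
    def __init__(self, name: str, hint_override: Optional[int] = None):
        self.name = name
        if hint_override is not None:
            self.hint_override = hint_override  # only hinted variants get the attribute

def filter_for_hint(choices, hint_override: Optional[int]):
    # Mirrors the list comprehension removed in the hunk above.
    return [
        c
        for c in choices
        if not hasattr(c, "hint_override") or c.hint_override == hint_override
    ]

choices = [Choice("extern_mm"), Choice("triton_mm_h64", 64), Choice("triton_mm_h4096", 4096)]
print([c.name for c in filter_for_hint(choices, 64)])  # ['extern_mm', 'triton_mm_h64']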
@@ -2671,7 +2627,6 @@ class AlgorithmSelectorCache(PersistentCache):
         input_nodes: list[ir.IRNode],
         layout: ir.Layout,
         input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
-        hint_override: Optional[int] = None,
     ) -> AutotuneArgs:
         """
         Factory method to create AutotuneArgs from a list of ChoiceCallers.
@@ -2681,9 +2636,7 @@ class AlgorithmSelectorCache(PersistentCache):
 
         # de-duplicate args
         unique_example_inputs = {
-            x.get_name(): input_gen_fns.get(
-                i, lambda x: cls.benchmark_example_value(x, hint_override=hint_override)
-            )(x)
+            x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
             for i, x in enumerate(input_nodes)
         }
         example_inputs = list(unique_example_inputs.values())
@@ -2696,23 +2649,20 @@ class AlgorithmSelectorCache(PersistentCache):
                     V.graph.sizevars.size_hints(
                         input_node.get_size(),
                         fallback=config.unbacked_symint_fallback,
-                        hint_override=hint_override,
                     ),
                     V.graph.sizevars.size_hints(
                         input_node.get_stride(),
                         fallback=config.unbacked_symint_fallback,
-                        hint_override=hint_override,
                     ),
                     V.graph.sizevars.size_hint(
                         input_node.get_layout().offset,
                         fallback=config.unbacked_symint_fallback,
-                        hint_override=hint_override,
                     ),
                 )
             )
             for input_node in input_nodes
         ]
-        out = cls.benchmark_example_value(layout, hint_override=hint_override)
+        out = cls.benchmark_example_value(layout)
         out_extern = torch.as_strided(
             out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
         )
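The size_hints calls above resolve symbolic sizes, strides, and offsets to concrete integers so benchmarking can run on real tensors. A minimal sketch of the materialization step, using torch.empty_strided as a stand-in for the allocation helper Inductor actually uses:

import torch

def example_tensor(sizes, strides, device="cpu", dtype=torch.float32):
    # Allocate with the exact layout the candidate kernels expect, then fill
    # with data so timings measure real work rather than uninitialized reads.
    return torch.empty_strided(sizes, strides, device=device, dtype=dtype).uniform_()

x = example_tensor((128, 64), (64, 1))
print(x.shape, x.stride())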
@@ -2821,11 +2771,8 @@ class AlgorithmSelectorCache(PersistentCache):
         input_nodes: list[ir.IRNode],
         layout: ir.Layout,
         input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
-        hint_override: Optional[int] = None,
     ) -> dict[ChoiceCaller, float]:
-        inputs = cls.get_inputs(
-            choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
-        )
+        inputs = cls.get_inputs(choices, input_nodes, layout, input_gen_fns)
         return cls.benchmark_choices(choices, inputs)
 
     @classmethod
@@ -2835,7 +2782,6 @@ class AlgorithmSelectorCache(PersistentCache):
         input_nodes: list[ir.IRNode],
         layout: ir.Layout,
         input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
-        hint_override: Optional[int] = None,
     ):
         from . import autotune_process
 
@@ -2845,7 +2791,7 @@ class AlgorithmSelectorCache(PersistentCache):
         triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
 
         timings = cls.benchmark_in_current_process(
-            extern, input_nodes, layout, input_gen_fns, hint_override=hint_override
+            extern, input_nodes, layout, input_gen_fns
         )
         timings.update(autotune_process.benchmark_in_sub_process(triton))  # type: ignore[arg-type]
         return timings
@@ -2857,7 +2803,6 @@ class AlgorithmSelectorCache(PersistentCache):
         input_nodes: list[ir.IRNode],
         layout: ir.Layout,
         input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
-        hint_override: Optional[int] = None,
     ):
         if DEBUG:
             print(f"{len(choices)} tuning requests:")
@@ -2868,7 +2813,6 @@ class AlgorithmSelectorCache(PersistentCache):
                 input_nodes=input_nodes,
                 layout=layout,
                 input_gen_fns=input_gen_fns,
-                hint_override=hint_override,
             )
         else:
             return functools.partial(
@@ -2876,7 +2820,6 @@ class AlgorithmSelectorCache(PersistentCache):
                 input_nodes=input_nodes,
                 layout=layout,
                 input_gen_fns=input_gen_fns,
-                hint_override=hint_override,
             )
 
     @staticmethod
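make_benchmark_fn above returns a functools.partial that fixes the per-operation arguments, leaving a callable that takes only the choices. A small sketch of the pattern with a dummy benchmark function (hypothetical, simplified signature):

import functools

def benchmark_in_current_process(choices, input_nodes, layout, input_gen_fns):
    return {c: 0.0 for c in choices}  # placeholder timings in milliseconds

benchmark_fn = functools.partial(
    benchmark_in_current_process,
    input_nodes=["a", "b"],
    layout="row_major",
    input_gen_fns=None,
)
print(benchmark_fn(["triton_mm", "extern_mm"]))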
@@ -3035,7 +2978,6 @@ class AlgorithmSelectorCache(PersistentCache):
         elapse: float,
         precompile_elapse: float,
         prescreening_elapse: Optional[float] = None,
-        hint_override: Optional[int] = None,
     ):
         V.debug.log_autotuning_results(
             name, input_nodes, timings, elapse, precompile_elapse
@@ -3050,7 +2992,6 @@ class AlgorithmSelectorCache(PersistentCache):
                     V.graph.sizevars.size_hints(
                         n.get_size(),
                         fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
-                        hint_override=hint_override,
                     ),
                 )
             )
@@ -3138,7 +3079,7 @@ class AlgorithmSelectorCache(PersistentCache):
         )
 
     @staticmethod
-    def benchmark_example_value(node, hint_override: Optional[int] = None):
+    def benchmark_example_value(node):
         """
         Convert an ir.Buffer into a concrete torch.Tensor we can use for
         benchmarking.
@@ -3157,12 +3098,10 @@ class AlgorithmSelectorCache(PersistentCache):
                 V.graph.sizevars.size_hints(
                     node.get_size(),
                     fallback=config.unbacked_symint_fallback,
-                    hint_override=hint_override,
                 ),
                 V.graph.sizevars.size_hints(
                     node.get_stride(),
                     fallback=config.unbacked_symint_fallback,
-                    hint_override=hint_override,
                 ),
                 node.get_device(),
                 node.get_dtype(),
@@ -3170,7 +3109,6 @@ class AlgorithmSelectorCache(PersistentCache):
                 V.graph.sizevars.size_hints(
                     V.graph.get_allocation_size(node),
                     fallback=config.unbacked_symint_fallback,
-                    hint_override=hint_override,
                 ),
             )
 