Preparatory refactor for https://github.com/pytorch/pytorch/pull/146942.

# Feature

This PR refactors the existing wrapper codegen into `WrapperLine` subclasses, extending the existing Memory Planning IR into a fully-fledged Wrapper IR. See the diagram below.

The IR currently supports the following ops:
- All existing memory planning IR ops (`AllocateLine`, `FreeIfNotReusedLine`, etc.)
- Reinterpret views (`ReinterpretLine`)
- Kernel definitions (`KernelDefinitionLine`)
- Calls to defined kernels (`KernelCallLine`)
- Calls to extern kernels (`ExternKernelLine`, `ExternKernelAllocLine`)
- Ops with multiple outputs (`MultiOutputLine`)
- Tensor cleanup at the end of a graph (`FreeLine`)
- Leaving comments in code (`CommentLine`)

There are two main motivations for this refactor:
1. Unlike free-form C++ and Python code, Wrapper IR lines provide structured information about what the wrapper code does. This serves as a natural extension point for other types of wrapper codegen. For example, the parent PR generates FX IR from Wrapper IR. Wrapper IR aims to give new backends enough information to generate wrapper code without needing to modify core Inductor files such as `ir.py`.
2. This design will hopefully promote stronger modularity and encapsulation.
   a. Inductor's core compilation passes don't need to worry about whether they're targeting Python, C++, FX or anything else. They can simply focus on generating Wrapper IR, and target-specific code can be refactored into the various backends.
   b. Backends do not need to know about all the details and internal state of `V.graph` IR. For example, they don't need to consider whether a buffer has been removed from the graph when generating code. Wrapper IR will hopefully provide a simpler interface for generating wrapper code, which abstracts away the details of device code.

# Implementation details

The implementation mainly consists of separating direct C++/Python codegen into two phases:
1. Emit Wrapper IR lines describing what the wrapper code is supposed to do.
2. Inside the `codegen()` method of each `WrapperLine`, call backend methods which generate pure Python/C++ code using the information stored in the Wrapper IR line. For example, `KernelCallLine` calls `wrapper._generate_kernel_call_helper`, which is overridden by the various Python and C++ backends to generate the final wrapper code.

The main difficulty in implementing this is that we need to be careful that code is generated in the correct order. Wrapper codegen happens in two passes: first we write code into `self.lines`, which mainly contains Wrapper IR but can also contain raw Python or C++ lines in some situations. Then, we convert the Wrapper IR into the final Python/C++ code in `self.wrapper_call`. Since the same macros may be used in both passes, it's difficult to ensure that code is written to the correct buffer. The easiest solution was to implement a context manager overriding the `writeline` method to write to `self.wrapper_call` after memory planning is finished. This way, `writeline` writes to `self.lines` in the first pass and to `self.wrapper_call` in the second. This obviated the need to pass `code` or `writeline` variables all the way through the call stack, which would have touched most of the existing macros. (A minimal sketch of this two-phase pattern follows.)
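The sketch below is illustrative only: `ToyWrapperCodegen`, `ToyKernelCallLine`, and `redirect_to_wrapper_call` are hypothetical stand-ins invented for this example, not the actual Inductor classes. It shows the shape of the pattern: an IR line stores *what* to emit, a backend helper decides *how* to print it, and a context manager redirects `writeline` to the second-pass buffer.

```python
from contextlib import contextmanager
from dataclasses import dataclass


class ToyWrapperCodegen:
    """Stand-in for a wrapper codegen class: collects Wrapper IR, then lowers it."""

    def __init__(self):
        self.lines = []         # first pass: Wrapper IR lines (plus occasional raw code)
        self.wrapper_call = []  # second pass: final wrapper code as text
        self._sink = self.lines

    def writeline(self, line):
        self._sink.append(line)

    @contextmanager
    def redirect_to_wrapper_call(self):
        # Analogue of the context manager described above: after memory planning,
        # writeline() targets self.wrapper_call instead of self.lines.
        prev, self._sink = self._sink, self.wrapper_call
        try:
            yield
        finally:
            self._sink = prev

    # Backend hook; a Python backend and a C++ backend would override this differently.
    def _generate_kernel_call_helper(self, kernel_name, call_args):
        self.writeline(f"{kernel_name}({', '.join(call_args)})")


@dataclass
class ToyKernelCallLine:
    """Toy analogue of KernelCallLine: stores what to call, not how to print it."""

    kernel_name: str
    call_args: list

    def codegen(self, wrapper):
        wrapper._generate_kernel_call_helper(self.kernel_name, self.call_args)


wrapper = ToyWrapperCodegen()
# Phase 1: record Wrapper IR.
wrapper.lines.append(ToyKernelCallLine("triton_poi_fused_add_0", ["buf0", "arg0_1"]))
# Phase 2: lower the IR to final wrapper code, with writeline() redirected.
with wrapper.redirect_to_wrapper_call():
    for line in wrapper.lines:
        if hasattr(line, "codegen"):
            line.codegen(wrapper)
        else:
            wrapper.writeline(line)
print(wrapper.wrapper_call)  # ['triton_poi_fused_add_0(buf0, arg0_1)']
```

The real `PythonWrapperCodegen`, `KernelCallLine`, and `_generate_kernel_call_helper` carry far more state than this, but the division of labor is the same as described above.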
# Test plan

Since this refactor touches all the existing wrapper codegen classes, the existing CI provides good coverage. The parent PR introduces new tests for the FX IR backend. Among other things, these tests assert that `self.lines` only contains Wrapper IR lines, and no free-form code. While this would not be true of all programs today, the tests suggest that the IR implemented in this PR is sufficient to cover basic PyTorch usage.

# Future directions

These two goals are only partially realized by this PR. There are several important steps which still undergo direct Python/C++ codegen in core files:
- User-defined Triton kernels.
- Reinterpret views on outputs, from `gen_output_refs()`. (In the parent PR, the FX converter has a custom way of handling this. This can eventually be ported into Wrapper IR.)
- Fallback ops with custom `codegen()` methods, e.g. `ScatterFallback`.
- Misc. C++ lines emitted by the various cpp backends, e.g. declaring constants.

These cases will gradually be handled in subsequent PRs, as the Inductor->FX converter expands its coverage. Given that these refactors are pretty tricky to do, it seems wiser to execute them in stages, as opposed to porting everything to Wrapper IR at once. Some Python and C++ codegen still lives in core files such as `ir.py`, as described in previous sections. Hopefully, this PR will serve as a starting point which moves the codebase towards a more modular design. Over time, we can gradually refactor the remaining codegen (mainly in `ir.py`) into backend classes.

One limitation of this PR is that codegen still happens in two phases during `PythonWrapperCodegen`. First, we generate Wrapper IR into `self.lines`, and from there we generate Python or C++ code into `self.wrapper_call`, `self.header`, etc. In the long term, it would be cleaner to split Wrapper IR into its own class which doesn't deal with Python/C++ codegen at all. (See the diagram at the top.) That would strictly enforce the boundary between Wrapper IR and Python/C++ wrapper code. However, this would probably be a much larger refactor.

Another limitation of the current code is that the helper functions have a lot of call args. It's also possible to clean this up by passing Wrapper IR ops, e.g. `KernelCallLine`, into helper functions like `_generate_kernel_call_helper`, since they store all the arguments. However, that change would likely be prone to merge conflicts, so I would like to save it for follow-up PRs if possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150458
Approved by: https://github.com/eellison
# mypy: allow-untyped-defs
import functools
import logging
import os
import pathlib

from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.utils._ordered_set import OrderedSet

from .. import config
from ..codecache import code_hash, CodeCacheFuture, get_path, write_atomic
from ..runtime.benchmarking import benchmarker
from ..utils import cache_on_self, IndentedBuffer
from ..virtualized import V
from .common import TensorArg, WorkspaceArg


log = logging.getLogger(__name__)


def get_kernel_argdefs(kernel):
    arg_defs, _, _, _ = kernel.args.python_argdefs()
    return [x.name for x in arg_defs]


def _get_all_args(args_list, arg_types_list=None):
    # Use the longest argument list as the combined list and check that every
    # other list is a subset of it.
    all_args = max(args_list, key=len)[:]
    arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
    for args in args_list:
        assert OrderedSet(args).issubset(OrderedSet(all_args)), (
            f"{args} v.s. {all_args}"
        )

    return all_args, arg_types


def get_all_kernel_argdefs(kernels):
    """
    The logic here must match `get_all_call_args`, except that there is no need
    to collect arg_types here.
    """
    argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]

    return _get_all_args(argdefs_list)[0]


def get_all_call_args(call_args_list, arg_types_list):
    """
    Passed the call_args for each subkernel, return the call_args for the
    combined multi-kernel.

    Note that an algorithm like the following does not always work:
    ```
    all_call_args: Dict[
        Any, None
    ] = {}  # use a dict rather than set to maintain insertion order
    for call_args in call_args_list:
        all_call_args.update({arg: None for arg in call_args})

    all_call_args = list(all_call_args.keys())
    ```
    It will fail if any kernel has the same argument passed in multiple times.
    See test_pass_same_arg_multi_times in test_multi_kernel.py.

    Instead, we pick the longest call_args list and assert that the other
    call_args lists are subsets of it.
    """
    return _get_all_args(call_args_list, arg_types_list)


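# For illustration (hypothetical argument names): if one subkernel is called
# with (buf0, s0) and another with (buf0, buf1, s0), get_all_call_args() picks
# the longer list (buf0, buf1, s0) for the combined multi-kernel call, and
# _get_all_args() asserts that the shorter list is a subset of it.

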
def get_numel_argdefs(kernel):
    numel_argdefs = [
        f"{tree.prefix}numel"
        for tree in kernel.range_trees
        if not tree.is_reduction or kernel.inside_reduction
    ]

    return numel_argdefs


class MultiKernelState:
    """
    Maintain the state of multi-kernel compilation so we don't define duplicate
    multi-kernels for the same set of sub-kernels.

    V.graph.wrapper_code has a reference to the MultiKernelState instance.
    """

    def __init__(self):
        self.subkernel_to_kernel_name = {}
        self.kernel_defs = IndentedBuffer()

    def define_kernel(self, kernels):
        """
        Previously we named the multi-kernel "multi_kernel_{kernel_names[0]}".
        This has a minor issue.

        E.g. for the persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca ,
        there are 2 flavors of non-persistent reduction:
        https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4
        and
        https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd

        The only difference is the cache eviction policy.

        We should name the multi-kernel differently in these 2 cases.
        """
        kernel_names = tuple(k.kernel_name for k in kernels)
        if kernel_names in self.subkernel_to_kernel_name:
            return self.subkernel_to_kernel_name[kernel_names]

        # name the multi-kernel with a running index rather than after the first kernel
        multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}"
        self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name

        if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
            # we should not generate any python code for multi-kernel during
            # the second pass of cpp-wrapper.
            return multi_kernel_name

        buf = self.kernel_defs
        buf.writeline("")
        buf.writeline(
            f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
        )
        with buf.indent():
            for name in kernel_names:
                buf.writeline(f"{name},")
        buf.writeline("])")

        if config.triton.autotune_at_compile_time:
            V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
                multi_kernel_name
            )

        return multi_kernel_name


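# For a pair of sub-kernels, define_kernel() above writes something like the
# following into kernel_defs (the sub-kernel names are illustrative):
#
#     multi_kernel_0 = async_compile.multi_kernel('multi_kernel_0', [
#         triton_red_fused_sum_0,
#         triton_per_fused_sum_0,
#     ])

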
class MultiKernel:
    """
    This class maintains the compile-time state for multi-kernels.

    Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
    The generated definition for the multi-kernel will look like:
    ```
    multi_kernel_kernel1 = MultiKernelCall(
        [kernel1, kernel2], multi_kernel_definition_code
    )
    ```

    Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39
    """

    def __init__(self, kernels):
        assert len(kernels) >= 2

        self.kernels = kernels
        self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
            kernels
        )

        # need this since some code in inductor checks whether the kernel object has
        # an args attribute to decide if it's a non-null kernel.
        self.args = object()

    @staticmethod
    def _merge_workspace_args(left: list[WorkspaceArg], right: list[WorkspaceArg]):
        if left == right:
            return left
        result = {x.inner_name: x for x in left}
        for arg in right:
            if arg.inner_name in result:
                result[arg.inner_name] = WorkspaceArg.maximum(
                    result[arg.inner_name], arg
                )
            else:
                result[arg.inner_name] = arg
        return [*result.values()]

    @staticmethod
    def merge_workspaces_inplace(kernels):
        if len(kernels) < 2:
            return
        # All kernels must share the same workspace
        workspace_args = functools.reduce(
            MultiKernel._merge_workspace_args,
            [kernel.args.workspace_args for kernel in kernels],
        )
        for kernel in kernels:
            kernel.args.workspace_args = workspace_args
        return workspace_args

    def call_kernel(self, kernel_name):
        """
        Collect the union of arguments from all subkernels as the arguments
        for the multi-kernel.
        """
        assert kernel_name == self.kernel_name
        V.graph.wrapper_code.write_triton_header_once()
        _, call_args, _, arg_types = self.kernels[0].args.python_argdefs()
        for kernel in self.kernels[1:]:
            _, other_call_args, _, other_arg_types = kernel.args.python_argdefs()
            assert call_args == other_call_args, (call_args, other_call_args)
            assert arg_types == other_arg_types

        if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time:
            # for the second pass of cpp-wrapper codegen, we should call
            # the fast kernel directly
            kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)

        # numels for all subkernels should be the same. Use kernels[0] here
        self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types)

        for ws in self.kernels[0].args.workspace_args:
            V.graph.wrapper_code.generate_workspace_allocation(ws)

        V.graph.wrapper_code.generate_kernel_call(
            kernel_name,
            call_args,
            arg_types=arg_types,
        )

        for ws in reversed(self.kernels[0].args.workspace_args):
            V.graph.wrapper_code.generate_workspace_deallocation(ws)

    def codegen_nan_check(self):
        wrapper = V.graph.wrapper_code
        seen = OrderedSet[str]()
        for k in self.kernels:
            _, call_args, precompile_args, _ = k.args.python_argdefs()
            for arg, precompile_arg in zip(call_args, precompile_args):
                if arg in seen:
                    continue
                seen.add(arg)
                if isinstance(precompile_arg, TensorArg):
                    line = f"assert not {arg}.isnan().any().item()"
                    wrapper.writeline(line)
                    line = f"assert not {arg}.isinf().any().item()"
                    wrapper.writeline(line)

    @property
    def removed_buffers(self):
        return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels])

    @property
    def inplaced_to_remove(self):
        return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels])

    @property
    @cache_on_self
    def inplace_update_buffers(self):
        """
        Make sure all kernels have the same inplace update mappings.
        """
        for k in self.kernels[1:]:
            assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers
        return self.kernels[0].inplace_update_buffers

    def warn_mix_layout(self, kernel_name: str):
        pass


class MultiKernelCall:
    """
    This class is called at run time to actually run the kernel.
    """

    def __init__(self, multi_kernel_name, kernels):
        assert len(kernels) >= 2
        self._kernels = kernels
        self.multi_kernel_name = multi_kernel_name

        self.disable_cache = os.environ.get(
            "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE"
        ) == "1" or is_metric_table_enabled("persistent_red_perf")

        self.picked_kernel = None
        if config.triton.multi_kernel > 1:
            # manually force a subkernel to ease perf testing
            picked_by_config = config.triton.multi_kernel - 2
            assert picked_by_config < len(self._kernels)
            self.picked_kernel = picked_by_config
        elif not self.disable_cache:
            self.load_cache()

        self._recorded = False

    def cache_file_path(self):
        key = code_hash(
            ",".join(
                [
                    f"{k.fn.cache_key}{k.size_hints!r}{k.triton_meta!r}"
                    for k in self.kernels
                ]
            )
        )
        _, _, path = get_path(key, "picked_kernel")
        return pathlib.Path(path)

    def load_cache(self):
        assert self.picked_kernel is None
        path = self.cache_file_path()
        if path.exists():
            with path.open() as fd:
                self.picked_kernel = int(fd.read())
                assert self.picked_kernel >= 0 and self.picked_kernel < len(
                    self._kernels
                )
                log.debug(
                    "Load picked kernel %d from cache file %s", self.picked_kernel, path
                )

    def store_cache(self):
        assert self.picked_kernel is not None
        path = self.cache_file_path()
        path.parent.mkdir(parents=True, exist_ok=True)

        write_atomic(path, str(self.picked_kernel))
        log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path)

    @property
    def kernels(self):
        """
        Read results from the futures.

        This should be called after parallel compilation is done.
        In case you call this before compilation is done,
        it may slow down the parallel compilation.
        """
        for i, kernel in enumerate(self._kernels):
            if isinstance(kernel, CodeCacheFuture):
                self._kernels[i] = kernel.result()

        return self._kernels

    def benchmark_sub_kernels(self, *args, **kwargs):
        """
        Benchmark all the sub-kernels and return the execution time
        (in milliseconds) for each of them.

        Unit tests may mock this method to force a specific kernel to
        be picked.
        """

        def wrap_fn(kernel):
            def inner():
                args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs)
                return kernel.run(*args_clone, **kwargs_clone)

            return inner

        return [
            benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40)
            for kernel in self.kernels
        ]

    # record_choice and lookup_choice are helper functions for cpp-wrapper
    # codegen. The first pass uses record_choice to store the choice, and
    # the second pass looks it up by calling lookup_choice.
    #
    # An alternative that reuses the multi-kernel cache does not work well,
    # since during codegen of the second pass it's very hard to know the
    # path of the cache file. Also, reading the cache file requires some IO,
    # which can be slow.
    @staticmethod
    def record_choice(multi_kernel_name: str, picked_kernel_name: str):
        """
        Record the multi-kernel choice for cpp-wrapper after autotuning.

        This does nothing if it is not called during codegen.
        """
        from torch._inductor.graph import GraphLowering

        if not isinstance(V.graph, GraphLowering):
            return

        if not V.graph.record_multi_kernel_choice:
            return

        V.graph.multi_kernel_to_choice[multi_kernel_name] = picked_kernel_name

    @staticmethod
    def lookup_choice(multi_kernel_name: str) -> str:
        # this should always be done during cpp-wrapper codegen
        assert (
            V.graph.record_multi_kernel_choice
            and multi_kernel_name in V.graph.multi_kernel_to_choice
        )
        # there should be no miss
        return V.graph.multi_kernel_to_choice[multi_kernel_name]

    def run(self, *args, **kwargs):
        if self.picked_kernel is None:
            timings = self.benchmark_sub_kernels(*args, **kwargs)
            self.picked_kernel = timings.index(min(timings))
            k0 = self.kernels[0]
            log.debug(
                "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s",
                self.picked_kernel,
                [k.inductor_meta.get("kernel_name") for k in self.kernels],
                k0.size_hints,
                k0.inductor_meta.get("reduction_hint"),
                timings,
            )
            get_metric_table("persistent_red_perf").add_row(
                functools.partial(self._metrics_table_row, timings)
            )
            if not self.disable_cache:
                self.store_cache()

        if not self._recorded:
            self._recorded = True
            picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get(
                "kernel_name"
            )
            assert picked_kernel_name is not None
            self.record_choice(self.multi_kernel_name, picked_kernel_name)
        self.run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
        self.run(*args, **kwargs)

    def _metrics_table_row(self, timings):
        def get_kernel_path(k):
            return k.fn.fn.__code__.co_filename

        k0 = self.kernels[0]
        row = {
            "size_hints": k0.size_hints,
            "reduction_hint": k0.inductor_meta.get("reduction_hint"),
        }
        max_kernels = 4
        assert len(timings) <= max_kernels
        for i in range(max_kernels):
            if i < len(self.kernels):
                row[f"kernel{i}_path"] = get_kernel_path(self.kernels[i])
                row[f"kernel{i}_latency"] = timings[i]
            else:
                row[f"kernel{i}_path"] = ""
                row[f"kernel{i}_latency"] = ""
        return row
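

# Note on runtime dispatch: the first MultiKernelCall.run() call benchmarks every
# sub-kernel, picks the index of the fastest one (persisted by store_cache() as a
# plain integer string, e.g. "1", unless caching is disabled), and then rebinds
# self.run to the picked sub-kernel's run() so subsequent calls dispatch directly.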