Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 18:05:13 +08:00)

Commit: init
diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py

@@ -19,6 +19,9 @@ from .common import TensorArg, WorkspaceArg

 log = logging.getLogger(__name__)

+FULL_BENCHMARK = False
+
+
 class MultiKernelState:
     """
     Maintain state of multi-kernel compilation so we don't define duplicated
@@ -31,7 +34,7 @@ class MultiKernelState:
         self.subkernel_to_kernel_name = {}
         self.kernel_defs = IndentedBuffer()

-    def define_kernel(self, kernels):
+    def define_kernel(self, kernels, kernel_shape_keys=None):
         """
         Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}".
         This has some minor issue.
@@ -85,7 +88,7 @@ class MultiKernelState:
         for i in range(len(kernels)):
             arg_index[i] = [slice(0, len(call_args))]

-        shape_specialize = isinstance(kernels[0], TritonTemplateKernel)
+        shape_specialize = kernel_shape_keys is not None
         buf = self.kernel_defs
         buf.writeline("")
         buf.writeline("arg_index = {")
@@ -93,13 +96,25 @@ class MultiKernelState:
             slice_reprs = ", ".join(repr(s) for s in slice_list)
             buf.writeline(f"    {key}: [{slice_reprs}],")
         buf.writeline("}")
-        buf.writeline(
-            f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
-        )
-        with buf.indent():
-            for name in kernel_names:
-                buf.writeline(f"{name},")
-        buf.writeline(f"], arg_index=arg_index, shape_specialize={shape_specialize})")
+
+        if not shape_specialize:  # no size hint keys, just call with list of kernels
+            buf.writeline(
+                f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
+            )
+            with buf.indent():
+                for name in kernel_names:
+                    buf.writeline(f"{name},")
+            buf.writeline(f"], arg_index=arg_index, shape_specialize=False)")
+        else:  # call with dict[size hint key, kernel]
+            assert isinstance(kernels[0], TritonTemplateKernel)
+            assert len(kernels) == len(kernel_shape_keys)
+            buf.writeline(
+                f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, {{"
+            )
+            with buf.indent():
+                for shape_key, name in zip(kernel_shape_keys, kernel_names):
+                    buf.writeline(f"{shape_key}: {name},")
+            buf.writeline(f"}}, arg_index=arg_index, shape_specialize=True)")

         if config.triton.autotune_at_compile_time:
             V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
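
For illustration, a hypothetical sketch of what define_kernel() now emits into the wrapper in each branch. The kernel names, arg_index contents, and shape keys below are invented, and the exact text depends on the surrounding wrapper codegen:

# Non-specialized branch: sub-kernels passed as a list.
arg_index = {
    0: [slice(0, 5)],
    1: [slice(0, 5)],
}
multi_kernel_0 = async_compile.multi_kernel('multi_kernel_0', [
    triton_poi_fused_add_0,
    triton_poi_fused_add_1,
], arg_index=arg_index, shape_specialize=False)

# Shape-specialized branch: sub-kernels passed as a dict keyed by the
# size-hint tuples built in SIMDScheduling._make_shape_cache_key.
multi_kernel_1 = async_compile.multi_kernel('multi_kernel_1', {
    ((64, 512), (512, 128), (64, 128)): triton_tem_fused_mm_64,
    ((4096, 512), (512, 128), (4096, 128)): triton_tem_fused_mm_4096,
}, arg_index=arg_index, shape_specialize=True)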
@@ -127,9 +142,18 @@ class MultiKernel:
     def __init__(self, kernels):
         assert len(kernels) >= 2

-        self.kernels = kernels
+        self.kernels = []
+        self.kernel_shape_keys = None
+        if isinstance(kernels, dict):
+            self.kernel_shape_keys = []
+            for shape_key, kernel in kernels.items():
+                self.kernels.append(kernel)
+                self.kernel_shape_keys.append(shape_key)
+        else:
+            self.kernels = kernels
+
         self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
-            kernels
+            self.kernels, self.kernel_shape_keys
         )

         # need this since some code in inductor check if the kernel object has an args
@@ -268,7 +292,21 @@ class MultiKernelCall:

     def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False):
         assert len(kernels) >= 2
-        self._kernels = kernels
+
+        if isinstance(kernels, dict):
+            # This means for each unique shape we will do a separate assessment
+            # for which kernel is the best. This is particularly useful for matmul
+            # kernels where the best kernel can vary based on very small differences
+            # in shape.
+            self._shape_specialize = True
+            self._shape_cache = {shape_key: i for i, shape_key in enumerate(kernels.keys()) if shape_key is not None}
+            self._kernels = list(kernels.values())
+            self._kernel_hints = list(kernels.keys())
+        else:
+            self._shape_specialize = False
+            self._shape_cache = {}
+            self._kernels = kernels
+            self._kernel_hints = None
         self.multi_kernel_name = multi_kernel_name

         self.disable_cache = os.environ.get(
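
The dict-vs-list handling above, as a minimal standalone sketch. It mirrors only the branching in MultiKernelCall.__init__; split_kernels and the string kernels are stand-ins (real entries are compiled Triton kernels):

def split_kernels(kernels):
    # A dict input enables shape specialization and pre-seeds the shape
    # cache: each known size-hint key maps to its own kernel's index.
    if isinstance(kernels, dict):
        shape_specialize = True
        shape_cache = {k: i for i, k in enumerate(kernels) if k is not None}
        kernel_list = list(kernels.values())
        kernel_hints = list(kernels.keys())
    else:
        shape_specialize = False
        shape_cache = {}
        kernel_list = list(kernels)
        kernel_hints = None
    return shape_specialize, shape_cache, kernel_list, kernel_hints

# Dict input: keys are shape tuples like those from _make_shape_cache_key.
spec, cache, ks, hints = split_kernels(
    {((64, 64),): "mm_small", ((4096, 4096),): "mm_large"}
)
assert spec and cache == {((64, 64),): 0, ((4096, 4096),): 1}

# List input: plain multi-kernel, no shape specialization.
spec, cache, ks, hints = split_kernels(["kernel_a", "kernel_b"])
assert not spec and cache == {} and hints is None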
@@ -287,13 +325,6 @@ class MultiKernelCall:

         self._recorded = False

-        # This means for each unique shape we will do a separate assessment
-        # for which kernel is the best. This is particularly useful for matmul
-        # kernels where the best kernel can vary based on very small differences
-        # in shape.
-        self._shape_specialize = shape_specialize
-        self._shape_cache = {}
-
     def cache_file_path(self):
         key = code_hash(
             ",".join(
@@ -418,7 +449,9 @@ class MultiKernelCall:
         if self._shape_specialize:
             cache_key = self._get_shape_cache_key(*args, **kwargs)
             cached_choice = self._get_cached_shape_choice(cache_key)
+            print("running with cache key", cache_key)
             if cached_choice is not None:
+                print("found cached choice", cached_choice)
                 self.picked_kernel = cached_choice
                 log.debug(
                     "using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
@@ -482,14 +515,29 @@ class MultiKernelCall:
         """
         self._shape_cache[cache_key] = kernel_idx

+    def _l1_dist(self, k1, k2):
+        dist = 0
+        for s1, s2 in zip(k1, k2):
+            dist += sum(abs(x - y) for x, y in zip(s1, s2))
+        return dist
+
     def _select_kernel_by_shape(self, *args, **kwargs):
         """
         Benchmark kernels for a particular shape and return the
         best kernel for this shape.
         """
+        assert self._shape_specialize and self._kernel_hints is not None
+        global FULL_BENCHMARK
+
         shape_key = self._get_shape_cache_key(*args, **kwargs)
-        timings = self.benchmark_sub_kernels(*args, **kwargs)
-        self.picked_kernel = timings.index(min(timings))
+        if FULL_BENCHMARK:
+            timings = self.benchmark_sub_kernels(*args, **kwargs)
+            self.picked_kernel = timings.index(min(timings))
+        else:
+            dists = [self._l1_dist(shape_key, key) if key is not None else 2**64 for key in self._kernel_hints]
+            self.picked_kernel = dists.index(min(dists))
+            print(f"Selected kernel index {self.picked_kernel} for fresh key {shape_key} based on existing key {self._kernel_hints[self.picked_kernel]}")

         self._cache_shape_choice(shape_key, self.picked_kernel)

     def _metrics_table_row(self, timings):
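
When FULL_BENCHMARK is off, the fallback above matches a fresh shape key to the nearest pre-benchmarked hint key by L1 distance over the per-tensor shape tuples (a None hint is pushed out of contention with a 2**64 distance). A self-contained illustration with invented shapes:

def l1_dist(k1, k2):
    # Same computation as MultiKernelCall._l1_dist: sum of elementwise
    # absolute differences across the zipped shape tuples.
    dist = 0
    for s1, s2 in zip(k1, k2):
        dist += sum(abs(x - y) for x, y in zip(s1, s2))
    return dist

# Hint keys the sub-kernels were compiled for (hypothetical).
kernel_hints = [((64, 64), (64, 64)), ((2048, 2048), (2048, 2048)), None]

# A runtime shape that was never benchmarked directly.
shape_key = ((100, 64), (100, 64))

dists = [l1_dist(shape_key, k) if k is not None else 2**64 for k in kernel_hints]
picked = dists.index(min(dists))
assert picked == 0  # 100x64 is much closer to the 64x64 hint than to 2048x2048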
diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py

@@ -1692,6 +1692,16 @@ class SIMDScheduling(BaseScheduling):

         return kernel

+    def _make_shape_cache_key(self, node: MultiTemplateBuffer, hint: int) -> tuple[tuple[int, ...], ...]:
+        """
+        Returns cache key for hint-based multi-graph; key is tuple of shapes with hint filled in.
+        """
+        out = []
+        for arg in list(node.inputs) + [node]:
+            if (size := arg.maybe_get_size()) is not None:
+                out.append(tuple(hint if isinstance(s, sympy.Symbol) else s for s in size))
+        return tuple(out)
+
     def codegen_template(
         self,
         template_node,
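
A standalone sketch of the key construction above, with the inductor buffer replaced by plain size tuples (sympy is already a dependency in this file; make_shape_cache_key below is a hypothetical stand-in for the method):

import sympy

def make_shape_cache_key(sizes, hint):
    # Mirrors _make_shape_cache_key: substitute the integer hint for every
    # symbolic (dynamic) dimension and keep static dimensions as-is.
    out = []
    for size in sizes:
        if size is not None:  # maybe_get_size() can return None
            out.append(tuple(hint if isinstance(s, sympy.Symbol) else s for s in size))
    return tuple(out)

s0 = sympy.Symbol("s0")  # a dynamic batch dimension
# Input sizes plus the output size, as in list(node.inputs) + [node].
sizes = [(s0, 512), (512, 128), (s0, 128)]
assert make_shape_cache_key(sizes, hint=64) == ((64, 512), (512, 128), (64, 128))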
@@ -1715,10 +1725,10 @@ class SIMDScheduling(BaseScheduling):
             isinstance(template_node.node, MultiTemplateBuffer)
             and template_node.node._make_kernel_renders
         ):
-            kernels = []
+            kernels = {}
             src_codes = []

-            for make_kernel_render in template_node.node._make_kernel_renders.values():
+            for size_hint, make_kernel_render in template_node.node._make_kernel_renders.items():
                 kernel, render = make_kernel_render(
                     template_node.node, hint_override=hint_override
                 )
@@ -1743,12 +1753,13 @@ class SIMDScheduling(BaseScheduling):
                     prologue_nodes,
                     only_gen_src_code=False,
                 )
-                kernels.append(kernel)
+                shape_cache_key = None if size_hint is None else self._make_shape_cache_key(template_node.node, size_hint)
+                kernels[shape_cache_key] = kernel

             if only_gen_src_code:
                 return "\n\n".join(src_codes)

-            MultiKernel.merge_workspaces_inplace(kernels)
+            MultiKernel.merge_workspaces_inplace(list(kernels.values()))
             multi_kernel = MultiKernel(kernels)
             node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
             self.codegen_comment(node_schedule)
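
Condensed, the new flow through codegen_template: each entry of _make_kernel_renders carries a size hint, the hint is turned into a shape cache key, and the resulting dict (rather than a list) is handed to MultiKernel, which switches the runtime into per-shape selection. A rough runnable sketch, where every name except the dict-keying itself is a stand-in:

# Stand-ins: real values are size hints and compiled template kernels.
make_kernel_renders = {None: "fallback_kernel", 64: "mm_64", 4096: "mm_4096"}

def make_shape_cache_key(hint):
    return ((hint, 512), (512, 128), (hint, 128))  # see _make_shape_cache_key

kernels = {}
for size_hint, kernel in make_kernel_renders.items():
    # A None size hint keeps a None key; the runtime gives it distance 2**64
    # so the nearest-hint heuristic never picks it over a concrete hint.
    key = None if size_hint is None else make_shape_cache_key(size_hint)
    kernels[key] = kernel

# merge_workspaces_inplace now needs the values, since kernels is a dict.
assert list(kernels.values()) == ["fallback_kernel", "mm_64", "mm_4096"]
# MultiKernel(kernels) with a dict input enables shape specialization.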