Compare commits

...

3 Commits

9b64c3a85b  [Cutlass] Restore search space for swizzle  (2025-02-18 11:05:26 -08:00)
    ghstack-source-id: b84769aaa79711bfb22f5aab6eb10cae33c78c29
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/147224

3fcbbd485b  [Cutlass] Add support for runtime param choices, starting with swizzle  (2025-02-18 11:05:26 -08:00)
    ghstack-source-id: b1110cc453c61e35271dcc65dd4ad42c22548b83
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/147223
    fix
    fix
    set fix

68104e28a6  [Inductor] Add autotuning artifact logging  (2025-02-18 11:05:25 -08:00)
    ghstack-source-id: 5be70c5efbafb8ea8bdb4fb9c4360d277dc73754
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/147222
    fix log test
    import fix
9 changed files with 91 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ import functools
 import logging
 import os
 import re
+import unittest
 import unittest.mock
 import torch
@@ -17,6 +18,7 @@ from torch._dynamo.testing import (
 )
 from torch._dynamo.trace_rules import _as_posix_path
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing._internal.common_cuda import SM90OrLater
 from torch.testing._internal.common_utils import (
     find_free_port,
     munge_exc,
@@ -756,6 +758,20 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
         self.assertGreater(len(records), 0)
         self.assertLess(len(records), 3)
 
+    @make_logging_test(autotuning=True)
+    @requires_cuda
+    @unittest.skipIf(not SM90OrLater, "requires H100+ GPU")
+    def test_autotuning(self, records):
+        with torch._inductor.utils.fresh_inductor_cache():
+
+            def f(a, b):
+                return torch.mm(a, b)
+
+            f = torch.compile(f, mode="max-autotune-no-cudagraphs")
+            f(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
+
+        self.assertGreater(len(records), 0)
+        self.assertLess(len(records), 40)
     @make_logging_test(graph_region_expansion=True)
     def test_graph_region_expansion(self, records):
         with torch._dynamo.config.patch("track_nodes_for_deduplication", True):
@@ -878,7 +894,7 @@ TorchDynamo attempted to trace the following frames: [
 )
 
-# single record tests
+# non single record tests
 exclusions = {
     "bytecode",
     "cudagraphs",
@@ -915,6 +931,7 @@ exclusions = {
     "cudagraph_static_inputs",
     "benchmarking",
     "loop_ordering",
+    "autotuning",
     "graph_region_expansion",
 }
 for name in torch._logging._internal.log_registry.artifact_names:
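
The new test_autotuning case compiles a small matmul under max-autotune and asserts that the "autotuning" artifact produced between 1 and 39 log records. A minimal sketch of reproducing this outside the test harness, assuming a CUDA device (the test itself additionally requires SM90+):

    # Run with: TORCH_LOGS="autotuning" python repro.py
    import torch

    def f(a, b):
        return torch.mm(a, b)

    f = torch.compile(f, mode="max-autotune-no-cudagraphs")
    f(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))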

View File

@@ -29,6 +29,7 @@ from torch._inductor.codecache import (
     PyCodeCache,
 )
 from torch._inductor.utils import get_gpu_type, is_gpu
+from torch._logging import getArtifactLogger
 from torch.utils._ordered_set import OrderedSet
@@ -38,6 +39,7 @@ if TYPE_CHECKING:
from types import ModuleType
from torch._inductor.select_algorithm import TritonTemplateCaller
from .codegen.common import WorkspaceArg
from . import config
@@ -49,7 +51,7 @@ from .virtualized import V
 
 CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
 EXIT_HANDLER_REGISTERED = False
-log = logging.getLogger(__name__)
+log = getArtifactLogger(__name__, "autotuning")
 
 # Used to synchronize between parent and child processes
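
Replacing logging.getLogger with getArtifactLogger routes these messages through the TORCH_LOGS artifact system, so existing log call sites in this file go silent unless the "autotuning" artifact is enabled. A sketch of the pattern in isolation (the message is hypothetical):

    from torch._logging import getArtifactLogger

    log = getArtifactLogger(__name__, "autotuning")

    # Only emitted when the artifact is on, e.g. TORCH_LOGS="autotuning"
    # or torch._logging.set_logs(autotuning=True).
    log.debug("autotune candidate: %s", "some_kernel")  # hypothetical message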

View File

@@ -153,7 +153,7 @@ class CUDATemplateKernel(CUDAKernel):
 
     _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
 
-    def __init__(self, kernel_name) -> None:
+    def __init__(self, kernel_name, runtime_arg_info, runtime_arg_values) -> None:
"""
Initializes a new instance of the CUDATemplateKernel class.
@@ -162,6 +162,8 @@ class CUDATemplateKernel(CUDAKernel):
         """
         super().__init__()
         self.kernel_name = kernel_name
+        self.runtime_arg_info = runtime_arg_info
+        self.runtime_arg_values = runtime_arg_values
 
     def check_not_null(self, node: IRNode) -> str:
         """
@@ -244,7 +246,13 @@ class CUDATemplateKernel(CUDAKernel):
             f"const int {s}" for s in ("M", "N", "K", "lda", "ldb", "ldc", "ldd")
         ]
-        signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {self._EXTRA_CPP_ARGS})"
+        runtime_arg_decls = ",".join(
+            [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
+        )
+        if runtime_arg_decls:
+            runtime_arg_decls += ", "
+        signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})"
+
         self.signature = signature
         return signature
@@ -278,7 +286,11 @@ class CUDATemplateKernel(CUDAKernel):
         layout_args = self.get_layout_args()
         call_args.extend(layout_args)  # type: ignore[arg-type]
+        for arg in self.runtime_arg_values:
+            call_args.append(arg)
         arg_types.extend("int" for a in layout_args)
+        for arg in self.runtime_arg_info:
+            arg_types.append(arg.ty)
         # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
         for i in range(len(call_args)):
             if V.graph.is_unspec_arg(call_args[i]):
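
The runtime arg declarations are spliced between the size args and the fixed C++ tail of the kernel signature. A rough sketch of the resulting string for the swizzle arg added later in this stack (the names here are placeholders, not the real generated ones):

    # Illustrative reconstruction of the signature assembly above.
    arg_defs = ["const half* X", "const half* W", "half* Y"]  # placeholders
    size_args = [f"const int {s}" for s in ("M", "N", "K", "lda", "ldb", "ldc", "ldd")]
    runtime_arg_decls = "const uint8_t swizzle, "  # from ArgInfo("swizzle", "const uint8_t")
    extra = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
    signature = f"int my_kernel({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{extra})"
    # -> int my_kernel(..., const int ldd, const uint8_t swizzle, size_t* workspace_size, ...)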

View File

@@ -1,13 +1,15 @@
 # mypy: allow-untyped-defs
 import functools
 import itertools
-import logging
-from typing import Optional
+from dataclasses import dataclass
+from typing import Any, Optional
+from typing_extensions import override
 from unittest.mock import patch
 
 import sympy
 
 import torch
+from torch._logging import getArtifactLogger
 
 from ...autotune_process import CUDABenchmarkRequest, TensorMeta
 from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout
@@ -17,7 +19,13 @@ from ..common import KernelTemplate
 from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel
 
-log = logging.getLogger(__name__)
+log = getArtifactLogger(__name__, "autotuning")
+
+
+@dataclass(frozen=True)
+class ArgInfo:
+    name: str
+    ty: str
 
 
 class CUDATemplate(KernelTemplate):
@@ -67,6 +75,8 @@ class CUDATemplate(KernelTemplate):
             V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
         ), CUDATemplateKernel(
             kernel_name=kernel_name,
+            runtime_arg_info=self.get_runtime_arg_info(),
+            runtime_arg_values=self.get_runtime_arg_values(**kwargs),
         ) as kernel:
             code = self.render(kernel=kernel, **kwargs)
             _, call_args, _, _ = kernel.args.python_argdefs()
@@ -92,6 +102,7 @@ class CUDATemplate(KernelTemplate):
         )
         V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :]))
         size_args = V.graph.sizevars.size_hints(kernel.get_layout_args())
+        extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))
 
         kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}"
@@ -100,7 +111,7 @@ class CUDATemplate(KernelTemplate):
             kernel_name=kernel_name,
             input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
             output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
-            extra_args=size_args,
+            extra_args=extra_args,
             source_code=code,
         )
@@ -110,6 +121,8 @@ class CUDATemplate(KernelTemplate):
         ):
             kernel = CUDATemplateKernel(
                 kernel_name="KERNEL_NAME",
+                runtime_arg_info=self.get_runtime_arg_info(),
+                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
             )
             render = functools.partial(
                 self.render,
@@ -169,6 +182,12 @@ class CUDATemplate(KernelTemplate):
     def render(self, **kwargs) -> str:
         raise NotImplementedError
 
+    def get_runtime_arg_info(self) -> list[ArgInfo]:
+        return []
+
+    def get_runtime_arg_values(self, **kwargs) -> list[Any]:
+        return []
+
 
 class CUTLASSTemplate(CUDATemplate):
     """
@@ -257,3 +276,14 @@ class CUTLASSTemplate(CUDATemplate):
         return (
             f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})"
         )
+
+    @override
+    def get_runtime_arg_info(self) -> list[ArgInfo]:
+        return [ArgInfo("swizzle", "const uint8_t")]
+
+    @override
+    def get_runtime_arg_values(self, **kwargs) -> list[Any]:
+        """
+        Helper method to retrieve runtime args from generate kwargs
+        """
+        return [kwargs[arg.name] for arg in self.get_runtime_arg_info()]
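
The base-class hooks return empty lists, so templates without runtime params are unaffected; CUTLASSTemplate opts in with the single swizzle arg, and its values are pulled out of the generate() kwargs by name. A hypothetical subclass adding a second runtime param would follow the same contract (split_k below is illustrative only and not part of this PR):

    from typing import Any
    from typing_extensions import override

    # Module path inferred from the diff above.
    from torch._inductor.codegen.cuda.cuda_template import ArgInfo, CUTLASSTemplate

    class MyTemplate(CUTLASSTemplate):
        @override
        def get_runtime_arg_info(self) -> list[ArgInfo]:
            return super().get_runtime_arg_info() + [ArgInfo("split_k", "const int")]

        # get_runtime_arg_values is inherited: it looks up each arg name
        # in the generate() kwargs, so callers must pass split_k=... as well.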

View File

@@ -46,7 +46,7 @@ PT_EXPORT {{kernel_call_signature}} {
     CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
   }
   {{instance_type}}::Arguments arguments;
-  {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, swizzle,
+  {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw,
                                    X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}}
   {{instance_type}} gemm_op;
   if (workspace_size) {
@@ -120,7 +120,7 @@ GEMM_ARGS_CUTLASS_3X = r"""
     {{epilogue_arguments}},
     hw_info
   };
-  arguments.scheduler.max_swizzle_size = {{swizzle}};
+  arguments.scheduler.max_swizzle_size = swizzle;
 """
# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied,
@ -980,7 +980,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
Bias=Bias,
epilogue_template=epilogue_template,
argument_template=argument_template,
swizzle=kwargs["swizzle"],
should_swap_xw=should_swap_xw,
template=self,
kernel=kernel,
@@ -1256,7 +1255,6 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         argument_template: str,
         epilogue_template: str,
         should_swap_xw: bool,
-        swizzle: int,
         X: IRNode,
         W: IRNode,
         Bias: IRNode,
@@ -1302,7 +1300,6 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
             M="M",
             N="N",
             epilogue_args=epilogue_args,
-            swizzle=swizzle,
         )
         assert epilogue_template is not None
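
This Jinja change is the crux of the commit: {{swizzle}} was expanded when the template was rendered, baking one swizzle value into every generated source file, whereas bare swizzle now refers to the kernel's runtime parameter. The rendered line changes roughly from

    arguments.scheduler.max_swizzle_size = 2;        (one compiled source per candidate swizzle value)

to

    arguments.scheduler.max_swizzle_size = swizzle;  (one compiled source; the value is chosen per benchmark run)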

View File

@@ -1262,7 +1262,7 @@ class cuda:
     cutlass_max_profiling_configs: Optional[int] = None
 
     # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune.
-    cutlass_max_profiling_swizzle_options: list[int] = [2]
+    cutlass_max_profiling_swizzle_options: list[int] = [1, 2, 4]
 
     # Path to CUDA NVCC.
     # NVCC search order:
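
With swizzle now a runtime argument (previous file), widening the default from [2] to [1, 2, 4] triples the profiled search space at the cost of extra benchmark runs, not extra compilations. The space can still be tuned per run; a sketch:

    from torch._inductor import config

    # Restrict CUTLASS profiling back to a single swizzle value,
    # e.g. to shorten max-autotune time.
    config.cuda.cutlass_max_profiling_swizzle_options = [2]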

View File

@@ -1784,7 +1784,16 @@ class AlgorithmSelectorCache(PersistentCache):
             start_times: dict[concurrent.futures.Future[Any], float] = {}
             elapsed_times: dict[concurrent.futures.Future[Any], float] = {}
 
+            # Some choices only differ in runtime arguments, so we
+            # skip a choice if it has the same hash as a previously seen choice
+            seen_choices: OrderedSet[str] = OrderedSet()
             for c in choices:
+                # Skip choices for which we have already issued a precompile
+                if c.hash_key() in seen_choices:
+                    continue
+                else:
+                    seen_choices.add(c.hash_key())
+
                 if hasattr(c, "precompile"):
                     triton_cuda_choice = isinstance(
                         c, TritonTemplateCaller
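
Choices that differ only in runtime arguments (such as the three swizzle values) share identical generated source and therefore identical hash_key() values, so only the first triggers compilation; the rest are benchmarked against the same binary. The pattern in isolation, with a stand-in choice type (a sketch, not the real ChoiceCaller API beyond hash_key):

    seen: set[str] = set()
    for c in choices:
        if c.hash_key() in seen:  # same source as an earlier choice
            continue
        seen.add(c.hash_key())
        c.precompile()  # each distinct source is compiled once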

View File

@@ -246,6 +246,7 @@ def set_logs(
     compiled_autograd_verbose: bool = False,
     cudagraph_static_inputs: bool = False,
     benchmarking: bool = False,
+    autotuning: bool = False,
    graph_region_expansion: bool = False,
 ):
     """
@@ -426,6 +427,9 @@ def set_logs(
         cudagraph_static_inputs (:class:`bool`):
             Whether to emit debug info for cudagraph static input detection. Default: ``False``
 
+        autotuning (:class:`bool`):
+            Autotuning choice logs, such as kernel source, perf, and tuning parameters. Default: ``False``
+
         graph_region_expansion (:class:`bool`):
             Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False``
@@ -528,6 +532,7 @@ def set_logs(
         compiled_autograd_verbose=compiled_autograd_verbose,
         cudagraph_static_inputs=cudagraph_static_inputs,
         benchmarking=benchmarking,
+        autotuning=autotuning,
         graph_region_expansion=graph_region_expansion,
     )
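
With the new keyword plumbed through, the artifact can be toggled programmatically as well as through the environment. A usage sketch:

    import torch

    # Programmatic toggle...
    torch._logging.set_logs(autotuning=True)

    # ...equivalent to running with TORCH_LOGS="autotuning".
    f = torch.compile(torch.mm, mode="max-autotune-no-cudagraphs")
    f(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))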

View File

@@ -196,6 +196,11 @@ register_artifact(
     "Detailed Inductor benchmarking information.",
     off_by_default=True,
 )
+register_artifact(
+    "autotuning",
+    "Autotuning choice logs, such as kernel source, perf, and tuning parameters.",
+)
+
 register_artifact(
     "graph_region_expansion",
     "Logs detailed steps of the duplicate graph region tracker expansion algorithm",