Add hooks to Scheduler nodes for generating device-specific debug strings (#135015)

Previously, instances of `SchedulerNode` and `FusedSchedulerNode` would explicitly check whether the compilation target is Triton when codegen'ing debug strings. Generating debug Triton code is instead implemented as a callback that `TritonScheduling` sets on scheduler nodes. This makes the codegen more device-agnostic and lets scheduling backends customise the debug output, rather than coupling device-specific codegen to the debug-string logic.
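For illustration, a minimal standalone sketch of the hook pattern (the `Node` and `TritonLikeScheduling` names and the placeholder kernel string are simplified stand-ins for this example, not the real Inductor classes):

```python
from typing import Callable, List


class Node:
    """Stand-in for BaseSchedulerNode: carries a device-specific debug hook."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Default hook: no device-specific debug output.
        self.debug_device_str: Callable[["Node"], List[str]] = lambda *args, **kwargs: []

    def debug_str(self) -> str:
        # The node no longer asks "is this Triton?"; it just invokes whatever
        # hook its scheduling backend installed.
        lines = [f"{self.name}: generic node info"]
        lines.extend(self.debug_device_str(self))
        return "\n".join(lines)


def debug_triton_code(node: Node) -> List[str]:
    # Stand-in for the Triton debug codegen installed by the backend.
    return [f"{node.name} Triton code:", "    <kernel source here>"]


class TritonLikeScheduling:
    # Stand-in for TritonScheduling.__init__: install the hook on each node.
    def __init__(self, nodes: List[Node]) -> None:
        for node in nodes:
            node.debug_device_str = debug_triton_code


nodes = [Node("buf0"), Node("buf1")]
TritonLikeScheduling(nodes)
print(nodes[0].debug_str())
```

The node itself stays device-agnostic: it only calls whatever callback its scheduling backend installed, which is how `BaseSchedulerNode._debug_str_for_device` and `TritonScheduling.__init__` interact in this change.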

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135015
Approved by: https://github.com/jansel
Artemiy Bulavin
2024-10-11 20:30:48 +00:00
committed by PyTorch MergeBot
parent 8543000c27
commit 74e871355b
2 changed files with 50 additions and 39 deletions

View File

@ -25,6 +25,7 @@ from typing import (
 import sympy
 
 import torch
+import torch._inductor.metrics as metrics
 import torch._logging
 from torch._dynamo.utils import preserve_rng_state
 from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
@ -38,10 +39,10 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty
 from ...utils._sympy.value_ranges import ValueRanges
 from .. import config, ir
 from ..codecache import code_hash, get_path, PyCodeCache
-from ..metrics import is_metric_table_enabled, log_kernel_metadata
 from ..runtime.benchmarking import benchmarker
 from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
 from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
+from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
 from ..utils import (
     cache_on_self,
     get_bounds_index_expr,
@ -3098,6 +3099,14 @@ class TritonScheduling(SIMDScheduling):
             )
         )
 
+    def __init__(self, scheduler: Scheduler) -> None:
+        super().__init__(scheduler)
+        if scheduler is None or not hasattr(scheduler, "nodes"):
+            return
+        for node in scheduler.nodes:
+            if isinstance(node, (SchedulerNode, FusedSchedulerNode)):
+                node.debug_device_str = debug_triton_code
+
     @classmethod
     def get_backend_features(cls, device: torch.device):
         return cls.backend_features
@ -3174,8 +3183,8 @@ class TritonScheduling(SIMDScheduling):
             # log kernel metadata for offline analysis.
             # E.g. one can find all unaligned inner reduction and check if
             # padding helps with the perf kernel by kernel.
-            if is_metric_table_enabled("kernel_metadata"):
-                log_kernel_metadata(kernel_name, kernel_path, src_code)
+            if metrics.is_metric_table_enabled("kernel_metadata"):
+                metrics.log_kernel_metadata(kernel_name, kernel_path, src_code)
 
         return kernel_name
@ -3346,3 +3355,33 @@ class TritonScheduling(SIMDScheduling):
         V.graph.removed_buffers = removed_buffers_orig
         V.graph.inplaced_to_remove = inplaced_to_remove_orig
         return total_ms, total_clone_ms, file_list
+
+
+def debug_triton_code(node: BaseSchedulerNode) -> List[str]:
+    lines = []
+    multi_template = node.get_template_node()
+    assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer)
+    if multi_template and multi_template.make_kernel_render is None:
+        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
+    else:
+        from torch._inductor.codegen.cuda_combined_scheduling import (
+            CUDACombinedScheduling,
+        )
+
+        device = node.get_device()
+        backend = node.scheduler.get_backend(device)
+        assert isinstance(
+            backend, (SIMDScheduling, CUDACombinedScheduling)
+        ), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
+
+        V.graph.scheduler.current_device = device
+
+        # Don't increment kernel count when generating debug string.
+        # This will confuse some unit tests that check the number of
+        # generated kernels.
+        old_generated_kernel_count = metrics.generated_kernel_count
+        triton_code = backend.generate_kernel_code_from_nodes(node.get_nodes()).strip()
+        metrics.generated_kernel_count = old_generated_kernel_count
+
+        lines.append(f"{node.get_name()} Triton code:")
+        lines.append(textwrap.indent(triton_code, " "))
+    return lines

View File

@ -175,6 +175,9 @@ class BaseSchedulerNode:
     def __init__(self, scheduler: Scheduler) -> None:
         self.scheduler: Scheduler = scheduler
+        self.debug_device_str: Callable[
+            [BaseSchedulerNode], List[str]
+        ] = lambda *args, **kwargs: []
 
     def _init_from_node(self, node: ir.Operation) -> None:
         self.node: Optional[ir.Operation] = node
@ -226,6 +229,9 @@ class BaseSchedulerNode:
     def debug_str_extra(self) -> str:
         return ""
 
+    def _debug_str_for_device(self) -> List[str]:
+        return self.debug_device_str(self)
+
     def debug_str_short(self) -> str:
         maybe_data = getattr(self.node, "data", None)
         data_str = ""
@ -953,8 +959,7 @@ class SchedulerNode(BaseSchedulerNode):
             lines.append(textwrap.indent(self._body.debug_str(), " "))
 
         assert self.node is not None
-        if ir.is_triton(self.node.get_device()):
-            lines.extend(debug_triton_code(self))
+        lines.extend(self._debug_str_for_device())
 
         return "\n".join(lines)
@ -1178,9 +1183,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
         ]
         node = self.snodes[0].node
         if node is not None:
-            device = node.get_device()
-            if ir.is_triton(device):
-                lines.extend(debug_triton_code(self))
+            lines.extend(self._debug_str_for_device())
 
         return textwrap.indent("\n".join(lines).rstrip(), " ")
@ -3757,34 +3760,3 @@ class BaseScheduling:
         and memory copy time in milliseconds on randomly generated inputs.
         """
         raise NotImplementedError
-
-
-def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]:
-    lines = []
-    multi_template = node.get_template_node()
-    assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer)
-    if multi_template and multi_template.make_kernel_render is None:
-        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
-    else:
-        from torch._inductor.codegen.cuda_combined_scheduling import (
-            CUDACombinedScheduling,
-        )
-        from .codegen.simd import SIMDScheduling
-
-        snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes
-        device = snodes[0].get_device()
-        backend = node.scheduler.get_backend(device)
-        assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling))
-
-        V.graph.scheduler.current_device = device
-
-        # Don't increment kernel count when generating debug string.
-        # This will confuse some unit tests that check the number of
-        # generated kernels.
-        old_generated_kernel_count = metrics.generated_kernel_count
-        triton_code = backend.generate_kernel_code_from_nodes(snodes).strip()
-        metrics.generated_kernel_count = old_generated_kernel_count
-
-        lines.append(f"{node.get_name()} Triton code:")
-        lines.append(textwrap.indent(triton_code, " "))
-    return lines