Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add hooks to Scheduler nodes for generating device-specific debug strings (#135015)
Previously, instances of `SchedulerNode` and `FusedSchedulerNode` would explicitly check whether the compilation target is Triton when codegen'ing debug strings. Generating debug Triton code is instead implemented as a callback set on scheduler nodes by `TritonScheduling`. This makes the codegen more device-agnostic and lets scheduling backends customise the debug output, rather than having it tightly coupled to the debug string codegen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135015
Approved by: https://github.com/jansel
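For illustration, a minimal, self-contained sketch of the hook pattern described above. The `Mini*` and `demo_*` names below are hypothetical stand-ins; in the actual diff the pieces are `BaseSchedulerNode.debug_device_str`, `BaseSchedulerNode._debug_str_for_device`, `TritonScheduling.__init__`, and `debug_triton_code`:

# Illustrative sketch only, not the actual Inductor classes.
from typing import Callable, List


class MiniSchedulerNode:
    def __init__(self) -> None:
        # Default hook: no device-specific debug output.
        self.debug_device_str: Callable[["MiniSchedulerNode"], List[str]] = (
            lambda *args, **kwargs: []
        )

    def _debug_str_for_device(self) -> List[str]:
        # The node's debug-string codegen delegates to whatever callback
        # the active scheduling backend installed.
        return self.debug_device_str(self)


def demo_debug_triton_code(node: "MiniSchedulerNode") -> List[str]:
    # The real callback renders the node's Triton kernel source.
    return [f"{type(node).__name__} Triton code: ..."]


class MiniTritonScheduling:
    def __init__(self, nodes: List[MiniSchedulerNode]) -> None:
        # The backend installs its debug-string callback on each node,
        # mirroring what TritonScheduling.__init__ does in this PR.
        for node in nodes:
            node.debug_device_str = demo_debug_triton_code


nodes = [MiniSchedulerNode()]
MiniTritonScheduling(nodes)
print(nodes[0]._debug_str_for_device())  # ['MiniSchedulerNode Triton code: ...']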
Committed by PyTorch MergeBot
parent 8543000c27
commit 74e871355b
torch/_inductor/codegen/triton.py

@@ -25,6 +25,7 @@ from typing import (
 import sympy
 
 import torch
+import torch._inductor.metrics as metrics
 import torch._logging
 from torch._dynamo.utils import preserve_rng_state
 from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
@@ -38,10 +39,10 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type
 from ...utils._sympy.value_ranges import ValueRanges
 from .. import config, ir
 from ..codecache import code_hash, get_path, PyCodeCache
-from ..metrics import is_metric_table_enabled, log_kernel_metadata
 from ..runtime.benchmarking import benchmarker
 from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
 from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
+from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
 from ..utils import (
     cache_on_self,
     get_bounds_index_expr,
@@ -3098,6 +3099,14 @@ class TritonScheduling(SIMDScheduling):
         )
     )
 
+    def __init__(self, scheduler: Scheduler) -> None:
+        super().__init__(scheduler)
+        if scheduler is None or not hasattr(scheduler, "nodes"):
+            return
+        for node in scheduler.nodes:
+            if isinstance(node, (SchedulerNode, FusedSchedulerNode)):
+                node.debug_device_str = debug_triton_code
+
     @classmethod
     def get_backend_features(cls, device: torch.device):
         return cls.backend_features
@@ -3174,8 +3183,8 @@ class TritonScheduling(SIMDScheduling):
         # log kernel metadata for offline analysis.
         # E.g. one can find all unaligned inner reduction and check if
         # padding helps with the perf kernel by kernel.
-        if is_metric_table_enabled("kernel_metadata"):
-            log_kernel_metadata(kernel_name, kernel_path, src_code)
+        if metrics.is_metric_table_enabled("kernel_metadata"):
+            metrics.log_kernel_metadata(kernel_name, kernel_path, src_code)
 
         return kernel_name
 
@@ -3346,3 +3355,33 @@ class TritonScheduling(SIMDScheduling):
         V.graph.removed_buffers = removed_buffers_orig
         V.graph.inplaced_to_remove = inplaced_to_remove_orig
         return total_ms, total_clone_ms, file_list
+
+
+def debug_triton_code(node: BaseSchedulerNode) -> List[str]:
+    lines = []
+    multi_template = node.get_template_node()
+    assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer)
+    if multi_template and multi_template.make_kernel_render is None:
+        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
+    else:
+        from torch._inductor.codegen.cuda_combined_scheduling import (
+            CUDACombinedScheduling,
+        )
+
+        device = node.get_device()
+        backend = node.scheduler.get_backend(device)
+        assert isinstance(
+            backend, (SIMDScheduling, CUDACombinedScheduling)
+        ), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
+        V.graph.scheduler.current_device = device
+
+        # Don't increment kernel count when generating debug string.
+        # This will confuse some unit tests that check the number of
+        # generated kernels.
+        old_generated_kernel_count = metrics.generated_kernel_count
+        triton_code = backend.generate_kernel_code_from_nodes(node.get_nodes()).strip()
+        metrics.generated_kernel_count = old_generated_kernel_count
+
+        lines.append(f"{node.get_name()} Triton code:")
+        lines.append(textwrap.indent(triton_code, "    "))
+    return lines

torch/_inductor/scheduler.py

@@ -175,6 +175,9 @@ class BaseSchedulerNode:
 
     def __init__(self, scheduler: Scheduler) -> None:
         self.scheduler: Scheduler = scheduler
+        self.debug_device_str: Callable[
+            [BaseSchedulerNode], List[str]
+        ] = lambda *args, **kwargs: []
 
     def _init_from_node(self, node: ir.Operation) -> None:
         self.node: Optional[ir.Operation] = node
@@ -226,6 +229,9 @@ class BaseSchedulerNode:
     def debug_str_extra(self) -> str:
         return ""
 
+    def _debug_str_for_device(self) -> List[str]:
+        return self.debug_device_str(self)
+
     def debug_str_short(self) -> str:
         maybe_data = getattr(self.node, "data", None)
         data_str = ""
@@ -953,8 +959,7 @@ class SchedulerNode(BaseSchedulerNode):
         lines.append(textwrap.indent(self._body.debug_str(), "    "))
 
         assert self.node is not None
-        if ir.is_triton(self.node.get_device()):
-            lines.extend(debug_triton_code(self))
+        lines.extend(self._debug_str_for_device())
 
         return "\n".join(lines)
 
@@ -1178,9 +1183,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
         ]
         node = self.snodes[0].node
         if node is not None:
-            device = node.get_device()
-            if ir.is_triton(device):
-                lines.extend(debug_triton_code(self))
+            lines.extend(self._debug_str_for_device())
 
         return textwrap.indent("\n".join(lines).rstrip(), "    ")
 
@@ -3757,34 +3760,3 @@ class BaseScheduling:
         and memory copy time in milliseconds on randomly generated inputs.
         """
         raise NotImplementedError
-
-
-def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]:
-    lines = []
-    multi_template = node.get_template_node()
-    assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer)
-    if multi_template and multi_template.make_kernel_render is None:
-        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
-    else:
-        from torch._inductor.codegen.cuda_combined_scheduling import (
-            CUDACombinedScheduling,
-        )
-
-        from .codegen.simd import SIMDScheduling
-
-        snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes
-        device = snodes[0].get_device()
-        backend = node.scheduler.get_backend(device)
-        assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling))
-        V.graph.scheduler.current_device = device
-
-        # Don't increment kernel count when generating debug string.
-        # This will confuse some unit tests that check the number of
-        # generated kernels.
-        old_generated_kernel_count = metrics.generated_kernel_count
-        triton_code = backend.generate_kernel_code_from_nodes(snodes).strip()
-        metrics.generated_kernel_count = old_generated_kernel_count
-
-        lines.append(f"{node.get_name()} Triton code:")
-        lines.append(textwrap.indent(triton_code, "    "))
-    return lines