Compare commits

...

4 Commits

SHA1        Message                          Date
6a862cdb23  do_bench_using_profiling update  2025-06-06 16:27:37 -07:00
4eb2cf1548  enable stuff                     2025-06-04 13:52:04 -07:00
e0ec6df976  remove breakpoint                2025-06-03 21:08:11 -07:00
27da636a63  TMP update autotune configs      2025-06-03 21:08:06 -07:00
3 changed files with 119 additions and 32 deletions


@@ -891,6 +891,8 @@ profile_bandwidth_with_do_bench_using_profiling = (
disable_cpp_codegen = False
new_configs: bool = True
# Freezing will attempt to inline weights as constants in optimization
# and run constant folding and other optimizations on them. After freezing, weights
# can no longer be updated.
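
For context, a minimal usage sketch of the toggle added above. It assumes the setting lives in torch._inductor.config (an assumption based on the surrounding inductor options); the mm_configs property in the next file checks the environment variable first, then falls back to this config flag.

import os
# Assumption: the settings above live in torch._inductor.config.
import torch._inductor.config as inductor_config

# Either toggle routes matmul autotuning through the new config list.
inductor_config.new_configs = True
os.environ["TORCHINDUCTOR_NEW_CONFIGS"] = "1"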


@@ -84,37 +84,86 @@ class BaseHeuristicSingleton(type):
cls._instances[cls] = instance
return cls._instances[cls]
import os
class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
"""
Base class for mm_configs; device-specific Triton kernel configs inherit from here.
"""
@property
def mm_configs(self) -> list[BaseConfig]:
if os.environ.get("TORCHINDUCTOR_NEW_CONFIGS", "0") == "1" or config.new_configs:
return [
# GemmConfig(16, 16, 128, 5, 1),
# GemmConfig(16, 16, 256, 4, 1),
GemmConfig(64, 16, 128, 4, 4),
GemmConfig(64, 16, 256, 4, 4),
GemmConfig(64, 32, 128, 4, 4),
GemmConfig(64, 32, 128, 5, 8),
# GemmConfig(63, 32, 256, 1, 8),
GemmConfig(64, 64, 128, 4, 4),
GemmConfig(64, 128, 64, 4, 4),
GemmConfig(64, 128, 128, 3, 4),
GemmConfig(128, 16, 128, 5, 8),
GemmConfig(128, 128, 32, 5, 8),
GemmConfig(128, 128, 64, 3, 4),
GemmConfig(128, 128, 64, 3, 8),
GemmConfig(128, 128, 64, 4, 4),
GemmConfig(128, 128, 64, 4, 8),
GemmConfig(128, 256, 32, 5, 8),
GemmConfig(128, 256, 64, 3, 8),
GemmConfig(128, 256, 64, 4, 8),
# GemmConfig(128, 256, 64, 5, 8),
GemmConfig(256, 128, 32, 5, 8)
]
else:
return [
GemmConfig(32, 32, 16, 1, 2),
GemmConfig(32, 32, 128, 2, 4),
GemmConfig(32, 64, 32, 5, 8),
GemmConfig(64, 32, 32, 5, 8),
GemmConfig(64, 32, 128, 5, 4),
GemmConfig(64, 64, 16, 2, 4),
GemmConfig(64, 64, 32, 2, 4),
GemmConfig(64, 64, 64, 3, 8),
GemmConfig(64, 64, 128, 5, 4),
GemmConfig(64, 128, 32, 3, 4),
GemmConfig(64, 128, 32, 4, 8),
GemmConfig(64, 128, 64, 3, 4),
GemmConfig(64, 128, 128, 4, 4),
GemmConfig(128, 64, 32, 3, 4),
GemmConfig(128, 64, 32, 4, 8),
GemmConfig(128, 128, 32, 2, 8),
GemmConfig(128, 128, 32, 3, 4),
GemmConfig(128, 128, 64, 3, 4),
GemmConfig(128, 128, 64, 5, 8),
]
def __init__(self) -> None:
# List of BaseConfig objects holding the kernel configs. Configs that evaluate to
# true will be used on the target platform. Each config is:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
self.mm_configs: list[BaseConfig] = [
GemmConfig(32, 32, 16, 1, 2),
GemmConfig(32, 32, 128, 2, 4),
GemmConfig(32, 64, 32, 5, 8),
GemmConfig(64, 32, 32, 5, 8),
GemmConfig(64, 32, 128, 5, 4),
GemmConfig(64, 64, 16, 2, 4),
GemmConfig(64, 64, 32, 2, 4),
GemmConfig(64, 64, 64, 3, 8),
GemmConfig(64, 64, 128, 5, 4),
GemmConfig(64, 128, 32, 3, 4),
GemmConfig(64, 128, 32, 4, 8),
GemmConfig(64, 128, 64, 3, 4),
GemmConfig(64, 128, 128, 4, 4),
GemmConfig(128, 64, 32, 3, 4),
GemmConfig(128, 64, 32, 4, 8),
GemmConfig(128, 128, 32, 2, 8),
GemmConfig(128, 128, 32, 3, 4),
GemmConfig(128, 128, 64, 3, 4),
GemmConfig(128, 128, 64, 5, 8),
]
# self.mm_configs: list[BaseConfig] = [
# GemmConfig(32, 32, 16, 1, 2),
# GemmConfig(32, 32, 128, 2, 4),
# GemmConfig(32, 64, 32, 5, 8),
# GemmConfig(64, 32, 32, 5, 8),
# GemmConfig(64, 32, 128, 5, 4),
# GemmConfig(64, 64, 16, 2, 4),
# GemmConfig(64, 64, 32, 2, 4),
# GemmConfig(64, 64, 64, 3, 8),
# GemmConfig(64, 64, 128, 5, 4),
# GemmConfig(64, 128, 32, 3, 4),
# GemmConfig(64, 128, 32, 4, 8),
# GemmConfig(64, 128, 64, 3, 4),
# GemmConfig(64, 128, 128, 4, 4),
# GemmConfig(128, 64, 32, 3, 4),
# GemmConfig(128, 64, 32, 4, 8),
# GemmConfig(128, 128, 32, 2, 8),
# GemmConfig(128, 128, 32, 3, 4),
# GemmConfig(128, 128, 64, 3, 4),
# GemmConfig(128, 128, 64, 5, 8),
# ]
# Exhaustive search for mm configs
self.exhaustive_configs: list[BaseConfig] = [

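As a reading aid for the tuples above, here is a hedged stand-in for GemmConfig; the real class is defined elsewhere in this file and is only assumed here to take these five positional fields, in the order given by the comment in __init__.

from dataclasses import dataclass

@dataclass(frozen=True)
class GemmConfigSketch:
    # Illustrative stand-in only; field order mirrors the comment above:
    # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int

# e.g. GemmConfig(128, 128, 64, 3, 8): 128x128 output tiles, K stepped in chunks
# of 64, a 3-stage software pipeline, 8 warps per Triton program.
example = GemmConfigSketch(128, 128, 64, 3, 8)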

@@ -305,21 +305,56 @@ def do_bench_using_profiling(
log.debug("raw events")
log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))
filtered_events = EventList(
[
event
for event in p.events()
if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
]
)
tmp = [
event
for event in p.events()
if event.device_type == DeviceType.CUDA
and event.name not in ["Context Sync", "Memset (Device)"]
]
filtered_events = EventList(tmp)
num_event_per_group = len(filtered_events) / n_repeat
if len(filtered_events) % n_repeat != 0:
raise RuntimeError(
log.debug(
"Failed to divide all profiling events into #repeat groups. "
"#CUDA events: %d, #repeats: %s",
"#CUDA events: %d, #repeats: %s. There could have been extra "
"events added, check filtered_events.",
len(filtered_events),
n_repeat,
)
num_event_per_group = len(filtered_events) / n_repeat
# There's a bug in recent CUDA CUPTI where kernel events aren't being recorded.
# This has been reported to NVIDIA, but there is no workaround yet, so we adjust the
# final estimate using some heuristics to see if we can still salvage a number.
# Build a histogram of kernel launch counts by name.
from collections import defaultdict
name_counts = defaultdict(int)
for event in tmp:
name_counts[event.name] += 1
largest_num, smallest_num = max(name_counts.values()), min(name_counts.values())
if largest_num > n_repeat:
# we're out of luck here, since kernels are being run multiple times in fn()
raise RuntimeError(
"Failed to divide all profiling events into #repeat groups and unable to adjust because"
" kernels are being run multiple times in the benchmarking function."
)
else:
# It could also be that kernels are running multiple times while the profiler is
# missing events. Not much we can do in that situation. This shouldn't come up in
# most uses of this function, which is meant for microbenchmarking, but beware when
# using it on larger models.
# Assume the benchmarking function launches each distinct kernel once per repeat;
# that assumption doesn't hold if the histogram is lopsided, so bail out in that case.
if smallest_num / largest_num < 0.90:
raise RuntimeError(
"Failed to divide all profiling events into #repeat groups and unable to adjust because"
" either the benchmarking function is too complex, or too many kernels are being skipped."
)
# TODO just sum one of each kernel and adjust the other ratios by their prevalence in the histogram.
breakpoint()
actual_events = EventList(
[
event
@@ -327,10 +362,11 @@ def do_bench_using_profiling(
if i % num_event_per_group != 0
]
)
actual_events._build_tree()
actual_events = actual_events.key_averages()
log.debug("profiling time breakdown")
log.info("profiling time breakdown")
log.debug(actual_events.table(row_limit=-1))
res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
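
Stripped of profiler plumbing, the salvage heuristic above reduces to: count how often each CUDA kernel name appears, expect each kernel once per repeat, and only adjust when the histogram is close to uniform. Below is a standalone sketch under those assumptions; events_per_group is a hypothetical helper name, the kernel names are illustrative, plain strings stand in for profiler events, collections.Counter builds the histogram, and the 0.90 threshold is taken from the diff.

from collections import Counter

def events_per_group(event_names: list[str], n_repeat: int) -> float:
    """Sketch of the grouping check in do_bench_using_profiling."""
    if len(event_names) % n_repeat == 0:
        # Clean case: events divide evenly into one group per repeat.
        return len(event_names) / n_repeat
    counts = Counter(event_names)
    largest_num, smallest_num = max(counts.values()), min(counts.values())
    if largest_num > n_repeat:
        # Some kernel launches more than once per repeat inside fn(): cannot adjust.
        raise RuntimeError("kernels run multiple times per repeat")
    if smallest_num / largest_num < 0.90:
        # Too lopsided: the profiler dropped too many events to assume uniformity.
        raise RuntimeError("too many kernel events missing")
    # Otherwise assume one launch of each distinct kernel per repeat.
    return float(len(counts))

events = ["gemm_kernel", "relu_kernel"] * 10
print(events_per_group(events, 10))        # clean split -> 2.0
print(events_per_group(events[:-1], 10))   # one dropped event, still salvageable -> 2.0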