mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[Profiler] Add queue depth computation (#79993)
Test Plan: Add test in test_profiler.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79993
Approved by: https://github.com/robieta
This commit is contained in:
committed by
PyTorch MergeBot
parent
8cfb8f94a0
commit
3a1e3e67c5
@ -1194,17 +1194,17 @@ class TestExperimentalUtils(TestCase):
|
||||
]
|
||||
|
||||
cpu_events = [
|
||||
MockProfilerEvent("CPU", 1, 0, 100000),
|
||||
MockProfilerEvent("CPU", 2, 100001, 100000),
|
||||
MockProfilerEvent("CPU", 3, 200001, 100000),
|
||||
MockProfilerEvent("CPU", 4, 300001, 100000),
|
||||
MockProfilerEvent("CPU", 5, 400001, 100000),
|
||||
MockProfilerEvent("CPU", 6, 500001, 100000),
|
||||
MockProfilerEvent("CPU", 7, 600001, 100000),
|
||||
MockProfilerEvent("CPU", 8, 700001, 100000),
|
||||
MockProfilerEvent("CPU", 9, 800001, 100000),
|
||||
MockProfilerEvent("CPU", 10, 900001, 100000),
|
||||
MockProfilerEvent("CPU", 11, 1000001, 100000),
|
||||
MockProfilerEvent("CPU (Before cudaLaunchKernel)", 1, 0, 100000),
|
||||
MockProfilerEvent("CPU (Before cudaLaunchKernel)", 2, 100001, 100000),
|
||||
MockProfilerEvent("CPU (Before cudaLaunchKernel)", 3, 200001, 100000),
|
||||
MockProfilerEvent("CPU (Before cudaLaunchKernel)", 4, 300001, 100000),
|
||||
MockProfilerEvent("CPU (After cudaLaunchKernel)", 5, 400001, 100000),
|
||||
MockProfilerEvent("CPU (After cudaLaunchKernel)", 6, 500001, 100000),
|
||||
MockProfilerEvent("CPU (After cudaLaunchKernel)", 7, 600001, 100000),
|
||||
MockProfilerEvent("CPU (After GPU)", 8, 700001, 100000),
|
||||
MockProfilerEvent("CPU (After GPU)", 9, 800001, 100000),
|
||||
MockProfilerEvent("CPU (After GPU)", 10, 900001, 100000),
|
||||
MockProfilerEvent("CPU (No Event)", 11, 1000001, 100000),
|
||||
]
|
||||
|
||||
profiler = unittest.mock.Mock()
|
||||
@ -1234,13 +1234,44 @@ class TestExperimentalUtils(TestCase):
|
||||
for child in event_key.event.children
|
||||
]))
|
||||
|
||||
def test_utils_compute_queue_depth_list(self):
|
||||
def test_utils_compute_queue_depth(self):
|
||||
|
||||
def format_queue_depth(queue_depth_list, events):
    """Render one ``"<queue_depth> [<event name>]"`` line per (data, event) pair.

    Args:
        queue_depth_list: iterable of objects exposing a ``queue_depth``
            attribute.
        events: iterable of objects exposing a ``name()`` method, paired
            positionally with ``queue_depth_list``.

    Returns:
        All rendered lines, each terminated by a newline, concatenated into
        one string. The result is empty when either input is empty. Pairing
        uses ``zip``, so surplus items in the longer input are ignored.
    """
    # Join once instead of growing a string with ``+=`` on every iteration.
    return "".join(
        f"{data.queue_depth} [{event.name()}]\n"
        for data, event in zip(queue_depth_list, events)
    )
|
||||
|
||||
# We have to use Mock because time series data is too flaky to test
|
||||
profiler = self.generate_mock_profile()
|
||||
basic_eval = _utils.BasicEvaluation(profiler)
|
||||
golden_queue_depth_list = [1, 2, 3, 2, 1, 0]
|
||||
for observed, golden in zip(basic_eval.compute_queue_depth(),
|
||||
golden_queue_depth_list):
|
||||
self.assertEqual(observed.queue_depth, golden)
|
||||
basic_evaluation = _utils.BasicEvaluation(profiler)
|
||||
self.assertExpectedInline(
|
||||
format_queue_depth(basic_evaluation.queue_depth_list,
|
||||
basic_evaluation.cuda_events), """\
|
||||
1 [cudaLaunchKernel]
|
||||
2 [cudaLaunchKernel]
|
||||
3 [cudaLaunchKernel]
|
||||
2 [GPU]
|
||||
1 [GPU]
|
||||
0 [GPU]
|
||||
""")
|
||||
self.assertExpectedInline(
|
||||
format_queue_depth([
|
||||
basic_evaluation.metrics[k]
|
||||
for k in basic_evaluation.event_keys
|
||||
], basic_evaluation.events), """\
|
||||
0 [CPU (Before cudaLaunchKernel)]
|
||||
0 [CPU (Before cudaLaunchKernel)]
|
||||
0 [CPU (Before cudaLaunchKernel)]
|
||||
0 [CPU (Before cudaLaunchKernel)]
|
||||
1 [CPU (After cudaLaunchKernel)]
|
||||
2 [CPU (After cudaLaunchKernel)]
|
||||
3 [CPU (After cudaLaunchKernel)]
|
||||
2 [CPU (After GPU)]
|
||||
1 [CPU (After GPU)]
|
||||
0 [CPU (After GPU)]
|
||||
0 [CPU (No Event)]
|
||||
""")
|
||||
|
||||
def test_utils_compute_queue_depth_when_no_cuda_events(self):
|
||||
# For traces with only CPU events, we expect an empty queue depth list
|
||||
|
@ -12,6 +12,7 @@ class EventMetrics:
|
||||
duration_time_ns: int = 0
|
||||
self_time_ns: int = 0
|
||||
idle_time_ns: int = 0
|
||||
queue_depth: int = 0
|
||||
|
||||
@property
|
||||
def fraction_idle_time(self):
|
||||
@ -62,6 +63,12 @@ class BasicEvaluation:
|
||||
self.profile = prof
|
||||
self.metrics: Dict[EventKey, EventMetrics] = {}
|
||||
self.compute_self_time()
|
||||
self.event_keys = sorted((e for e in self.metrics.keys()),
|
||||
key=lambda x: x.event.start_time_ns)
|
||||
self.events = [e.event for e in self.event_keys]
|
||||
self.cuda_events: List[_KinetoEvent] = []
|
||||
self.queue_depth_list = self.compute_queue_depth()
|
||||
self.compute_idle_time()
|
||||
|
||||
def compute_self_time(self):
|
||||
'''
|
||||
@ -88,9 +95,9 @@ class BasicEvaluation:
|
||||
|
||||
def compute_queue_depth(self):
|
||||
'''
|
||||
Computes event's idle time. Idle time is defined as the time when the CUDA kernel queue depth is 0.
|
||||
It also return a Time series of the queue depth data.
|
||||
qd = cuda kernel queue depth
|
||||
Computes the queue depth at each event. This calculates the queue depth data for
|
||||
all the events in the tree.
|
||||
Returns a list of Intervals holding the queue depth data of CUDA launches and kernels.
|
||||
'''
|
||||
assert (self.profile.kineto_results is not None)
|
||||
cuda_event_list = self.profile.kineto_results.events()
|
||||
@ -104,10 +111,6 @@ class BasicEvaluation:
|
||||
return e.device_type() == DeviceType.CUDA and "mem" not in e.name(
|
||||
).lower()
|
||||
|
||||
# Record All the idle intervals
|
||||
idle_interval: List[Interval] = []
|
||||
queue_depth_list: List[Interval] = []
|
||||
|
||||
cuda_launch_events = sorted(
|
||||
(e for e in cuda_event_list if is_cuda_launch_kernel(e)),
|
||||
key=lambda x: x.start_us())
|
||||
@ -115,6 +118,9 @@ class BasicEvaluation:
|
||||
(e for e in cuda_event_list if is_cuda_kernel(e)),
|
||||
key=lambda x: x.start_us())
|
||||
|
||||
self.cuda_events = sorted(cuda_launch_events + cuda_kernel_events,
|
||||
key=lambda x: x.start_us())
|
||||
|
||||
kernel_mapping: Dict[_KinetoEvent, int] = {}
|
||||
last_mapped_kernel = 0
|
||||
for cuda_launch_event in cuda_launch_events:
|
||||
@ -126,36 +132,67 @@ class BasicEvaluation:
|
||||
kernel_mapping[cuda_launch_event] = index
|
||||
last_mapped_kernel = index if index is not None else last_mapped_kernel
|
||||
|
||||
current_kernel_index = -1
|
||||
spawned_kernel_index = None
|
||||
for cuda_launch_event in cuda_launch_events:
|
||||
current_kernel_index = 0
|
||||
spawned_kernel_index = -1
|
||||
|
||||
all_events = cuda_launch_events + cuda_kernel_events + self.events
|
||||
|
||||
def new_old_event_comparator(event):
    """Return the start time of ``event`` in nanoseconds, for use as a sort key.

    Two kinds of event are told apart by duck typing:

    * objects with a ``start_us()`` method, whose result is multiplied by
      1000 so that it is expressed in nanoseconds;
    * objects with a ``start_time_ns`` attribute, returned unchanged.

    ``start_us`` is checked first, so it wins if an object exposes both.

    Raises:
        TypeError: if ``event`` exposes neither ``start_us`` nor
            ``start_time_ns``.
    """
    if hasattr(event, "start_us"):
        # Scale microseconds to nanoseconds so both event kinds share a unit.
        return event.start_us() * 1000
    if hasattr(event, "start_time_ns"):
        return event.start_time_ns
    # TypeError is still an Exception subclass, so existing handlers that
    # catch Exception keep working, while the failure is now specific.
    raise TypeError("Unknown Event Type")
|
||||
|
||||
queue_depth_list: List[Interval] = []
|
||||
all_events.sort(key=new_old_event_comparator)
|
||||
for event in all_events:
|
||||
# Find latest cuda kernel event
|
||||
while (current_kernel_index + 1 < len(cuda_kernel_events) and
|
||||
cuda_kernel_events[current_kernel_index + 1].start_us() +
|
||||
cuda_kernel_events[current_kernel_index + 1].duration_us() <
|
||||
cuda_launch_event.start_us() +
|
||||
cuda_launch_event.duration_us()):
|
||||
if hasattr(event, "start_us"):
|
||||
start_time = event.start_us() * 1000
|
||||
end_time = (event.start_us() + event.duration_us()) * 1000
|
||||
# Find current spawned cuda kernel event
|
||||
if event in kernel_mapping and kernel_mapping[
|
||||
event] is not None:
|
||||
spawned_kernel_index = kernel_mapping[event]
|
||||
elif hasattr(event, "start_time_ns"):
|
||||
start_time = event.start_time_ns # type: ignore[attr-defined]
|
||||
end_time = event.end_time_ns # type: ignore[attr-defined]
|
||||
|
||||
while (current_kernel_index < len(cuda_kernel_events) and
|
||||
(cuda_kernel_events[current_kernel_index].start_us()) * 1000
|
||||
<= start_time):
|
||||
current_kernel_index += 1
|
||||
current_queue_depth = spawned_kernel_index - current_kernel_index + 1
|
||||
|
||||
# Find current spawned cuda kernel event
|
||||
spawned_kernel_index = kernel_mapping[cuda_launch_event]
|
||||
if spawned_kernel_index is None:
|
||||
current_queue_depth = 0
|
||||
else:
|
||||
current_queue_depth = spawned_kernel_index - current_kernel_index
|
||||
if hasattr(event, "start_us"):
|
||||
queue_depth_list.append(
|
||||
Interval(start_time, end_time, current_queue_depth))
|
||||
elif hasattr(event, "start_time_ns"):
|
||||
self.metrics[EventKey(event)].queue_depth = current_queue_depth
|
||||
|
||||
queue_depth_list.append(
|
||||
Interval(
|
||||
cuda_launch_event.start_us(),
|
||||
cuda_launch_event.start_us() +
|
||||
cuda_launch_event.duration_us(), current_queue_depth))
|
||||
return queue_depth_list
|
||||
|
||||
def compute_idle_time(self):
|
||||
'''
|
||||
Computes idle time of the profile.
|
||||
'''
|
||||
# Based on queue_depth_list, we can calculate idle time for all the events
|
||||
idle = False
|
||||
idle_start = 0
|
||||
idle_intervals: List[Interval] = []
|
||||
for data_point in self.queue_depth_list:
|
||||
if data_point.queue_depth == 0 and not idle:
|
||||
idle_start = data_point.end
|
||||
idle = True
|
||||
if data_point.queue_depth > 0 and idle:
|
||||
idle_intervals.append(Interval(idle_start, data_point.start))
|
||||
idle = False
|
||||
|
||||
event_list = [e.event for e in self.metrics.keys()]
|
||||
for event in event_list:
|
||||
self.metrics[EventKey(event)].idle_time_ns = EventKey(
|
||||
event).intervals_overlap(idle_interval)
|
||||
|
||||
return queue_depth_list
|
||||
event).intervals_overlap(idle_intervals)
|
||||
|
||||
|
||||
def index_of_first_match(seq, predicate, start=0, end=None):
|
||||
|
Reference in New Issue
Block a user