# Source listing: pytorch/test/profiler/test_profiler.py (3570 lines, 130 KiB, Python)

# Owner(s): ["oncall: profiler"]
import collections
import gc
import json
import os
import re
import subprocess
import sys
import tempfile
import textwrap
import threading
import unittest
import weakref
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import patch
import expecttest
import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
import torch.utils.data.datapipes as dp
from torch._C._profiler import _TensorMetadata
from torch.autograd import (
_record_function_with_args_enter,
_record_function_with_args_exit,
)
from torch.autograd.profiler import KinetoStepTracker, profile as _profile
from torch.autograd.profiler_legacy import profile as _profile_legacy
from torch.profiler import (
_utils,
DeviceType,
ExecutionTraceObserver,
kineto_available,
profile,
ProfilerAction,
ProfilerActivity,
record_function,
supported_activities,
)
from torch.profiler._pattern_matcher import (
Conv2dBiasFollowedByBatchNorm2dPattern,
ExtraCUDACopyPattern,
ForLoopIndexingPattern,
FP32MatMulPattern,
GradNotSetToNonePattern,
MatMulDimInFP16Pattern,
NamePattern,
OptimizerSingleTensorPattern,
Pattern,
report_all_anti_patterns,
SynchronizedDataLoaderPattern,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import skipCUDAVersionIn
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_JETSON,
IS_WINDOWS,
parametrize,
run_tests,
serialTest,
skipIfTorchDynamo,
TemporaryDirectoryName,
TemporaryFileName,
TEST_WITH_ASAN,
TEST_WITH_CROSSREF,
TEST_WITH_ROCM,
TestCase,
)
# Alias used in annotations for a decoded JSON object.
Json = Dict[str, Any]

# psutil is optional: record whether it could be imported so that tests
# depending on it can be skipped.
try:
    import psutil
except ImportError:
    HAS_PSUTIL = False
else:
    HAS_PSUTIL = True
import pickle
from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
@unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestProfilerCUDA(TestCase):
    """Profiler tests that are skipped unless CUDA and psutil are available."""

    @skipCUDAVersionIn([(11, 5)])  # https://github.com/pytorch/pytorch/issues/69023
    def test_mem_leak(self):
        """Checks that there's no memory leak when using profiler with CUDA"""
        tensor = torch.rand(1, 1).cuda()
        process = psutil.Process()
        # Bounded deque: only the RSS readings of the last five rounds survive.
        last_rss = collections.deque(maxlen=5)
        for _round in range(10):
            with _profile(use_cuda=True):
                for _ in range(1024):
                    tensor = torch.mm(tensor, tensor)

            gc.collect()
            torch.cuda.empty_cache()
            last_rss.append(process.memory_info().rss)

        # with CUDA events leaking the increase in memory was ~7 MB between
        # profiler invocations above
        readings = list(last_rss)
        deltas = [later - earlier for earlier, later in zip(readings, readings[1:])]
        is_increasing = all(delta > 0 for delta in deltas)
        max_diff = max((-1, *deltas))
        # Fail only when RSS grew on every retained round AND at least one
        # step grew by more than 100 KiB.
        self.assertTrue(
            not is_increasing or max_diff <= 100 * 1024,
            msg=f"memory usage is increasing, {str(last_rss)}",
        )

    def test_custom_module_input_op_ids(self):
        """Run a custom autograd function forward and backward under emit_nvtx."""

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (saved,) = ctx.saved_tensors
                return saved

        def custom_layer(input_ten):
            return MyFunc.apply(input_ten)

        # Only testing that emit_nvtx runs when
        # record_shapes option is enabled.
        with torch.autograd.profiler.emit_nvtx(record_shapes=True):
            lhs = torch.randn(10, 10, requires_grad=True)
            rhs = torch.randn(10, 10, requires_grad=True)
            custom_layer(lhs + rhs).sum().backward()

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_cudagraph_profiling_workaround(self):
        """Profile a CUDA-graphed callable in a fresh interpreter process."""
        # repro taken from #75504
        script = """
import os
import torch
from torch.profiler import ProfilerActivity, profile
def add_one(in_: torch.Tensor):
    return in_ + 1
sample_arg = torch.zeros(10, device="cuda").requires_grad_(True)
# add this before cuda graphs are created
torch.profiler._utils._init_for_cuda_graphs()
add_one_graphed = torch.cuda.graphs.make_graphed_callables(add_one, sample_args=(sample_arg,))
zeros = torch.zeros(10, device="cuda")
out = add_one_graphed(zeros)
assert out[0] == 1
with profile(activities=[ProfilerActivity.CPU]):
    add_one_graphed(zeros)
with profile(activities=[ProfilerActivity.CUDA]):
    add_one_graphed(zeros)
"""
        # Launch in a separate process to catch hanging/illegal memory errors
        # and to make sure CUPTI isn't already initialized.
        # check_call throws an exception if the script fails.
        subprocess.check_call(
            [sys.executable, "-c", script],
            universal_newlines=True,
            timeout=60,
        )
@unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required")
class TestProfilerITT(TestCase):
    """Profiler tests that are skipped unless ITT is available."""

    def test_custom_module_input_op_ids(self):
        """Run a custom autograd function forward and backward under emit_itt."""

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (saved,) = ctx.saved_tensors
                return saved

        def custom_layer(input_ten):
            return MyFunc.apply(input_ten)

        # Only testing that emit_itt runs when
        # record_shapes option is enabled.
        with torch.autograd.profiler.emit_itt(record_shapes=True):
            lhs = torch.randn(10, 10, requires_grad=True)
            rhs = torch.randn(10, 10, requires_grad=True)
            custom_layer(lhs + rhs).sum().backward()
class TestRecordFunction(TestCase):
def _record_function_with_param(self):
u = torch.randn(3, 4, 5, requires_grad=True)
with _profile(
with_stack=True, use_kineto=kineto_available(), record_shapes=True
) as prof:
with record_function("## TEST 1 ##", "1, 2, 3"):
rf_handle = _record_function_with_args_enter(
"## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u
)
_record_function_with_args_exit(rf_handle)
with record_function("## TEST 3 ##"):
rf_handle = _record_function_with_args_enter("## TEST 4 ##")
_record_function_with_args_exit(rf_handle)
return prof
def test_record_function(self):
prof_result = self._record_function_with_param()
found_test_1 = False
found_test_2 = False
found_test_3 = False
found_test_4 = False
for e in prof_result.function_events:
if "## TEST 1 ##" == e.name:
found_test_1 = True
self.assertTrue(e.input_shapes == [[]])
elif "## TEST 2 ##" == e.name:
found_test_2 = True
self.assertTrue(e.input_shapes == [[], [], [], [], [], [3, 4, 5]])
elif "## TEST 3 ##" == e.name:
found_test_3 = True
self.assertTrue(e.input_shapes == [])
elif "## TEST 4 ##" == e.name:
found_test_4 = True
self.assertTrue(e.input_shapes == [])
self.assertTrue(found_test_1)
self.assertTrue(found_test_2)
self.assertTrue(found_test_3)
self.assertTrue(found_test_4)
def test_datapipe_with_record_function(self):
with _profile(
with_stack=True, use_kineto=kineto_available(), record_shapes=True
) as prof:
input_dp1 = dp.iter.IterableWrapper(range(4))
input_dp2 = dp.iter.IterableWrapper(range(4, 8))
input_dp3 = dp.iter.IterableWrapper(range(8, 12))
output_dp = input_dp1.mux(input_dp2, input_dp3)
output = list(output_dp)
has_iter = False
has_mux = False
for e in prof.function_events:
if has_iter and has_mux:
break
if not has_iter and "IterableWrapper" in e.name:
has_iter = True
if not has_mux and "Multiplexer" in e.name:
has_mux = True
self.assertTrue(has_iter)
self.assertTrue(has_mux)
def test_datapipe_delegation_with_profiler(self):
class IDPIterator(torch.utils.data.IterDataPipe):
def __init__(self):
self.data = list(range(10))
self._idx = 0
def __iter__(self):
return self
def __next__(self):
if self._idx >= 10:
self._idx = 0
raise StopIteration
self._idx += 1
return self.data[self._idx - 1]
def get_value(self, idx):
return self.data[idx]
dp1 = IDPIterator() # The object itself is an iterator
self.assertEqual(5, dp1.get_value(5))
it_dp1 = iter(dp1) # This creates the 1st iterator
self.assertEqual(5, it_dp1.get_value(5)) # type: ignore[attr-defined]
self.assertEqual(list(range(10)), list(it_dp1))
class IDPDelegator(torch.utils.data.IterDataPipe):
def __init__(self, datapipe):
self.datapipe = datapipe
def __iter__(self):
return iter(self.datapipe)
dp2 = IDPDelegator(dp1)
it_dp2 = iter(dp2)
self.assertEqual(5, it_dp2.get_value(5))
self.assertEqual(list(range(10)), list(it_dp2))
def test_datapipe_with_record_function_fork(self):
with _profile(
with_stack=True, use_kineto=kineto_available(), record_shapes=True
) as prof:
input_dp = dp.iter.IterableWrapper(range(10))
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
output1 = list(dp1)
has_iter = False
has_child = False
for e in prof.function_events:
if has_iter and has_child:
break
if not has_iter and "IterableWrapper" in e.name:
has_iter = True
if not has_child and "_ChildDataPipe" in e.name:
has_child = True
self.assertTrue(has_iter)
self.assertTrue(has_child)
class TestExecutionTrace(TestCase):
def payload(self, use_cuda=False):
u = torch.randn(3, 4, 5, requires_grad=True)
with record_function("## TEST 1 ##", "1, 2, 3"):
inf_val = float("inf")
neg_inf_val = float("-inf")
nan_val = float("nan")
rf_handle = _record_function_with_args_enter(
"## TEST 2 ##",
1,
False,
2.5,
[u, u],
(u, u),
"hello",
u,
inf_val,
neg_inf_val,
nan_val,
)
x = torch.randn(10, 10, requires_grad=True)
if use_cuda:
x = x.cuda()
y = torch.randn(10, 10, requires_grad=True)
if use_cuda:
y = y.cuda()
z = x + y + x * y + x * y
z.backward(z)
gelu = nn.GELU()
m = torch.randn(2)
_ = gelu(m)
if use_cuda:
z = z.cpu()
_record_function_with_args_exit(rf_handle)
def get_execution_trace_root(self, output_file_name) -> Json:
nodes = []
with open(output_file_name) as f:
et_graph = json.load(f)
assert "nodes" in et_graph
nodes = et_graph["nodes"]
return nodes
def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]:
"""Returns a sorted list of rf_id (record function ids) in execution trace"""
def get_rf_id(node):
attrs = node["attrs"]
for a in attrs:
if a["name"] == "rf_id":
return a["value"]
return None
rf_ids_ = (
get_rf_id(n)
for n in nodes
if n["name"] != "[pytorch|profiler|execution_trace|process]"
and n["name"] != "[pytorch|profiler|execution_trace|thread]"
)
return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)
def get_kineto_rf_ids(self, events: List[Json]) -> List[int]:
"""Returns a sorted list of Record function IDs for CPU operators and user annotations"""
ops_and_annotations = (
e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]
)
return sorted(
e.get("args", {}).get("Record function id", -1) for e in ops_and_annotations
)
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_execution_trace_with_kineto(self):
trace_called_num = 0
def trace_handler(p):
nonlocal trace_called_num
trace_called_num += 1
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
# Create a temp file to save execution trace and kineto data.
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()
kt = tempfile.NamedTemporaryFile(
mode="w+t", suffix=".kineto.json", delete=False
)
kt.close()
with profile(
activities=supported_activities(),
schedule=torch.profiler.schedule(
skip_first=3, wait=1, warmup=1, active=2, repeat=1
),
on_trace_ready=trace_handler,
execution_trace_observer=(
ExecutionTraceObserver().register_callback(fp.name)
),
) as p:
for idx in range(10):
with record_function(f"## LOOP {idx} ##"):
self.payload(use_cuda=use_cuda)
p.step()
self.assertEqual(fp.name, p.execution_trace_observer.get_output_file_path())
# Uncomment for debugging
# print("Output kineto = ", kt.name)
# print("Output ET = ", fp.name)
p.export_chrome_trace(kt.name)
self.assertEqual(trace_called_num, 1)
nodes = self.get_execution_trace_root(fp.name)
loop_count = 0
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
if n["name"].startswith("## LOOP "):
loop_count += 1
self.assertTrue(found_root_node)
# Since profiler trace is active for 2 iterations
self.assertEqual(loop_count, 2)
# Compare the collected Execution Trace and Kineto Trace
# in terms of record func ID (rf_id) and External IDs
# both of these should match for the same trace window.
with open(kt.name) as f:
kineto = json.load(f)
events = kineto["traceEvents"]
# Look up rf_ids in both Execution and Kineto trace as two lists.
rf_ids_et = self.get_execution_trace_rf_ids(nodes)
rf_ids_kineto = self.get_kineto_rf_ids(events)
self.assertCountEqual(rf_ids_et, rf_ids_kineto)
self.assertListEqual(
rf_ids_et,
rf_ids_kineto,
msg=f"ET and kineto rf_id should exactly match\n"
f" rf_ids_et = {rf_ids_et}\n"
f" rf_ids_kineto = {rf_ids_kineto}\n",
)
def test_execution_trace_alone(self):
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
# Create a temp file to save execution trace data.
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()
expected_loop_events = 0
et = ExecutionTraceObserver().register_callback(fp.name)
et.start()
for idx in range(5):
expected_loop_events += 1
with record_function(f"## LOOP {idx} ##"):
self.payload(use_cuda=use_cuda)
et.stop()
assert fp.name == et.get_output_file_path()
et.unregister_callback()
nodes = self.get_execution_trace_root(fp.name)
loop_count = 0
# Expected tensor object tuple size, in th form of:
# [tensor_id, storage_id, offset, numel, itemsize, device_str]
tensor_tuple_size = 6
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
if n["name"].startswith("## LOOP "):
loop_count += 1
# Check if tensor tuple representation size is correct.
if n["name"] == "## TEST 2 ##":
assert len(n["inputs"]["values"][3][0]) == tensor_tuple_size
assert found_root_node
assert loop_count == expected_loop_events
@unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
def test_execution_trace_with_pt2(self):
class ConvAndRelu(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(4096, 4096)
self.relu = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear(x)
x = self.relu(x)
return x
# Create a temp file to save execution trace data.
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()
test_module = torch.compile(ConvAndRelu())
x = torch.rand(128, 4096)
et = ExecutionTraceObserver().register_callback(fp.name)
et.start()
test_module.forward(x)
et.stop()
assert fp.name == et.get_output_file_path()
et.unregister_callback()
nodes = self.get_execution_trace_root(fp.name)
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
assert found_root_node
def test_execution_trace_start_stop(self):
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
# Create a temp file to save execution trace data.
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()
expected_loop_events = 0
et = ExecutionTraceObserver()
et.register_callback(fp.name)
for idx in range(10):
if idx == 3:
et.start()
elif idx == 5:
et.stop()
elif idx == 8:
et.start()
elif idx == 9:
et.stop()
if et._execution_trace_running:
expected_loop_events += 1
with record_function(f"## LOOP {idx} ##"):
self.payload(use_cuda=use_cuda)
assert fp.name == et.get_output_file_path()
et.unregister_callback()
nodes = self.get_execution_trace_root(fp.name)
loop_count = 0
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
if n["name"].startswith("## LOOP "):
loop_count += 1
assert found_root_node
assert loop_count == expected_loop_events
def test_execution_trace_repeat_in_loop(self):
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
iter_list = {3, 4, 6, 8}
expected_loop_events = len(iter_list)
output_files = []
for idx in range(10):
if idx in iter_list:
# Create a temp file to save execution trace data.
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()
output_files.append(fp.name)
et = ExecutionTraceObserver()
et.register_callback(fp.name)
et.start()
with record_function(f"## LOOP {idx} ##"):
self.payload(use_cuda=use_cuda)
if idx in iter_list:
et.stop()
et.unregister_callback()
event_count = 0
for et_file in output_files:
nodes = self.get_execution_trace_root(et_file)
found_root_node = False
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
assert n["id"] == 1
found_root_node = True
if n["name"].startswith("## LOOP "):
event_count += 1
assert found_root_node
assert event_count == expected_loop_events
def test_execution_trace_no_capture(self):
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()
et = ExecutionTraceObserver()
et.register_callback(fp.name)
assert fp.name == et.get_output_file_path()
et.unregister_callback()
nodes = self.get_execution_trace_root(fp.name)
for n in nodes:
assert "name" in n
if "[pytorch|profiler|execution_trace|process]" in n["name"]:
found_root_node = True
assert found_root_node
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
def test_execution_trace_nested_tensor(self):
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()
et = ExecutionTraceObserver()
observer = et.register_callback(fp.name)
def fn(nt):
return nt.sin().cos()
with torch.profiler.profile(execution_trace_observer=observer) as prof:
for i in range(3):
values = torch.rand((8 + i, 4 + i))
offsets = torch.tensor([0, 2, 4, 6, 8 + i])
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
fn(nt)
nodes = self.get_execution_trace_root(fp.name)
found_cos = False
for n in nodes:
assert "name" in n
if "cos" in n["name"]:
found_cos = True
assert found_cos
@instantiate_parametrized_tests
class TestProfiler(TestCase):
@unittest.skipIf(
    TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
)
def test_source(self):
    """Checks that source code attribution works for eager, TS and autograd mode

    The graph executor optimize flag is switched off for the duration of the
    test and restored afterwards, including when an assertion fails.
    """
    # avoid automatic inlining
    prev_opt = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(False)
    try:

        @torch.jit.script
        def ts_method_2(x, y):
            return torch.matmul(x, y)

        @torch.jit.script
        def ts_method_1(x, y, z):
            a = x + z
            w = ts_method_2(x, y) + a
            return w.sum()

        class DummyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 2, kernel_size=1, stride=2, padding=3, bias=False
                )

            def forward(self, x):
                return self.conv(x)

        mod = DummyModule()

        def call_module(x):
            return mod(x)

        with _profile(
            with_stack=True,
            use_kineto=kineto_available(),
            experimental_config=_ExperimentalConfig(verbose=True),
        ) as p:
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            w = ts_method_1(x, y, z)
            v = 2 * w
            v.backward()
            a = torch.randn(2, 3, 2, 2, requires_grad=True)
            b = call_module(a)
            c = b.sum()
            c.backward()

        # Every add / AddBackward event must carry a stack that mentions this
        # file and one of the functions that issued the op.
        for e in p.function_events:
            if "aten::add" in e.name or "AddBackward" in e.name:
                self.assertTrue(any("test_profiler" in entry for entry in e.stack))
                self.assertTrue(
                    any(
                        (
                            "test_source" in entry
                            or "ts_method_1" in entry
                            or "ts_method_2" in entry
                        )
                        for entry in e.stack
                    )
                )

        # TODO: https://github.com/pytorch/kineto/issues/617
        if kineto_available() and not IS_WINDOWS:
            with TemporaryFileName(mode="w+") as fname:
                p.export_chrome_trace(fname)
                with open(fname) as f:
                    events = json.load(f)["traceEvents"]

                def extract(pattern: str):
                    # Exactly one exported event may match the pattern.
                    matches = [e for e in events if re.search(pattern, e["name"])]
                    self.assertEqual(
                        len(matches), 1, repr([e["name"] for e in matches])
                    )
                    return matches[0]

                module_event = extract(r"DummyModule_0")
                wrapper_event = extract(r"call_module")
                self.assertEqual(
                    module_event["args"]["Python parent id"],
                    wrapper_event["args"]["Python id"],
                )
    finally:
        # Restore the process-wide flag even if something above failed, so a
        # failure here does not change the setting for later tests.
        torch._C._set_graph_executor_optimize(prev_opt)
@parametrize(
    "name,thread_spec",
    {
        "basic": ((False, False),),
        "multiple_preexisting": ((False, False),) * 2,
        "open_in_scope": ((True, False),),
        "close_in_scope": ((False, True),),
        "complex": (
            # Large number of background threads
            (False, False),
            (False, False),
            (False, False),
            (False, False),
            # some of which finish during profiling
            (False, True),
            (False, True),
            # And the profiled section is also multithreaded
            (True, False),
            (True, True),
        ),
    }.items(),
    name_fn=lambda name, thread_spec: name,
)
@serialTest()
@parametrize("work_in_main_thread", [True, False])
def test_source_multithreaded(self, name, thread_spec, work_in_main_thread):
    """Test various threading configurations.

    `thread_spec` is a Tuple[Tuple[bool, bool], ...] where each pair is a
    thread. The first bool indicates if the thread should be started under
    the profiler context and the second is if it should be joined under the
    profiler context.

    `work_in_main_thread` selects whether the main thread runs the same
    workload as the worker threads or only waits on the two barriers.
    """

    timeout = 15
    num_threads = len(thread_spec) + 1  # Main thread
    # Both barriers are sized for every worker plus the main thread.
    start_barrier = threading.Barrier(num_threads, timeout=timeout)
    end_barrier = threading.Barrier(num_threads, timeout=timeout)

    class Task(threading.Thread):
        def __init__(self):
            self._end_gate = threading.Event()
            super().__init__(daemon=True)
            # The thread starts itself as soon as it is constructed.
            self.start()
            self.finished = False

        def run(self):
            self._run(self._end_gate)

        def release(self):
            # Let a thread blocked on its end gate finish.
            self._end_gate.set()

        @staticmethod
        def _run(end_gate=None):
            def known_preexisting_function():
                start_barrier.wait()

            # Fixed point that we can use to test capture of functions
            # which are already running when profiling is enabled.
            known_preexisting_function()

            model = torch.nn.Sequential(
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            )

            def invoked_during_run():
                pass

            invoked_during_run()

            _ = model(torch.rand(4, 10))
            end_barrier.wait()

            # The main thread calls _run() without a gate and returns here.
            if end_gate is not None:
                end_gate.wait(timeout=timeout)

    # Maps the index into thread_spec to the Task created for it.
    threads = {}

    def add_threads(context: bool):
        # Start the threads whose "start under profiler" flag equals context.
        for idx, (start_under_profiler, _) in enumerate(thread_spec):
            if start_under_profiler == context:
                assert idx not in threads
                threads[idx] = Task()

    def join_threads(context: bool):
        # Release every matching thread first, then join them.
        for idx, (_, end_under_profiler) in enumerate(thread_spec):
            if end_under_profiler == context:
                threads[idx].release()

        for idx, (_, end_under_profiler) in enumerate(thread_spec):
            t = threads[idx]
            if end_under_profiler == context:
                t.join(timeout=timeout)

    try:
        add_threads(False)
        with torch.profiler.profile(with_stack=True) as prof:
            # Threads added while the profiler are running will not be observed
            # since there is no way to hook into Python's thread start call to
            # register the observer. These are here purely to verify safety.
            add_threads(True)

            if work_in_main_thread:
                Task._run()
            else:
                start_barrier.wait()
                end_barrier.wait()

            join_threads(True)
        join_threads(False)

    finally:
        # It is very important that we clean up everything because the
        # Python tracer will detect ALL active threads. (Even orphans from
        # prior failed tests.) If we don't clean up properly we can
        # contaminate subsequent tests.
        start_barrier.abort()
        end_barrier.abort()
        for t in threads.values():
            t.release()

        for t in threads.values():
            t.join(timeout=timeout)

        for t in threads.values():
            self.assertFalse(t.is_alive())

    # Keep only the Python call nodes of the event tree.
    roots = prof.profiler.kineto_results.experimental_event_tree()
    nodes = [
        node
        for node in _utils.traverse_dfs(roots)
        if isinstance(node.extra_fields, _ExtraFields_PyCall)
    ]
    tid_counts = collections.Counter([node.start_tid for node in nodes])

    # Only threads started before the profiler, plus the main thread, are
    # expected to show up.
    prior_threads = sum(
        not start_under_profiler for start_under_profiler, _ in thread_spec
    )
    expected_threads = prior_threads + 1
    self.assertEqual(
        len(tid_counts), expected_threads, f"{expected_threads}, {tid_counts}"
    )
    self.assertEqual(len(nodes), sum(tid_counts.values()))

    # Profiler uses uint64_t max as a placeholder until TID can be determined.
    no_tid = 2**64 - 1
    self.assertFalse(no_tid in tid_counts)

    # The main thread counts as a worker only when it ran the workload.
    worker_threads = prior_threads + (1 if work_in_main_thread else 0)

    # Each worker must be seen exactly once, each with a distinct thread id.
    observed_preexisting = [
        node.start_tid
        for node in nodes
        if "known_preexisting_function" in node.name
    ]
    self.assertEqual(len(observed_preexisting), worker_threads)
    self.assertEqual(len(observed_preexisting), len(set(observed_preexisting)))

    observed_during_run = [
        node.start_tid for node in nodes if "invoked_during_run" in node.name
    ]
    self.assertEqual(len(observed_during_run), worker_threads)
    self.assertEqual(len(observed_during_run), len(set(observed_during_run)))
def payload(self, use_cuda=False):
    """Run one 10x10 matrix multiply followed by an add.

    Args:
        use_cuda: when true, both operands are moved to CUDA before the
            computation and the result is copied back to the CPU.
    """
    lhs = torch.randn(10, 10)
    if use_cuda:
        lhs = lhs.cuda()
    rhs = torch.randn(10, 10)
    if use_cuda:
        rhs = rhs.cuda()
    total = torch.mm(lhs, rhs) + rhs
    if use_cuda:
        total.cpu()
def _check_stats(self, profiler_stats):
self.assertGreater(profiler_stats.profiling_window_duration_sec, 0)
self.assertGreater(profiler_stats.number_of_events, 0)
self.assertGreater(profiler_stats.profiler_prepare_call_duration_us, 0)
self.assertGreater(profiler_stats.profiler_enable_call_duration_us, 0)
self.assertGreater(profiler_stats.profiler_disable_call_duration_us, 0)
self.assertGreater(profiler_stats.parse_kineto_call_duration_us, 0)
self.assertGreater(
profiler_stats.function_events_build_tree_call_duration_us, 0
)
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_kineto(self):
    """Profile the payload with Kineto and look for the expected events."""
    use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
    with _profile(use_cuda=use_cuda, use_kineto=True):
        self.payload(use_cuda=use_cuda)

    # rerun to avoid initial start overhead
    with _profile(use_cuda=use_cuda, use_kineto=True) as p:
        self.payload(use_cuda=use_cuda)

    # The table is built but not inspected.
    sort_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"
    p.key_averages().table(sort_by=sort_key, row_limit=-1)

    event_names = [event.name for event in p.function_events]
    found_mm = any("aten::mm" in name for name in event_names)
    found_gemm = any("gemm" in name or "Cijk" in name for name in event_names)
    found_memcpy = any(
        "Memcpy" in name or "memcpy" in name for name in event_names
    )

    if not use_cuda:
        self.assertTrue(found_mm)
    else:
        self.assertTrue(found_gemm)
        self.assertTrue(found_memcpy)

    self._check_stats(p._stats)
@unittest.skipIf(not kineto_available(), "Kineto is required")
@unittest.skipIf(not TEST_MULTIGPU, "Multiple GPUs needed")
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
def test_kineto_multigpu(self):
    """Run a matmul on GPUs 0 and 1 and expect a gemm event from each."""
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for gpu_id in [0, 1]:
            lhs = torch.randn(10, 10).cuda(gpu_id)
            rhs = torch.randn(10, 10).cuda(gpu_id)
            lhs.matmul(rhs)

    # Device indices that produced a CUDA-side "gemm" event.
    gemm_devices = set()
    found_cuda = False
    for evt in prof.events():
        lowered_name = evt.name.lower()
        if "gemm" in lowered_name and evt.device_type == DeviceType.CUDA:
            gemm_devices.add(evt.device_index)
        if "cuda" in lowered_name and evt.device_type == DeviceType.CPU:
            found_cuda = True

    self.assertTrue(0 in gemm_devices)
    self.assertTrue(1 in gemm_devices)
    self.assertTrue(found_cuda)
    self._check_stats(prof._stats())
def test_memory_profiler(self):
    """Check memory usage reported per event for CPU, CUDA and MKLDNN tensors.

    Allocations are expected to show up as positive usage on the allocating
    ops and the enclosing user scope, deallocations as negative usage.
    """

    def run_profiler(tensor_creation_fn):
        # collecting allocs / deallocs
        with _profile(
            profile_memory=True, record_shapes=True, use_kineto=kineto_available()
        ) as prof:
            x = None
            with record_function("test_user_scope_alloc"):
                x = tensor_creation_fn()
            with record_function("test_user_scope_dealloc"):
                del x
        return prof.key_averages(group_by_input_shape=True)

    def check_metrics(stats, metric, allocs=None, deallocs=None):
        # Map each event key to the requested metric; if a key occurs more
        # than once the last occurrence wins.
        stat_metrics = {}
        for stat in stats:
            stat_metrics[stat.key] = getattr(stat, metric)
        # Keys listed in allocs must be present with a positive value.
        if allocs is not None:
            for alloc_fn in allocs:
                self.assertTrue(alloc_fn in stat_metrics)
                self.assertTrue(stat_metrics[alloc_fn] > 0)
        # Keys listed in deallocs must be present with a negative value.
        if deallocs is not None:
            for dealloc_fn in deallocs:
                self.assertTrue(dealloc_fn in stat_metrics)
                self.assertTrue(stat_metrics[dealloc_fn] < 0)

    def create_cpu_tensor():
        return torch.rand(10, 10)

    def create_cuda_tensor():
        return torch.rand(10, 10).cuda()

    def create_mkldnn_tensor():
        return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

    stats = run_profiler(create_cpu_tensor)
    check_metrics(
        stats,
        "cpu_memory_usage",
        allocs=[
            "aten::empty",
            "aten::rand",
            "test_user_scope_alloc",
        ],
        deallocs=[
            "test_user_scope_dealloc",
        ],
    )

    # With Kineto, also check the "[memory]" events of an exported trace.
    if kineto_available():
        with TemporaryFileName(mode="w+") as fname:
            with profile(profile_memory=True) as prof:
                x = None
                with record_function("test_user_scope_alloc"):
                    x = create_cpu_tensor()
                with record_function("test_user_scope_dealloc"):
                    del x
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)
                assert "traceEvents" in trace
                events = trace["traceEvents"]
                found_memory_events = False
                for evt in events:
                    assert "name" in evt
                    if evt["name"] == "[memory]":
                        found_memory_events = True
                        assert "args" in evt
                        assert "Addr" in evt["args"]
                        assert "Device Type" in evt["args"]
                        assert "Device Id" in evt["args"]
                        assert "Bytes" in evt["args"]

                    # Memory should be an instantaneous event.
                    assert "dur" not in evt["args"]
                    assert "cat" not in evt["args"]
                assert found_memory_events

    if torch.cuda.is_available():
        # One un-profiled creation first, then the profiled run.
        create_cuda_tensor()
        stats = run_profiler(create_cuda_tensor)
        check_metrics(
            stats,
            "cuda_memory_usage",
            allocs=[
                "test_user_scope_alloc",
                "aten::to",
                "aten::empty_strided",
            ],
            deallocs=[
                "test_user_scope_dealloc",
            ],
        )
        check_metrics(
            stats,
            "cpu_memory_usage",
            allocs=[
                "aten::rand",
                "aten::empty",
            ],
        )

    if torch.backends.mkldnn.is_available():
        # One un-profiled creation first, then the profiled run.
        create_mkldnn_tensor()
        stats = run_profiler(create_mkldnn_tensor)
        check_metrics(
            stats,
            "cpu_memory_usage",
            allocs=[
                "test_user_scope_alloc",
                "aten::rand",
                "aten::empty",
                "aten::to_mkldnn",
            ],
            deallocs=[
                "test_user_scope_dealloc",
            ],
        )

    # check top-level memory events
    with _profile(profile_memory=True, use_kineto=kineto_available()) as prof:
        x = torch.rand(10, 10)
        del x
        if torch.cuda.is_available():
            y = torch.rand(10, 10).cuda()
            del y
        gc.collect()
    stats = prof.key_averages(group_by_input_shape=True)
    check_metrics(
        stats,
        "cpu_memory_usage",
        allocs=["aten::rand", "aten::empty"],
        deallocs=["[memory]"],
    )
    if torch.cuda.is_available():
        check_metrics(stats, "cuda_memory_usage", deallocs=["[memory]"])
@unittest.skipIf(
    IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"
)
def test_oom_tracing(self):
    """A failed CUDA allocation must appear as "[OutOfMemory]" in the trace.

    Nothing is checked when CUDA is not available.
    """

    def create_cuda_tensor_oom():
        return torch.empty(
            1024, 1024, 1024, 20, dtype=torch.float32, device=torch.device("cuda:0")
        )

    def run_profiler(tensor_creation_fn):
        # The allocation itself is expected to raise inside the profiler.
        with _profile(profile_memory=True, record_shapes=True) as oom_prof:
            with self.assertRaisesRegex(RuntimeError, ".*[tT]ried to allocate.*"):
                tensor_creation_fn()
        return oom_prof

    def check_trace(oom_prof, fname):
        oom_prof.export_chrome_trace(fname)
        with open(fname) as f:
            trace = json.load(f)
        self.assertTrue("traceEvents" in trace)

        oom_event_count = 0
        for evt in trace["traceEvents"]:
            self.assertTrue("name" in evt)
            if evt["name"] != "[OutOfMemory]":
                continue
            oom_event_count += 1
            self.assertTrue("args" in evt)
            event_args = evt["args"]
            for required_key in ("Device Type", "Device Id", "Bytes"):
                self.assertTrue(required_key in event_args)
            # Memory should be an instantaneous event.
            for forbidden_key in ("dur", "cat"):
                self.assertTrue(forbidden_key not in event_args)
        self.assertTrue(oom_event_count > 0)

    if not torch.cuda.is_available():
        return
    with TemporaryFileName(mode="w+") as fname:
        check_trace(run_profiler(create_cuda_tensor_oom), fname)
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_module_hierarchy(self):
    """Exported events of a scripted model carry the expected module hierarchy."""

    class A(nn.Module):
        def my_new_method(self, x):
            return x * 3

        def forward_impl_(self, x, y):
            return self.my_new_method(x) + y

        def forward(self, x, y):
            y = y - 2
            return self.forward_impl_(x, y)

    class B(nn.Module):
        def forward(self, x):
            return x + 2

    class C(nn.Module):
        def __init__(self):
            super().__init__()
            self.A0 = A()
            self.B0 = B()

        def call_b(self, x):
            return self.B0.forward(x)

        def forward(self, x, y):
            return self.A0.forward(x, y) + self.call_b(x)

    scripted_model = torch.jit.script(C())
    input_a = torch.rand(128, 128)
    input_b = torch.rand(128, 128)

    # Hierarchy strings accepted for each op name.
    op_to_module_hierarchy = {
        "aten::sub": ["TOP(C)::forward.A0(A)::forward."],
        "aten::mul": [
            "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method."
        ],
        "aten::add": [
            "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.",
            "TOP(C)::forward.SELF(C)::call_b.B0(B)::forward.",
            "TOP(C)::forward.",
        ],
    }

    with TemporaryFileName(mode="w+") as fname:
        with profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            with_modules=True,
        ) as prof:
            scripted_model(input_a, input_b)
        prof.export_chrome_trace(fname)
        with open(fname) as f:
            trace = json.load(f)

        assert "traceEvents" in trace
        for evt in trace["traceEvents"]:
            assert "name" in evt
            # Only events that carry a hierarchy for a known op are checked.
            if "args" not in evt or "Module Hierarchy" not in evt["args"]:
                continue
            op_name = evt["name"]
            if op_name not in op_to_module_hierarchy:
                continue
            assert evt["args"]["Module Hierarchy"] in op_to_module_hierarchy[op_name]
def test_high_level_trace(self):
    """Checks that python side high level events are recorded."""

    class RepeatedDataset(torch.utils.data.Dataset):
        """Dataset of length ``N`` that returns the same (x, y) pair for every index."""

        def __init__(self, N, D_in, D_out):
            self.N = N
            self.x = torch.randn(N, D_in)
            self.y = torch.randn(N, D_out)

        def __len__(self):
            return self.N

        def __getitem__(self, idx):
            return self.x, self.y

    class TwoLayerNet(torch.nn.Module):
        """Linear -> clamp(min=0) -> Linear."""

        def __init__(self, D_in, H, D_out):
            super().__init__()
            self.linear1 = torch.nn.Linear(D_in, H)
            self.linear2 = torch.nn.Linear(H, D_out)

        def forward(self, x):
            hidden = self.linear1(x).clamp(min=0)
            return self.linear2(hidden)

    class CustomSGD(torch.optim.SGD):
        """SGD subclass; only its class name differs from the base optimizer."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

    def train():
        """Run one pass over ``dataloader`` with the current ``optimizer``."""
        for _, batch in enumerate(dataloader):
            x, y = batch[0], batch[1]
            loss = criterion(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    N, D_in, H, D_out = 8, 10, 5, 2
    model = TwoLayerNet(D_in, H, D_out)
    criterion = torch.nn.MSELoss(reduction="sum")
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    ds = RepeatedDataset(N, D_in, D_out)
    dataloader = torch.utils.data.DataLoader(ds, batch_size=1)

    # Sanity check: the training loop must work before any profiling.
    try:
        train()
    except Exception:
        self.assertTrue(False, "Expected no exception without profiling.")

    # Create multiple instances, expect each func is hooked only one time.
    # Nested wrappers(repeated patching) will make following test fail.
    optimizer_duplicate = torch.optim.SGD(model.parameters(), lr=1e-4)
    dataloader_duplicate = torch.utils.data.DataLoader(ds, batch_size=1)

    def judge(expected_event_count, prof):
        """Require each expected event name to occur exactly the expected number of times."""
        observed = collections.Counter(
            evt.name
            for evt in prof.function_events
            if "#" in evt.name and evt.name in expected_event_count.keys()
        )
        for name, wanted in expected_event_count.items():
            self.assertTrue((name in observed.keys()) and (wanted == observed[name]))

    loader_event = "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__"

    with _profile(use_kineto=kineto_available()) as prof:
        train()
    expected_event_count = {
        # "+1" because the final iteration will enter __next__ but skip the loop body.
        loader_event: (N + 1),
        "Optimizer.step#SGD.step": N,
        "Optimizer.zero_grad#SGD.zero_grad": N,
    }
    judge(expected_event_count, prof)

    # Test on pickle/unpickle. Expect to work in multi-processing.
    optimizer = pickle.loads(pickle.dumps(optimizer))
    with _profile(use_kineto=kineto_available()) as prof:
        train()
    judge(expected_event_count, prof)

    # Test on customized optimizer.
    optimizer = CustomSGD(model.parameters(), lr=1e-4)
    with _profile(use_kineto=kineto_available()) as prof:
        train()
    expected_event_count = {
        loader_event: (N + 1),
        "Optimizer.step#CustomSGD.step": N,
        "Optimizer.zero_grad#CustomSGD.zero_grad": N,
    }
    judge(expected_event_count, prof)
def test_flops(self):
    """Check that ``with_flops=True`` adds a "Total MFLOPs" column to the table.

    The CPU path always runs; the CPU+CUDA path runs only when both Kineto
    and CUDA are available.
    """
    layers = [
        nn.Conv2d(16, 33, 18),
        nn.ReLU(),
        nn.Linear(243, 243),
        nn.ReLU(),
    ]
    model = torch.nn.Sequential(*layers)
    inputs = torch.randn(40, 16, 18, 260)
    nested_tensor = torch.nested.nested_tensor(
        [torch.randn((2, 5)), torch.randn((3, 5))], layout=torch.jagged
    )

    with _profile(
        record_shapes=True, with_flops=True, use_kineto=kineto_available()
    ) as prof:
        model(inputs)
        # test that nested tensor won't cause exception during flop compute
        nested_tensor = nested_tensor + nested_tensor

    averages = prof.key_averages(group_by_input_shape=True)
    cpu_table = averages.table(sort_by="cpu_time_total", row_limit=10)
    self.assertIn("Total MFLOPs", cpu_table)

    if not kineto_available() or not torch.cuda.is_available():
        return

    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    with profile(
        activities=activities,
        record_shapes=True,
        with_flops=True,
    ) as kineto_profiler:
        model(inputs)

    cuda_table = kineto_profiler.key_averages().table(
        sort_by="self_cuda_time_total", row_limit=-1
    )
    self.assertIn("Total MFLOPs", cuda_table)
def test_kineto_profiler_api(self):
    """Exercise the public ``torch.profiler.profile`` API.

    Covers: a plain warm-up run, a scheduled run with an ``on_trace_ready``
    callback (callback count and step counter are checked), a run without a
    schedule, and the action sequence produced by ``torch.profiler.schedule``.
    """
    use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
    sort_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"

    with profile(activities=supported_activities()):
        self.payload(use_cuda=use_cuda)

    called_num = 0

    def trace_handler(p):
        """Build the summary table (to exercise it) and count the invocation."""
        nonlocal called_num
        p.key_averages().table(sort_by=sort_key, row_limit=-1)
        called_num += 1

    initial_step = KinetoStepTracker.current_step()

    with profile(
        activities=supported_activities(),
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
        on_trace_ready=trace_handler,
    ) as p:
        for _ in range(8):
            self.payload(use_cuda=use_cuda)
            p.step()

    # 8 steps over a wait=1/warmup=1/active=2 cycle are expected to hand
    # two traces to the callback.
    self.assertEqual(called_num, 2)
    self.assertEqual(KinetoStepTracker.current_step(), initial_step + 8)

    # case without schedule
    with profile(activities=supported_activities()) as p:
        self.payload(use_cuda=use_cuda)
        self.payload(use_cuda=use_cuda)
    # The table is built only to check that it can be produced.
    p.key_averages().table(sort_by=sort_key, row_limit=-1)

    test_schedule = torch.profiler.schedule(
        skip_first=2, wait=1, warmup=1, active=2, repeat=2
    )
    test_schedule_expected_outputs = [
        ProfilerAction.NONE,
        ProfilerAction.NONE,
        ProfilerAction.NONE,
        ProfilerAction.WARMUP,
        ProfilerAction.RECORD,
        ProfilerAction.RECORD_AND_SAVE,
        ProfilerAction.NONE,
        ProfilerAction.WARMUP,
        ProfilerAction.RECORD,
        ProfilerAction.RECORD_AND_SAVE,
        ProfilerAction.NONE,
        ProfilerAction.NONE,
        ProfilerAction.NONE,
        ProfilerAction.NONE,
    ]
    for step, expected_action in enumerate(test_schedule_expected_outputs):
        self.assertEqual(test_schedule(step), expected_action)
def test_kineto_profiler_multiple_steppers(self):
    """Check the global step count when two sources increment ``KinetoStepTracker``.

    A simulated optimizer hook bumps the tracker once per batch, and the
    profiler's own ``step()`` is called once per batch in the second loop;
    the final step count is expected to be ``initial + 2 * niters``.
    """
    niters = 8
    use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
    net = SimpleNet()
    opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    opt.zero_grad()
    inputs = torch.rand(10)

    with profile(activities=supported_activities()):
        self.payload(use_cuda=use_cuda)

    initial_step = KinetoStepTracker.current_step()

    def run_batch():
        """Forward, backward, optimizer step, then the simulated step hook."""
        prediction = net(inputs)
        loss = torch.nn.functional.cross_entropy(prediction, torch.rand(2))
        loss.backward()
        opt.step()
        # Manually call the hook. TODO: Remove this once we add the
        # profiler step hooks in the Optimizer class that will get triggered above.
        # See https://github.com/pytorch/pytorch/issues/88446
        # This simulates a step() hook in the optimizer.
        KinetoStepTracker.increment_step("yet_another_step")

    # First pass without a profiler attached.
    for _ in range(niters):
        run_batch()

    scheduled = profile(
        activities=supported_activities(),
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
    )
    with scheduled as p:
        for _ in range(niters):
            run_batch()
            p.step()

    expected_step = initial_step + 2 * niters
    self.assertEqual(KinetoStepTracker.current_step(), expected_step)
def test_export_stacks(self):
    """Check that ``export_stacks`` writes records ending in a positive integer.

    Every line of the exported file must end, after the last space, with a
    field that parses as an integer greater than zero.
    """
    with _profile(
        with_stack=True,
        use_kineto=kineto_available(),
        experimental_config=_ExperimentalConfig(verbose=True),
    ) as p:
        x = torch.randn(10, 10)
        y = torch.randn(10, 10)
        z = torch.mm(x, y)
        z = z + y

    with TemporaryFileName(mode="w+") as fname:
        p.export_stacks(fname)
        with open(fname) as f:
            lines = f.readlines()

    # unittest assertions are used (rather than bare ``assert``) so the
    # checks still run under ``python -O``.
    self.assertTrue(lines, "Empty stacks file")
    for line in lines:
        last_field = line.split(" ")[-1]
        try:
            value = int(last_field)
        except ValueError:
            self.fail("Invalid stacks record")
        self.assertGreater(value, 0, "Invalid stacks record")
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_tensorboard_trace_handler(self):
    """Check the files written by ``tensorboard_trace_handler``.

    Runs the same scheduled workload twice — once through the context
    manager with plain JSON output and once through explicit
    ``start()``/``stop()`` with gzip output — and validates the number and
    names of the trace files produced in a temporary directory.
    """
    use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
    with _profile(use_cuda=use_cuda, use_kineto=True):
        self.payload(use_cuda=use_cuda)

    activities = [torch.profiler.ProfilerActivity.CPU] + (
        [torch.profiler.ProfilerActivity.CUDA] if use_cuda else []
    )

    def check_trace_files(dname, suffix):
        """Require exactly 3 files named ``<...>.<positive int>.<suffix parts>``."""
        self.assertTrue(os.path.exists(dname))
        file_names = os.listdir(dname)
        for file_name in file_names:
            parts = file_name.split(".")
            self.assertTrue(len(parts) > 4)
            # The numeric field sits immediately before the suffix parts.
            stamp = parts[-(len(suffix) + 1)]
            self.assertTrue(
                stamp.isdigit() and int(stamp) > 0,
                "Wrong tracing file name pattern",
            )
            self.assertEqual(parts[-len(suffix):], suffix)
        self.assertEqual(len(file_names), 3)

    with TemporaryDirectoryName() as dname:
        with profile(
            activities=activities,
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(dname),
        ) as p:
            for _ in range(18):
                self.payload(use_cuda=use_cuda)
                p.step()
        check_trace_files(dname, ["pt", "trace", "json"])

    # test case for gzip file format
    with TemporaryDirectoryName() as dname:
        p = profile(
            activities=activities,
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                dname, use_gzip=True
            ),
        )
        p.start()
        try:
            for _ in range(18):
                self.payload(use_cuda=use_cuda)
                p.step()
        finally:
            # Always stop, so a failing workload does not leave the
            # profiler running for subsequent tests.
            p.stop()
        check_trace_files(dname, ["pt", "trace", "json", "gz"])
@unittest.skipIf(not kineto_available(), "Kineto is required")
def test_profiler_metadata(self):
    """Check that metadata added during profiling appears in the chrome trace.

    ``add_metadata`` is expected to store a string value and
    ``add_metadata_json`` a JSON value, both as top-level keys of the
    exported trace.
    """
    t1, t2 = torch.ones(1), torch.ones(1)
    with profile() as prof:
        torch.add(t1, t2)
        prof.add_metadata("test_key1", "test_value1")
        prof.add_metadata_json("test_key2", "[1,2,3]")

    with TemporaryFileName(mode="w+") as fname:
        prof.export_chrome_trace(fname)
        with open(fname) as f:
            trace = json.load(f)

    # unittest assertions are used (rather than bare ``assert``) so the
    # checks still run under ``python -O``.
    self.assertIn("test_key1", trace)
    self.assertEqual(trace["test_key1"], "test_value1")
    self.assertIn("test_key2", trace)
    self.assertEqual(trace["test_key2"], [1, 2, 3])
def _test_profiler_tracing(self, use_kineto):
    """Export chrome traces for CPU, empty and (if supported) CUDA profiles.

    Args:
        use_kineto: forwarded to the autograd profiler's ``use_kineto`` flag.
    """

    def export(prof, parse):
        """Write ``prof`` to a temporary file; optionally parse it back as JSON."""
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            if parse:
                # If the JSON generated by export_chrome_trace is not valid,
                # json.load throws and fails the test.
                with open(fname) as f:
                    json.load(f)

    with _profile(use_kineto=use_kineto) as prof:
        t1, t2 = torch.ones(1), torch.ones(1)
        torch.add(t1, t2)
    export(prof, parse=True)

    # An empty trace only has to be exportable; it is not parsed back.
    with _profile(use_kineto=use_kineto) as prof:
        pass
    export(prof, parse=False)

    # Same test but for cuda.
    if torch.profiler.ProfilerActivity.CUDA not in supported_activities():
        return

    device = torch.device("cuda:0")
    with _profile(use_cuda=True, use_kineto=use_kineto) as prof:
        t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
        torch.add(t1, t2)
    export(prof, parse=True)
def test_profiler_tracing(self):
self._test_profiler_tracing(False)
if kineto_available():
self._test_profiler_tracing(True)
def test_profiler_op_event_args(self):
    """Check the per-op arguments written to the chrome trace.

    With concrete-input recording enabled and ``record_shapes=True``, the
    ``aten::ones`` and ``aten::cat`` events must carry the expected input
    types / concrete inputs / dims, and every ``cpu_op`` event must carry a
    non-negative "Record function id".
    """
    torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
    with _profile(record_shapes=True) as prof:
        a = torch.ones((64, 32), dtype=torch.float32)
        torch.cat([a, a]).sin()

    with TemporaryFileName(mode="w+") as fname:
        prof.export_chrome_trace(fname)
        with open(fname) as f:
            j = json.load(f)

    op_events = [e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"]
    for e in op_events:
        args = e["args"]
        if e["name"] == "aten::ones":
            self.assertEqual(
                args["Input type"],
                ["ScalarList", "Scalar", "", "", "Scalar"],
            )
            self.assertEqual(
                args["Concrete Inputs"], ["[64, 32]", "6", "", "", "False"]
            )

        if e["name"] == "aten::cat":
            self.assertEqual(args["Input Dims"], [[[64, 32], [64, 32]], []])
            self.assertEqual(args["Input type"], ["TensorList", "Scalar"])

        # check that each op has record function id
        self.assertGreaterEqual(
            args.get("Record function id", -1),
            0,
            f"Failed finding record function for op = {e}",
        )
def test_profiler_fwd_bwd_link(self):
    """Check the forward/backward flow events in an exported Kineto trace.

    Exactly two "fwdbwd" flows (ids 1 and 2) are expected; each flow's start
    and finish timestamps must coincide with a complete ("X") event, and the
    start events must be the loss op and the add op respectively.
    """
    with _profile(use_kineto=True) as prof:
        t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
            1, requires_grad=True
        )
        z = torch.add(t1, t2)
        y = torch.ones(1)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
        loss.backward()

    with TemporaryFileName(mode="w+") as fname:
        prof.export_chrome_trace(fname)
        with open(fname) as f:
            events = json.load(f)["traceEvents"]

    ts_to_name = {}
    # Flow id -> timestamp, keyed by flow phase ("s" = start, "f" = finish).
    flow_ts = {"s": {}, "f": {}}
    for e in events:
        phase = e["ph"]
        if phase == "X":
            ts_to_name[e["ts"]] = e["name"]
        is_fwdbwd = (
            "cat" in e
            and "name" in e
            and e["cat"] == "fwdbwd"
            and e["name"] == "fwdbwd"
        )
        if is_fwdbwd and phase in flow_ts:
            flow_ts[phase][e["id"]] = e["ts"]

    flow_s_to_ts, flow_f_to_ts = flow_ts["s"], flow_ts["f"]
    self.assertEqual(len(flow_s_to_ts), 2)
    self.assertEqual(len(flow_f_to_ts), 2)
    self.assertIn(1, flow_s_to_ts)
    self.assertIn(1, flow_f_to_ts)
    self.assertIn(2, flow_s_to_ts)
    self.assertIn(2, flow_f_to_ts)

    endpoints = [
        flow_s_to_ts[1],
        flow_f_to_ts[1],
        flow_s_to_ts[2],
        flow_f_to_ts[2],
    ]
    self.assertTrue(all(ts in ts_to_name.keys() for ts in endpoints))
    self.assertTrue(
        ts_to_name[flow_s_to_ts[1]] == "aten::binary_cross_entropy_with_logits"
    )
    self.assertTrue(ts_to_name[flow_s_to_ts[2]] == "aten::add")
def test_profiler_disable_fwd_bwd_link(self):
    """Check that no "fwdbwd" events are exported when the link is disabled.

    The global switch is turned back on in ``finally`` regardless of outcome.
    """
    set_fwd_bwd = torch._C._profiler._set_fwd_bwd_enabled_val
    try:
        set_fwd_bwd(False)

        with _profile(use_kineto=True) as prof:
            t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
                1, requires_grad=True
            )
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
            loss.backward()

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)

        for event in trace["traceEvents"]:
            self.assertNotEqual(event.get("cat", None), "fwdbwd")
    finally:
        set_fwd_bwd(True)
# This test is broken on Windows; the likely reason is that kineto/CUPTI
# is not supported in that particular environment. Once the CI stabilizes
# we can narrow the condition so Windows is checked as well (TODO)
@unittest.skipIf(not kineto_available(), "Kineto is required")
@unittest.skipIf(IS_WINDOWS, "Test does not work on Windows")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_profiler_cuda_sync_events(self):
    """Check that "cuda_sync" events are exported when sync tracing is enabled.

    The feature is switched on in two ways: through ``_ExperimentalConfig``
    and through the global ``_set_cuda_sync_enabled_val`` toggle (which is
    switched back off afterwards).
    """
    device = torch.device("cuda:0")
    t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)

    def workload() -> None:
        """Two adds separated by an explicit device synchronize."""
        torch.add(t1, t2)
        torch.cuda.synchronize()
        torch.add(t1, t2)

    def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None:
        """Profile the workload and require a "cuda_sync" category in the trace."""
        with _profile(
            use_kineto=True,
            use_cuda=True,
            experimental_config=exp_config,
        ) as prof:
            workload()

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)

        cats = {event.get("cat", None) for event in trace["traceEvents"]}
        self.assertTrue(
            "cuda_sync" in cats,
            "Expected to find cuda_sync event" f" found = {cats}",
        )

    print("Testing enable_cuda_sync_events in _ExperimentalConfig")
    trace_and_check(exp_config=_ExperimentalConfig(enable_cuda_sync_events=True))

    print("Testing _profiler._set_cuda_sync_enabled_val()")
    set_cuda_sync = torch._C._profiler._set_cuda_sync_enabled_val
    try:
        set_cuda_sync(True)
        trace_and_check(exp_config=None)
    finally:
        set_cuda_sync(False)
def test_profiler_type(self):
    """Check the reported active profiler type outside and inside each profiler."""
    kinds = torch._C._profiler.ActiveProfilerType
    active_kind = torch._C._autograd._profiler_type

    # No profiler running.
    self.assertEqual(active_kind(), kinds.NONE)

    # Autograd profiler
    with _profile_legacy():
        self.assertEqual(active_kind(), kinds.LEGACY)

    # Kineto profiler
    with profile():
        self.assertEqual(active_kind(), kinds.KINETO)
def test_profiler_correlation_id(self):
    """
    Correlation ids are expected to be unique across multiple invocations of
    the profiler, so one set of already-seen ids is shared by every run.
    """
    seen_ids = set()
    model = torch.nn.Sequential(
        nn.Conv2d(16, 33, 18),
        nn.ReLU(),
        nn.Linear(243, 243),
        nn.ReLU(),
    )
    inputs = torch.randn(40, 16, 18, 260)
    uint32_max = 2**32 - 1

    for _ in range(5):
        with profile() as prof:
            model(inputs)

        for event in prof.profiler.kineto_results.events():
            corr_id = event.correlation_id()
            # Only CPU events with a non-zero correlation id are checked.
            if not (corr_id and event.device_type() == DeviceType.CPU):
                continue
            self.assertTrue(corr_id not in seen_ids)
            seen_ids.add(corr_id)
            self.assertTrue(corr_id < uint32_max)
def test_nested_tensor_with_shapes(self):
    """Check that shapes are recorded for the matmul behind a nested-tensor linear."""
    first = torch.randn(4, 4)
    second = torch.randn(4, 4)
    weight = torch.randn(4, 4)
    inp = torch.nested.nested_tensor([first, second])

    with torch.profiler.profile(record_shapes=True) as prof:
        torch.nn.functional.linear(inp, weight, None)

    for evt in prof.events():
        if evt.name not in ("aten::mm", "aten::addmm"):
            continue
        # intentionally vague tests to protect against possible future changes
        # of mm to addmm or other impl, or changing internal order of args
        shapes = evt.input_shapes
        self.assertTrue(len(shapes) > 0)
        self.assertTrue(len(shapes[0]) > 0)
@patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
@patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
def test_kineto_profiler_with_environment_variable(self):
    """Run a small profiling script in a subprocess with daemon env vars set.

    ``KINETO_USE_DAEMON`` and ``KINETO_DAEMON_INIT_DELAY_S`` are patched into
    ``os.environ`` for the duration of the test. The script is executed with
    ``sys.executable -c`` from this file's directory; a non-zero exit status
    fails the test.
    """
    # Source of the child program. It profiles a tiny training loop and
    # asserts on the KinetoStepTracker step count itself.
    script = """
import torch
import torch.nn as nn
from torch.profiler import supported_activities, profile
from torch.autograd.profiler import KinetoStepTracker
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
return self.fc2(self.fc1(x))
def payload(use_cuda=False):
x = torch.randn(10, 10)
if use_cuda:
x = x.cuda()
y = torch.randn(10, 10)
if use_cuda:
y = y.cuda()
z = torch.mm(x, y)
z = z + y
if use_cuda:
z = z.cpu()
niters = 8
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
net = SimpleNet()
opt = torch.optim.SGD(net.parameters(), lr=0.01)
opt.zero_grad()
inputs = torch.rand(10)
with profile(activities=supported_activities()):
payload(use_cuda=use_cuda)
initial_step = KinetoStepTracker.current_step()
def run_batch():
out = net(inputs)
loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
loss.backward()
opt.step()
for _ in range(niters):
run_batch()
with profile(
activities=supported_activities(),
schedule=torch.profiler.schedule(
wait=1,
warmup=1,
active=2),
) as p:
for _ in range(niters):
run_batch()
p.step()
assert KinetoStepTracker.current_step() == initial_step + 2 * niters
"""
    try:
        # check_output raises CalledProcessError when the child exits non-zero.
        subprocess.check_output(
            [sys.executable, "-W", "all", "-c", script],
            cwd=os.path.dirname(os.path.realpath(__file__)),
        )
    except subprocess.CalledProcessError as e:
        if e.returncode != 0:
            self.assertTrue(
                False,
                "Kineto is not working properly with the Dynolog environment variable",
            )
def test_concrete_inputs_profiling(self):
    """Check shapes and concrete inputs recorded for ``aten::as_strided``.

    With ``record_shapes=True`` the event must report the input tensor's
    shape and the literal size/stride arguments passed to ``as_strided``.
    """
    x = torch.rand(2, 6)
    with profile(record_shapes=True) as p:
        _ = x.as_strided([4, 3], [1, 4])

    found = False
    for e in p.events():
        # Exact comparison: ``name in ("aten::as_strided")`` has no trailing
        # comma, so it is a substring test against a plain string and would
        # also accept names such as "aten" or the empty string.
        if e.name != "aten::as_strided":
            continue
        found = True
        self.assertTrue(len(e.input_shapes) > 0)
        self.assertTrue(len(e.concrete_inputs) > 0)
        self.assertEqual([2, 6], e.input_shapes[0])
        self.assertEqual([4, 3], e.concrete_inputs[1])
        self.assertEqual([1, 4], e.concrete_inputs[2])

    self.assertTrue(found, "Expected to find aten::as_strided but did not")
def test_concrete_inputs_profiling_toggling(self):
    """Toggle concrete-input recording around a profiled op, in both directions.

    For each (before, after) pair the flag is set before profiling and
    flipped while the profiler is active; the ``aten::as_strided`` event
    must still be present with non-empty input shapes. The flag is reset to
    ``True`` in ``finally``.
    """
    set_enabled = torch._C._profiler._set_record_concrete_inputs_enabled_val
    try:
        for before, after in [(True, False), (False, True)]:
            x = torch.rand(2, 6)
            set_enabled(before)
            with profile(record_shapes=True) as p:
                _ = x.as_strided([4, 3], [1, 4])
                set_enabled(after)

            found = False
            for e in p.events():
                # Exact comparison: ``name in ("aten::as_strided")`` has no
                # trailing comma, so it is a substring test against a plain
                # string rather than tuple membership.
                if e.name != "aten::as_strided":
                    continue
                found = True
                self.assertTrue(len(e.input_shapes))

            self.assertTrue(found, "Expected to find aten::as_strided but did not")
    finally:
        set_enabled(True)
def test_record_function_fast(self):
    """Exercise ``_RecordFunctionFast`` with the different argument forms.

    Each section records a named range four times and then checks how many
    events carry that name and which input shapes they report.
    """
    RecordFunctionFast = torch._C._profiler._RecordFunctionFast
    x, y = (torch.rand((4, 4)) for _ in range(2))

    def named(prof, name):
        """Return the profiler events whose name equals ``name``."""
        return [evt for evt in prof.events() if evt.name == name]

    # Test first with no optional args
    with profile(record_shapes=True) as p:
        for _ in range(4):
            with RecordFunctionFast("add_test_fast_rf1"):
                x.add(y)

    hits = named(p, "add_test_fast_rf1")
    self.assertGreaterEqual(len(hits), 4)
    for evt in hits:
        self.assertTrue(evt.input_shapes == [])

    # add optional args
    with profile(record_shapes=True) as p:
        cm = RecordFunctionFast(
            "add_test_fast_rf2", [x, y], {"stream": 0, "grid": "lambda x : x + 1"}
        )
        for _ in range(4):
            with cm:
                x.add(y)

    hits = named(p, "add_test_fast_rf2")
    self.assertGreaterEqual(len(hits), 4)
    for evt in hits:
        self.assertTrue(evt.input_shapes == [[4, 4], [4, 4]])

    # keyword form, with an exception raised inside the range
    with profile(record_shapes=True) as p:
        cm = RecordFunctionFast(
            "add_test_fast_rf3", input_values=["hi"], keyword_values={"hi": "hello"}
        )
        for _ in range(4):
            try:
                with cm:
                    x.add(y)
                    raise ValueError
                    # Deliberately unreachable: no relu event may be recorded.
                    x.relu()
            except ValueError:
                pass

    hits = named(p, "add_test_fast_rf3")
    self.assertGreaterEqual(len(hits), 4)
    self.assertFalse(any((evt.name and "relu" in evt.name) for evt in p.events()))
    for evt in hits:
        self.assertTrue(evt.input_shapes == [[]])

    # nested ranges, profiled without record_shapes
    with profile() as p:
        for _ in range(4):
            with RecordFunctionFast("add_test_fast_rf4", [x, y]):
                x.add(y)
                with RecordFunctionFast("add_test_fast_rf5"):
                    x.relu()

    hits = named(p, "add_test_fast_rf4")
    self.assertGreaterEqual(len(hits), 4)
    for evt in hits:
        self.assertTrue(evt.input_shapes == [])
    self.assertGreaterEqual(len(named(p, "add_test_fast_rf5")), 4)

    # test optional args with tuple
    with profile(record_shapes=True) as p:
        cm = RecordFunctionFast("add_test_fast_rf6", (x, y))
        for _ in range(4):
            with cm:
                x.add(y)

    hits = named(p, "add_test_fast_rf6")
    self.assertGreaterEqual(len(hits), 4)
    for evt in hits:
        self.assertTrue(evt.input_shapes == [[4, 4], [4, 4]])
def test_is_profiler_enabled(self):
    """Check ``_is_profiler_enabled`` before, during and after each profiler."""
    autograd_profiler = torch.autograd.profiler

    self.assertFalse(autograd_profiler._is_profiler_enabled)

    # torch.profiler front end
    with profile():
        self.assertTrue(autograd_profiler._is_profiler_enabled)
    self.assertFalse(autograd_profiler._is_profiler_enabled)

    # autograd profiler front end
    with autograd_profiler.profile():
        self.assertTrue(autograd_profiler._is_profiler_enabled)
    self.assertFalse(autograd_profiler._is_profiler_enabled)
def test_guarded_record_function_fast(self):
    """Check a ``_RecordFunctionFast`` range guarded by ``_is_profiler_enabled``."""
    x, y = (torch.rand((4, 4)) for _ in range(2))

    with profile() as p:
        cm = torch._C._profiler._RecordFunctionFast("guarded_rff")
        for _ in range(4):
            # Only pay for the record-function range while a profiler is on.
            if not torch.autograd.profiler._is_profiler_enabled:
                x.add(y)
                continue
            with cm:
                x.add(y)

    guarded = [evt for evt in p.events() if evt.name == "guarded_rff"]
    self.assertGreaterEqual(len(guarded), 4)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_event_list(self):
    """Basic sanity checks for ``EventList``: chrome-trace export and ``table()``."""
    # AFAIK event list is part of legacy profiler and/or used when kineto is not available.
    # This test has basic sanity checks to test against obvious regressions.
    x, y = (torch.rand((4, 4), requires_grad=True, device="cuda") for _ in range(2))
    with profile(with_stack=True) as p:
        total = (x @ y).relu().sum()
        total.backward()

    event_list = torch.autograd.profiler_util.EventList(p.events())

    with TemporaryFileName(mode="w+") as fname:
        event_list.export_chrome_trace(fname)
        # The exported file must be valid JSON.
        with open(fname) as trace_file:
            json.load(trace_file)

    event_list.table()
def find_node_with_name(nodes, name):
    """Return the first node in DFS order whose ``name`` equals ``name``, else None."""
    matches = (node for node in _utils.traverse_dfs(nodes) if node.name == name)
    return next(matches, None)
def find_node_with_regex(nodes, pattern):
    """Return the first node in DFS order whose ``name`` matches ``pattern``, else None."""
    matches = (
        node for node in _utils.traverse_dfs(nodes) if re.search(pattern, node.name)
    )
    return next(matches, None)
class SimpleNet(nn.Module):
    """Two stacked linear layers (10 -> 5 -> 2) with no activation in between."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=10, out_features=5)
        self.fc2 = nn.Linear(in_features=5, out_features=2)

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden)
class TestTorchTidyProfiler(TestCase):
def _get_tensor_fields(self, node, index):
    """Return ``(impl_ptr, storage_data_ptr, id)`` for input ``index`` of ``node``.

    Asserts that ``node`` is a TorchOp event, that the selected input is
    tensor metadata, and that none of the three fields is ``None``.
    """
    self.assertIsNotNone(node)
    extra = node.extra_fields
    self.assertIsInstance(extra, torch._C._profiler._ExtraFields_TorchOp)

    metadata = extra.inputs[index]
    self.assertIsInstance(metadata, _TensorMetadata)

    fields = (metadata.impl_ptr, metadata.storage_data_ptr, metadata.id)
    for value in fields:
        self.assertIsNotNone(value)
    return fields
def test_pointers_and_ids(self):
    """Check TensorImpl pointers, storage pointers and IDs reported per op input.

    A handful of ops are profiled on tensors that are viewed, resized and
    re-pointed with ``set_``; the metadata of each op's inputs is then
    compared. Lookups use the first event with a given op name, so the
    order of the ops inside the profiled region matters.
    """
    a = torch.randn(4, 3)
    a_initial_storage_data = a.storage().data_ptr()

    # Views of tensors can share the same storage, but have different TensorImpls
    b = a.view((1, 12))
    c = torch.randn(4, 1)
    c_initial_storage_data = c.storage().data_ptr()
    d = torch.randn(4, 3)

    with profile(with_stack=True, profile_memory=True, record_shapes=True) as p:
        _ = a + c
        _ = b * c

        # Resize should create a new data_ptr but keep the TensorImpl the same.
        f = a.resize_(128, 129)
        _ = torch.relu(f)

        # `.set_` points a Tensor at an existing storage.
        _ = d.sin()
        c.set_(d.storage())
        _ = c.cos()

    nodes = p.profiler.kineto_results.experimental_event_tree()

    def get_fields(op_name, index):
        """Return (impl_ptr, storage_data_ptr, id) for input ``index`` of ``op_name``."""
        return self._get_tensor_fields(find_node_with_name(nodes, op_name), index)

    a_impl, a_storage_data, a_id = get_fields("aten::add", 0)
    b_impl, b_storage_data, b_id = get_fields("aten::mul", 0)

    # Profiler matches ground truth from Python API.
    self.assertEqual(a_storage_data, a_initial_storage_data)

    # Views are handled correctly.
    self.assertEqual(a_storage_data, b_storage_data)
    self.assertNotEqual(a_impl, b_impl)

    # The same Tensor used in multiple calls gives identical results.
    c_impl, c_storage_data, c_id = get_fields("aten::add", 1)
    self.assertEqual((c_impl, c_storage_data, c_id), get_fields("aten::mul", 1))
    self.assertEqual(c_storage_data, c_initial_storage_data)

    # Mutations to the underlying storage are reflected. (But ID is shared.)
    f_impl, f_storage_data, f_id = get_fields("aten::relu", 0)
    self.assertEqual(a_impl, f_impl)
    self.assertNotEqual(a_storage_data, f_storage_data)
    self.assertEqual(a_id, f_id)

    # Calling `set_` with an existing Tensor makes them share an ID.
    d_impl, d_storage_data, d_id = get_fields("aten::sin", 0)
    c_impl_new, c_storage_data_new, c_id_new = get_fields("aten::cos", 0)
    self.assertNotEqual(d_impl, c_impl_new)
    self.assertEqual(d_storage_data, c_storage_data_new)
    self.assertEqual(c_id, c_id_new)
    self.assertEqual(d_id, c_id_new)
@staticmethod
def _format_allocations(profiled_code):
    """Profile ``profiled_code`` and render its allocation events as text.

    Garbage is collected before profiling and again inside the profiled
    region after the callable returns. Allocation events are ordered by
    start time and rendered one per line as ``id``, ``allocation_id`` and
    "Allocation"/"Free" (by the sign of ``alloc_size``), with each line
    indented by 12 spaces.
    """
    gc.collect()
    with profile(profile_memory=True, record_shapes=True) as prof:
        profiled_code()
        gc.collect()

    tree = prof.profiler.kineto_results.experimental_event_tree()
    ordered = list(_utils.traverse_dfs(tree))
    ordered.sort(key=lambda node: node.start_time_ns)

    rows = []
    for node in ordered:
        i = node.extra_fields
        if not isinstance(i, torch._C._profiler._ExtraFields_Allocation):
            continue
        rows.append(
            f"{repr(i.id):>5}{' ' * 6}"
            f"{repr(i.allocation_id):>5}{' ' * 6}"
            f"{'Allocation' if i.alloc_size > 0 else 'Free'}"
        )
    return textwrap.indent("\n".join(rows), " " * 12)
def test_tensorimpl_invalidation_set(self) -> None:
    """Compare allocation events when ``set_`` replaces a tensor's storage.

    The profiled code is run twice, with and without an intermediate
    argument-less ``set_()``, and the formatted allocation table is compared
    against an inline expectation for each case.
    """

    def profiled_code(add_empty_set: bool):
        """Create a tensor, optionally clear it with ``set_()``, then give it new storage."""
        x = torch.ones((1,))

        # Determines if new storage is created before or after the old one
        # is destroyed.
        if add_empty_set:
            x.set_()

        x.set_(torch.ones((1,)).storage())
        x.view_as(x)

    # Columns are: id, allocation id, event kind.
    self.assertExpectedInline(
        self._format_allocations(lambda: profiled_code(add_empty_set=False)),
        """\
0 1 Allocation
0 2 Allocation
0 1 Free
0 2 Free""",
    )

    self.assertExpectedInline(
        self._format_allocations(lambda: profiled_code(add_empty_set=True)),
        """\
0 1 Allocation
0 1 Free
0 2 Allocation
0 2 Free""",
    )
def test_tensorimpl_invalidation_keep_alive(self) -> None:
    """Compare allocation events when old storages are kept alive across ``set_``.

    Storages are held in a list while the tensor is repeatedly re-pointed,
    then released one by one before further ``set_`` calls. The formatted
    allocation table is compared against an inline expectation, with and
    without an extra argument-less ``set_()`` in the middle.
    """

    def profiled_code(add_empty_set: bool):
        """Re-point a tensor at fresh storages while holding, then releasing, the old ones."""
        x = torch.ones((1,))
        x_storages = [x.storage()]
        for _ in range(3):
            x.set_()
            x.set_(torch.ones((1,)).storage())

            # This keeps the StorageImpls alive and preserves the chain.
            # (Despite the `set_()` call.)
            x_storages.append(x.storage())
        x.view_as(x)

        # Free storage in a deterministic fashion.
        while x_storages:
            x_storages.pop()
            gc.collect()

        # Determines if new storage is created before or after the old one
        # is destroyed.
        if add_empty_set:
            x.set_()

        for _ in range(3):
            x.set_(torch.ones((1,)).storage())
        x.view_as(x)

        del x
        gc.collect()

    # Columns are: id, allocation id, event kind.
    self.assertExpectedInline(
        self._format_allocations(lambda: profiled_code(add_empty_set=False)),
        """\
0 1 Allocation
0 2 Allocation
0 4 Allocation
0 5 Allocation
0 4 Free
0 2 Free
0 1 Free
0 6 Allocation
0 5 Free
0 7 Allocation
0 6 Free
0 8 Allocation
0 7 Free
0 8 Free""",
    )

    self.assertExpectedInline(
        self._format_allocations(lambda: profiled_code(add_empty_set=True)),
        """\
0 1 Allocation
0 2 Allocation
0 4 Allocation
0 5 Allocation
0 4 Free
0 2 Free
0 1 Free
0 5 Free
0 6 Allocation
0 7 Allocation
0 6 Free
0 8 Allocation
0 7 Free
0 8 Free""",
    )
def test_tensorimpl_invalidation_full(self) -> None:
    """Compare allocation events across a longer mix of ``set_`` and ``resize_`` calls.

    The formatted allocation table for the whole sequence is compared
    against a single inline expectation.
    """

    def profiled_code():
        """Re-point, release, re-point again and finally resize one tensor."""
        x = torch.ones((1,))
        x_storages = [x.storage()]
        for _ in range(3):
            x.set_()
            x.set_(torch.ones((1,)).storage())
            x_storages.append(x.storage())
        x.view_as(x)

        # Free storage in a deterministic fashion.
        while x_storages:
            x_storages.pop()
            gc.collect()

        # Replace the storage directly, without clearing the tensor first.
        for _ in range(3):
            x.set_(torch.ones((1,)).storage())

        # Clear the tensor with an argument-less set_() before each replacement.
        for _ in range(3):
            x.set_()
            x.set_(torch.ones((1,)).storage())

        # Resize in place to 1, 2, 3 and 4 elements.
        for i in range(4):
            x.resize_((1 + i,))
        x.view_as(x)

    # Columns are: id, allocation id, event kind.
    self.assertExpectedInline(
        self._format_allocations(profiled_code),
        """\
0 1 Allocation
0 2 Allocation
0 4 Allocation
0 5 Allocation
0 4 Free
0 2 Free
0 1 Free
0 6 Allocation
0 5 Free
0 7 Allocation
0 6 Free
0 8 Allocation
0 7 Free
0 8 Free
0 9 Allocation
0 9 Free
0 10 Allocation
0 10 Free
0 11 Allocation
0 12 Allocation
0 11 Free
0 13 Allocation
0 12 Free
0 14 Allocation
0 13 Free
0 14 Free""",
    )
def test_tensorimpl_invalidation_scalar_args(self) -> None:
    """Compare allocation events for repeated in-place adds with a scalar argument.

    One tensor is created and ``add_(2)`` is applied ten times under
    ``no_grad``; the formatted allocation table is compared against an
    inline expectation.
    """

    def profiled_code():
        """Create a one-element tensor and add the Python scalar 2 to it ten times."""
        with torch.no_grad():
            x = torch.ones((1,))
            for _ in range(10):
                x.add_(2)

    # Columns are: id, allocation id, event kind.
    self.assertExpectedInline(
        self._format_allocations(profiled_code),
        """\
0 1 Allocation
1 2 Allocation
2 3 Allocation
2 3 Free
1 2 Free
3 4 Allocation
4 5 Allocation
4 5 Free
3 4 Free
5 6 Allocation
6 7 Allocation
6 7 Free
5 6 Free
7 8 Allocation
8 9 Allocation
8 9 Free
7 8 Free
9 10 Allocation
10 11 Allocation
10 11 Free
9 10 Free
11 12 Allocation
12 13 Allocation
12 13 Free
11 12 Free
13 14 Allocation
14 15 Allocation
14 15 Free
13 14 Free
15 16 Allocation
16 17 Allocation
16 17 Free
15 16 Free
17 18 Allocation
18 19 Allocation
18 19 Free
17 18 Free
19 20 Allocation
20 21 Allocation
20 21 Free
19 20 Free
0 1 Free""",
    )
def test_module_and_optimizer_ids(self) -> None:
    """Check that module and optimizer records agree with op-level tensor IDs.

    A ``Linear`` layer is trained for one SGD-with-momentum step under the
    profiler.  Three extra ops (``sin``, ``cos``, ``tan``) are applied to the
    input, the momentum buffer and the weight gradient so that their IDs can
    be read back from the trace and compared with the IDs reported by the
    ``nn.Module`` and optimizer-step events.
    """
    model = torch.nn.Linear(2, 1, bias=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    def check(cold_start: bool) -> None:
        with profile(with_stack=True, profile_memory=True, record_shapes=True) as prof:
            x = torch.ones((1, 2))
            _ = x.sin()  # Mark `x`
            model(x).backward()
            optimizer.step()
            # Mark weight momentum
            _ = optimizer.state[model.weight]["momentum_buffer"].cos()
            _ = model.weight.grad.tan()  # Mark weight gradient

        event_tree = prof.profiler.kineto_results.experimental_event_tree()

        def marked_tensor_id(op_name):
            # The helper yields three values; only the last one (the ID) is
            # needed for the marker ops.
            _, _, tensor_id = self._get_tensor_fields(
                find_node_with_name(event_tree, op_name), 0
            )
            return tensor_id

        # Marked Tensors act as ground truth for python tracer IDs.
        x_id = marked_tensor_id("aten::sin")
        weight_momentum_id = marked_tensor_id("aten::cos")
        weight_grad_id = marked_tensor_id("aten::tan")
        self.assertNotEqual(x_id, weight_momentum_id)
        self.assertNotEqual(x_id, weight_grad_id)
        self.assertNotEqual(weight_momentum_id, weight_grad_id)

        # Use linear op to identify weight ground truth.
        linear_op_node = find_node_with_name(event_tree, "aten::linear")
        self.assertIsNotNone(linear_op_node)
        x_metadata, weight_metadata, _ = linear_op_node.extra_fields.inputs
        self.assertEqual(x_id, x_metadata.id)

        # Module
        module_node = find_node_with_name(event_tree, "nn.Module: Linear_0")
        self.assertIsNotNone(module_node)
        module_fields = module_node.extra_fields
        self.assertIsNotNone(module_fields.module)
        self.assertIsNone(module_fields.optimizer)
        name, weight, weight_grad = module_fields.module.parameters[0]
        self.assertEqual(name, "weight")
        self.assertEqual(weight.id, weight_metadata.id)
        # On the first pass no gradient exists yet when the module runs.
        self.assertEqual(weight_grad is None, cold_start)
        if not cold_start:
            self.assertEqual(weight_grad.id, weight_grad_id)

        # Optimizer
        step_node = find_node_with_regex(event_tree, "_optimizer_step_code")
        self.assertIsNotNone(step_node)
        step_fields = step_node.extra_fields
        self.assertIsNone(step_fields.module)
        self.assertIsNotNone(step_fields.optimizer)
        optimizer_parameters = step_fields.optimizer.parameters
        self.assertEqual(len(optimizer_parameters), 2)  # Weight and bias
        weight, weight_grad, state = optimizer_parameters[0]
        self.assertEqual(weight.id, weight_metadata.id)
        self.assertEqual(weight_grad.id, weight_grad_id)
        self.assertEqual(len(state), 1)
        self.assertEqual(state[0][0], "momentum_buffer")
        self.assertEqual(state[0][1].id, weight_momentum_id)

    # Check that we handle first step (lazy initialization) and steady state.
    check(cold_start=True)
    check(cold_start=False)
def _test_allocation_ids(self, before_fn, after_fn) -> None:
    """Check that a tensor keeps its ID across a storage-changing resize.

    Args:
        before_fn: Zero-argument callable run inside the profile before the
            tensor under test is created.
        after_fn: Zero-argument callable run inside the profile after the
            tensor under test has been used.
    """
    with profile(profile_memory=True, record_shapes=True) as prof:
        # Introduce other operations and allocations to check robustness
        _ = before_fn()

        x = torch.rand(4, 3)
        x.resize_(4, 4)

        # We need to use `x` post resize for profiler to determine its ID.
        x.sin()

        # Introduce other operations and allocations to check robustness
        _ = after_fn()

        # Ensure `x` is the last variable collected to make it easier to
        # find the deallocation event.
        gc.collect()
        del x
        gc.collect()

    event_tree = prof.profiler.kineto_results.experimental_event_tree()

    def find_chain(names: List[str]):
        # Each name is searched beneath the node found for the previous one;
        # the first name is searched from the roots.
        chain = []
        for name in names:
            search_roots = [chain[-1]] if chain else event_tree
            chain.append(find_node_with_name(search_roots, name))
            self.assertIsNotNone(chain[-1], name)
        return chain

    memory_node = find_chain(["aten::rand", "aten::empty", "[memory]"])[-1]
    allocation = memory_node.extra_fields
    _, uniform_node = find_chain(["aten::rand", "aten::uniform_"])
    _, x_storage_data, x_id = self._get_tensor_fields(uniform_node, 0)

    # Make sure IDs are consistent between allocations and op inputs
    self.assertEqual(allocation.ptr, x_storage_data)
    self.assertEqual(allocation.id, x_id)

    resize_node = find_node_with_name(event_tree, "aten::resize_")
    self.assertIsNotNone(resize_node)
    self.assertEqual(len(resize_node.children), 2)
    allocate_new = resize_node.children[0].extra_fields
    free_old = resize_node.children[1].extra_fields

    # Destruction of the old storage for x.
    self.assertEqual(free_old.id, allocation.id)
    self.assertEqual(free_old.ptr, allocation.ptr)

    # Make sure ID is retained through change in storage.
    self.assertEqual(allocate_new.id, allocation.id)
    self.assertNotEqual(allocate_new.ptr, allocation.ptr)

    # Deletion when `x` goes out of scope.
    allocation_roots = [
        root
        for root in event_tree
        if root.tag == torch._C._profiler._EventType.Allocation
    ]
    free_new = allocation_roots[-1].extra_fields
    self.assertIsInstance(free_new, torch._C._profiler._ExtraFields_Allocation)
    self.assertEqual(free_new.id, allocate_new.id)
    self.assertEqual(free_new.ptr, allocate_new.ptr)
def test_allocation_ids(self) -> None:
    """Run the allocation-ID scenario with no extra work around the tensor."""

    def do_nothing():
        return None

    self._test_allocation_ids(do_nothing, do_nothing)
def test_allocation_ids_with_other_ops(self) -> None:
    """Run the allocation-ID scenario surrounded by unrelated tensor ops."""
    x = torch.ones((1,))

    def before():
        return (x + 1).relu_()

    def after():
        return torch.zeros((1,)).cos()

    self._test_allocation_ids(before, after)
def test_impl_reuse(self) -> None:
    """Check the ``impl_ptr`` values reported for repeated ``torch.ones`` calls.

    One ``aten::fill_`` event is expected per call, and the ``impl_ptr`` of
    its first input must be different for every one of them.
    """
    repeats = 1_000
    with profile(profile_memory=True, record_shapes=True) as prof:
        for _ in range(repeats):
            torch.ones((1,))
        gc.collect()

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    fill_events = [
        event
        for event in _utils.traverse_dfs(event_tree)
        if event.name == "aten::fill_"
    ]
    impl_ptrs = [event.extra_fields.inputs[0].impl_ptr for event in fill_events]

    self.assertEqual(len(impl_ptrs), repeats)
    self.assertEqual(len(set(impl_ptrs)), repeats)
def test_allocation_id_uniqueness(self) -> None:
    """Check that each ``torch.ones`` call contributes one distinct allocation ID.

    Allocation IDs are gathered from the tensor inputs of op events and from
    allocation events; after dropping ``None`` exactly ``repeats`` distinct
    values must remain.
    """
    repeats = 1_000
    with profile(profile_memory=True, record_shapes=True) as prof:
        for _ in range(repeats):
            torch.ones((1,))
        gc.collect()

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    seen_ids = set()
    for event in _utils.traverse_dfs(event_tree):
        fields = event.extra_fields
        if isinstance(fields, torch._C._profiler._ExtraFields_TorchOp):
            # Non-tensor inputs (scalars, lists, ...) carry no allocation ID.
            seen_ids.update(
                arg.allocation_id
                for arg in fields.inputs
                if isinstance(arg, _TensorMetadata)
            )
        elif isinstance(fields, torch._C._profiler._ExtraFields_Allocation):
            seen_ids.add(fields.allocation_id)

    seen_ids.discard(None)
    self.assertEqual(repeats, len(seen_ids))
def test_extra_fields(self):
    """Check the ``extra_fields`` types around a single ``torch.ones`` call.

    The ``aten::ones`` node must be a TorchOp whose parent is a PyCCall, and
    its first child chain must be ``aten::empty`` -> ``[memory]`` with the
    latter carrying Allocation fields.
    """
    with profile(with_stack=True, profile_memory=True) as prof:
        _ = torch.ones((1,))

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    ones_node = find_node_with_name(event_tree, "aten::ones")
    self.assertIsNotNone(ones_node)

    profiler_types = torch._C._profiler
    self.assertIsInstance(ones_node.extra_fields, profiler_types._ExtraFields_TorchOp)
    self.assertIsInstance(
        ones_node.parent.extra_fields, profiler_types._ExtraFields_PyCCall
    )

    empty_node = ones_node.children[0]
    self.assertEqual(empty_node.name, "aten::empty")
    memory_node = empty_node.children[0]
    self.assertEqual(memory_node.name, "[memory]")
    self.assertIsInstance(
        memory_node.extra_fields, profiler_types._ExtraFields_Allocation
    )
def test_tensor_properties(self):
    """Check the per-input metadata recorded for a dense ``aten::add``.

    Sizes, strides, layout, device and dtype are compared for the three
    recorded inputs; the third input yields the fallback value for each.
    The test also checks the op's scope and that the following ``aten::mul``
    has the next sequence number.
    """
    x = torch.ones(10, 10).as_strided([4, 4], [12, 3])
    y = torch.ones(4, 1, requires_grad=True)
    with profile(with_stack=True, profile_memory=True, record_shapes=True) as prof:
        _ = x + y
        _ = x * y

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    add_node = find_node_with_name(event_tree, "aten::add")
    self.assertIsNotNone(add_node)
    self.assertIsInstance(
        add_node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
    )

    def input_attrs(name, default):
        return [getattr(arg, name, default) for arg in add_node.extra_fields.inputs]

    cpu = torch.device("cpu")
    for attr, fallback, expected in (
        ("sizes", [], [[4, 4], [4, 1], []]),
        ("strides", [], [[12, 3], [1, 1], []]),
        ("layout", None, [torch.strided, torch.strided, None]),
        ("device", None, [cpu, cpu, None]),
        ("dtype", None, [torch.float32, torch.float32, None]),
    ):
        self.assertEqual(input_attrs(attr, fallback), expected)

    self.assertEqual(add_node.extra_fields.scope, torch.profiler.RecordScope.FUNCTION)

    mul_node = find_node_with_name(event_tree, "aten::mul")
    self.assertIsNotNone(mul_node)
    self.assertEqual(
        add_node.extra_fields.sequence_number + 1,
        mul_node.extra_fields.sequence_number,
    )
def test_sparse_tensors(self):
    """Check the per-input metadata recorded for ``aten::add`` on COO tensors.

    Sizes, layout and device are reported for the two sparse inputs, while
    strides come back empty for every input.
    """
    indices = [[0, 1, 1], [2, 0, 2]]
    values = [3, 4, 5]
    sparse = torch.sparse_coo_tensor(indices, values, (2, 3))
    with profile(with_stack=True, profile_memory=True, record_shapes=True) as prof:
        _ = sparse + sparse

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    add_node = find_node_with_name(event_tree, "aten::add")
    self.assertIsNotNone(add_node)
    self.assertIsInstance(
        add_node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
    )

    def input_attrs(name, default):
        return [getattr(arg, name, default) for arg in add_node.extra_fields.inputs]

    cpu = torch.device("cpu")
    for attr, fallback, expected in (
        ("sizes", [], [[2, 3], [2, 3], []]),
        ("strides", [], [[], [], []]),
        ("layout", None, [torch.sparse_coo, torch.sparse_coo, None]),
        ("device", None, [cpu, cpu, None]),
    ):
        self.assertEqual(input_attrs(attr, fallback), expected)
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
def test_mkldnn_tensors(self):
    """Check the per-input metadata recorded for ``aten::add`` on MKL-DNN tensors.

    Sizes, layout and device are reported for the two MKL-DNN inputs, while
    strides come back empty for every input.
    """
    x = torch.ones(4, 3).to_mkldnn()
    with profile(with_stack=True, profile_memory=True, record_shapes=True) as prof:
        _ = x + x

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    add_node = find_node_with_name(event_tree, "aten::add")
    self.assertIsNotNone(add_node)
    self.assertIsInstance(
        add_node.extra_fields, torch._C._profiler._ExtraFields_TorchOp
    )

    def input_attrs(name, default):
        return [getattr(arg, name, default) for arg in add_node.extra_fields.inputs]

    cpu = torch.device("cpu")
    for attr, fallback, expected in (
        ("sizes", [], [[4, 3], [4, 3], []]),
        ("strides", [], [[], [], []]),
        ("layout", None, [torch._mkldnn, torch._mkldnn, None]),
        ("device", None, [cpu, cpu, None]),
    ):
        self.assertEqual(input_attrs(attr, fallback), expected)
def test_scalar_ins(self):
    """Check how scalar arguments of ``torch.add`` show up in the recorded inputs."""
    x = torch.ones(5, 5)
    alpha = 0.9
    with profile(with_stack=True, profile_memory=True, record_shapes=True) as prof:
        _ = torch.add(x, 9.1, alpha=alpha)

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    add_node = find_node_with_name(event_tree, "aten::add")
    self.assertIsNotNone(add_node)

    def input_attrs(name, default):
        return [getattr(arg, name, default) for arg in add_node.extra_fields.inputs]

    # The second argument to the add gets promoted to a zerodim Tensor
    self.assertEqual(
        input_attrs("dtype", None), [torch.float32, torch.float64, None]
    )
    self.assertEqual(input_attrs("sizes", []), [[5, 5], [], []])
    # `alpha` stays a plain scalar and is recorded by value.
    self.assertEqual(add_node.extra_fields.inputs[2], alpha)
def test_tensor_lists(self):
    """Check that a tensor-list argument is recorded as a list of tensor metadata."""
    x = torch.ones((1,))
    y = torch.ones((1,))
    with profile(with_stack=True, profile_memory=True, record_shapes=True) as prof:
        _ = torch.stack((x, y))

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    stack_node = find_node_with_name(event_tree, "aten::stack")
    inputs = stack_node.extra_fields.inputs
    self.assertEqual(len(inputs), 2)
    self.assertIsInstance(inputs[0], list)
    self.assertEqual(len(inputs[0]), 2)
    # Each list entry must point at the storage of the matching live tensor.
    for tensor, metadata in zip((x, y), inputs[0]):
        self.assertEqual(tensor.storage().data_ptr(), metadata.storage_data_ptr)
def test_nnmodule_params(self):
    """Check the parameter and gradient addresses recorded for ``nn.Module`` calls.

    A forward/backward pass is run first so that gradients exist; the
    profiled forward pass must then report, for ``fc1`` and ``fc2``, the
    same storage addresses as the live parameters and their gradients.
    """

    def collect_modules(nodes, out=None):
        # Depth-first walk gathering module records that list parameters.
        if out is None:
            out = []
        for node in nodes:
            fields = node.extra_fields
            if isinstance(fields, _ExtraFields_PyCall) and fields.module:
                if fields.module.parameters:
                    out.append(fields.module)
            collect_modules(node.children, out)
        return out

    inputs = torch.rand(10)
    net = SimpleNet()
    out = net(inputs)
    torch.nn.functional.cross_entropy(out, torch.rand(2)).backward()
    with torch.profiler.profile(with_stack=True, profile_memory=True) as prof:
        _ = net(inputs)

    modules = collect_modules(prof.profiler.kineto_results.experimental_event_tree())
    self.assertEqual(
        len(modules), 2, f"Expected two parameter list, but got {len(modules)}"
    )

    params = [
        (name, param.storage_data_ptr, grad.storage_data_ptr)
        for module in modules
        for (name, param, grad) in module.parameters
    ]
    expected = [
        (name, val.storage().data_ptr(), val.grad.storage().data_ptr())
        for layer in (net.fc1, net.fc2)
        for name, val in layer._parameters.items()
    ]
    self.assertEqual(expected, params, f"{expected} vs. {params}")
def _flat_out_extrafields(self, nodes, out=None):
    """Collect optimizer records from an event tree, depth first.

    Args:
        nodes: Iterable of event-tree nodes to walk (children are visited
            recursively).
        out: Accumulator list; a fresh list is created when omitted.

    Returns:
        The accumulator, holding one optimizer record per distinct storage
        address of the record's first parameter.
    """
    if out is None:
        out = []
    for node in nodes:
        fields = node.extra_fields
        if (
            isinstance(fields, _ExtraFields_PyCall)
            and fields.optimizer
            and fields.optimizer.parameters
        ):
            # avoiding OptInfo duplicates from iterations
            addr = fields.optimizer.parameters[0][0].storage_data_ptr
            already_recorded = any(
                addr == seen.parameters[0][0].storage_data_ptr for seen in out
            )
            if not already_recorded:
                out.append(fields.optimizer)
        self._flat_out_extrafields(node.children, out)
    return out
def _check_results(self, opt, opts, check_items=False):
# Compare a live optimizer `opt` with `opts`, the optimizer records that the
# callers in this file collect from a profiler event tree.
# `check_items` additionally enables the parameter-address comparison below.
# Exactly one profiled optimizer record is expected.
self.assertEqual(len(opts), 1, f"Expected 1 optimizer: len(opts): {len(opts)}")
# The record's `self_ptr` must equal the Python `id()` of the live optimizer.
self.assertEqual(
id(opt),
opts[0].self_ptr,
f"Optimizer addr ({id(opt)}) vs. profiled addr ({opts[0].self_ptr})",
)
if check_items:
self.assertEqual(len(opt.param_groups), len(opts))
# Pair each param group with a profiled record and compare the storage
# addresses of their parameters, in order.
for group, opt_ in zip(opt.param_groups, opts):
self.assertEqual(
[(v.storage().data_ptr()) for v in group.get("params", [])],
[(o.storage_data_ptr) for (o, _, _) in opt_.parameters],
)
for opt_ in opts:
# Map parameter storage address -> {state name: state storage address},
# as seen by the profiler.
observed_state = {
p.storage_data_ptr: {name: s.storage_data_ptr for name, s in state}
for (p, _, state) in opt_.parameters
}
# Make sure the profiler collected all optimizer state and check
# that the address recorded by the profiler is correct.
for parameter, parameter_state in opt.state.items():
# NOTE(review): the fallback for a missing parameter is a list while the
# expected value is a dict -- confirm this mismatch is intentional.
self.assertEqual(
{
name: value.storage().data_ptr()
for name, value in parameter_state.items()
},
observed_state.get(parameter.storage().data_ptr(), []),
)
def test_optimizer(self):
    """Profile one SGD training step and check the recorded optimizer identity."""
    inputs = torch.rand(10)
    with torch.profiler.profile(with_stack=True, profile_memory=True) as prof:
        net = SimpleNet()
        opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

        opt.zero_grad()
        out = net(inputs)
        loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
        loss.backward()
        opt.step()

    profiled_optimizers = self._flat_out_extrafields(
        prof.profiler.kineto_results.experimental_event_tree()
    )
    # Only the optimizer identity is checked here, not its parameters.
    self._check_results(opt, profiled_optimizers, False)
def _test_optimizer_parameters(self, optimizer_factory):
    """Profile two training steps and check recorded optimizer parameters.

    Args:
        optimizer_factory: Callable that receives the network's parameters
            and returns the optimizer to profile.
    """
    inputs = torch.rand(10)
    with torch.profiler.profile(with_stack=True, profile_memory=True) as prof:
        net = SimpleNet()
        opt = optimizer_factory(net.parameters())
        for _ in range(2):
            opt.zero_grad()
            out = net(inputs)
            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
            loss.backward()
            opt.step()

    profiled_optimizers = self._flat_out_extrafields(
        prof.profiler.kineto_results.experimental_event_tree()
    )
    self._check_results(opt, profiled_optimizers, True)
def test_optimizer_parameters_sgd(self):
    """Run the optimizer-parameter check with SGD plus momentum."""

    def make_sgd(params):
        return torch.optim.SGD(params, lr=0.01, momentum=0.9)

    self._test_optimizer_parameters(make_sgd)
def test_optimizer_parameters_adam(self):
    """Run the optimizer-parameter check with the ``foreach`` variant of Adam."""

    def make_adam(params):
        return torch.optim.Adam(params, foreach=True)

    self._test_optimizer_parameters(make_adam)
def test_allocations(self):
    """Check the ``[memory]`` events for allocating and freeing one CPU tensor.

    The first profile covers the allocation of a 3x4 tensor, the second one
    its release; the free event must mirror the allocation (same pointer,
    negated size, running total reduced by the size).
    """
    gc.collect()
    with profile(profile_memory=True) as prof:
        x = torch.empty((3, 4))

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    alloc_node = find_node_with_name(event_tree, "[memory]")
    self.assertIsNotNone(alloc_node)

    alloc_size = 3 * 4 * 4  # fp32 -> 4 bytes
    ptr = alloc_node.extra_fields.ptr
    self.assertGreater(ptr, 0)
    self.assertEqual(alloc_node.extra_fields.alloc_size, alloc_size)
    self.assertEqual(alloc_node.extra_fields.device, torch.device("cpu"))
    total_allocated = alloc_node.extra_fields.total_allocated

    # total_reserved is only for CUDACachingAllocator
    self.assertEqual(alloc_node.extra_fields.total_reserved, 0)

    with profile(profile_memory=True) as prof:
        del x
        gc.collect()

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    free_node = find_node_with_name(event_tree, "[memory]")
    self.assertIsNotNone(free_node)

    self.assertEqual(free_node.extra_fields.ptr, ptr)
    self.assertEqual(free_node.extra_fields.alloc_size, -alloc_size)
    self.assertEqual(free_node.extra_fields.device, torch.device("cpu"))
    self.assertEqual(
        free_node.extra_fields.total_allocated, total_allocated - alloc_size
    )
def test_refcounts(self):
# Check, via weak references, that running code under
# `profile(with_stack=True)` leaves no extra strong references to objects
# captured by the profiled closures.
class Sentinel:
pass
# Factory returning `outer` plus a weak reference to the sentinel that
# `outer` closes over.
def make():
outer_sentinel = Sentinel()
def outer():
# Python will only close over variables used in the function.
_ = outer_sentinel
inner_sentinel = Sentinel()
def inner():
_ = inner_sentinel
# `inner` is the only call made while the profiler is active.
with profile(with_stack=True):
inner()
return weakref.ref(inner_sentinel)
return outer, weakref.ref(outer_sentinel)
# Use a factory function to ensure the test scope never sees strong
# references. `del` has strange semantics that interact with closures
# at an AST level, so this is simpler.
outer, outer_sentinel_ref = make()
inner_sentinel_ref = outer()
# Once `outer()` has returned, the inner sentinel must already be gone.
self.assertIsNone(inner_sentinel_ref())
# `outer` holds the last reference via closure.
self.assertIsNotNone(outer_sentinel_ref())
# Dropping `outer` must release the outer sentinel as well.
del outer
self.assertIsNone(outer_sentinel_ref())
@dataclass(frozen=True)
class MockKinetoEvent:
    """Immutable stand-in for a kineto event, built from plain values.

    Times are stored in microseconds and exposed in nanoseconds through
    accessor methods; ``_device_type`` is an integer flag where ``1`` means
    CUDA and anything else means CPU.
    """

    _name: str
    _start_us: int
    _duration_us: int
    _linked_correlation_id: int
    _device_type: int

    @property
    def name(self) -> str:
        """Event name as given at construction."""
        return self._name

    def start_ns(self) -> int:
        """Start time converted from microseconds to nanoseconds."""
        return 1000 * self._start_us

    def duration_ns(self) -> int:
        """Duration converted from microseconds to nanoseconds."""
        return 1000 * self._duration_us

    def linked_correlation_id(self) -> int:
        """Correlation ID as given at construction."""
        return self._linked_correlation_id

    def device_type(self) -> DeviceType:
        """Translate the integer flag into a ``DeviceType``."""
        if self._device_type == 1:
            return DeviceType.CUDA
        return DeviceType.CPU
@dataclass(frozen=True)
class MockProfilerEvent:
    """Immutable stand-in for a profiler event-tree node, built from plain values."""

    _name: str
    id: int
    start_time_ns: int
    duration_time_ns: int
    correlation_id: int = 0
    children: List["MockProfilerEvent"] = field(default_factory=list)
    parent: Optional["MockProfilerEvent"] = None

    @property
    def name(self) -> str:
        """Event name as given at construction."""
        return self._name

    @property
    def end_time_ns(self):
        """End time, i.e. start time plus duration."""
        return self.duration_time_ns + self.start_time_ns

    def __post__init__(self, parent, children):
        """Overwrite ``parent`` and ``children`` on this frozen instance.

        This is a regular method that has to be called explicitly (it is not
        the dataclass ``__post_init__`` hook).  ``object.__setattr__`` is
        used because the dataclass is frozen.
        """
        for attribute, value in (("parent", parent), ("children", children)):
            object.__setattr__(self, attribute, value)
class MockNode:
    """Minimal tree node built recursively from a nested ``{name: subtree}`` dict."""

    def __init__(self, name, children) -> None:
        self.name = name
        # Child order follows the iteration order of the mapping.
        self.children = [
            MockNode(child_name, subtree) for child_name, subtree in children.items()
        ]
class TestExperimentalUtils(TestCase):
def make_tree(self) -> List[MockNode]:
    """Build the fixed two-root ``MockNode`` forest used by the traversal tests."""
    layout = {
        "root_0": {
            "1": {"2": {}},
            "3": {"4": {}, "5": {}},
        },
        "root_1": {
            "6": {},
            "7": {},
            "8": {"9": {"10": {}}},
        },
    }
    return [MockNode(root_name, subtree) for root_name, subtree in layout.items()]
def test_dfs(self) -> None:
    """Check the visiting order of ``_utils.traverse_dfs`` on the mock forest."""
    visited = [node.name for node in _utils.traverse_dfs(self.make_tree())]
    self.assertEqual(
        " ".join(visited),
        "root_0 1 2 3 4 5 root_1 6 7 8 9 10",
    )
def test_bfs(self) -> None:
    """Check the visiting order of ``_utils.traverse_bfs`` on the mock forest."""
    visited = [node.name for node in _utils.traverse_bfs(self.make_tree())]
    self.assertEqual(
        " ".join(visited),
        "root_0 root_1 1 3 6 7 8 2 4 5 9 10",
    )
@staticmethod
def generate_mock_profile():
    """Build a mock profiler with a hand-written CUDA/CPU timeline.

    Returns:
        A ``Mock`` whose ``kineto_results.events()`` returns twelve
        ``MockKinetoEvent`` objects (six kernel launches followed by six GPU
        events) and whose ``kineto_results.experimental_event_tree()``
        returns twelve ``MockProfilerEvent`` objects.
    """
    # Start times (us); correlation IDs 1..6 link the n-th launch to the
    # n-th GPU event.  Every kineto event lasts 100us.
    launch_starts_us = (400, 500, 600, 700, 800, 1500)
    gpu_starts_us = (900, 1000, 1100, 1200, 1300, 1700)
    cuda_events = [
        MockKinetoEvent("cudaLaunchKernel", start, 100, correlation, 0)
        for correlation, start in enumerate(launch_starts_us, start=1)
    ]
    cuda_events += [
        MockKinetoEvent("GPU", start, 100, correlation, 1)
        for correlation, start in enumerate(gpu_starts_us, start=1)
    ]

    # (name, start ns, duration ns); IDs are assigned 1..12 in this order.
    cpu_rows = (
        ("CPU (Before cudaLaunchKernel)", 0, 100000),
        ("CPU (Before cudaLaunchKernel)", 100000, 100000),
        ("CPU (Before cudaLaunchKernel)", 200000, 100000),
        ("CPU (Before cudaLaunchKernel)", 300000, 100000),
        ("CPU (After cudaLaunchKernel)", 400000, 100000),
        ("CPU (After cudaLaunchKernel)", 500000, 100000),
        ("CPU (After cudaLaunchKernel)", 600000, 100000),
        ("CPU (After cudaLaunchKernel)", 700000, 100000),
        ("CPU (After GPU)", 800000, 100000),
        ("CPU (After GPU)", 900000, 100000),
        ("CPU (After GPU)", 1100000, 100000),
        ("CPU (After GPU)", 1200000, 500000),
    )
    cpu_events = [
        MockProfilerEvent(name, event_id, start_ns, duration_ns)
        for event_id, (name, start_ns, duration_ns) in enumerate(cpu_rows, start=1)
    ]

    profiler = unittest.mock.Mock()
    kineto_results = unittest.mock.Mock()
    profiler.kineto_results = kineto_results
    kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
    kineto_results.experimental_event_tree = unittest.mock.Mock(
        return_value=cpu_events
    )
    return profiler
@staticmethod
def load_mock_profile():
# Build a mock profiler from `profiler_utils_mock_events.json`, which lives
# next to this file.  When expecttest's ACCEPT flag is set and CUDA is
# available, the JSON file is first regenerated from a real profile.
accept = expecttest.ACCEPT
json_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"profiler_utils_mock_events.json",
)
if accept and torch.cuda.is_available():
# Element-wise Python-side writes into the tensor, run between the matmuls.
def garbage_code(x):
for i in range(5):
x[0, i] = i
x = torch.ones((4096, 4096), device="cuda")
x = x @ x
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
) as prof:
for _ in range(5):
x = x @ x
garbage_code(x)
for _ in range(5):
x = x @ x
# NOTE(review): the time keys written here are `_start_ns`/`_duration_ns`
# with nanosecond values, while the loader below feeds the values
# positionally into MockKinetoEvent, whose fields are `_start_us` and
# `_duration_us` -- confirm the units of the checked-in JSON file.
kineto_events = [
{
"_name": e.name,
"_start_ns": e.start_ns(),
"_duration_ns": e.duration_ns(),
"_linked_correlation_id": e.linked_correlation_id(),
"_device_type": 1 if e.device_type() == DeviceType.CUDA else 0,
}
for e in prof.profiler.kineto_results.events()
]
# Stack-based traversal: nodes are popped from the right end, so the most
# recently pushed node is yielded next.
def EventTreeDFS(event_tree):
from collections import deque
stack = deque(event_tree)
while stack:
curr_event = stack.pop()
yield curr_event
for child_event in curr_event.children:
stack.append(child_event)
# Parent/children are serialized as event IDs and turned back into object
# references by the loader below.
profiler_events = [
{
"_name": e.name,
"id": e.id,
"start_time_ns": e.start_time_ns,
"duration_time_ns": e.duration_time_ns,
"correlation_id": e.correlation_id,
"children": [child.id for child in e.children],
"parent": e.parent.id if e.parent else None,
}
for e in EventTreeDFS(
prof.profiler.kineto_results.experimental_event_tree()
)
]
with open(json_file_path, "w") as f:
json.dump([kineto_events, profiler_events], f)
assert os.path.exists(json_file_path)
with open(json_file_path) as f:
kineto_events, profiler_events = json.load(f)
# Values are passed positionally, so the JSON key names are not used here.
cuda_events = [MockKinetoEvent(*event.values()) for event in kineto_events]
cpu_events = []
id_map = {}
# First pass: create every event (parent/children still hold IDs) and
# index it by ID.
for e in profiler_events:
event = MockProfilerEvent(**e)
id_map[event.id] = event
cpu_events.append(event)
# Second pass: replace the IDs with references to the created events.
for event in cpu_events:
parent = None if event.parent is None else id_map[event.parent]
children = [id_map[child] for child in event.children]
event.__post__init__(parent, children)
# Only root events are handed out as the event tree.
cpu_events = [event for event in cpu_events if event.parent is None]
profiler = unittest.mock.Mock()
profiler.kineto_results = unittest.mock.Mock()
profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
return_value=cpu_events
)
return profiler
def test_utils_compute_self_time(self):
    """Check that self time equals an event's duration minus its children's durations."""
    with profile() as prof:
        t1 = torch.ones(1, requires_grad=True)
        t2 = torch.ones(1, requires_grad=True)
        z = torch.add(t1, t2)
        y = torch.ones(1)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
        loss.backward()

    basic_eval = _utils.BasicEvaluation(prof.profiler)
    metrics = basic_eval.metrics
    self.assertTrue(len(metrics) > 0)
    for event_key, event_metrics in metrics.items():
        event = event_key.event
        children_time_ns = sum(child.duration_time_ns for child in event.children)
        self.assertEqual(
            event_metrics.self_time_ns,
            event.duration_time_ns - children_time_ns,
        )
def test_utils_intervals_overlap(self):
    """Check ``EventKey.intervals_overlap`` against a fixed set of intervals.

    The mock event starts at 5ns and lasts 5ns; for the intervals below the
    reported overlap must be 5.
    """
    event = _utils.EventKey(MockProfilerEvent("Event 1", 1, 5, 5))
    intervals = [
        _utils.Interval(0, 9),
        _utils.Interval(1, 2),
        _utils.Interval(2, 3),
        _utils.Interval(3, 4),
        _utils.Interval(4, 5),
        _utils.Interval(8, 12),
    ]
    # The assertion reports the actual value on failure, so no separate
    # debug print of the overlap is needed.
    self.assertEqual(event.intervals_overlap(intervals), 5)
def test_utils_compute_queue_depth(self):
    """Pin the queue depths computed for the hand-written mock timeline.

    The first expectation covers the per-CUDA-event queue depth list, the
    second the queue depth stored in the metrics of each CPU event.
    """

    def format_queue_depth(queue_depth_list, events):
        # One "<depth> [<event name>]" line per pair, each newline-terminated.
        return "".join(
            f"{data.queue_depth} [{event.name}]\n"
            for data, event in zip(queue_depth_list, events)
        )

    # We have to use Mock because time series data is too flaky to test
    profiler = self.generate_mock_profile()
    basic_evaluation = _utils.BasicEvaluation(profiler)

    cuda_report = format_queue_depth(
        basic_evaluation.queue_depth_list, basic_evaluation.cuda_events
    )
    self.assertExpectedInline(
        cuda_report,
        """\
1 [cudaLaunchKernel]
2 [cudaLaunchKernel]
3 [cudaLaunchKernel]
4 [cudaLaunchKernel]
5 [cudaLaunchKernel]
4 [GPU]
3 [GPU]
2 [GPU]
1 [GPU]
0 [GPU]
1 [cudaLaunchKernel]
0 [GPU]
""",
    )

    cpu_metrics = [basic_evaluation.metrics[k] for k in basic_evaluation.event_keys]
    cpu_report = format_queue_depth(cpu_metrics, basic_evaluation.events)
    self.assertExpectedInline(
        cpu_report,
        """\
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
1 [CPU (After cudaLaunchKernel)]
2 [CPU (After cudaLaunchKernel)]
3 [CPU (After cudaLaunchKernel)]
4 [CPU (After cudaLaunchKernel)]
5 [CPU (After GPU)]
4 [CPU (After GPU)]
2 [CPU (After GPU)]
1 [CPU (After GPU)]
""",
    )
def test_utils_compute_queue_depth_when_no_cuda_events(self):
    """Check that a CPU-only trace yields a falsy queue-depth result."""
    # For traces with only cpu events, we expect empty queue depth list
    x = torch.ones((1024, 1024))
    with profile() as prof:
        for _ in range(5):
            x = x @ x

    evaluation = _utils.BasicEvaluation(prof.profiler)
    queue_depths = evaluation.compute_queue_depth()
    self.assertFalse(queue_depths)
def test_utils_compute_idle_time(self):
    """Pin the per-event idle time computed for the hand-written mock timeline."""
    evaluation = _utils.BasicEvaluation(self.generate_mock_profile())
    # One "<idle ns> [<event name>]" line per event key, in key order.
    report_lines = [
        f"{evaluation.metrics[event_key].idle_time_ns} [{event_key.event.name}]"
        for event_key in evaluation.event_keys
    ]
    report = "\n".join(report_lines)
    self.assertExpectedInline(
        report,
        """\
100000 [CPU (Before cudaLaunchKernel)]
100000 [CPU (Before cudaLaunchKernel)]
100000 [CPU (Before cudaLaunchKernel)]
100000 [CPU (Before cudaLaunchKernel)]
0 [CPU (After cudaLaunchKernel)]
0 [CPU (After cudaLaunchKernel)]
0 [CPU (After cudaLaunchKernel)]
0 [CPU (After cudaLaunchKernel)]
0 [CPU (After GPU)]
0 [CPU (After GPU)]
0 [CPU (After GPU)]
100000 [CPU (After GPU)]""",
    )
@unittest.skipIf(IS_JETSON, "JSON not behaving as expected on Jetson")
def test_utils_get_optimizable_events(self):
    """Pin the names of the top two optimizable events of the recorded mock profile."""
    evaluation = _utils.BasicEvaluation(self.load_mock_profile())
    top_events = evaluation.get_optimizable_events(2, print_enable=False)
    report = "\n".join(f"{event_key.event.name}" for event_key in top_events)
    self.assertExpectedInline(
        report,
        """\
<built-in function _cuda_synchronize>
aten::copy_""",
    )
def test_profiler_name_pattern(self):
    """Check that ``NamePattern`` finds one ``aten::mm`` event per matmul."""
    x = torch.ones((4096, 4096))
    with profile() as prof:
        for _ in range(5):
            x = x @ x
            x = x + x

    mm_events = NamePattern(prof, "aten::mm").matched_events()
    report = "\n".join(f"{event.name}" for event in mm_events)
    self.assertExpectedInline(
        report,
        """\
aten::mm
aten::mm
aten::mm
aten::mm
aten::mm""",
    )
# TODO: Add logic for CUDA version of test
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
def test_profiler_pattern_match_helper(self):
    """Exercise the navigation helpers of ``Pattern`` on a real event tree.

    Covers ``siblings_of``, ``root_of``, ``next_of`` and ``prev_of`` for
    root events and for the children of the first root.
    """
    x = torch.ones((100, 100))
    with profile() as prof:
        for _ in range(5):
            x = x @ x
            x = x + x

    event_tree = prof.profiler.kineto_results.experimental_event_tree()
    pattern = Pattern(prof)
    first_root = event_tree[0]

    # The first root has nothing before it and every other root after it.
    self.assertEqual([], pattern.siblings_of(first_root)[0])
    self.assertEqual(event_tree[1:], pattern.siblings_of(first_root)[1])

    # Same expectation one level down.
    child_nodes = first_root.children
    self.assertEqual([], pattern.siblings_of(child_nodes[0])[0])
    self.assertEqual(child_nodes[1:], pattern.siblings_of(child_nodes[0])[1])

    grandchild = first_root.children[0].children[0]
    self.assertEqual(first_root, pattern.root_of(grandchild))

    self.assertEqual(None, pattern.next_of(event_tree[-1]))
    self.assertEqual(event_tree[1], pattern.next_of(first_root))
    self.assertEqual(first_root, pattern.prev_of(event_tree[1]))
@unittest.skipIf(
    TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_profiler_extra_cuda_copy_pattern(self):
    """Check how many events ``ExtraCUDACopyPattern`` matches per snippet.

    Each case pairs the expected number of matches with a callable that is
    profiled on its own.
    """
    cases = (
        (0, lambda: torch.ones((100, 100), device="cuda")),
        (1, lambda: torch.ones((100, 100)).to("cuda")),
        (1, lambda: torch.zeros((100, 100)).to("cuda")),
        (1, lambda: torch.empty((100, 100)).fill_(5).to("cuda")),
        (1, lambda: torch.ones((100, 100)).cuda()),
        (1, lambda: torch.zeros((100, 100)).cuda()),
        (1, lambda: torch.empty((100, 100)).fill_(5).cuda()),
        (1, lambda: torch.rand((100, 100)).cuda()),
        (1, lambda: torch.randn((100, 100)).cuda()),
        (1, lambda: torch.full((100, 100), 10).cuda()),
        (0, lambda: torch.rand((100, 100)).to(dtype=torch.float16)),
        (0, lambda: torch.rand((100, 100)).half()),
        (0, lambda: torch.rand((100, 100), device="cuda").half()),
    )
    expected_counts = [count for count, _ in cases]

    observed_counts = []
    for _, fn in cases:
        with profile(with_stack=True, record_shapes=True) as prof:
            fn()
        matches = ExtraCUDACopyPattern(prof).matched_events()
        observed_counts.append(len(matches))

    self.assertEqual(observed_counts, expected_counts)
@unittest.skipIf(
    TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
)
def test_profiler_for_loop_indexing_pattern(self):
    """Check how many events ``ForLoopIndexingPattern`` matches per loop shape.

    Each case pairs the expected number of matches with a function that is
    profiled on its own; all of them operate on the same 100x100 tensor.
    """
    x = torch.ones((100, 100))

    def case1():
        for i in range(100):
            x[i] = i

    def case2():
        y = 0
        for i in range(100):
            y += x[i]

    def case3():
        y = 1
        for i in range(100):
            y *= x[i]

    def case4():
        y = x
        for _ in range(100):
            y = y @ x

    def case5():
        for i in range(100):
            x[i, :] = torch.arange(100) + i

    cases = ((1, case1), (1, case2), (1, case3), (0, case4), (1, case5))
    expected_counts = [count for count, _ in cases]

    observed_counts = []
    for _, fn in cases:
        with profile(with_stack=True) as prof:
            fn()
        matches = ForLoopIndexingPattern(prof).matched_events()
        observed_counts.append(len(matches))

    self.assertEqual(observed_counts, expected_counts)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_profiler_fp32_matmul_pattern(self):
    """Check that ``FP32MatMulPattern`` matches one matmul unless it is skipped."""
    x = torch.ones((100, 100), device="cuda")
    with profile(with_stack=True) as prof:
        x = x @ x

    pattern = FP32MatMulPattern(prof)
    # When the pattern marks itself as skipped no match is expected.
    expected_matches = 0 if pattern.skip else 1
    self.assertEqual(len(pattern.matched_events()), expected_matches)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_profiler_extra_cuda_copy_pattern_benchmark(self):
    """Check that benchmarking two differently shaped copies yields two entries."""
    with profile(with_stack=True, record_shapes=True) as prof:
        x = torch.ones((100, 100)).to("cuda")
        x = torch.ones((50, 50)).to("cuda")

    pattern = ExtraCUDACopyPattern(prof)
    matches = pattern.matched_events()
    shapes_factor_map = pattern.benchmark(matches)
    self.assertEqual(len(shapes_factor_map), 2)
def test_profiler_optimizer_single_tensor_pattern(self):
    """Check how many events ``OptimizerSingleTensorPattern`` matches per optimizer.

    Each case pairs the expected number of matches with a factory building
    an optimizer for the model created inside the profiled region.
    """
    x = torch.ones((100, 100))
    # The factories read `model` at call time, i.e. the model built in the
    # current loop iteration.
    cases = (
        (1, lambda: torch.optim.Adam(model.parameters())),
        (1, lambda: torch.optim.SGD(model.parameters(), lr=0.01)),
        (1, lambda: torch.optim.AdamW(model.parameters())),
        (0, lambda: torch.optim.Adam(model.parameters(), foreach=True)),
        (0, lambda: torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)),
        (0, lambda: torch.optim.AdamW(model.parameters(), foreach=True)),
    )
    expected_counts = [count for count, _ in cases]

    observed_counts = []
    for _, fn in cases:
        with profile(with_stack=True) as prof:
            model = nn.Sequential(
                nn.Linear(100, 100),
                nn.ReLU(),
                nn.Linear(100, 10),
            )
            optimizer = fn()
            optimizer.zero_grad()
            y_hat = model(x)
            loss = torch.nn.functional.cross_entropy(
                y_hat, torch.randint(0, 10, (100,))
            )
            loss.backward()
            optimizer.step()
        matches = OptimizerSingleTensorPattern(prof).matched_events()
        observed_counts.append(len(matches))

    self.assertEqual(observed_counts, expected_counts)
def test_profiler_synchronized_dataloader_pattern(self):
    """Check that ``SynchronizedDataLoaderPattern`` matches exactly one event.

    One loader without ``num_workers`` and one with four workers are each
    asked for a single batch while profiling.
    """
    dataset = torch.rand((100, 100))
    sync_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
    async_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=10, num_workers=4
    )
    with profile(with_stack=True) as prof:
        next(iter(sync_dataloader))
        next(iter(async_dataloader))

    matches = SynchronizedDataLoaderPattern(prof).matched_events()
    self.assertEqual(len(matches), 1)
@skipIfTorchDynamo(
    "pattern checks for aten::_zero op which might not be there with torch.compile'd graph"
)
def test_profiler_grad_not_set_to_none_pattern(self):
    """Check how many events ``GradNotSetToNonePattern`` matches per ``zero_grad`` call.

    Each case pairs the expected number of matches with a callable that
    clears gradients at the end of a profiled training step.
    """
    x = torch.ones((100, 100))
    model = nn.Sequential(
        nn.Linear(100, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    optimizer = torch.optim.Adam(model.parameters())
    cases = (
        (0, lambda: optimizer.zero_grad()),
        (0, lambda: model.zero_grad()),
        (1, lambda: optimizer.zero_grad(set_to_none=False)),
        (1, lambda: model.zero_grad(set_to_none=False)),
    )
    expected_counts = [count for count, _ in cases]

    observed_counts = []
    for _, fn in cases:
        with profile(with_stack=True) as prof:
            y_hat = model(x)
            loss = torch.nn.functional.cross_entropy(
                y_hat, torch.randint(0, 10, (100,))
            )
            loss.backward()
            optimizer.step()
            fn()
        matches = GradNotSetToNonePattern(prof).matched_events()
        observed_counts.append(len(matches))

    self.assertEqual(observed_counts, expected_counts)
def test_profiler_conv2d_bias_followed_by_batchnorm2d_pattern(self):
    """Check how many events ``Conv2dBiasFollowedByBatchNorm2dPattern`` matches.

    Each case pairs the expected number of matches with a small model that
    is run once on the same input while profiling.
    """
    x = torch.randn((1, 3, 32, 32))
    cases = (
        (1, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1), nn.BatchNorm2d(3))),
        (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1, bias=False), nn.BatchNorm2d(3))),
        (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1))),
    )
    expected_counts = [count for count, _ in cases]

    observed_counts = []
    for _, model in cases:
        with profile(with_stack=True, record_shapes=True) as prof:
            model(x)
        matches = Conv2dBiasFollowedByBatchNorm2dPattern(prof).matched_events()
        observed_counts.append(len(matches))

    self.assertEqual(observed_counts, expected_counts)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_profiler_matmul_dim_fp16_pattern(self):
    """Check how many events ``MatMulDimInFP16Pattern`` matches per input shape.

    Each case pairs the expected number of matches with a half-precision
    CUDA tensor that is multiplied with itself while profiling.
    """
    cases = (
        (1, torch.randn((201, 201), device="cuda", dtype=torch.float16)),
        (1, torch.randn((3, 97, 97), device="cuda", dtype=torch.float16)),
        (0, torch.randn((200, 200), device="cuda", dtype=torch.float16)),
        (0, torch.randn((3, 200, 200), device="cuda", dtype=torch.float16)),
    )
    expected_counts = [count for count, _ in cases]

    observed_counts = []
    for _, x in cases:
        with profile(with_stack=True, record_shapes=True) as prof:
            x @ x
        matches = MatMulDimInFP16Pattern(prof).matched_events()
        observed_counts.append(len(matches))

    self.assertEqual(observed_counts, expected_counts)
def test_profiler_pattern_matcher_json_report(self):
    """Check the structure of the JSON report written by ``report_all_anti_patterns``.

    One training step is profiled, the report is written to disk and read
    back.  It must contain exactly one key ending in ``test_profiler.py``
    whose entries each carry the fields ``line_number``, ``name``, ``url``
    and ``message``.
    """
    x = torch.ones((100, 100))
    model = nn.Sequential(
        nn.Linear(100, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    optimizer = torch.optim.Adam(model.parameters())
    with profile(with_stack=True, record_shapes=True) as prof:
        y_hat = model(x)
        loss = torch.nn.functional.cross_entropy(
            y_hat, torch.randint(0, 10, (100,))
        )
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Write the report into a private scratch directory instead of the
    # current working directory: nothing is left behind, concurrent runs
    # cannot collide, a stale report cannot interfere, and a missing report
    # surfaces as the original error instead of a failing cleanup.
    with tempfile.TemporaryDirectory() as report_dir:
        report_all_anti_patterns(
            prof, json_report_dir=report_dir, print_enable=False
        )
        with open(os.path.join(report_dir, "torchtidy_report.json")) as f:
            report = json.load(f)

    # It is platform dependent whether the path will include "profiler/"
    keys = [k for k in report if k.endswith("test_profiler.py")]
    self.assertEqual(len(keys), 1, f"{keys}")
    entry = report[keys[0]]

    self.assertTrue(len(entry) > 0)
    expected_fields = sorted(["line_number", "name", "url", "message"])
    for event in entry:
        actual_fields = sorted(event.keys())
        self.assertEqual(expected_fields, actual_fields)
# When this file is executed directly, hand control to `run_tests` (imported
# from torch.testing._internal.common_utils at the top of the file).
if __name__ == "__main__":
run_tests()