Use libkineto in profiler (#46470)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46470

Adds the ability to use Kineto (CUPTI) to profile CUDA kernels.

Test Plan:
USE_KINETO=1 USE_CUDA=1 USE_MKLDNN=1 BLAS=MKL BUILD_BINARY=1 python setup.py develop install
python test/test_profiler.py

python test/test_autograd.py -k test_profile
python test/test_autograd.py -k test_record

```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us        33.33%       2.000us       1.000us             2
                                      sgemm_32x32x32_NN         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us        33.33%       2.000us       2.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us        16.67%       1.000us       1.000us             1
                       Memcpy DtoH (Device -> Pageable)         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us        16.67%       1.000us       1.000us             1
                                            aten::randn         5.17%      74.000us         6.71%      96.000us      48.000us       0.000us         0.00%       0.000us       0.000us             2
                                            aten::empty         1.33%      19.000us         1.33%      19.000us       4.750us       0.000us         0.00%       0.000us       0.000us             4
                                          aten::normal_         1.05%      15.000us         1.05%      15.000us       7.500us       0.000us         0.00%       0.000us       0.000us             2
                                               aten::to        77.90%       1.114ms        91.61%       1.310ms     436.667us       0.000us         0.00%       3.000us       1.000us             3
                                    aten::empty_strided         2.52%      36.000us         2.52%      36.000us      12.000us       0.000us         0.00%       0.000us       0.000us             3
                                            aten::copy_         2.73%      39.000us        11.19%     160.000us      53.333us       0.000us         0.00%       3.000us       1.000us             3
                                        cudaMemcpyAsync         4.34%      62.000us         4.34%      62.000us      20.667us       0.000us         0.00%       0.000us       0.000us             3
                                  cudaStreamSynchronize         1.61%      23.000us         1.61%      23.000us       7.667us       0.000us         0.00%       0.000us       0.000us             3
                                               aten::mm         0.21%       3.000us         7.20%     103.000us     103.000us       0.000us         0.00%       2.000us       2.000us             1
                                           aten::stride         0.21%       3.000us         0.21%       3.000us       1.000us       0.000us         0.00%       0.000us       0.000us             3
                                       cudaLaunchKernel         2.45%      35.000us         2.45%      35.000us      17.500us       0.000us         0.00%       0.000us       0.000us             2
                                              aten::add         0.49%       7.000us         4.27%      61.000us      61.000us       0.000us         0.00%       1.000us       1.000us             1
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
```
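
For reference, a minimal sketch of the kind of snippet that produces a table like the one above, assuming a build with USE_KINETO=1 and a CUDA device (tensor sizes are illustrative, not the exact benchmark):

```
import torch
from torch.autograd.profiler import profile

x = torch.randn(1024, 1024)
with profile(use_cuda=True, use_kineto=True) as p:
    y = x.to("cuda")          # Memcpy HtoD
    z = torch.mm(y, y) + y    # gemm + elementwise kernels
    z = z.cpu()               # Memcpy DtoH
print(p.key_averages().table(sort_by="self_cuda_time_total"))
```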

benchmark: https://gist.github.com/ilia-cher/a5a9eb6b68504542a3cad5150fc39b1a

Reviewed By: Chillee

Differential Revision: D25142223

Pulled By: ilia-cher

fbshipit-source-id: b0dff46c28da5fb0a8e01cf548aa4f2b723fde80
Authored by Ilia Cherniavskii on 2020-11-25 04:30:15 -08:00
Committed by Facebook GitHub Bot
Parent: e9efd8df1b; commit: f7a8bf2855
26 changed files with 2044 additions and 1021 deletions

View File

@@ -10,13 +10,13 @@ namespace {
// Used to generate unique callback handles
CallbackHandle next_unique_callback_handle() {
static std::atomic<uint64_t> unique_cb_id {0};
return CallbackHandle(++unique_cb_id);
static std::atomic<uint64_t> unique_cb_id {1};
return CallbackHandle(unique_cb_id++);
}
RecordFunctionHandle next_unique_record_function_handle() {
static std::atomic<uint64_t> unique_rf_id {0};
return RecordFunctionHandle(++unique_rf_id);
static std::atomic<uint64_t> unique_rf_id {1};
return RecordFunctionHandle(unique_rf_id++);
}
thread_local RecordFunctionTLS rf_tls_;
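
Note: both the old and new counter schemes above return handles 1, 2, 3, ...; the switch to post-increment simply makes the reserved invalid handle 0 explicit in the initial value. A Python sketch of the equivalent logic:

```
import itertools

# old: start at 0, pre-increment  -> 1, 2, 3, ...
# new: start at 1, post-increment -> 1, 2, 3, ... (0 stays reserved as "invalid")
_unique_cb_id = itertools.count(1)

def next_unique_callback_handle():
    return next(_unique_cb_id)
```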

View File

@@ -1,10 +1,9 @@
import argparse
import statistics
import sys
import timeit
import torch
from torch.utils._benchmark import Timer
from torch.utils.benchmark import Timer
PARALLEL_TASKS_NUM = 4
INTERNAL_ITER = None
@@ -34,12 +33,12 @@ if __name__ == '__main__':
parser.add_argument('--with_cuda', action='store_true')
parser.add_argument('--with_stack', action='store_true')
parser.add_argument('--use_script', action='store_true')
parser.add_argument('--use_kineto', action='store_true')
parser.add_argument('--profiling_tensor_size', default=1, type=int)
parser.add_argument('--workload', default='loop', type=str)
parser.add_argument('--internal_iter', default=256, type=int)
parser.add_argument('--n', default=100, type=int)
parser.add_argument('--use_timer', action='store_true')
parser.add_argument('--timer_min_run_time', default=100, type=int)
parser.add_argument('--timer_min_run_time', default=10, type=int)
parser.add_argument('--cuda_only', action='store_true')
args = parser.parse_args()
@@ -47,16 +46,17 @@ if __name__ == '__main__':
print("No CUDA available")
sys.exit()
print("Payload: {}; {} iterations, N = {}\n".format(
args.workload, args.internal_iter, args.n))
print("Payload: {}, {} iterations; timer min. runtime = {}\n".format(
args.workload, args.internal_iter, args.timer_min_run_time))
INTERNAL_ITER = args.internal_iter
for profiling_enabled in [False, True]:
print("Profiling {}, tensor size {}x{}, use cuda: {}, with stacks: {}, use script: {}".format(
print("Profiling {}, tensor size {}x{}, use cuda: {}, use kineto: {}, with stacks: {}, use script: {}".format(
"enabled" if profiling_enabled else "disabled",
args.profiling_tensor_size,
args.profiling_tensor_size,
args.with_cuda,
args.use_kineto,
args.with_stack,
args.use_script))
@@ -83,27 +83,18 @@ if __name__ == '__main__':
x = None
with torch.autograd.profiler.profile(
use_cuda=args.with_cuda,
with_stack=args.with_stack) as prof:
with_stack=args.with_stack,
use_kineto=args.use_kineto,
use_cpu=not args.cuda_only) as prof:
x = workload(input_x)
return x
else:
def payload():
return workload(input_x)
if args.use_timer:
t = Timer(
"payload()",
globals={"payload": payload},
timer=timeit.default_timer,
).blocked_autorange(min_run_time=args.timer_min_run_time)
print(t)
else:
runtimes = timeit.repeat(payload, repeat=args.n, number=1)
avg_time = statistics.mean(runtimes) * 1000.0
stddev_time = statistics.stdev(runtimes) * 1000.0
print("\tavg. time: {:.3f} ms, stddev: {:.3f} ms".format(
avg_time, stddev_time))
if args.workload == "loop":
print("\ttime per iteration: {:.3f} ms".format(
avg_time / args.internal_iter))
print()
t = Timer(
"payload()",
globals={"payload": payload},
timer=timeit.default_timer,
).blocked_autorange(min_run_time=args.timer_min_run_time)
print(t)
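
With the statistics/timeit fallback removed, every run now reports through torch.utils.benchmark.Timer; a self-contained sketch of that pattern:

```
import timeit
from torch.utils.benchmark import Timer

def payload():
    return sum(range(10000))  # stand-in for the benchmark workload

t = Timer(
    "payload()",
    globals={"payload": payload},
    timer=timeit.default_timer,
).blocked_autorange(min_run_time=10)
print(t)
```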

View File

@@ -1751,7 +1751,8 @@ endif()
#
# End ATen checks
#
set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)
# Disable compiler feature checks for `fmt`.
@@ -1764,6 +1765,7 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)
set_target_properties(fmt-header-only PROPERTIES INTERFACE_COMPILE_FEATURES "")
list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only)
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
# ---[ Kineto
if(USE_KINETO)
@@ -1774,8 +1776,34 @@ if(USE_KINETO)
set(KINETO_LIBRARY_TYPE "static" CACHE STRING "")
set(CUDA_SOURCE_DIR "${CUDA_TOOLKIT_ROOT_DIR}" CACHE STRING "")
message(STATUS "Configuring Kineto dependency:")
message(STATUS " KINETO_SOURCE_DIR = ${KINETO_SOURCE_DIR}")
message(STATUS " KINETO_BUILD_TESTS = ${KINETO_BUILD_TESTS}")
message(STATUS " KINETO_LIBRARY_TYPE = ${KINETO_LIBRARY_TYPE}")
message(STATUS " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}")
if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/include)
set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/extras/CUPTI/include")
elseif(EXISTS ${CUDA_SOURCE_DIR}/include/cupti.h)
set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/include")
endif()
if((NOT DEFINED CUDA_cupti_LIBRARY) OR (${CUDA_cupti_LIBRARY} STREQUAL "CUDA_cupti_LIBRARY-NOTFOUND"))
if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a)
set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a")
elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti_static.a)
set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti_static.a")
elseif(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so)
set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so")
elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti.so)
set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti.so")
endif()
endif()
message(STATUS " CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}")
message(STATUS " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}")
add_subdirectory("${KINETO_SOURCE_DIR}")
message(STATUS "Configured libkineto as a dependency.")
message(STATUS "Configured Kineto as a dependency.")
endif()
list(APPEND Caffe2_DEPENDENCY_LIBS kineto)
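
After configuring the build as above (USE_KINETO=1 with CUPTI headers and library found), the result can be verified at runtime through the binding this PR adds:

```
from torch.autograd import kineto_available

print(kineto_available())  # True only when PyTorch was built with USE_KINETO=1
```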

View File

@@ -2163,7 +2163,7 @@ TEST(TLSFutureCallbacksTest, Basic) {
// test running callbacks with propagation of TLS state.
{
// Enable the profiler in this thread
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2172,12 +2172,12 @@ TEST(TLSFutureCallbacksTest, Basic) {
// Since we join here, we can ensure that all callbacks corresponding to
// markCompleted() have finished.
t.join();
torch::autograd::profiler::disableProfiler();
torch::autograd::profiler::disableProfilerLegacy();
}
// then() with TLS State
{
// Enable the profiler in this thread
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2190,7 +2190,7 @@ TEST(TLSFutureCallbacksTest, Basic) {
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
t.join();
s2->wait();
torch::autograd::profiler::disableProfiler();
torch::autograd::profiler::disableProfilerLegacy();
}
}
@@ -2199,7 +2199,7 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
auto profilerEnabledCb = []() {
ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
};
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2212,10 +2212,10 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
// Don't cleanup TLSState, and just consolidate.
auto opts = torch::autograd::profiler::ProfilerDisableOptions(false, true);
auto thread_event_lists =
torch::autograd::profiler::disableProfiler(std::move(opts));
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
// Ensure that the events from this thread are still profiled and we obtain
// the expected events in our consolidated list when calling
// disableProfiler().
// disableProfilerLegacy().
bool found_ones = false;
bool found_add = false;
for (const auto& li : thread_event_lists) {
@@ -2237,13 +2237,13 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
s1->addCallback(verifyProfilerCb);
// Disable the profiler, but do not consolidate results in the main thread.
auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
torch::autograd::profiler::disableProfiler(std::move(opts));
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
t.join();
// Similar to above test, but verifies correctness in the case where
// continuation runs on the main thread.
torch::autograd::profiler::enableProfiler(
torch::autograd::profiler::enableProfilerLegacy(
torch::autograd::profiler::ProfilerConfig(
torch::autograd::profiler::ProfilerState::CPU, false, false));
s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2251,7 +2251,7 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
// Runs callback inline
s1->markCompleted(at::IValue(1));
opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
torch::autograd::profiler::disableProfiler(std::move(opts));
torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
}
TEST(IValueKWargsTest, Basic) {
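
The same rename is visible from Python; a hedged sketch of driving the legacy profiler through the renamed internal entry points (internal API, normally reached via the profile() context manager):

```
import torch

config = torch.autograd.ProfilerConfig(
    torch.autograd.ProfilerState.CPU,
    False,  # report_input_shapes
    False,  # profile_memory
    False)  # with_stack
torch.autograd._enable_profiler_legacy(config)
torch.ones(1) + torch.ones(1)
records = torch.autograd._disable_profiler_legacy()  # List[List[ProfilerEvent]]
```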

View File

@@ -33,7 +33,7 @@ from torch.testing._internal.common_utils import (TEST_MKL, TEST_WITH_ROCM, Test
suppress_warnings, slowTest,
load_tests, random_symmetric_matrix,
IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
from torch.autograd import Variable, Function, detect_anomaly
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
from torch.testing import randn_like
from torch.testing._internal.common_methods_invocations import (method_tests,
@@ -2954,7 +2954,7 @@ class TestAutograd(TestCase):
https://github.com/pytorch/pytorch/issues/34086""")
def test_profiler_tracing(self):
t1, t2 = torch.ones(1), torch.ones(1)
with torch.autograd.profiler.profile() as prof:
with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
torch.add(t1, t2)
with tempfile.NamedTemporaryFile(mode="w+") as f:
@@ -2969,7 +2969,7 @@ class TestAutograd(TestCase):
device = torch.device("cuda:0")
t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
with torch.autograd.profiler.profile(use_cuda=True) as prof:
with torch.autograd.profiler.profile(use_cuda=True, use_kineto=kineto_available()) as prof:
torch.add(t1, t2)
with tempfile.NamedTemporaryFile(mode="w+") as f:
@@ -2980,7 +2980,7 @@ class TestAutograd(TestCase):
def test_profiler(self):
x = torch.randn(10, 10)
with profile() as p:
with profile(use_kineto=kineto_available()) as p:
self.assertTrue(torch.autograd._profiler_enabled())
y = x * 2 + 4
@@ -2991,22 +2991,21 @@ class TestAutograd(TestCase):
'aten::empty', 'aten::add', 'aten::to', 'aten::empty_strided',
'aten::copy_', 'aten::empty']
top_level_names = ['aten::mul', 'aten::add']
top_level_iter = iter(top_level_names)
self.assertEqual(len(p.function_events), len(names))
for info, expected_name in zip(p.function_events, names):
if info.cpu_interval.start > last_end:
top_level_name_expected = next(top_level_iter)
self.assertEqual(info.name, top_level_name_expected)
last_end = info.cpu_interval.end
self.assertEqual(info.name, expected_name)
for evt in p.function_events:
if evt.time_range.start > last_end:
self.assertTrue(evt.name in top_level_names)
last_end = evt.time_range.end
self.assertTrue(evt.name in names)
def test_profiler_seq_nr(self):
with profile() as p:
with profile(use_kineto=kineto_available()) as p:
x = torch.randn(10, 10, requires_grad=True)
y = torch.randn(10, 10, requires_grad=True)
z = x + y
s = z.sum()
s.backward()
print(p.key_averages().table(
sort_by="self_cpu_time_total", row_limit=-1))
# expecting aten::add, aten::sum to have the sequence numbers,
# expecting the corresponding backward nodes to have the same numbers
# as the forward ops
@@ -3049,7 +3048,7 @@ class TestAutograd(TestCase):
def test_profiler_unboxed_only(self):
x = torch.rand(3, 4)
with torch.autograd.profiler.profile() as prof:
with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
x.resize_([3, 2])
def test_profiler_propagation(self):
@@ -3074,7 +3073,7 @@ class TestAutograd(TestCase):
traced_bar = torch.jit.trace(bar, x)
with profile() as p:
with profile(use_kineto=kineto_available()) as p:
traced_bar(x)
found_foo = False
@@ -3096,7 +3095,7 @@ class TestAutograd(TestCase):
def test_record_function_callbacks(self):
x = torch.randn(10, 10)
with profile() as p:
with profile(use_kineto=kineto_available()) as p:
with record_function("foo"):
y = x * 2 + 4
@@ -3128,12 +3127,12 @@ class TestAutograd(TestCase):
node_id=0,
name="",
thread=thread,
cpu_start=range[0],
cpu_end=range[1],
start_us=range[0],
end_us=range[1],
)
)
events.populate_cpu_children()
events._populate_cpu_children()
# Note that [1, 3] pushes out [0, 2] first. Then we record [1, 2]
# as a child of [1, 3]
@@ -3152,7 +3151,7 @@ class TestAutograd(TestCase):
"""
x = torch.randn(1024)
with torch.autograd.profiler.profile() as prof:
with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
torch.einsum("i->", x)
prof_str = str(prof)
@@ -3162,8 +3161,8 @@ class TestAutograd(TestCase):
def test_profiler_function_event_avg(self):
avg = FunctionEventAvg()
avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, cpu_start=10, cpu_end=15))
avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, cpu_start=20, cpu_end=30))
avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15))
avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30))
avg.add(avg)
self.assertEqual(avg.key, "foo")
@@ -3182,7 +3181,7 @@ class TestAutograd(TestCase):
layer1 = torch.nn.Linear(20, 30)
layer2 = torch.nn.Linear(30, 40)
input = torch.randn(128, 20)
with profile(record_shapes=True) as prof:
with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
layer2(layer1(input))
print(prof.function_events)
@@ -3198,18 +3197,18 @@ class TestAutograd(TestCase):
last_end = 0
for event in prof.function_events:
if event.cpu_interval.start > last_end:
if event.time_range.start > last_end:
name_expected, input_shape_expected = next(expected_iter)
if name_expected is not None:
self.assertEqual(event.name, name_expected)
self.assertEqual(event.input_shapes, input_shape_expected)
last_end = event.cpu_interval.end
last_end = event.time_range.end
def test_profiler_no_cuda(self):
print("")
layer = torch.nn.Linear(20, 30)
x = torch.randn(128, 20)
with profile(use_cuda=False) as prof:
with profile(use_cuda=False, use_kineto=kineto_available()) as prof:
layer(x)
prof_str = str(prof)
@@ -3221,7 +3220,7 @@ class TestAutograd(TestCase):
print("")
rnn = torch.nn.LSTM(10, 20, 2)
total_time_s = 0
with profile(record_shapes=True) as prof:
with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
for i in range(20):
input = torch.randn(5, 3, 10)
h = torch.randn(2, 3, 20)
@@ -3258,7 +3257,7 @@ class TestAutograd(TestCase):
def test_memory_profiler(self):
def run_profiler(tensor_creation_fn, metric):
# collecting allocs / deallocs
with profile(profile_memory=True, record_shapes=True) as prof:
with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
x = None
with record_function("test_user_scope_alloc"):
x = tensor_creation_fn()
@@ -3350,7 +3349,7 @@ class TestAutograd(TestCase):
# check partial overlap of tensor allocation with memory profiler
x = torch.rand(10, 10)
with profile(profile_memory=True, record_shapes=True) as prof:
with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
del x
x = torch.rand(10, 10)
del x
@@ -3376,7 +3375,7 @@ class TestAutograd(TestCase):
forward(x)
with profile() as p:
with profile(use_kineto=kineto_available()) as p:
forward(x)
events = p.function_events
@@ -3401,7 +3400,7 @@ class TestAutograd(TestCase):
def f(x, y):
return x + y
with profile() as p:
with profile(use_kineto=kineto_available()) as p:
f(1, 2)
self.assertTrue('my_func' in str(p))
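
The renamed FunctionEvent arguments and the now-private _populate_cpu_children exercised in the updated tests can also be driven directly; a sketch with hypothetical intervals:

```
from torch.autograd.profiler import EventList, FunctionEvent

events = EventList([
    FunctionEvent(id=0, node_id=0, name="outer", thread=0, start_us=0, end_us=100),
    FunctionEvent(id=1, node_id=0, name="inner", thread=0, start_us=10, end_us=20),
])
events._populate_cpu_children()
# [10, 20) lies inside [0, 100), so "inner" becomes a child of "outer"
assert events[0].cpu_children == [events[1]]
```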

View File

@@ -44,6 +44,7 @@ from torch._six import PY37, StringIO
from torch.autograd import Variable
from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401
from torch.testing import FileCheck
import torch.autograd.profiler
import torch.cuda
import torch.jit
import torch.jit._logging
@@ -2552,10 +2553,10 @@ graph(%Ra, %Rb):
for e in prof.function_events:
if e.name == "aten::mul":
self.assertTrue(e.thread not in mul_events)
mul_events[e.thread] = e.cpu_interval.elapsed_us()
mul_events[e.thread] = e.time_range.elapsed_us()
elif e.name == "other_fn":
self.assertTrue(e.thread not in other_fn_events)
other_fn_events[e.thread] = e.cpu_interval.elapsed_us()
other_fn_events[e.thread] = e.time_range.elapsed_us()
self.assertTrue(len(mul_events) == 2)
self.assertTrue(len(other_fn_events) == 2)
@@ -8268,7 +8269,7 @@ dedent """
def _test_dtype_op_shape(self, ops, args, input_dims=1):
if input_dims < 1:
raise 'input dims must be at least 1'
raise RuntimeError("input dims must be at least 1")
dtypes = [torch.float32, torch.float64, torch.int64, torch.int32]
str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '')
tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']')
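
The cpu_interval -> time_range rename applies to every FunctionEvent consumer; reading per-event timing now looks like this (sketch, assuming p is a finished profile):

```
# p is a completed torch.autograd.profiler.profile instance
for e in p.function_events:
    print(e.name, e.thread, e.time_range.start, e.time_range.elapsed_us())
```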

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from torch.testing._internal.common_utils import (
TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS)
from torch.autograd.profiler import profile
from torch.autograd import kineto_available
try:
import psutil
@@ -73,7 +74,7 @@ class TestProfiler(TestCase):
mod = DummyModule()
with profile(with_stack=True) as p:
with profile(with_stack=True, use_kineto=kineto_available()) as p:
x = torch.randn(10, 10, requires_grad=True)
y = torch.randn(10, 10, requires_grad=True)
z = x + y
@@ -99,6 +100,34 @@ class TestProfiler(TestCase):
torch._C._set_graph_executor_optimize(prev_opt)
def payload(self):
x = torch.randn(10, 10).cuda()
y = torch.randn(10, 10).cuda()
z = torch.mm(x, y)
z = z + y
z = z.cpu()
@unittest.skipIf(not kineto_available(), "Kineto is required")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
def test_kineto(self):
with profile(use_cuda=True, use_kineto=True):
self.payload()
# rerun to avoid initial start overhead
with profile(use_cuda=True, use_kineto=True) as p:
self.payload()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
found_gemm = False
found_memcpy = False
for e in p.function_events:
if "gemm" in e.name:
found_gemm = True
if "Memcpy" in e.name or "memcpy" in e.name:
found_memcpy = True
self.assertTrue(found_gemm)
self.assertTrue(found_memcpy)
# p.export_chrome_trace("/tmp/test_trace.json")
if __name__ == '__main__':
run_tests()
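
The commented-out export at the end of the test corresponds to the new trace path: with use_kineto=True, export_chrome_trace delegates to ProfilerResult.save rather than re-serializing FunctionEvents. A sketch (imports as in the test above):

```
with profile(use_cuda=True, use_kineto=True) as p:
    torch.mm(torch.randn(32, 32).cuda(), torch.randn(32, 32).cuda())
p.export_chrome_trace("/tmp/test_trace.json")  # saved via ProfilerResult.save
```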

View File

@@ -74,7 +74,8 @@ jit_core_sources = [
# list for the shared files.
core_sources_common = [
"torch/csrc/autograd/profiler.cpp",
"torch/csrc/autograd/profiler_legacy.cpp",
"torch/csrc/autograd/profiler_kineto.cpp",
"torch/csrc/jit/frontend/edit_distance.cpp",
"torch/csrc/jit/frontend/string_to_type.cpp",
"torch/csrc/jit/mobile/type_parser.cpp",

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Set
from enum import Enum
# Defined in tools/autograd/init.cpp
@@ -8,7 +8,16 @@ class ProfilerState(Enum):
CPU = ...
CUDA = ...
NVTX = ...
KINETO = ...
class ProfilerActivity(Enum):
CPU = ...
CUDA = ...
class DeviceType(Enum):
CPU = ...
CUDA = ...
...
class ProfilerConfig:
def __init__(
@@ -37,9 +46,25 @@ class ProfilerEvent:
def thread_id(self) -> int: ...
...
class KinetoEvent:
def name(self) -> str: ...
def device_index(self) -> int: ...
def start_us(self) -> int: ...
def duration_us(self) -> int: ...
...
def _enable_profiler(config: ProfilerConfig) -> None: ...
def _disable_profiler() -> List[List[ProfilerEvent]]: ...
class ProfilerResult:
def events(self) -> List[KinetoEvent]: ...
def legacy_events(self) -> List[List[ProfilerEvent]]: ...
def save(self, str) -> None: ...
def _enable_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ...
def _prepare_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ...
def _disable_profiler() -> ProfilerResult: ...
def _profiler_enabled() -> bool: ...
def kineto_available() -> bool: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
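
Taken together, the stubs describe the new Kineto enable/disable sequence; a sketch of calling it directly on a Kineto-enabled build (internal API; the profile() context manager is the supported wrapper):

```
import torch

config = torch.autograd.ProfilerConfig(
    torch.autograd.ProfilerState.KINETO,
    False,  # report_input_shapes
    False,  # profile_memory
    False)  # with_stack
activities = {torch.autograd.ProfilerActivity.CPU}
torch.autograd._prepare_profiler(config, activities)
torch.autograd._enable_profiler(config, activities)
torch.mm(torch.randn(8, 8), torch.randn(8, 8))
result = torch.autograd._disable_profiler()  # ProfilerResult
print([e.name() for e in result.events()])
```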

View File

@@ -18,7 +18,6 @@ from .gradcheck import gradcheck, gradgradcheck
from .grad_mode import no_grad, enable_grad, set_grad_enabled
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from ..overrides import has_torch_function, handle_torch_function
from . import profiler
from . import functional
__all__ = ['Variable', 'Function', 'backward', 'grad_mode']
@@ -251,6 +250,10 @@ if not torch._C._autograd_init():
raise RuntimeError("autograd initialization failed")
# Import all native method/classes
from torch._C._autograd import (ProfilerState, ProfilerConfig, ProfilerEvent,
_enable_profiler, _disable_profiler, _profiler_enabled,
_enable_record_function, _set_empty_test_observer)
from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, ProfilerConfig, ProfilerEvent,
_enable_profiler_legacy, _disable_profiler_legacy, _profiler_enabled,
_enable_record_function, _set_empty_test_observer, kineto_available)
if kineto_available():
from torch._C._autograd import (ProfilerResult, KinetoEvent,
_prepare_profiler, _enable_profiler, _disable_profiler)
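
With DeviceType now importable from torch.autograd, event post-processing can branch on the device of each event, mirroring what build_table does further down; a sketch (assumes a finished Kineto-mode profile p):

```
from torch.autograd import DeviceType

cuda_total = sum(e.self_cuda_time_total
                 for e in p.function_events
                 if e.device_type == DeviceType.CUDA)
```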

View File

@@ -1,12 +1,13 @@
import itertools
from typing import Any
import torch
from torch.autograd import DeviceType
from torch.futures import Future
from collections import defaultdict, namedtuple
from operator import attrgetter
from typing import List, Dict, Tuple, Optional
from typing import Dict, List, Tuple, Optional
try:
# Available in Python >= 3.2
@@ -37,14 +38,38 @@ class EventList(list):
use_cuda = kwargs.pop('use_cuda', True)
profile_memory = kwargs.pop('profile_memory', False)
super(EventList, self).__init__(*args, **kwargs)
self._cpu_children_populated = False
self._use_cuda = use_cuda
self._profile_memory = profile_memory
self._tree_built = False
def _build_tree(self):
self._populate_cpu_children()
self._remove_dup_nodes()
self._set_backward_stacktraces()
self._tree_built = True
def __str__(self):
return self.table()
def populate_cpu_children(self):
def _remove_dup_nodes(self):
while True:
to_delete = []
for idx in range(len(self)):
if (self[idx].cpu_parent is not None and
self[idx].cpu_parent.name == self[idx].name and
len(self[idx].cpu_parent.cpu_children) == 1):
self[idx].cpu_parent.cpu_children = self[idx].cpu_children
self[idx].cpu_parent.kernels = self[idx].kernels # lift kernels up
for ch in self[idx].cpu_children:
ch.cpu_parent = self[idx].cpu_parent
to_delete.append(idx)
if len(to_delete) == 0:
break
new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
self.clear()
self.extend(new_evts)
def _populate_cpu_children(self):
"""Populates child events into each underlying FunctionEvent object.
One event is a child of another if [s1, e1) is inside [s2, e2). Where
s1 and e1 would be start and end of the child event's interval. And
@@ -56,13 +81,11 @@ class EventList(list):
If for any reason two intervals intersect only partially, this function
will not record a parent-child relationship between them.
"""
if self.cpu_children_populated:
return
# Some events can be async (i.e. start and end on different threads),
# since it's generally undefined how to attribute children ranges to
# async ranges, we do not use them when calculating nested ranges and stats
sync_events = [evt for evt in self if not evt.is_async]
sync_events = [evt for evt in self if not evt.is_async and evt.device_type == DeviceType.CPU]
events = sorted(
sync_events,
key=attrgetter("thread"),
@@ -89,15 +112,15 @@ class EventList(list):
for thread_id, thread_events in threads:
thread_events_ = sorted(
thread_events,
key=lambda event: [event.cpu_interval.start, -event.cpu_interval.end],
key=lambda event: [event.time_range.start, -event.time_range.end],
)
current_events: List[FunctionEvent] = []
cur_end = 0
for event in thread_events_:
while len(current_events) > 0:
parent = current_events[-1]
if event.cpu_interval.start >= parent.cpu_interval.end or \
event.cpu_interval.end > parent.cpu_interval.end:
if event.time_range.start >= parent.time_range.end or \
event.time_range.end > parent.time_range.end:
# this can't be a parent
current_events.pop()
else:
@@ -112,22 +135,18 @@ class EventList(list):
current_events.append(event)
self._cpu_children_populated = True
def set_backward_stacktraces(self):
self.populate_cpu_children()
def _set_backward_stacktraces(self):
def bw_parent(evt):
if evt is None:
return None
elif evt.scope == 1:
elif evt.scope == 1: # BACKWARD_FUNCTION
return evt
else:
return bw_parent(evt.cpu_parent)
fwd_stacks = {}
for evt in self:
if bw_parent(evt) is None:
if bw_parent(evt) is None and evt.stack is not None:
t = (evt.sequence_nr, evt.thread)
if t not in fwd_stacks:
fwd_stacks[t] = evt.stack
@@ -142,15 +161,10 @@ class EventList(list):
else:
evt.stack = []
@property
def self_cpu_time_total(self):
return sum([event.self_cpu_time_total for event in self])
@property
def cpu_children_populated(self):
return self._cpu_children_populated
def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False):
"""Prints an EventList as a nicely formatted table.
@@ -205,8 +219,8 @@ class EventList(list):
'"args": {}}, '
% (
evt.name,
evt.cpu_interval.start,
evt.cpu_interval.elapsed_us(),
evt.time_range.start,
evt.time_range.elapsed_us(),
evt.thread
if not evt.is_remote
else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "',
@@ -222,7 +236,7 @@ class EventList(list):
'"pid": "CPU functions", '
'"id": %s, '
'"cat": "cpu_to_cuda", '
'"args": {}}, ' % (evt.name, evt.cpu_interval.start,
'"args": {}}, ' % (evt.name, evt.time_range.start,
evt.thread, next_id))
f.write('{"name": "%s", '
'"ph": "f", '
@@ -262,11 +276,11 @@ class EventList(list):
Returns:
An EventList containing FunctionEventAvg objects.
"""
self.populate_cpu_children()
stats: Dict[Tuple[int, Tuple[int, int]], FunctionEventAvg] = defaultdict(FunctionEventAvg)
assert self._tree_built
stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
def get_key(event, group_by_input_shapes, group_by_stack_n):
key = [str(event.key), str(event.node_id)]
def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]:
key = [str(event.key), str(event.node_id), str(event.device_type), str(event.is_legacy)]
if group_by_input_shapes:
key.append(str(event.input_shapes))
if group_by_stack_n > 0:
@@ -326,6 +340,11 @@ class profile(object):
with_stack (bool, optional): record source information (file and line number) for the ops
use_kineto (bool, default False): experimental support for Kineto profiler
use_cpu (bool, default True): whether to profile CPU events; setting it to False requires
use_kineto=True and can be used to lower the overhead for GPU-only profiling
.. warning::
Enabling memory profiling or source attribution incurs additional profiler
overhead
@@ -365,44 +384,83 @@ class profile(object):
use_cuda=False,
record_shapes=False,
profile_memory=False,
with_stack=False):
self.enabled = enabled
self.use_cuda = use_cuda
self.function_events = None
with_stack=False,
use_kineto=False,
use_cpu=True):
self.enabled: bool = enabled
if not self.enabled:
return
self.use_cuda = use_cuda
self.function_events = None
self.entered = False
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.with_stack = with_stack
self.use_cpu = use_cpu
self.kineto_results = None
if not self.use_cpu:
assert use_kineto, \
"Device-only events supported only with Kineto (use_kineto=True)"
self.profiler_kind = None
self.kineto_activities = set()
if use_kineto:
self.profiler_kind = torch.autograd.ProfilerState.KINETO
if self.use_cpu:
self.kineto_activities.add(torch.autograd.ProfilerActivity.CPU)
if self.use_cuda:
self.kineto_activities.add(
# uses CUPTI
torch.autograd.ProfilerActivity.CUDA)
assert len(self.kineto_activities) > 0, \
"No activities specified for Kineto profiler"
elif self.use_cuda:
# legacy CUDA mode
self.profiler_kind = torch.autograd.ProfilerState.CUDA
else:
self.profiler_kind = torch.autograd.ProfilerState.CPU
if self.profiler_kind == torch.autograd.ProfilerState.KINETO:
assert (
torch.autograd.kineto_available()
), """Requested Kineto profiling but Kineto is not available,
make sure PyTorch is built with USE_KINETO=1"""
def config(self):
assert self.profiler_kind is not None
return torch.autograd.ProfilerConfig(
self.profiler_kind,
self.record_shapes,
self.profile_memory,
self.with_stack)
def __enter__(self):
if not self.enabled:
return
if self.entered:
raise RuntimeError("autograd profiler traces are not reentrant")
raise RuntimeError("profiler context manager is not reentrant")
self.entered = True
profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \
else torch.autograd.ProfilerState.CPU
config = torch.autograd.ProfilerConfig(
profiler_kind,
self.record_shapes,
self.profile_memory,
self.with_stack)
torch.autograd._enable_profiler(config)
if self.kineto_activities:
torch.autograd._prepare_profiler(self.config(), self.kineto_activities)
torch.autograd._enable_profiler(self.config(), self.kineto_activities)
else:
torch.autograd._enable_profiler_legacy(self.config())
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.enabled:
return
records = torch.autograd._disable_profiler()
if self.kineto_activities:
self.kineto_results = torch.autograd._disable_profiler()
parsed_results = parse_kineto_results(self.kineto_results)
else:
records = torch.autograd._disable_profiler_legacy()
parsed_results = parse_legacy_records(records)
self.function_events = EventList(
parse_event_records(records),
parsed_results,
use_cuda=self.use_cuda,
profile_memory=self.profile_memory)
if self.with_stack:
self.function_events.set_backward_stacktraces()
self.function_events._build_tree()
return False
def __repr__(self):
@@ -413,13 +471,11 @@ class profile(object):
def __str__(self):
if self.function_events is None:
return '<unfinished torch.autograd.profile>'
self.function_events.populate_cpu_children()
return str(self.function_events)
def _check_finish(self):
if self.function_events is None:
raise RuntimeError("can't export a trace that didn't finish running")
self.function_events.populate_cpu_children()
def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False):
self._check_finish()
@@ -432,8 +488,11 @@ class profile(object):
def export_chrome_trace(self, path):
self._check_finish()
assert self.function_events is not None
return self.function_events.export_chrome_trace(path)
if self.kineto_results is not None:
self.kineto_results.save(path)
else:
assert self.function_events is not None
return self.function_events.export_chrome_trace(path)
export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
@@ -630,7 +689,7 @@ class emit_nvtx(object):
raise RuntimeError("NVTX annotation context manager is not reentrant")
self.entered = True
torch.cuda.synchronize()
torch.autograd._enable_profiler(
torch.autograd._enable_profiler_legacy(
torch.autograd.ProfilerConfig(
torch.autograd.ProfilerState.NVTX,
self.record_shapes,
@@ -643,7 +702,7 @@ class emit_nvtx(object):
if not self.enabled:
return
torch.cuda.synchronize()
torch.autograd._disable_profiler()
torch.autograd._disable_profiler_legacy()
return False
@@ -731,13 +790,14 @@ Kernel = namedtuple('Kernel', ['name', 'device', 'interval'])
class FunctionEvent(FormattedTimesMixin):
"""Profiling information about a single function."""
def __init__(
self, id, node_id, name, thread, cpu_start, cpu_end, fwd_thread=None, input_shapes=None,
self, id, name, thread, start_us, end_us, fwd_thread=None, input_shapes=None,
stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False,
is_remote=True, sequence_nr=-1):
is_remote=False, sequence_nr=-1, node_id=-1, device_type=DeviceType.CPU, device_index=0,
is_legacy=False):
self.id: int = id
self.node_id: int = node_id
self.name: str = name
self.cpu_interval: Interval = Interval(cpu_start, cpu_end)
self.time_range: Interval = Interval(start_us, end_us)
self.thread: int = thread
self.fwd_thread: Optional[int] = fwd_thread
self.kernels: List[Kernel] = []
@@ -752,8 +812,12 @@ class FunctionEvent(FormattedTimesMixin):
self.is_async: bool = is_async
self.is_remote: bool = is_remote
self.sequence_nr: int = sequence_nr
self.device_type: DeviceType = device_type
self.device_index: int = device_index
self.is_legacy: bool = is_legacy
def append_kernel(self, name, device, start, end):
assert self.device_type == DeviceType.CPU
self.kernels.append(Kernel(name, device, Interval(start, end)))
def append_cpu_child(self, child):
@@ -762,7 +826,9 @@ class FunctionEvent(FormattedTimesMixin):
Only direct children should be appended to the event so that self CPU
time is reported correctly.
"""
assert(self.device_type == DeviceType.CPU)
assert(isinstance(child, FunctionEvent))
assert(child.device_type == DeviceType.CPU)
self.cpu_children.append(child)
def set_cpu_parent(self, parent):
@@ -772,14 +838,16 @@ class FunctionEvent(FormattedTimesMixin):
the child's range interval is completely inside the parent's. We use
this connection to determine whether the event is from a top-level op or not.
"""
assert(self.device_type == DeviceType.CPU)
assert(isinstance(parent, FunctionEvent))
assert(parent.device_type == DeviceType.CPU)
self.cpu_parent = parent
# Note: async events don't have children, are not used when computing 'self'
# metrics of other events, have only total cpu time
@property
def self_cpu_memory_usage(self):
if self.is_async:
if self.is_async or self.device_type != DeviceType.CPU:
return 0
return self.cpu_memory_usage - sum(
[child.cpu_memory_usage for child in self.cpu_children]
@@ -787,7 +855,7 @@ class FunctionEvent(FormattedTimesMixin):
@property
def self_cuda_memory_usage(self):
if self.is_async:
if self.is_async or self.device_type != DeviceType.CPU:
return 0
return self.cuda_memory_usage - sum(
[child.cuda_memory_usage for child in self.cpu_children]
@@ -795,7 +863,7 @@ class FunctionEvent(FormattedTimesMixin):
@property
def self_cpu_time_total(self):
if self.is_async:
if self.is_async or self.device_type != DeviceType.CPU:
return 0
return self.cpu_time_total - sum(
[child.cpu_time_total for child in self.cpu_children]
@@ -803,16 +871,37 @@ class FunctionEvent(FormattedTimesMixin):
@property
def cuda_time_total(self):
return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels)
if self.is_async:
return 0
if self.device_type == DeviceType.CPU:
if not self.is_legacy:
# account for the kernels in the children ops
return (sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) +
sum(ch.cuda_time_total for ch in self.cpu_children))
else:
# each legacy CPU event has a single (fake) kernel
return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels)
else:
assert self.device_type == DeviceType.CUDA
return self.time_range.elapsed_us()
@property
def self_cuda_time_total(self):
return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) - \
sum([child.cuda_time_total for child in self.cpu_children])
if self.is_async:
return 0
if self.device_type == DeviceType.CPU:
return self.cuda_time_total - \
sum([child.cuda_time_total for child in self.cpu_children])
else:
assert(self.device_type == DeviceType.CUDA)
return self.cuda_time_total
@property
def cpu_time_total(self):
return self.cpu_interval.elapsed_us()
if self.device_type == DeviceType.CPU:
return self.time_range.elapsed_us()
else:
return 0
@property
def key(self):
@@ -820,14 +909,16 @@ class FunctionEvent(FormattedTimesMixin):
def __repr__(self):
return (
'<FunctionEvent id={} node_id={} cpu_time={} cpu_start={} cpu_end={} '
'<FunctionEvent id={} name={} device_type={} node_id={} cpu_time={} start_us={} end_us={} '
'cpu_children={} cuda_time={} name={} thread={} input_shapes={} '
'cpu_memory_usage={} cuda_memory_usage={} is_async={} is_remote={} seq_nr={}>'.format(
'cpu_memory_usage={} cuda_memory_usage={} is_async={} is_remote={} seq_nr={} is_legacy={}>'.format(
self.id,
self.name,
self.device_type,
self.node_id,
self.cpu_time_str,
self.cpu_interval.start,
self.cpu_interval.end,
self.time_range.start,
self.time_range.end,
str([child.id for child in self.cpu_children]),
self.cuda_time_str,
self.name,
@ -838,6 +929,7 @@ class FunctionEvent(FormattedTimesMixin):
self.is_async,
self.is_remote,
self.sequence_nr,
self.is_legacy,
)
)
@@ -863,6 +955,8 @@ class FunctionEventAvg(FormattedTimesMixin):
self.self_cuda_memory_usage: int = 0
self.cpu_children: Optional[List[FunctionEvent]] = None
self.cpu_parent: Optional[FunctionEvent] = None
self.device_type: DeviceType = DeviceType.CPU
self.is_legacy: bool = False
def add(self, other):
if self.key is None:
@@ -878,6 +972,8 @@ class FunctionEventAvg(FormattedTimesMixin):
self.input_shapes = other.input_shapes
self.stack = other.stack
self.scope = other.scope
self.device_type = other.device_type
self.is_legacy = other.is_legacy
assert isinstance(other, (FunctionEvent, FunctionEventAvg))
assert other.key == self.key
@@ -923,10 +1019,111 @@ class StringTable(defaultdict):
self[key] = torch._C._demangle(key) if len(key) > 1 else key
return self[key]
def parse_event_records(thread_records):
def filter_stack_entry(entry):
filtered_entries = [
("autograd/__init__", "_make_grads"),
("autograd/__init__", "backward"),
("torch/tensor", "backward"),
("_internal/common_utils", "prof_callable"),
("_internal/common_utils", "prof_func_call"),
("_internal/common_utils", "prof_meth_call"),
]
return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries])
def filter_name(name):
# ignoring the following utility ops
filtered_out_names = [
"profiler::_record_function_enter",
"profiler::_record_function_exit",
"aten::is_leaf",
"aten::output_nr",
"aten::_version",
]
return name in filtered_out_names
# Parsing of kineto profiler events
def parse_kineto_results(result):
# result.events() has most of the events - PyTorch op-level and device-level events
# result.legacy_events() has events not yet ported to kineto
# (e.g. start/stop marks, tensor memory allocator events)
# First, find __start_profile mark to get the absolute time of the start of the trace;
# save memory allocation records
start_record = None
mem_records = []
for record in itertools.chain(*result.legacy_events()):
if record.kind() == 'mark' and record.name() == '__start_profile':
assert start_record is None
start_record = record
if record.kind() == 'memory_alloc':
mem_records.append(record)
assert start_record is not None, "Invalid profiler output, __start_profile is missing"
# Create and return FunctionEvent list
string_table = StringTable()
function_events = []
cuda_corr_map: Dict[int, List[torch.autograd.KinetoEvent]] = {}
for kineto_event in result.events():
if filter_name(kineto_event.name()):
continue
rel_start_us = kineto_event.start_us() - start_record.start_us()
rel_end_us = rel_start_us + kineto_event.duration_us()
abs_end_us = kineto_event.start_us() + kineto_event.duration_us()
cpu_memory_usage = 0
cuda_memory_usage = 0
if kineto_event.device_type() == DeviceType.CPU:
# find the corresponding memory allocation events
for mem_record in mem_records:
if (mem_record.start_us() >= kineto_event.start_us() and
mem_record.start_us() <= abs_end_us):
cpu_memory_usage += mem_record.cpu_memory_usage()
cuda_memory_usage += mem_record.cuda_memory_usage()
is_async = kineto_event.start_thread_id() != kineto_event.end_thread_id()
fe = FunctionEvent(
id=kineto_event.correlation_id(),
name=string_table[kineto_event.name()],
thread=kineto_event.start_thread_id(),
start_us=rel_start_us,
end_us=rel_end_us,
fwd_thread=kineto_event.fwd_thread_id(),
input_shapes=kineto_event.shapes(),
stack=[entry for entry in kineto_event.stack() if filter_stack_entry(entry)],
scope=kineto_event.scope(),
cpu_memory_usage=cpu_memory_usage,
cuda_memory_usage=cuda_memory_usage,
is_async=is_async,
sequence_nr=kineto_event.sequence_nr(),
device_type=kineto_event.device_type(),
device_index=kineto_event.device_index(),
)
function_events.append(fe)
if kineto_event.device_type() == DeviceType.CUDA:
corr_id = kineto_event.linked_correlation_id()
if corr_id > 0:
if corr_id not in cuda_corr_map:
cuda_corr_map[corr_id] = []
cuda_corr_map[corr_id].append(kineto_event)
# associate CUDA kernels with CPU events
for fe in function_events:
if (fe.device_type == DeviceType.CPU and not fe.is_async and
fe.id in cuda_corr_map):
for k_evt in cuda_corr_map[fe.id]:
fe.append_kernel(
k_evt.name(),
k_evt.device_index(),
k_evt.start_us(),
k_evt.start_us() + k_evt.duration_us())
function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
return function_events
# Parsing of legacy profiler events
def parse_legacy_records(thread_records):
def get_record_key(record):
"""
Returns a tuple to be used by parse_event_records for correlating start and
Returns a tuple to be used by parse_legacy_records for correlating start and
end records.
"""
return (record.handle(), record.node_id())
@@ -938,26 +1135,6 @@ def parse_event_records(thread_records):
record_stack = []
string_table = StringTable()
# ignoring the following utility ops
filtered_out_names = [
"profiler::_record_function_enter",
"profiler::_record_function_exit",
"aten::is_leaf",
"aten::output_nr",
"aten::_version",
]
def filter_stack_entry(entry):
filtered_entries = [
("autograd/__init__", "_make_grads"),
("autograd/__init__", "backward"),
("torch/tensor", "backward"),
("_internal/common_utils", "prof_callable"),
("_internal/common_utils", "prof_func_call"),
("_internal/common_utils", "prof_meth_call"),
]
return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries])
# cuda start events and the overall profiler start event don't happen
# at exactly the same time because we need to record an event on each device
# and each record takes ~4us. So we adjust here by the difference
@@ -994,7 +1171,7 @@ def parse_event_records(thread_records):
prev_record = None
for record in thread_record_list:
record_key = get_record_key(record)
if (record.name() in filtered_out_names or
if (filter_name(record.name()) or
record_key in filtered_handles):
filtered_handles.add(record_key)
continue
@@ -1035,8 +1212,8 @@ def parse_event_records(thread_records):
node_id=record.node_id(),
name=string_table[start.name()],
thread=start.thread_id(),
cpu_start=start_record.cpu_elapsed_us(start),
cpu_end=start_record.cpu_elapsed_us(record),
start_us=start_record.cpu_elapsed_us(start),
end_us=start_record.cpu_elapsed_us(record),
fwd_thread=start.fwd_thread_id(),
input_shapes=start.shapes(),
stack=[entry for entry in start.stack() if filter_stack_entry(entry)],
@@ -1046,6 +1223,8 @@ def parse_event_records(thread_records):
is_async=is_async,
is_remote=is_remote_event,
sequence_nr=start.sequence_nr(),
device_type=DeviceType.CPU,
is_legacy=True,
)
# note: async events have only cpu total time
if not is_async and start.has_cuda():
@@ -1074,7 +1253,7 @@ def parse_event_records(thread_records):
# granularity of the given clock tick)--we always show
# the outermost nested call first. This adds stability
# in how FunctionEvents appear
functions.sort(key=lambda evt: [evt.cpu_interval.start, -evt.cpu_interval.end])
functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
return functions
@@ -1121,8 +1300,8 @@ def parse_nvprof_trace(path):
node_id=0, # missing a node_id when calling FunctionEvent. This is just to ensure
# that pytorch doesn't crash when creating a FunctionEvent() object
name=strings[row['name']],
cpu_start=row['start_time'],
cpu_end=row['end_time'],
start_us=row['start_time'],
end_us=row['end_time'],
thread=0) # TODO: find in sqlite database
functions.append(evt)
functions_map[evt.id] = evt
@@ -1153,7 +1332,7 @@ def parse_nvprof_trace(path):
row['kernel_start'],
row['kernel_end'])
functions.sort(key=lambda evt: evt.cpu_interval.start)
functions.sort(key=lambda evt: evt.time_range.start)
return functions
@@ -1182,7 +1361,9 @@ def build_table(
has_input_shapes = any(
[(event.input_shapes is not None and len(event.input_shapes) > 0) for event in events])
MAX_NAME_COLUMN_WIDTH = 55
name_column_width = max([len(evt.key) for evt in events]) + 4
name_column_width = min(name_column_width, MAX_NAME_COLUMN_WIDTH)
DEFAULT_COLUMN_WIDTH = 12
@@ -1269,7 +1450,16 @@ def build_table(
result.append('\n') # Yes, newline after the end as well
self_cpu_time_total = sum([event.self_cpu_time_total for event in events])
cuda_time_total = sum([evt.self_cuda_time_total for evt in events])
cuda_time_total = 0
for evt in events:
if evt.device_type == DeviceType.CPU:
# in legacy profiler, kernel info is stored in cpu events
if evt.is_legacy:
cuda_time_total += evt.self_cuda_time_total
elif evt.device_type == DeviceType.CUDA:
# in Kineto mode, there are events with the correct device type (e.g. CUDA)
cuda_time_total += evt.self_cuda_time_total
# Actual printing
if header is not None:
append('=' * line_length)
@@ -1290,8 +1480,11 @@
continue
else:
event_limit += 1
name = evt.key
if len(name) >= MAX_NAME_COLUMN_WIDTH - 3:
name = name[:(MAX_NAME_COLUMN_WIDTH - 3)] + "..."
row_values = [
evt.key, # Name
name,
# Self CPU total, 0 for async events. %
format_time_share(evt.self_cpu_time_total,
self_cpu_time_total),
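
End to end, the reworked profiler.py keeps the same high-level workflow on both backends; a sketch combining the new use_kineto flag with shape grouping (assumes this branch is installed):

```
import torch
from torch.autograd import kineto_available
from torch.autograd.profiler import profile

layer = torch.nn.Linear(20, 30)
with profile(record_shapes=True, use_kineto=kineto_available()) as p:
    layer(torch.randn(128, 20))
print(p.key_averages(group_by_input_shape=True).table(
    sort_by="self_cpu_time_total", row_limit=10))
```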

View File

@@ -1,5 +1,6 @@
#include <torch/csrc/python_headers.h>
#include <c10/core/DeviceType.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/autograd/grad_mode.h>
@@ -39,37 +40,132 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
.value("Disabled", ProfilerState::Disabled)
.value("CPU", ProfilerState::CPU)
.value("CUDA", ProfilerState::CUDA)
.value("NVTX", ProfilerState::NVTX);
.value("NVTX", ProfilerState::NVTX)
.value("KINETO", ProfilerState::KINETO);
py::enum_<ActivityType>(m, "ProfilerActivity")
.value("CPU", ActivityType::CPU)
.value("CUDA", ActivityType::CUDA);
py::class_<ProfilerConfig>(m, "ProfilerConfig")
.def(py::init<ProfilerState, bool, bool, bool>());
py::class_<Event>(m, "ProfilerEvent")
.def("kind", &Event::kind)
.def("name", [](const Event& e) { return e.name(); })
.def("thread_id", &Event::threadId)
.def("fwd_thread_id", &Event::fwdThreadId)
.def("device", &Event::device)
.def("cpu_elapsed_us", &Event::cpuElapsedUs)
.def("cuda_elapsed_us", &Event::cudaElapsedUs)
.def("has_cuda", &Event::hasCuda)
.def("shapes", &Event::shapes)
.def("cpu_memory_usage", &Event::cpuMemoryUsage)
.def("cuda_memory_usage", &Event::cudaMemoryUsage)
.def("handle", &Event::handle)
.def("node_id", &Event::nodeId)
.def("is_remote", &Event::isRemote)
.def("sequence_nr", &Event::sequenceNr)
.def("stack", &Event::stack)
.def("scope", &Event::scope);
py::class_<LegacyEvent>(m, "ProfilerEvent")
.def("kind", &LegacyEvent::kindStr)
.def("name", [](const LegacyEvent& e) { return e.name(); })
.def("thread_id", &LegacyEvent::threadId)
.def("fwd_thread_id", &LegacyEvent::fwdThreadId)
.def("device", &LegacyEvent::device)
.def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs)
.def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs)
.def("has_cuda", &LegacyEvent::hasCuda)
.def("shapes", &LegacyEvent::shapes)
.def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage)
.def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage)
.def("handle", &LegacyEvent::handle)
.def("node_id", &LegacyEvent::nodeId)
.def("is_remote", &LegacyEvent::isRemote)
.def("sequence_nr", &LegacyEvent::sequenceNr)
.def("stack", &LegacyEvent::stack)
.def("scope", &LegacyEvent::scope)
.def("correlation_id", &LegacyEvent::correlationId)
.def("start_us", &LegacyEvent::cpuUs);
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
.def(py::init<bool, bool>());
py::enum_<c10::DeviceType>(m, "DeviceType")
.value("CPU", c10::DeviceType::CPU)
.value("CUDA", c10::DeviceType::CUDA)
.value("MKLDNN", c10::DeviceType::MKLDNN)
.value("OPENGL", c10::DeviceType::OPENGL)
.value("OPENCL", c10::DeviceType::OPENCL)
.value("IDEEP", c10::DeviceType::IDEEP)
.value("HIP", c10::DeviceType::HIP)
.value("FPGA", c10::DeviceType::FPGA)
.value("MSNPU", c10::DeviceType::MSNPU)
.value("XLA", c10::DeviceType::XLA)
.value("Vulkan", c10::DeviceType::Vulkan)
.value("Metal", c10::DeviceType::Metal);
#ifdef USE_KINETO
py::class_<KinetoEvent>(m, "KinetoEvent")
// name of the event
.def("name", &KinetoEvent::name)
// PyTorch thread id of the start callback
.def("start_thread_id", [](const KinetoEvent& e) {
return e.startThreadId();
})
// PyTorch thread id of the end callback
.def("end_thread_id", [](const KinetoEvent& e) {
return e.endThreadId();
})
// for events of scope BACKWARD_FUNCTION - PyTorch thread id
// of the corresponding forward op
.def("fwd_thread_id", [](const KinetoEvent& e) {
return e.fwdThreadId();
})
// together with fwd_thread_id, used to uniquely identify
// the forward op
.def("sequence_nr", [](const KinetoEvent& e) {
return e.sequenceNr();
})
// absolute start time (since unix epoch) in us
.def("start_us", &KinetoEvent::startUs)
// duration in us
.def("duration_us", &KinetoEvent::durationUs)
// used for correlation between high-level PyTorch events
// and low-level device events
.def("correlation_id", [](const KinetoEvent& e) {
return e.correlationId();
})
// shapes of input tensors
.def("shapes", [](const KinetoEvent& e) {
if (e.hasShapes()) {
return e.shapes();
} else {
return std::vector<std::vector<int64_t>>();
}
})
// stack traces of the PyTorch CPU events
.def("stack", [](const KinetoEvent& e) {
if (e.hasStack()) {
return e.stack();
} else {
return std::vector<std::string>();
}
})
// type of the RecordFunction that generated a PyTorch CPU event
// (op, torchscript function, user label, etc)
.def("scope", [](const KinetoEvent& e) {
return e.scope();
})
// device number, for CPU - process id
.def("device_index", &KinetoEvent::deviceIndex)
// for CUDA - stream id, for CPU - start thread id
.def("device_resource_id", &KinetoEvent::deviceResourceId)
// device type
.def("device_type", [](const KinetoEvent& e) {
return e.deviceType();
})
// correlation id of a linked event
.def("linked_correlation_id", &KinetoEvent::linkedCorrelationId);
py::class_<ProfilerResult>(m, "ProfilerResult")
.def("events", &ProfilerResult::events)
.def("legacy_events", &ProfilerResult::legacy_events)
.def("save", &ProfilerResult::save);
m.def("_enable_profiler", enableProfiler);
m.def("_disable_profiler", disableProfiler);
m.def("_prepare_profiler", prepareProfiler);
#endif
m.def("kineto_available", kinetoAvailable);
m.def("_enable_profiler_legacy", enableProfilerLegacy);
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
.def(py::init<bool, bool>());
m.def(
"_disable_profiler",
disableProfiler,
"_disable_profiler_legacy",
disableProfilerLegacy,
py::arg("profiler_disable_options") = ProfilerDisableOptions());
m.def("_profiler_enabled", profilerEnabled);
m.def("_enable_record_function", [](bool enable) {
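
The KinetoEvent accessors bound above are what parse_kineto_results consumes on the Python side; a sketch of inspecting raw events (internal API; result as returned by torch.autograd._disable_profiler()):

```
for e in result.events():
    print(e.name(), e.device_type(), e.device_index(),
          e.start_us(), e.duration_us(),
          e.correlation_id(), e.linked_correlation_id())
```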

View File

@@ -1,461 +1,4 @@
#pragma once
#include <iostream>
#include <mutex>
#include <memory>
#include <vector>
#include <cstdint>
#include <string>
#include <sstream>
#include <forward_list>
#include <tuple>
#include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#ifndef _WIN32
#include <ctime>
#endif
#if defined(C10_IOS) && defined(C10_MOBILE)
#include <sys/time.h> // for gettimeofday()
#endif
#include <ATen/record_function.h>
struct CUevent_st;
typedef std::shared_ptr<CUevent_st> CUDAEventStub;
namespace torch { namespace autograd {
struct Node;
namespace profiler {
struct TORCH_API CUDAStubs {
virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) {
fail();
}
virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) {
fail();
return 0.f;
}
virtual void nvtxMarkA(const char* name) {
fail();
}
virtual void nvtxRangePushA(const char* name) {
fail();
}
virtual void nvtxRangePop() {
fail();
}
virtual bool enabled() {
return false;
}
virtual void onEachDevice(std::function<void(int)> op) {
fail();
}
virtual void synchronize() {
fail();
}
virtual ~CUDAStubs();
private:
void fail() {
AT_ERROR("CUDA used in profiler but not enabled.");
}
};
TORCH_API void registerCUDAMethods(CUDAStubs* stubs);
constexpr inline size_t ceilToMultiple(size_t a, size_t b) {
return ((a + b - 1) / b) * b;
}
inline int64_t getTime() {
#if defined(C10_IOS) && defined(C10_MOBILE)
// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on
// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not
struct timeval now;
gettimeofday(&now, NULL);
return static_cast<int64_t>(now.tv_sec) * 1000000000 + static_cast<int64_t>(now.tv_usec) * 1000;
#elif defined(_WIN32) || defined(__MACH__)
using namespace std::chrono;
using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
#else
// clock_gettime is *much* faster than std::chrono implementation on Linux
struct timespec t{};
clock_gettime(CLOCK_MONOTONIC, &t);
return static_cast<int64_t>(t.tv_sec) * 1000000000 + static_cast<int64_t>(t.tv_nsec);
#endif
}
// A struct to control settings of disableProfiler options.
struct TORCH_API ProfilerDisableOptions {
ProfilerDisableOptions() = default;
ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate)
: cleanupTLSState(shouldCleanupTLSState),
consolidate(shouldConsolidate) {}
// Whether we should clean up profiler states that are thread local, such as
// ThreadLocalDebugInfo and thread local RecordFunction callbacks.
bool cleanupTLSState = true;
// Whether we should consolidate all currently recorded profiled events. If
// false, will not consolidate and other threads can continue to write to the
// event lists.
bool consolidate = true;
};
enum class C10_API_ENUM ProfilerState {
Disabled,
CPU, // CPU-only profiling
CUDA, // CPU + CUDA events
NVTX, // only emit NVTX markers
};
struct TORCH_API ProfilerConfig {
ProfilerConfig(
ProfilerState state,
bool report_input_shapes = false,
bool profile_memory = false,
bool with_stack = false)
: state(state),
report_input_shapes(report_input_shapes),
profile_memory(profile_memory),
with_stack(with_stack) {}
~ProfilerConfig();
ProfilerState state;
bool report_input_shapes;
bool profile_memory;
bool with_stack;
// Returns IValues corresponding to ProfilerConfig struct, to be used for
// serialization.
at::IValue toIValue() const;
// Reconstructs a ProfilerConfig from IValues given by toIValue.
static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue);
};
enum class C10_API_ENUM EventKind : uint16_t {
Mark,
PushRange,
PopRange,
MemoryAlloc,
};
struct TORCH_API Event final {
Event(
EventKind kind,
at::StringView name,
uint16_t thread_id,
bool record_cuda,
at::RecordFunctionHandle handle = 0,
std::vector<std::vector<int64_t>>&& shapes = {},
int node_id = -1)
: name_(std::move(name)),
kind_(kind),
thread_id_(thread_id),
handle_(handle),
shapes_(shapes),
node_id_(node_id) {
record(record_cuda);
}
// Constructor to be used in conjunction with Event::fromIValue.
Event(
EventKind kind,
at::StringView name,
uint16_t thread_id,
at::RecordFunctionHandle handle,
std::vector<std::vector<int64_t>>&& shapes,
int node_id,
bool is_remote,
int64_t cpu_memory_usage,
int64_t cpu_ns,
bool cuda_recorded,
int64_t cuda_memory_usage = 0,
int device = -1,
double cuda_us = -1)
: cpu_ns_(cpu_ns),
name_(std::move(name)),
kind_(kind),
thread_id_(thread_id),
handle_(handle),
shapes_(shapes),
cpu_memory_usage_(cpu_memory_usage),
cuda_memory_usage_(cuda_memory_usage),
device_(device),
node_id_(node_id),
is_remote_(is_remote),
cuda_us_(cuda_us) {
// Sanity check values that were deserialized
TORCH_INTERNAL_ASSERT(cpu_ns_ > 0);
if (cuda_recorded) {
TORCH_INTERNAL_ASSERT(device_ >= 0);
TORCH_INTERNAL_ASSERT(cuda_us_ >= 0);
}
}
// Returns IValues corresponding to event structure, to be used for
// serialization.
at::IValue toIValue() const;
// Reconstructs an event from IValues given by toIValue.
static Event fromIValue(const at::IValue& eventIValue);
void record(bool record_cuda);
std::string kind() const {
switch(kind_) {
case EventKind::Mark: return "mark";
case EventKind::PushRange: return "push";
case EventKind::PopRange: return "pop";
case EventKind::MemoryAlloc: return "memory_alloc";
}
throw std::runtime_error("unknown EventKind");
}
// Get enum kind of this event.
EventKind eventKind() const {
return kind_;
}
const char* name() const {
return name_.str();
}
uint64_t threadId() const {
return thread_id_;
}
std::vector<std::vector<int64_t>> shapes() const {
return shapes_;
}
double cpuElapsedUs(const Event& e) const {
return (e.cpu_ns_ - cpu_ns_)/(1000.0);
}
double cpuUs() const {
return cpu_ns_ / (1000.0);
}
double cudaElapsedUs(const Event& e) const;
bool hasCuda() const {
return cuda_event != nullptr || (isRemote() && device_ != -1);
}
int device() const {
return device_;
}
void updateMemoryStats(int64_t alloc_size, c10::Device device) {
if (device.type() == c10::DeviceType::CUDA ||
device.type() == c10::DeviceType::HIP) {
cuda_memory_usage_ = alloc_size;
} else if (device.type() == c10::DeviceType::CPU ||
device.type() == c10::DeviceType::MKLDNN ||
device.type() == c10::DeviceType::IDEEP) {
cpu_memory_usage_ = alloc_size;
} else {
LOG(WARNING) << "Unsupported memory profiling device: " << device;
}
}
int64_t cpuMemoryUsage() const {
return cpu_memory_usage_;
}
int64_t cudaMemoryUsage() const {
return cuda_memory_usage_;
}
at::RecordFunctionHandle handle() const {
return handle_;
}
// Node ID corresponding to this event.
int nodeId() const {
return node_id_;
}
// Set Node ID on this event.
void setNodeId(int node_id) {
node_id_ = node_id;
}
void setName(at::StringView newName_) {
name_ = std::move(newName_);
}
bool isRemote() const {
return is_remote_;
}
void setCudaUs(int64_t cuda_us) {
cuda_us_ = cuda_us;
}
void setSequenceNr(int64_t sequence_nr) {
sequence_nr_ = sequence_nr;
}
int64_t sequenceNr() const {
return sequence_nr_;
}
const std::vector<std::string>& stack() const {
return stack_;
}
void setStack(const std::vector<std::string>& stack) {
stack_ = stack;
}
uint64_t fwdThreadId() const {
return fwd_thread_id_;
}
void setFwdThreadId(uint64_t fwd_thread_id) {
fwd_thread_id_ = fwd_thread_id;
}
uint8_t scope() const {
return scope_;
}
void setScope(uint8_t scope) {
scope_ = scope;
}
private:
// signed to allow for negative intervals, initialized for safety.
int64_t cpu_ns_ = 0;
at::StringView name_;
EventKind kind_;
uint64_t thread_id_;
uint64_t fwd_thread_id_;
at::RecordFunctionHandle handle_ {0};
std::vector<std::vector<int64_t>> shapes_;
int64_t cpu_memory_usage_ = 0;
int64_t cuda_memory_usage_ = 0;
int device_ = -1;
CUDAEventStub cuda_event = nullptr;
int node_id_ = 0;
bool is_remote_ = false;
int64_t cuda_us_ = -1;
int64_t sequence_nr_ = -1;
std::vector<std::string> stack_;
uint8_t scope_;
};
// a linked-list of fixed sized vectors, to avoid
// a std::vector resize from taking a large amount of time inside
// a profiling event
struct RangeEventList {
RangeEventList() {
events_.reserve(kReservedCapacity);
}
template<typename... Args>
void record(Args&&... args) {
std::lock_guard<std::mutex> guard(mutex_);
events_.emplace_back(std::forward<Args>(args)...);
}
std::vector<Event> consolidate() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<Event> result;
result.insert(
result.begin(),
std::make_move_iterator(events_.begin()),
std::make_move_iterator(events_.end()));
events_.erase(events_.begin(), events_.end());
return result;
}
size_t size() {
std::lock_guard<std::mutex> lock(mutex_);
return events_.size();
}
private:
// This mutex is used to serialize access when different threads are writing
// to the same instance of RangeEventList.
std::mutex mutex_;
std::vector<Event> events_;
static const size_t kReservedCapacity = 1024;
};
using thread_event_lists = std::vector<std::vector<Event>>;
// NOTE: profiler mode is thread local, with automatic propagation
// across thread boundary (e.g. at::launch tasks)
TORCH_API void enableProfiler(const ProfilerConfig&);
TORCH_API thread_event_lists disableProfiler(c10::optional<ProfilerDisableOptions> profilerDisableOptions = c10::nullopt);
// adds profiledEvents to the current thread local recorded events. Each event
// will be marked with node ID given by fromNodeId.
TORCH_API void addEventList(std::vector<Event>&& profiledEvents);
// Returns if the profiler is currently enabled in the current thread.
TORCH_API bool profilerEnabled();
// Retrieve the thread_local ProfilerConfig.
TORCH_API ProfilerConfig getProfilerConfig();
// Writes profiled events to a stream.
TORCH_API void writeProfilerEventsToStream(std::ostream& out, const std::vector<Event*>& events);
// Usage:
// {
// RecordProfile guard("filename.trace");
// // code you want to profile
// }
// Then open filename.trace in chrome://tracing
struct TORCH_API RecordProfile {
RecordProfile(std::ostream& out);
RecordProfile(const std::string& filename);
~RecordProfile();
private:
void init();
std::unique_ptr<std::ofstream> file_;
std::ostream& out_;
void processEvents(const std::vector<Event*>& events);
};
// A guard that enables the profiler, taking in an optional callback to process
// the results
// Usage:
// {
// TLSProfilerGuard g([](thread_event_lists profilerResults) {
// // process profilerResults
// });
// Code to profile
// }
struct TORCH_API TLSProfilerGuard {
explicit TLSProfilerGuard(
const ProfilerConfig& cfg,
c10::optional<std::function<void(const thread_event_lists&)>>
resultCallback = c10::nullopt,
c10::optional<ProfilerDisableOptions> profilerDisableOptions =
c10::nullopt)
: cb_(std::move(resultCallback)),
profilerDisableOptions_(std::move(profilerDisableOptions)) {
enableProfiler(cfg);
}
~TLSProfilerGuard() {
thread_event_lists event_lists = disableProfiler(profilerDisableOptions_);
if (cb_) {
try {
(*cb_)(event_lists);
} catch (const std::exception& e) {
LOG(ERROR) << "Got error processing profiler events: " << e.what();
}
}
}
private:
c10::optional<std::function<void(const thread_event_lists&)>> cb_;
const c10::optional<ProfilerDisableOptions> profilerDisableOptions_;
};
} // namespace profiler
}} // namespace torch::autograd
#include <torch/csrc/autograd/profiler_legacy.h>
#include <torch/csrc/autograd/profiler_kineto.h>
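
With profiler.h reduced to the two includes above, existing call sites that only know the old header keep compiling. A minimal sketch of such an unchanged caller (the function name and workload are placeholders, not part of this PR):

```
#include <torch/csrc/autograd/profiler.h>

// Sketch: drives the legacy profiler through the umbrella header.
void profileOnCpu() {
  using namespace torch::autograd::profiler;
  enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU));
  // ... workload to profile (placeholder) ...
  thread_event_lists lists = disableProfilerLegacy();
  // each inner vector holds the LegacyEvents recorded by one thread
}
```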

View File

@@ -32,7 +32,7 @@ static inline void cudaCheck(cudaError_t result, const char * file, int line) {
#define TORCH_CUDA_CHECK(result) cudaCheck(result,__FILE__,__LINE__);
struct CUDAMethods : public CUDAStubs {
void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const override {
TORCH_CUDA_CHECK(cudaGetDevice(device));
CUevent_st* cuda_event_ptr;
TORCH_CUDA_CHECK(cudaEventCreate(&cuda_event_ptr));
@@ -43,23 +43,28 @@ struct CUDAMethods : public CUDAStubs {
*cpu_ns = getTime();
TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream));
}
float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const override {
TORCH_CUDA_CHECK(cudaEventSynchronize(event->get()));
TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get()));
float ms;
TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event->get(), event2->get()));
return ms*1000.0;
}
void nvtxMarkA(const char* name) const override {
::nvtxMark(name);
}
void nvtxRangePushA(const char* name) const override {
::nvtxRangePushA(name);
}
void nvtxRangePop() const override {
::nvtxRangePop();
}
void onEachDevice(std::function<void(int)> op) const override {
at::cuda::OptionalCUDAGuard device_guard;
int count = at::cuda::device_count();
for(int i = 0; i < count; i++) {
@@ -67,13 +72,14 @@ struct CUDAMethods : public CUDAStubs {
op(i);
}
}
void synchronize() const override {
cudaDeviceSynchronize();
}
bool enabled() const override {
return true;
}
};
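// Note for readers: elapsed() above synchronizes both events and converts
// cudaEventElapsedTime's milliseconds to microseconds (ms * 1000.0), matching
// the microsecond units used elsewhere in the profiler.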
struct RegisterCUDAMethods {

View File

@@ -0,0 +1,368 @@
#include <torch/csrc/autograd/profiler_kineto.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <sstream>
#ifdef USE_KINETO
#include <pthread.h>
#include <libkineto.h>
#endif
namespace torch { namespace autograd { namespace profiler {
#ifdef USE_KINETO
namespace {
// TODO: consider TLS (tid + tls counter)
uint64_t next_correlation_id() {
static std::atomic<uint64_t> corr_id_ {1};
return corr_id_++;
}
inline int64_t getTimeUs() {
using namespace std::chrono;
return duration_cast<microseconds>(high_resolution_clock::now().time_since_epoch()).count();
}
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes);
struct TORCH_API KinetoThreadLocalState : public ProfilerThreadLocalState {
using ProfilerThreadLocalState::ProfilerThreadLocalState;
virtual ~KinetoThreadLocalState() override = default;
void reportClientActivity(
const at::RecordFunction& fn,
const KinetoObserverContext* ctx) {
if (!ctx) {
return;
}
libkineto::ClientTraceActivity op;
op.startTime = ctx->startUs;
op.endTime = getTimeUs();
op.opType = std::string(fn.name().str());
op.device = 0;
op.threadId = ctx->startThreadId;
op.correlation = ctx->correlationId;
// optimization - postpone shapesToStr till finalizeCPUTrace
// is called from disableProfiler
// if (ctx->shapes && !ctx->shapes->empty()) {
// op.inputDims = shapesToStr(*ctx->shapes);
// }
// Not setting atm
op.inputTypes = "[]";
op.arguments = "[]";
op.outputDims = "[]";
op.outputTypes = "[]";
op.inputNames = "[]";
op.outputNames = "[]";
//
op.threadId = pthread_self();
{
std::lock_guard<std::mutex> guard(state_mutex_);
kineto_events_.emplace_back();
kineto_events_.back()
.activity(op)
.startThreadId(ctx->startThreadId)
.endThreadId(ctx->endThreadId)
.sequenceNr(ctx->sequenceNr)
.fwdThreadId(ctx->fwdThreadId)
.scope(ctx->recFunScope);
if (ctx->shapes && !ctx->shapes->empty()) {
kineto_events_.back().shapes(*ctx->shapes);
}
if (ctx->stack && !ctx->stack->empty()) {
kineto_events_.back().stack(*ctx->stack);
}
cpu_trace->activities.emplace_back(std::move(op));
}
}
// TODO: use kineto
void reportMemoryUsage(
void* /* unused */,
int64_t alloc_size,
c10::Device device) override {
if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
uint64_t thread_id = at::RecordFunction::currentThreadId();
LegacyEvent evt(
EventKind::MemoryAlloc,
at::StringView(""),
thread_id,
config_.state == ProfilerState::CUDA);
evt.setCpuUs(getTimeUs()); // upd. time using Kineto's clock
evt.updateMemoryStats(alloc_size, device);
getEventList(thread_id).record(std::move(evt));
}
}
void addTraceEvents(libkineto::ActivityTraceInterface& trace) {
const auto& events = *(trace.activities());
for (const auto& ev_ptr : events) {
// ClientTraceActivity events are already processed
if (ev_ptr->type() != libkineto::ActivityType::CPU_OP) {
kineto_events_.emplace_back();
kineto_events_.back()
.activity(*ev_ptr);
}
}
}
void finalizeCPUTrace() {
TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size());
for (auto idx = 0; idx < cpu_trace->activities.size(); ++idx) {
if (kineto_events_[idx].hasShapes()) {
cpu_trace->activities[idx].inputDims = shapesToStr(kineto_events_[idx].shapes());
} else {
cpu_trace->activities[idx].inputDims = "[]";
}
}
}
std::vector<KinetoEvent> kineto_events_;
std::unique_ptr<libkineto::CpuTraceBuffer> cpu_trace =
std::make_unique<libkineto::CpuTraceBuffer>();
};
KinetoThreadLocalState* getProfilerTLSState() {
const auto& state = c10::ThreadLocalDebugInfo::get(
c10::DebugInfoKind::PROFILER_STATE);
return static_cast<KinetoThreadLocalState*>(state);
}
void pushProfilingCallbacks() {
auto state_ptr = getProfilerTLSState();
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) {
auto state_ptr = getProfilerTLSState();
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
return std::make_unique<KinetoObserverContext>();
}
auto corr_id = next_correlation_id();
libkineto::api().activityProfiler().pushCorrelationId(corr_id);
auto ctx_ptr = std::make_unique<KinetoObserverContext>();
ctx_ptr->startUs = getTimeUs();
ctx_ptr->correlationId = corr_id;
ctx_ptr->startThreadId = at::RecordFunction::currentThreadId();
if (state_ptr->config().report_input_shapes) {
ctx_ptr->shapes = inputSizes(fn);
}
ctx_ptr->sequenceNr = fn.seqNr();
ctx_ptr->fwdThreadId = fn.forwardThreadId();
ctx_ptr->recFunScope = (uint8_t)fn.scope();
#ifndef C10_MOBILE
// backward nodes source range corresponds to the forward node
// TODO: consider using C++ stack trace
if (state_ptr->config().with_stack &&
fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
auto cs = prepareCallstack(jit::currentCallstack());
if (cs.empty()) {
cs = prepareCallstack(jit::tracer::pythonCallstack());
}
ctx_ptr->stack = callstackStr(cs);
}
#endif
return ctx_ptr;
},
[](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
auto state_ptr = getProfilerTLSState();
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
return;
}
auto* kineto_ctx_ptr = static_cast<KinetoObserverContext*>(ctx_ptr);
TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr);
kineto_ctx_ptr->endThreadId = at::RecordFunction::currentThreadId();
state_ptr->reportClientActivity(fn, kineto_ctx_ptr);
libkineto::api().activityProfiler().popCorrelationId();
})
.needsInputs(state_ptr->config().report_input_shapes)
.needsIds(true));
state_ptr->setCallbackHandle(handle);
}
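// Lifecycle of the two callbacks above (summary): the "enter" callback
// allocates a fresh correlation id, pushes it into Kineto and stamps it
// (plus timestamps, thread ids, shapes and stack) onto the observer context;
// the "exit" callback reports the finished op via reportClientActivity() and
// pops the id, so CUPTI activities recorded in between share that id.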
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
std::ostringstream oss;
oss << "[";
for (auto t_idx = 0; t_idx < shapes.size(); ++t_idx) {
if (t_idx > 0) {
oss << ", ";
}
oss << "[";
for (auto s_idx = 0; s_idx < shapes[t_idx].size(); ++s_idx) {
if (s_idx > 0) {
oss << ", ";
}
oss << shapes[t_idx][s_idx];
}
oss << "]";
}
oss << "]";
return oss.str();
}
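// For reference, shapesToStr({{32, 64}, {64}}) produces "[[32, 64], [64]]".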
} // namespace
void prepareProfiler(
const ProfilerConfig& config,
const std::set<ActivityType>& activities) {
TORCH_CHECK(config.state == ProfilerState::KINETO,
"Supported only in Kineto profiler");
std::set<libkineto::ActivityType> cpuTypes = {
libkineto::ActivityType::CPU_OP,
libkineto::ActivityType::EXTERNAL_CORRELATION,
libkineto::ActivityType::CUDA_RUNTIME,
};
std::set<libkineto::ActivityType> cudaTypes = {
libkineto::ActivityType::GPU_MEMCPY,
libkineto::ActivityType::GPU_MEMSET,
libkineto::ActivityType::CONCURRENT_KERNEL,
// also including CUDA_RUNTIME
libkineto::ActivityType::CUDA_RUNTIME,
};
std::set<libkineto::ActivityType> k_activities;
if (activities.count(ActivityType::CPU)) {
k_activities.insert(cpuTypes.begin(), cpuTypes.end());
}
if (activities.count(ActivityType::CUDA)) {
k_activities.insert(cudaTypes.begin(), cudaTypes.end());
}
if (!libkineto::api().isProfilerRegistered()) {
libkineto_init();
}
if (!libkineto::api().isProfilerInitialized()) {
libkineto::api().initProfilerIfRegistered();
}
libkineto::api().activityProfiler().prepareTrace(k_activities);
}
void enableProfiler(
const ProfilerConfig& config,
const std::set<ActivityType>& activities) {
TORCH_CHECK(config.state == ProfilerState::KINETO);
TORCH_CHECK(!activities.empty(), "No activities specified for Kineto profiler");
auto state_ptr = getProfilerTLSState();
TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
auto state = std::make_shared<KinetoThreadLocalState>(config);
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
state->cpu_trace = std::make_unique<libkineto::CpuTraceBuffer>();
state->cpu_trace->span.startTime = getTimeUs();
// TODO: number of GPU ops
state->cpu_trace->gpuOpCount = -1;
state->cpu_trace->span.name = "PyTorch Profiler";
if (activities.count(ActivityType::CPU)) {
pushProfilingCallbacks();
}
libkineto::api().activityProfiler().startTrace();
state->mark("__start_profile", false);
}
std::unique_ptr<ProfilerResult> disableProfiler() {
// all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard
auto state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
auto state_ptr = static_cast<KinetoThreadLocalState*>(state.get());
TORCH_CHECK(state_ptr && state_ptr->config().state == ProfilerState::KINETO,
"Can't disable Kineto profiler when it's not running");
if (state_ptr->hasCallbackHandle()) {
at::removeCallback(state_ptr->callbackHandle());
}
state_ptr->mark("__stop_profile");
state_ptr->cpu_trace->span.endTime = getTimeUs();
state_ptr->finalizeCPUTrace();
libkineto::api().activityProfiler().transferCpuTrace(std::move(state_ptr->cpu_trace));
auto trace = std::move(libkineto::api().activityProfiler().stopTrace());
TORCH_CHECK(trace);
state_ptr->addTraceEvents(*trace);
return std::make_unique<ProfilerResult>(
std::move(state_ptr->kineto_events_),
std::move(state_ptr->consolidate()),
std::move(trace));
}
KinetoEvent& KinetoEvent::activity(const libkineto::TraceActivity& activity) {
name_ = activity.name();
device_index_ = activity.deviceId();
device_resource_id_ = activity.resourceId();
start_us_ = activity.timestamp();
duration_us_ = activity.duration();
correlation_id_ = activity.correlationId();
activity_type_ = (uint8_t)activity.type();
if (activity.linkedActivity()) {
linked_correlation_id_ = activity.linkedActivity()->correlationId();
}
return *this;
}
c10::DeviceType KinetoEvent::deviceType() const {
switch (activity_type_) {
case (uint8_t)libkineto::ActivityType::CPU_OP:
return c10::DeviceType::CPU;
case (uint8_t)libkineto::ActivityType::GPU_MEMCPY:
return c10::DeviceType::CUDA;
case (uint8_t)libkineto::ActivityType::GPU_MEMSET:
return c10::DeviceType::CUDA;
case (uint8_t)libkineto::ActivityType::CONCURRENT_KERNEL:
return c10::DeviceType::CUDA;
case (uint8_t)libkineto::ActivityType::EXTERNAL_CORRELATION:
return c10::DeviceType::CPU;
case (uint8_t)libkineto::ActivityType::CUDA_RUNTIME:
return c10::DeviceType::CPU;
}
TORCH_CHECK(false, "Unknown activity type");
}
KinetoEvent::KinetoEvent() : activity_type_((uint8_t)libkineto::ActivityType::CPU_OP) {}
ProfilerResult::ProfilerResult(
std::vector<KinetoEvent> events,
thread_event_lists legacy_events,
std::unique_ptr<libkineto::ActivityTraceInterface> trace)
: events_(std::move(events)),
legacy_events_(std::move(legacy_events)),
trace_(std::move(trace)) {}
ProfilerResult::~ProfilerResult() {}
void ProfilerResult::save(const std::string& path) {
// Kineto's save is destructive
TORCH_CHECK(!saved_, "Trace is already saved");
trace_->save(path);
saved_ = true;
}
#endif
bool kinetoAvailable() {
#ifdef USE_KINETO
return true;
#else
return false;
#endif
}
}}}
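
End to end, the entry points defined in this file compose as below; a minimal C++ sketch (`runWorkload` and the output path are placeholders, not part of this PR):

```
#include <torch/csrc/autograd/profiler_kineto.h>
#include <memory>
#include <set>

using namespace torch::autograd::profiler;

// Sketch: prepare/enable/disable cycle of the Kineto profiler.
void traceWithKineto() {
  ProfilerConfig cfg(
      ProfilerState::KINETO,
      /*report_input_shapes=*/true,
      /*profile_memory=*/false,
      /*with_stack=*/true);
  std::set<ActivityType> activities{ActivityType::CPU, ActivityType::CUDA};
  prepareProfiler(cfg, activities);  // set up CUPTI before the region of interest
  enableProfiler(cfg, activities);
  // runWorkload();  // placeholder for the code being profiled
  std::unique_ptr<ProfilerResult> result = disableProfiler();
  // result->events() holds the KinetoEvents; result->legacy_events() holds
  // the manually recorded ones (start/stop marks, memory events).
  result->save("/tmp/trace.json");  // Kineto's save is destructive, call once
}
```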

View File

@@ -0,0 +1,213 @@
#pragma once
#include <torch/csrc/autograd/profiler_legacy.h>
#ifdef USE_KINETO
namespace libkineto {
class TraceActivity;
class ActivityTraceInterface;
}
#endif
namespace torch {
namespace autograd {
namespace profiler {
enum class C10_API_ENUM ActivityType {
CPU = 0,
CUDA, // CUDA kernels, runtime
NUM_KINETO_ACTIVITIES, // must be the last one
};
#ifdef USE_KINETO
struct KinetoObserverContext : public at::ObserverContext {
int64_t startUs;
uint64_t correlationId;
uint64_t startThreadId;
uint64_t endThreadId;
c10::optional<std::vector<std::vector<int64_t>>> shapes;
int64_t sequenceNr;
uint64_t fwdThreadId;
uint8_t recFunScope;
c10::optional<std::vector<std::string>> stack;
};
struct TORCH_API KinetoEvent {
KinetoEvent();
uint64_t startThreadId() const {
return start_thread_id_;
}
uint64_t endThreadId() const {
return end_thread_id_;
}
uint8_t activityType() const {
return activity_type_;
}
uint64_t fwdThreadId() const {
return fwd_thread_id_;
}
bool hasShapes() const {
return shapes_ != c10::nullopt;
}
const std::vector<std::vector<int64_t>>& shapes() const {
return *shapes_;
}
int64_t sequenceNr() const {
return sequence_nr_;
}
bool hasStack() const {
return stack_ != c10::nullopt;
}
const std::vector<std::string>& stack() const {
return *stack_;
}
uint8_t scope() const {
return scope_;
}
KinetoEvent& startThreadId(uint64_t start_thread_id) {
start_thread_id_ = start_thread_id;
return *this;
}
KinetoEvent& endThreadId(uint64_t end_thread_id) {
end_thread_id_ = end_thread_id;
return *this;
}
KinetoEvent& fwdThreadId(uint64_t fwd_thread_id) {
fwd_thread_id_ = fwd_thread_id;
return *this;
}
KinetoEvent& shapes(const std::vector<std::vector<int64_t>>& shapes) {
shapes_ = shapes;
return *this;
}
KinetoEvent& sequenceNr(int64_t sequence_nr) {
sequence_nr_ = sequence_nr;
return *this;
}
KinetoEvent& stack(const std::vector<std::string>& st) {
stack_ = st;
return *this;
}
KinetoEvent& scope(uint8_t scope) {
scope_ = scope;
return *this;
}
// Kineto fields
KinetoEvent& activity(const libkineto::TraceActivity& activity);
std::string name() const {
return name_;
}
uint64_t deviceIndex() const {
return device_index_;
}
uint64_t startUs() const {
return start_us_;
}
uint64_t durationUs() const {
return duration_us_;
}
uint64_t correlationId() const {
return correlation_id_;
}
KinetoEvent& correlationId(uint64_t correlation_id) {
correlation_id_ = correlation_id;
return *this;
}
uint64_t linkedCorrelationId() const {
return linked_correlation_id_;
}
int64_t deviceResourceId() const {
return device_resource_id_;
}
c10::DeviceType deviceType() const;
uint64_t start_thread_id_ = 0;
uint64_t end_thread_id_ = 0;
uint64_t fwd_thread_id_ = 0;
int64_t sequence_nr_ = -1;
uint8_t scope_ = 0;
uint8_t activity_type_;
c10::optional<std::vector<std::vector<int64_t>>> shapes_;
c10::optional<std::vector<std::string>> stack_;
std::string name_;
uint64_t device_index_ = 0;
uint64_t start_us_ = 0;
uint64_t duration_us_ = 0;
uint64_t correlation_id_ = 0;
uint64_t linked_correlation_id_ = 0;
int64_t device_resource_id_ = 0;
};
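// Usage (builder-style population, as done in profiler_kineto.cpp):
//   KinetoEvent e;
//   e.startThreadId(tid)
//     .sequenceNr(seq)
//     .fwdThreadId(fwd_tid)
//     .scope(scope)
//     .activity(trace_activity);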
// Consolidating events returned directly from Kineto
// with events manually created by us (e.g. start/stop marks,
// memory allocation events)
struct TORCH_API ProfilerResult {
ProfilerResult(
std::vector<KinetoEvent> events,
thread_event_lists legacy_events,
std::unique_ptr<libkineto::ActivityTraceInterface> trace);
~ProfilerResult();
const std::vector<KinetoEvent>& events() const {
return events_;
}
const thread_event_lists& legacy_events() const {
return legacy_events_;
}
void save(const std::string& path);
private:
bool saved_ = false;
std::vector<KinetoEvent> events_;
thread_event_lists legacy_events_;
std::unique_ptr<libkineto::ActivityTraceInterface> trace_;
};
TORCH_API void enableProfiler(
const ProfilerConfig& config,
const std::set<ActivityType>& activities);
TORCH_API std::unique_ptr<ProfilerResult> disableProfiler();
TORCH_API void prepareProfiler(
const ProfilerConfig& config,
const std::set<ActivityType>& activities);
#endif // USE_KINETO
TORCH_API bool kinetoAvailable();
} // namespace profiler
}} // namespace torch::autograd

View File

@@ -23,54 +23,33 @@
namespace torch { namespace autograd { namespace profiler {
std::vector<FileLineFunc> prepareCallstack(const std::vector<jit::StackEntry>& cs) {
std::vector<FileLineFunc> entries;
entries.reserve(cs.size());
for (const auto& entry : cs) {
auto& range = entry.range;
if (range.source()) {
auto& src = range.source();
if (src && src->filename()) {
auto line = src->starting_line_no() +
src->lineno_for_offset(range.start());
entries.emplace_back(FileLineFunc{*(src->filename()), line, entry.filename});
}
}
}
return entries;
}
std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) {
std::vector<std::string> cs_str;
cs_str.reserve(cs.size());
for (const auto& entry : cs) {
std::stringstream loc;
loc << entry.filename << "(" << entry.line << "): " << entry.funcname;
cs_str.push_back(loc.str());
}
return cs_str;
}
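// For reference, a frame recorded from "model.py", line 12, in function
// "forward" is rendered by callstackStr as "model.py(12): forward".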
// We decompose the profiler logic into the following components:
//
@@ -163,252 +142,267 @@ static CUDAStubs* cuda_stubs = default_stubs_addr;
// - save profiling events into the profiling state
//
namespace {
const CUDAStubs default_stubs;
constexpr const CUDAStubs* default_stubs_addr = &default_stubs;
// Constant initialization, so it is guaranteed to be initialized before
// static initialization calls which may invoke registerCUDAMethods
inline const CUDAStubs*& cuda_stubs() {
static const CUDAStubs* stubs_ = default_stubs_addr;
return stubs_;
}
}
// Profiler state
const ProfilerConfig& ProfilerThreadLocalState::config() const {
  return config_;
}

thread_event_lists ProfilerThreadLocalState::consolidate() {
  std::lock_guard<std::mutex> g(state_mutex_);
  thread_event_lists result;
  for (auto& kv : event_lists_map_) {
    auto& list = kv.second;
    result.emplace_back(list->consolidate());
  }
  // Consolidate remote events if applicable as well.
  if (remoteProfiledEvents_) {
    result.insert(
        result.end(),
        std::make_move_iterator(remoteProfiledEvents_->begin()),
        std::make_move_iterator(remoteProfiledEvents_->end()));
  }
  return result;
}

void ProfilerThreadLocalState::mark(std::string name, bool include_cuda) {
  if (config_.state == ProfilerState::Disabled) {
    return;
  }
  if (config_.state == ProfilerState::NVTX) {
    cuda_stubs()->nvtxMarkA(name.c_str());
  } else {
    LegacyEvent evt(
        EventKind::Mark,
        at::StringView(std::move(name)),
        at::RecordFunction::currentThreadId(),
        include_cuda && config_.state == ProfilerState::CUDA);
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList().record(std::move(evt));
  }
}

void ProfilerThreadLocalState::setOrAddRemoteProfiledEvents(
    std::vector<LegacyEvent>&& remoteProfiledEvents) {
  // Lock to serialize access from multiple callback threads.
  std::lock_guard<std::mutex> guard(state_mutex_);
  if (remoteProfiledEvents_) {
    (*remoteProfiledEvents_).emplace_back(remoteProfiledEvents);
  } else {
    remoteProfiledEvents_ = {std::move(remoteProfiledEvents)};
  }
}

void ProfilerThreadLocalState::pushRange(
    const at::RecordFunction& fn,
    const bool record_cuda,
    const char* msg,
    std::vector<std::vector<int64_t>>&& shapes) {
  if (config_.state == ProfilerState::Disabled) {
    return;
  }
  if (config_.state == ProfilerState::NVTX) {
    cuda_stubs()->nvtxRangePushA(getNvtxStr(
        fn.name(), msg, fn.seqNr(), shapes).c_str());
  } else {
    LegacyEvent evt(
        EventKind::PushRange,
        fn.name(),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle(),
        std::move(shapes),
        at::RecordFunction::getDefaultNodeId());
    evt.setSequenceNr(fn.seqNr());
    evt.setFwdThreadId(fn.forwardThreadId());
    evt.setScope((uint8_t)fn.scope());
#ifndef C10_MOBILE
    // backward nodes source range corresponds to the forward node
    // TODO: consider using C++ stack trace
    if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
      auto cs = prepareCallstack(jit::currentCallstack());
      if (cs.empty()) {
        cs = prepareCallstack(jit::tracer::pythonCallstack());
      }
      evt.setStack(callstackStr(cs));
    }
#endif
    getEventList().record(std::move(evt));
  }
}

void ProfilerThreadLocalState::popRange(const at::RecordFunction& fn, const bool record_cuda) {
  if (config_.state == ProfilerState::Disabled) {
    return;
  }
  if (config_.state == ProfilerState::NVTX) {
    cuda_stubs()->nvtxRangePop();
  } else {
    // In some cases RecordFunction (and popRange) may be
    // called on a different thread than pushRange
    // As a convention, we put the async pop on the original
    // thread and save current thread id in pop event
    LegacyEvent evt(
        EventKind::PopRange,
        at::StringView(""),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle());
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList(fn.threadId()).record(std::move(evt));
  }
}

void ProfilerThreadLocalState::reportMemoryUsage(
    void* /* unused */,
    int64_t alloc_size,
    c10::Device device) {
  if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
    uint64_t thread_id = at::RecordFunction::currentThreadId();
    LegacyEvent evt(
        EventKind::MemoryAlloc,
        at::StringView(""),
        thread_id,
        config_.state == ProfilerState::CUDA);
    evt.updateMemoryStats(alloc_size, device);
    getEventList(thread_id).record(std::move(evt));
  }
}

bool ProfilerThreadLocalState::memoryProfilingEnabled() const {
  return config_.profile_memory;
}

std::string ProfilerThreadLocalState::getNvtxStr(
    const at::StringView& name,
    const char* msg,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes) const {
  if (sequence_nr >= 0 || shapes.size() > 0) {
    std::stringstream s;
#ifdef __HIP_PLATFORM_HCC__
    s << name.str();
#endif
    if (sequence_nr >= 0) {
#ifdef __HIP_PLATFORM_HCC__
      s << msg << sequence_nr;
#else
      s << name.str() << msg << sequence_nr;
#endif
    }
    if (shapes.size() > 0) {
      s << ", sizes = [";
      for (size_t idx = 0; idx < shapes.size(); ++idx) {
        if (shapes[idx].size() > 0) {
          s << "[";
          for (size_t dim = 0; dim < shapes[idx].size(); ++dim) {
            s << shapes[idx][dim];
            if (dim < shapes[idx].size() - 1) {
              s << ", ";
            }
          }
          s << "]";
        } else {
          s << "[]";
        }
        if (idx < shapes.size() - 1) {
          s << ", ";
        }
      }
      s << "]";
    }
    return s.str();
  } else {
    return name.str();
  }
}

RangeEventList& ProfilerThreadLocalState::getEventList(int64_t thread_id) {
  if (thread_id < 0) {
    thread_id = at::RecordFunction::currentThreadId();
  }
  RangeEventList* list_ptr = nullptr;
  std::lock_guard<std::mutex> guard(state_mutex_);
  auto it = event_lists_map_.find(thread_id);
  if (it != event_lists_map_.end()) {
    list_ptr = it->second.get();
  } else {
    auto event_list = std::make_shared<RangeEventList>();
    event_lists_map_[thread_id] = event_list;
    list_ptr = event_list.get();
  }
  return *list_ptr;
}

std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn) {
  std::vector<std::vector<int64_t>> sizes;
  sizes.reserve(fn.inputs().size());
  for (const c10::IValue& input : fn.inputs()) {
    if (!input.isTensor()) {
      sizes.emplace_back();
      continue;
    }
    const at::Tensor& tensor = input.toTensor();
    if (tensor.defined()) {
      sizes.push_back(input.toTensor().sizes().vec());
    } else {
      sizes.emplace_back();
    }
  }
  return sizes;
}

namespace {

enum EventIValueIdx {
  KIND = 0,
  NAME,
  THREAD_ID,
  HANDLE,
  NODE_ID,
  CPU_MEM_USAGE,
  CPU_NS,
  CUDA_RECORDED,
  CUDA_MEM_USAGE,
  CUDA_DEVICE,
  CUDA_US,
  SHAPES,
  NUM_EVENT_IVALUE_IDX // must be last in list
};

enum ProfilerIValueIdx {
  STATE = 0,
  REPORT_INPUT_SHAPES,
  PROFILE_MEMORY,
  NUM_PROFILER_CFG_IVALUE_IDX // must be last in list
};

const std::unordered_set<std::string> disable_cuda_profiling = {
  "aten::view",
  "aten::t",
  "aten::transpose",
  "aten::stride",
  "aten::empty",
  "aten::empty_like",
  "aten::empty_strided",
  "aten::as_strided",
  "aten::expand",
  "aten::resize_",
  "aten::squeeze",
  "aten::unsqueeze",
  "aten::slice",
  "aten::_unsafe_view",
  "aten::size"
};
ProfilerThreadLocalState* getProfilerTLSState() {
@@ -416,7 +410,7 @@ ProfilerThreadLocalState* getProfilerTLSState() {
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
}
void pushProfilingCallbacksLegacy() {
auto state_ptr = getProfilerTLSState();
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
@@ -433,21 +427,8 @@ void pushProfilingCallbacks() {
auto* msg = (fn.seqNr() >= 0) ? ", seq = " : "";
if (state_ptr->config().report_input_shapes) {
auto sizes = inputSizes(fn);
state_ptr->pushRange(fn, record_cuda, msg, std::move(sizes));
} else {
state_ptr->pushRange(fn, record_cuda, msg);
}
@@ -474,11 +455,9 @@ const int kCUDAWarmupStart = 5;
} // namespace
void registerCUDAMethods(CUDAStubs* stubs) {
cuda_stubs() = stubs;
}
at::IValue ProfilerConfig::toIValue() const {
c10::impl::GenericList eventIValueList(at::AnyType::get());
eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX);
@@ -519,38 +498,40 @@ bool profilerEnabled() {
return state_ptr && state_ptr->config().state != ProfilerState::Disabled;
}
void enableProfilerLegacy(const ProfilerConfig& new_config) {
TORCH_CHECK(new_config.state != ProfilerState::NVTX || cuda_stubs()->enabled(),
"Can't use NVTX profiler - PyTorch was compiled without CUDA");
TORCH_CHECK(new_config.state != ProfilerState::KINETO);
auto state_ptr = getProfilerTLSState();
TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
auto state = std::make_shared<ProfilerThreadLocalState>(new_config);
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
pushProfilingCallbacksLegacy();
if (new_config.state == ProfilerState::CUDA) {
// event recording appears to have some startup overhead, so we need
// to generate some dummy events first before recording synchronization events
for (int idx = 0; idx < kCUDAWarmupStart; ++idx) {
cuda_stubs()->onEachDevice([state](int /* unused */) {
state->mark("__cuda_startup");
cuda_stubs()->synchronize();
});
}
// cuda events must be on the same device, so we need a start event recorded
// for each gpu. we then use this event to synchronize time on the GPU
// with the CPU clock.
cuda_stubs()->onEachDevice([state](int d) {
state->mark("__cuda_start_event");
});
}
state->mark("__start_profile", false);
}
thread_event_lists disableProfilerLegacy(c10::optional<ProfilerDisableOptions> profilerDisableOptions) {
auto cleanupTLSState = profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true;
auto consolidate = profilerDisableOptions ? profilerDisableOptions->consolidate : true;
// all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard
@@ -578,21 +559,21 @@ thread_event_lists disableProfiler(c10::optional<ProfilerDisableOptions> profile
return state_ptr->consolidate();
}
void addEventList(std::vector<LegacyEvent>&& profiledEvents) {
auto state_ptr = getProfilerTLSState();
TORCH_CHECK(state_ptr, "Profiler must be enabled.");
state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents));
}
void LegacyEvent::record(bool record_cuda) {
if (record_cuda) {
cuda_stubs()->record(&device_, &cuda_event, &cpu_ns_);
return;
}
cpu_ns_ = getTime();
}
/* static */ LegacyEvent LegacyEvent::fromIValue(const at::IValue& eventIValue) {
TORCH_INTERNAL_ASSERT(
eventIValue.isList(),
"Expected IValue to contain type c10::impl::GenericList");
@@ -601,7 +582,7 @@ void Event::record(bool record_cuda) {
ivalues.size() >= NUM_EVENT_IVALUE_IDX,
"Expected at least ",
NUM_EVENT_IVALUE_IDX,
" elements to reconstruct Event.");
" elements to reconstruct LegacyEvent.");
// Reconstruct input shapes from ivalues.
auto shapeListIValue = ivalues.get(EventIValueIdx::SHAPES);
@@ -627,7 +608,7 @@ void Event::record(bool record_cuda) {
shapes.emplace_back(s);
}
LegacyEvent evt(
static_cast<EventKind>(
ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind
at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name
@@ -647,7 +628,7 @@ void Event::record(bool record_cuda) {
return evt;
}
at::IValue LegacyEvent::toIValue() const {
c10::impl::GenericList eventIValueList(at::AnyType::get());
eventIValueList.reserve(NUM_EVENT_IVALUE_IDX);
eventIValueList.emplace_back(static_cast<int64_t>(kind_));
@@ -679,7 +660,7 @@ at::IValue Event::toIValue() const {
return at::IValue(eventIValueList);
}
double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const {
TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA");
TORCH_CHECK(
e.device() == device(),
@@ -690,13 +671,12 @@ double Event::cudaElapsedUs(const Event& e) const {
TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
return static_cast<double>(e.cuda_us_ - cuda_us_);
}
return cuda_stubs()->elapsed(&cuda_event, &e.cuda_event);
}
CUDAStubs::~CUDAStubs() = default;
static jit::CodeTemplate event_template(R"(
static const jit::CodeTemplate event_template(R"(
{
"name": "${name}",
"ph": "X",
@@ -707,10 +687,10 @@ static jit::CodeTemplate event_template(R"(
"args": {}
})");
void writeProfilerEventsToStream(std::ostream& out, const std::vector<LegacyEvent*>& events) {
TORCH_CHECK(out, "Could not open file");
LegacyEvent* profiler_start = nullptr;
for (LegacyEvent* e : events) {
if (0 == strcmp(e->name(), "__start_profile")) {
profiler_start = e;
break;
@@ -724,20 +704,20 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector<Event*>& e
return std::hash<at::RecordFunctionHandle>()(p.first) ^ std::hash<int64_t>()(p.second);
}
};
std::unordered_map<std::pair<at::RecordFunctionHandle, int64_t>, LegacyEvent*, PairHash> events_map;
out << "[\n";
bool first = true;
for (LegacyEvent* evt : events) {
if (evt->kindStr() == "push") {
events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt;
} else if (evt->kind() == "pop") {
} else if (evt->kindStr() == "pop") {
if (!first) {
out << ",\n";
}
first = false;
auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId()));
TORCH_CHECK(it != events_map.end(), "Unmatched pop event");
LegacyEvent* evt_start = it->second;
events_map.erase(it);
jit::TemplateEnv env;
@@ -751,7 +731,6 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector<Event*>& e
out << "]\n";
}
RecordProfile::RecordProfile(std::ostream& out)
: out_(out) {
init();
@@ -763,24 +742,27 @@ RecordProfile::RecordProfile(const std::string& filename)
}
void RecordProfile::init() {
enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU));
}
RecordProfile::~RecordProfile() {
try {
  thread_event_lists event_lists = disableProfilerLegacy();
  std::vector<LegacyEvent*> events;
  for (auto& l : event_lists) {
    for (auto& e : l) {
      events.push_back(&e);
    }
  }
  processEvents(events);
} catch (const std::exception& e) {
  LOG(ERROR) << e.what() << std::endl;
} catch (...) {
  LOG(ERROR) << "Unknown error" << std::endl;
}
}
void RecordProfile::processEvents(const std::vector<LegacyEvent*>& events) {
writeProfilerEventsToStream(out_, events);
}
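// Usage of the pieces above, end to end (sketch):
//   {
//     RecordProfile guard("out.trace");
//     // ... code to profile; the destructor disables the legacy profiler,
//     // collects the per-thread events and writes a chrome://tracing JSON
//     // to out.trace
//   }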

View File

@@ -0,0 +1,544 @@
#pragma once
#include <iostream>
#include <mutex>
#include <memory>
#include <vector>
#include <cstdint>
#include <string>
#include <sstream>
#include <forward_list>
#include <tuple>
#include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#ifndef _WIN32
#include <ctime>
#endif
#if defined(C10_IOS) && defined(C10_MOBILE)
#include <sys/time.h> // for gettimeofday()
#endif
#include <ATen/record_function.h>
#include <torch/csrc/jit/frontend/source_range.h>
struct CUevent_st;
typedef std::shared_ptr<CUevent_st> CUDAEventStub;
namespace torch { namespace autograd {
struct Node;
namespace profiler {
struct TORCH_API CUDAStubs {
virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const {
fail();
}
virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const {
fail();
return 0.f;
}
virtual void nvtxMarkA(const char* name) const {
fail();
}
virtual void nvtxRangePushA(const char* name) const {
fail();
}
virtual void nvtxRangePop() const {
fail();
}
virtual bool enabled() const {
return false;
}
virtual void onEachDevice(std::function<void(int)> op) const {
fail();
}
virtual void synchronize() const {
fail();
}
virtual ~CUDAStubs();
private:
void fail() const {
AT_ERROR("CUDA used in profiler but not enabled.");
}
};
TORCH_API void registerCUDAMethods(CUDAStubs* stubs);
constexpr inline size_t ceilToMultiple(size_t a, size_t b) {
return ((a + b - 1) / b) * b;
}
inline int64_t getTime() {
#if defined(C10_IOS) && defined(C10_MOBILE)
// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on
// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not
struct timeval now;
gettimeofday(&now, NULL);
return static_cast<int64_t>(now.tv_sec) * 1000000000 + static_cast<int64_t>(now.tv_usec) * 1000;
#elif defined(_WIN32) || defined(__MACH__)
using namespace std::chrono;
using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
#else
// clock_gettime is *much* faster than std::chrono implementation on Linux
struct timespec t{};
clock_gettime(CLOCK_MONOTONIC, &t);
return static_cast<int64_t>(t.tv_sec) * 1000000000 + static_cast<int64_t>(t.tv_nsec);
#endif
}
enum class C10_API_ENUM EventKind : uint16_t {
Mark,
PushRange,
PopRange,
MemoryAlloc,
};
// To be deprecated, once we switch to Kineto profiling
struct TORCH_API LegacyEvent {
LegacyEvent(
EventKind kind,
at::StringView name,
uint16_t thread_id,
bool record_cuda,
at::RecordFunctionHandle handle = 0,
std::vector<std::vector<int64_t>>&& shapes = {},
int node_id = -1)
: name_(std::move(name)),
kind_(kind),
thread_id_(thread_id),
handle_(handle),
shapes_(shapes),
node_id_(node_id) {
record(record_cuda);
}
// Constructor to be used in conjunction with LegacyEvent::fromIValue.
LegacyEvent(
EventKind kind,
at::StringView name,
uint16_t thread_id,
at::RecordFunctionHandle handle,
std::vector<std::vector<int64_t>>&& shapes,
int node_id,
bool is_remote,
int64_t cpu_memory_usage,
int64_t cpu_ns,
bool cuda_recorded,
int64_t cuda_memory_usage = 0,
int device = -1,
double cuda_us = -1)
: cpu_ns_(cpu_ns),
name_(std::move(name)),
kind_(kind),
thread_id_(thread_id),
handle_(handle),
shapes_(shapes),
cpu_memory_usage_(cpu_memory_usage),
cuda_memory_usage_(cuda_memory_usage),
device_(device),
node_id_(node_id),
is_remote_(is_remote),
cuda_us_(cuda_us) {
// Sanity check values that were deserialized
TORCH_INTERNAL_ASSERT(cpu_ns_ > 0);
if (cuda_recorded) {
TORCH_INTERNAL_ASSERT(device_ >= 0);
TORCH_INTERNAL_ASSERT(cuda_us_ >= 0);
}
}
// Returns IValues corresponding to event structure, to be used for
// serialization.
at::IValue toIValue() const;
// Reconstructs an event from IValues given by toIValue.
static LegacyEvent fromIValue(const at::IValue& eventIValue);
void record(bool record_cuda);
std::string kindStr() const {
switch (kind_) {
case EventKind::Mark: return "mark";
case EventKind::PushRange: return "push";
case EventKind::PopRange: return "pop";
case EventKind::MemoryAlloc: return "memory_alloc";
}
throw std::runtime_error("unknown event kind");
}
const char* name() const {
return name_.str();
}
uint64_t threadId() const {
return thread_id_;
}
std::vector<std::vector<int64_t>> shapes() const {
return shapes_;
}
double cpuElapsedUs(const LegacyEvent& e) const {
return (e.cpu_ns_ - cpu_ns_)/(1000.0);
}
void setCpuUs(int64_t cpu_us) {
cpu_ns_ = cpu_us * 1000.0;
}
double cpuUs() const {
return cpu_ns_ / (1000.0);
}
double cudaElapsedUs(const LegacyEvent& e) const;
bool hasCuda() const {
return cuda_event != nullptr || (isRemote() && device_ != -1);
}
int device() const {
return device_;
}
void updateMemoryStats(int64_t alloc_size, c10::Device device) {
if (device.type() == c10::DeviceType::CUDA ||
device.type() == c10::DeviceType::HIP) {
cuda_memory_usage_ = alloc_size;
} else if (device.type() == c10::DeviceType::CPU ||
device.type() == c10::DeviceType::MKLDNN ||
device.type() == c10::DeviceType::IDEEP) {
cpu_memory_usage_ = alloc_size;
} else {
LOG(WARNING) << "Unsupported memory profiling device: " << device;
}
}
int64_t cpuMemoryUsage() const {
return cpu_memory_usage_;
}
int64_t cudaMemoryUsage() const {
return cuda_memory_usage_;
}
at::RecordFunctionHandle handle() const {
return handle_;
}
// Node ID corresponding to this event.
int nodeId() const {
return node_id_;
}
// Set Node ID on this event.
void setNodeId(int node_id) {
node_id_ = node_id;
}
void setName(at::StringView newName_) {
name_ = std::move(newName_);
}
bool isRemote() const {
return is_remote_;
}
void setCudaUs(int64_t cuda_us) {
cuda_us_ = cuda_us;
}
void setSequenceNr(int64_t sequence_nr) {
sequence_nr_ = sequence_nr;
}
int64_t sequenceNr() const {
return sequence_nr_;
}
void setCorrelationId(uint64_t correlation_id) {
correlation_id_ = correlation_id;
}
uint64_t correlationId() const {
return correlation_id_;
}
const std::vector<std::string>& stack() const {
return stack_;
}
void setStack(const std::vector<std::string>& stack) {
stack_ = stack;
}
uint64_t fwdThreadId() const {
return fwd_thread_id_;
}
void setFwdThreadId(uint64_t fwd_thread_id) {
fwd_thread_id_ = fwd_thread_id;
}
uint8_t scope() const {
return scope_;
}
void setScope(uint8_t scope) {
scope_ = scope;
}
 private:
  // signed to allow for negative intervals, initialized for safety.
  int64_t cpu_ns_ = 0;
  at::StringView name_;
  EventKind kind_;
  uint64_t thread_id_;
  uint64_t fwd_thread_id_;
  at::RecordFunctionHandle handle_ {0};
  std::vector<std::vector<int64_t>> shapes_;
  int64_t cpu_memory_usage_ = 0;
  int64_t cuda_memory_usage_ = 0;
  int device_ = -1;
  CUDAEventStub cuda_event = nullptr;
  int node_id_ = 0;
  bool is_remote_ = false;
  int64_t cuda_us_ = -1;
  int64_t sequence_nr_ = -1;
  std::vector<std::string> stack_;
  uint8_t scope_;
  uint64_t correlation_id_;
};
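
// Illustrative sketch (not part of the original header): round-tripping a
// LegacyEvent through its IValue form, as the RPC layer below does, and
// computing elapsed times from a hypothetical start/stop pair of events.
inline void sketchLegacyEventUsage(
    const LegacyEvent& start,
    const LegacyEvent& stop) {
  // Serialize for transport, then reconstruct on the receiving side.
  at::IValue wire = stop.toIValue();
  LegacyEvent restored = LegacyEvent::fromIValue(wire);
  (void)restored;
  // Wall-clock duration in microseconds, from the stored CPU timestamps.
  double cpu_us = start.cpuElapsedUs(stop);
  (void)cpu_us;
  // CUDA timing is meaningful only if a CUDA event was recorded locally,
  // or the event is remote with a valid device.
  if (start.hasCuda() && stop.hasCuda()) {
    double cuda_us = start.cudaElapsedUs(stop);
    (void)cuda_us;
  }
}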
// A list of profiling events with pre-reserved storage, used to avoid
// a std::vector resize from taking a large amount of time inside
// a profiling event.
struct RangeEventList {
  RangeEventList() {
    events_.reserve(kReservedCapacity);
  }

  template <typename... Args>
  void record(Args&&... args) {
    std::lock_guard<std::mutex> guard(mutex_);
    events_.emplace_back(std::forward<Args>(args)...);
  }

  std::vector<LegacyEvent> consolidate() {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<LegacyEvent> result;
    result.insert(
        result.begin(),
        std::make_move_iterator(events_.begin()),
        std::make_move_iterator(events_.end()));
    events_.erase(events_.begin(), events_.end());
    return result;
  }

  size_t size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return events_.size();
  }

 private:
  // This mutex is used to serialize access when different threads are writing
  // to the same instance of RangeEventList.
  std::mutex mutex_;
  std::vector<LegacyEvent> events_;
  static const size_t kReservedCapacity = 1024;
};
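
// Illustrative sketch (not part of the original header): record() perfectly
// forwards its arguments to a LegacyEvent constructor, so a mark recorded
// via the deserialization constructor above can later be drained like this
// (the argument values here are hypothetical):
inline std::vector<LegacyEvent> sketchRangeEventListUsage() {
  RangeEventList list;
  list.record(
      EventKind::Mark,
      at::StringView("sketch_mark"),
      /*thread_id=*/0,
      at::RecordFunctionHandle{0},
      std::vector<std::vector<int64_t>>{},
      /*node_id=*/-1,
      /*is_remote=*/false,
      /*cpu_memory_usage=*/0,
      /*cpu_ns=*/1,
      /*cuda_recorded=*/false);
  // consolidate() moves every event out under the lock, leaving the list
  // empty for subsequent writers.
  return list.consolidate();
}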
enum class C10_API_ENUM ProfilerState {
  Disabled = 0,
  CPU, // CPU-only profiling
  CUDA, // CPU + CUDA events
  NVTX, // only emit NVTX markers
  KINETO, // use libkineto
  NUM_PROFILER_STATES, // must be the last one
};
struct TORCH_API ProfilerConfig {
  ProfilerConfig(
      ProfilerState state,
      bool report_input_shapes = false,
      bool profile_memory = false,
      bool with_stack = false)
      : state(state),
        report_input_shapes(report_input_shapes),
        profile_memory(profile_memory),
        with_stack(with_stack) {}
  ~ProfilerConfig() = default;

  ProfilerState state;
  bool report_input_shapes;
  bool profile_memory;
  bool with_stack;

  // Returns IValues corresponding to ProfilerConfig struct, to be used for
  // serialization.
  at::IValue toIValue() const;

  // Reconstructs a ProfilerConfig from IValues given by toIValue.
  static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue);
};
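
// Illustrative sketch (not part of the original header): configs for the
// legacy CPU path and the new Kineto path, with one round-tripped through
// its IValue form as the RPC layer does.
inline void sketchProfilerConfigUsage() {
  ProfilerConfig cpu_cfg(ProfilerState::CPU, /*report_input_shapes=*/true);
  ProfilerConfig kineto_cfg(
      ProfilerState::KINETO,
      /*report_input_shapes=*/true,
      /*profile_memory=*/true,
      /*with_stack=*/false);
  at::IValue wire = kineto_cfg.toIValue();
  ProfilerConfig restored = ProfilerConfig::fromIValue(wire);
  (void)cpu_cfg;
  (void)restored;
}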
// A struct to control settings of disableProfiler options.
struct TORCH_API ProfilerDisableOptions {
  ProfilerDisableOptions() = default;
  ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate)
      : cleanupTLSState(shouldCleanupTLSState),
        consolidate(shouldConsolidate) {}

  // Whether we should clean up profiler states that are thread local, such as
  // ThreadLocalDebugInfo and thread local RecordFunction callbacks.
  bool cleanupTLSState = true;

  // Whether we should consolidate all currently recorded profiled events. If
  // false, will not consolidate and other threads can continue to write to the
  // event lists.
  bool consolidate = true;
};
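
// Illustrative sketch (not part of the original header): the RPC server code
// changed later in this patch disables the profiler without cleaning up
// thread-local state (callbacks are cleaned up by the main thread) while
// still consolidating the recorded events:
//
//   ProfilerDisableOptions opts(
//       /*shouldCleanupTLSState=*/false,
//       /*shouldConsolidate=*/true);
//   auto event_lists = disableProfilerLegacy(opts); // declared just below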
// NOTE: profiler mode is thread local, with automatic propagation
// across thread boundaries (e.g. at::launch tasks).
TORCH_API void enableProfilerLegacy(const ProfilerConfig&);

using thread_event_lists = std::vector<std::vector<LegacyEvent>>;
TORCH_API thread_event_lists disableProfilerLegacy(
    c10::optional<ProfilerDisableOptions> profilerDisableOptions =
        c10::nullopt);

// Adds profiledEvents to the current thread local recorded events. Each event
// will be marked with the node ID given by fromNodeId.
TORCH_API void addEventList(std::vector<LegacyEvent>&& profiledEvents);

// Returns whether the profiler is currently enabled in the current thread.
TORCH_API bool profilerEnabled();

// Retrieves the thread-local ProfilerConfig.
TORCH_API ProfilerConfig getProfilerConfig();

// Writes profiled events to a stream.
TORCH_API void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events);
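
// Illustrative sketch (not part of the original header) of the legacy
// thread-local flow declared above; the std::cout printing is for
// demonstration only:
//
//   enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU));
//   // ... code under measurement ...
//   thread_event_lists lists = disableProfilerLegacy();
//   for (const auto& per_thread : lists) {
//     for (const auto& evt : per_thread) {
//       std::cout << evt.kindStr() << " " << evt.name() << " "
//                 << evt.cpuUs() << "us\n";
//     }
//   }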
// Usage:
//   {
//     RecordProfile guard("filename.trace");
//     // code you want to profile
//   }
// Then open filename.trace in chrome://tracing
struct TORCH_API RecordProfile {
  RecordProfile(std::ostream& out);
  RecordProfile(const std::string& filename);

  ~RecordProfile();

 private:
  void init();
  std::unique_ptr<std::ofstream> file_;
  std::ostream& out_;
  void processEvents(const std::vector<LegacyEvent*>& events);
};
// A guard that enables the profiler, taking in an optional callback to process
// the results.
// Usage:
//   {
//     TLSProfilerGuard g([](thread_event_lists profilerResults) {
//       // process profilerResults
//     });
//     // code to profile
//   }
struct TORCH_API TLSProfilerGuard {
  explicit TLSProfilerGuard(
      const ProfilerConfig& cfg,
      c10::optional<std::function<void(const thread_event_lists&)>>
          resultCallback = c10::nullopt,
      c10::optional<ProfilerDisableOptions> profilerDisableOptions =
          c10::nullopt)
      : cb_(std::move(resultCallback)),
        profilerDisableOptions_(std::move(profilerDisableOptions)) {
    enableProfilerLegacy(cfg);
  }

  ~TLSProfilerGuard() {
    thread_event_lists event_lists =
        disableProfilerLegacy(profilerDisableOptions_);
    if (cb_) {
      try {
        (*cb_)(event_lists);
      } catch (const std::exception& e) {
        LOG(ERROR) << "Got error processing profiler events: " << e.what();
      }
    }
  }

 private:
  c10::optional<std::function<void(const thread_event_lists&)>> cb_;
  const c10::optional<ProfilerDisableOptions> profilerDisableOptions_;
};
struct TORCH_API FileLineFunc {
  std::string filename;
  size_t line;
  std::string funcname;
};

TORCH_API std::vector<FileLineFunc> prepareCallstack(
    const std::vector<jit::StackEntry>& cs);
TORCH_API std::vector<std::string> callstackStr(
    const std::vector<FileLineFunc>& cs);
TORCH_API std::vector<std::vector<int64_t>> inputSizes(
    const at::RecordFunction& fn);
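
// Illustrative sketch (not part of the original header): how the with_stack
// option can turn a TorchScript callstack into printable strings. The use of
// jit::currentCallstack() as the frame source is an assumption of this
// sketch:
//
//   std::vector<FileLineFunc> frames = prepareCallstack(jit::currentCallstack());
//   std::vector<std::string> lines = callstackStr(frames);
//   for (const auto& s : lines) {
//     LOG(INFO) << s; // e.g. "model.py(42): forward"
//   }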
struct TORCH_API ProfilerThreadLocalState : public c10::MemoryReportingInfoBase {
  explicit ProfilerThreadLocalState(const ProfilerConfig& config)
      : config_(config), remoteProfiledEvents_{c10::nullopt} {}
  ~ProfilerThreadLocalState() override = default;

  const ProfilerConfig& config() const;

  thread_event_lists consolidate();

  void mark(std::string name, bool include_cuda = true);

  void setOrAddRemoteProfiledEvents(
      std::vector<LegacyEvent>&& remoteProfiledEvents);

  void pushRange(
      const at::RecordFunction& fn,
      const bool record_cuda,
      const char* msg = "",
      std::vector<std::vector<int64_t>>&& shapes = {});

  void popRange(const at::RecordFunction& fn, const bool record_cuda);

  void setCallbackHandle(at::CallbackHandle handle) {
    handle_ = handle;
  }

  at::CallbackHandle callbackHandle() const {
    return handle_;
  }

  bool hasCallbackHandle() {
    return handle_ > 0;
  }

  void reportMemoryUsage(
      void* /* unused */,
      int64_t alloc_size,
      c10::Device device) override;

  bool memoryProfilingEnabled() const override;

 protected:
  std::string getNvtxStr(
      const at::StringView& name,
      const char* msg,
      int64_t sequence_nr,
      const std::vector<std::vector<int64_t>>& shapes) const;

  RangeEventList& getEventList(int64_t thread_id = -1);

  std::mutex state_mutex_;
  std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
      event_lists_map_;

  ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled);
  at::CallbackHandle handle_ = 0;
  c10::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
};
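
// Illustrative sketch (not part of the original header): the enableProfiler*
// entry points install one of these states thread-locally (via
// c10::ThreadLocalDebugInfo in the implementation, an assumption of this
// sketch), after which marks and ranges accumulate per thread:
//
//   ProfilerThreadLocalState state(ProfilerConfig(ProfilerState::CPU));
//   state.mark("init_done", /*include_cuda=*/false);
//   thread_event_lists lists = state.consolidate();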
} // namespace profiler
}} // namespace torch::autograd

View File

@@ -13,7 +13,7 @@ constexpr auto kProfileEventsStartIdx = 3;
 RpcWithProfilingResp::RpcWithProfilingResp(
     rpc::MessageType messageType,
     rpc::Message&& wrappedMessage,
-    std::vector<torch::autograd::profiler::Event> profiledEvents,
+    std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
     rpc::ProfilingId profilingId)
     : messageType_(messageType),
       wrappedMessage_(std::move(wrappedMessage)),
@@ -32,7 +32,7 @@ RpcWithProfilingResp::RpcWithProfilingResp(
     std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
     rpc::MessageType wrappedMessageType,
     std::vector<torch::Tensor> tensors,
-    std::vector<torch::autograd::profiler::Event> profiledEvents,
+    std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
     rpc::ProfilingId profilingId)
     : messageType_(messageType),
       wrappedRpc_(std::move(wrappedRpc)),
@@ -52,7 +52,7 @@ rpc::MessageType RpcWithProfilingResp::wrappedMessageType() const {
   return wrappedMessageType_;
 }
-std::vector<torch::autograd::profiler::Event> RpcWithProfilingResp::
+std::vector<torch::autograd::profiler::LegacyEvent> RpcWithProfilingResp::
     getProfiledEvents() const {
   return profiledEvents_;
 }
@@ -119,15 +119,15 @@ std::unique_ptr<RpcWithProfilingResp> RpcWithProfilingResp::fromMessage(
       static_cast<rpc::MessageType>(tupleElements[0].toInt());
   rpc::ProfilingId profilingId = rpc::ProfilingId::fromIValue(tupleElements[1]);
   int profiledEventsSize = tupleElements[2].toInt();
-  std::vector<torch::autograd::profiler::Event> remoteEvents;
+  std::vector<torch::autograd::profiler::LegacyEvent> remoteEvents;
   remoteEvents.reserve(profiledEventsSize);
   for (int i = kProfileEventsStartIdx;
        i < kProfileEventsStartIdx + profiledEventsSize;
        ++i) {
     TORCH_CHECK(i < tupleElements.size());
     // Reconstruct remote event from the ivalues.
-    torch::autograd::profiler::Event fromIvalueEvent =
-        torch::autograd::profiler::Event::fromIValue(tupleElements[i]);
+    torch::autograd::profiler::LegacyEvent fromIvalueEvent =
+        torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]);
     remoteEvents.push_back(std::move(fromIvalueEvent));
   }

View File

@@ -15,7 +15,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
   RpcWithProfilingResp(
       rpc::MessageType messageType,
       rpc::Message&& wrappedMessage,
-      std::vector<torch::autograd::profiler::Event> profiledEvents,
+      std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
       rpc::ProfilingId profilingId);
   // For receiving RPCs. Used in from message when converting a message received
@@ -25,13 +25,13 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
       std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
       rpc::MessageType wrappedMessageType,
       std::vector<torch::Tensor> tensors,
-      std::vector<torch::autograd::profiler::Event> profiledEvents,
+      std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
       rpc::ProfilingId profilingId);
   rpc::Message toMessageImpl() && override;
   static std::unique_ptr<RpcWithProfilingResp> fromMessage(
       const rpc::Message& message);
   // Retrieve remote Events
-  std::vector<torch::autograd::profiler::Event> getProfiledEvents() const;
+  std::vector<torch::autograd::profiler::LegacyEvent> getProfiledEvents() const;
   // Retrieve the globally unique profiling ID corresponding to this command.
   const rpc::ProfilingId& getProfilingId() const;
   // Retrieve the original RPC which this ProfilingRPC wraps.
@@ -51,7 +51,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
   std::unique_ptr<RpcCommandBase> wrappedRpc_;
   rpc::MessageType wrappedMessageType_;
   std::vector<torch::Tensor> tensors_;
-  const std::vector<torch::autograd::profiler::Event> profiledEvents_;
+  const std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents_;
   const rpc::ProfilingId profilingId_;
 };
 } // namespace autograd

View File

@@ -81,7 +81,7 @@ std::shared_ptr<FutureMessage> RequestCallbackNoPython::processMessage(
   if (serverProcessGlobalProfilerStateStackEntryPtr) {
     // Initialize thread-local profiler state from process-global
     // profiler state.
-    ::torch::autograd::profiler::enableProfiler(
+    ::torch::autograd::profiler::enableProfilerLegacy(
         serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
             ->config());
   }
@@ -93,7 +93,7 @@ std::shared_ptr<FutureMessage> RequestCallbackNoPython::processMessage(
     if (serverProcessGlobalProfilerStateStackEntryPtr) {
       // Restore thread-local profiler state.
       ::torch::autograd::profiler::thread_event_lists event_lists =
-          ::torch::autograd::profiler::disableProfiler();
+          ::torch::autograd::profiler::disableProfilerLegacy();
       // Put thread_local event_lists into the process-global profiler
       // state.
       profiler::processglobal::pushResultRecursive(
@@ -509,7 +509,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq(
         responseFuture,
         profilingKeyId,
         profilingConfig] {
-      std::vector<torch::autograd::profiler::Event> profiledEvents;
+      std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents;
       // Defer consolidation of profiler events until async work has
       // completed (such as async UDF)
@@ -521,7 +521,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq(
         // they will be cleaned up by main thread, and consolidate all
         // events so we obtain asynchronously run events.
         torch::autograd::profiler::ProfilerDisableOptions opts(false, true);
-        auto event_lists = torch::autograd::profiler::disableProfiler(opts);
+        auto event_lists = torch::autograd::profiler::disableProfilerLegacy(opts);
         if (wrappedRpcResponseFuture->hasError()) {
           // Propagate error
           // No need to propagate remote events in the case of an error.

View File

@@ -36,7 +36,7 @@ void processRemoteProfiledEvents(
       "Profiler was expected to be enabled. This can happen in callback "
       "continuations that run in different threads, and the TLS of the "
       "profiler was not propagated.");
-  std::vector<torch::autograd::profiler::Event> events =
+  std::vector<torch::autograd::profiler::LegacyEvent> events =
       rpcWithProfilingResp.getProfiledEvents();
   const auto& profilingId = rpcWithProfilingResp.getProfilingId();
   auto& remoteProfilerManager = RemoteProfilerManager::getInstance();
@@ -46,7 +46,7 @@ void processRemoteProfiledEvents(
   std::for_each(
       events.begin(),
       events.end(),
-      [&keyPrefixStr](torch::autograd::profiler::Event& event) {
+      [&keyPrefixStr](torch::autograd::profiler::LegacyEvent& event) {
         std::string name = keyPrefixStr + std::string(event.name());
         event.setName(at::StringView(name));
       });
@@ -511,9 +511,9 @@ std::vector<at::IValue> readWrappedPayload(
 }
 void populateRemoteProfiledEvents(
-    std::vector<torch::autograd::profiler::Event>& profiledEvents,
+    std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
     const torch::autograd::profiler::ProfilerConfig& profilingConfig,
-    const std::vector<std::vector<torch::autograd::profiler::Event>>&
+    const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
         eventLists) {
   // Gather all events into a vector
   for (auto& l : eventLists) {
@@ -525,11 +525,11 @@ void populateRemoteProfiledEvents(
   bool cudaProfilingEnabled =
       profilingConfig.state == torch::autograd::profiler::ProfilerState::CUDA;
   bool foundCpuStart = false;
-  const torch::autograd::profiler::Event* profilerStart = nullptr;
+  const torch::autograd::profiler::LegacyEvent* profilerStart = nullptr;
   // Each device has its own cudaProfilerStart, so we must take
   // care to use the correct one depending on the device the
   // operation ran on.
-  std::unordered_map<int, const torch::autograd::profiler::Event*>
+  std::unordered_map<int, const torch::autograd::profiler::LegacyEvent*>
       cudaProfilerStarts;
   for (auto& e : profiledEvents) {
     if (!foundCpuStart && 0 == strcmp(e.name(), "__start_profile")) {

View File

@@ -82,9 +82,9 @@ TORCH_API std::vector<at::IValue> readWrappedPayload(
 // Takes a list of events from autograd profiler and populates them into
 // profiledEvents to be carried over RPC.
 TORCH_API void populateRemoteProfiledEvents(
-    std::vector<torch::autograd::profiler::Event>& profiledEvents,
+    std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
     const torch::autograd::profiler::ProfilerConfig& profilerConfig,
-    const std::vector<std::vector<torch::autograd::profiler::Event>>&
+    const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
         eventLists);
 } // namespace rpc

View File

@@ -103,7 +103,7 @@ class _server_process_global_profile(profile):
         if not self.enabled:
             return
-        if self.entered:
+        if self.entered:  # type: ignore[has-type]
             raise RuntimeError("autograd profiler traces are not reentrant")
         self.entered = True
@@ -145,13 +145,13 @@ class _server_process_global_profile(profile):
         process_global_function_events = []
         for thread_local_events in process_global_events:
             # Parse from ``Event``s to ``FunctionEvent``s.
-            thread_local_function_events = torch.autograd.profiler.parse_event_records(
+            thread_local_function_events = torch.autograd.profiler.parse_legacy_records(
                 thread_local_events
             )
             thread_local_function_events.sort(
                 key=lambda function_event: [
-                    function_event.cpu_interval.start,
-                    -(function_event.cpu_interval.end),
+                    function_event.time_range.start,
+                    -(function_event.time_range.end),
                 ]
             )
             process_global_function_events.append(thread_local_function_events)
@@ -164,6 +164,7 @@ class _server_process_global_profile(profile):
             use_cuda=self.use_cuda,
             profile_memory=self.profile_memory,
         )
+        self.function_events._build_tree()
         self.process_global_function_events = process_global_function_events

View File

@@ -177,9 +177,9 @@ class AllreduceNCCLTest : public NCCLTest {
     // Make sure enabling profiling does not cause any issues. Note, in single
     // process multi-device mode we do not expect any events to be populated
     // for collective operations, since profiling for that mode is not supported.
-    enableProfiler({ProfilerState::CPU});
+    enableProfilerLegacy({ProfilerState::CPU});
     auto results = pg_->allreduce(tensors_);
-    disableProfiler();
+    disableProfilerLegacy();
     return results;
   }
 };

View File

@@ -1567,8 +1567,8 @@ class RpcTest(RpcAgentTestFixture):
         scope_event = get_function_event(events, "foo")
         # Since the RPC call is within the scope, its CPU interval should be
         # contained within foo's interval.
-        self.assertTrue(scope_event.cpu_interval.start < rpc_event.cpu_interval.start)
-        self.assertTrue(scope_event.cpu_interval.end > rpc_event.cpu_interval.end)
+        self.assertTrue(scope_event.time_range.start < rpc_event.time_range.start)
+        self.assertTrue(scope_event.time_range.end > rpc_event.time_range.end)
         # The sender, dest worker, function run, and type of RPC should all
         # be recorded.
         self_worker_name = worker_name(self.rank)
@@ -1760,10 +1760,10 @@ class RpcTest(RpcAgentTestFixture):
         last_end_time = 0
         for event in thread_local_events:
             event_name = event.name
-            cpu_interval = event.cpu_interval
-            if cpu_interval.start > last_end_time:
+            time_range = event.time_range
+            if time_range.start > last_end_time:
                 top_level_event_names.append(event_name)
-                last_end_time = cpu_interval.end
+                last_end_time = time_range.end
         self.assertEqual(sorted(top_level_event_names), sorted(expected_top_level_event_names))

     @dist_init