Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Use libkineto in profiler (#46470)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46470

Adding ability to use Kineto (CUPTI) to profile CUDA kernels.

Test Plan:
USE_KINETO=1 USE_CUDA=1 USE_MKLDNN=1 BLAS=MKL BUILD_BINARY=1 python setup.py develop install
python test/test_profiler.py
python test/test_autograd.py -k test_profile
python test/test_autograd.py -k test_record

```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  -------------  ------------
Name                                                      Self CPU %    Self CPU      CPU total %   CPU total     CPU time avg  Self CUDA     Self CUDA %   CUDA total    CUDA time avg  # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  -------------  ------------
Memcpy HtoD (Pageable -> Device)                          0.00%         0.000us       0.00%         0.000us       0.000us       2.000us       33.33%        2.000us       1.000us        2
sgemm_32x32x32_NN                                         0.00%         0.000us       0.00%         0.000us       0.000us       2.000us       33.33%        2.000us       2.000us        1
void at::native::vectorized_elementwise_kernel<4, at...   0.00%         0.000us       0.00%         0.000us       0.000us       1.000us       16.67%        1.000us       1.000us        1
Memcpy DtoH (Device -> Pageable)                          0.00%         0.000us       0.00%         0.000us       0.000us       1.000us       16.67%        1.000us       1.000us        1
aten::randn                                               5.17%         74.000us      6.71%         96.000us      48.000us      0.000us       0.00%         0.000us       0.000us        2
aten::empty                                               1.33%         19.000us      1.33%         19.000us      4.750us       0.000us       0.00%         0.000us       0.000us        4
aten::normal_                                             1.05%         15.000us      1.05%         15.000us      7.500us       0.000us       0.00%         0.000us       0.000us        2
aten::to                                                  77.90%        1.114ms       91.61%        1.310ms       436.667us     0.000us       0.00%         3.000us       1.000us        3
aten::empty_strided                                       2.52%         36.000us      2.52%         36.000us      12.000us      0.000us       0.00%         0.000us       0.000us        3
aten::copy_                                               2.73%         39.000us      11.19%        160.000us     53.333us      0.000us       0.00%         3.000us       1.000us        3
cudaMemcpyAsync                                           4.34%         62.000us      4.34%         62.000us      20.667us      0.000us       0.00%         0.000us       0.000us        3
cudaStreamSynchronize                                     1.61%         23.000us      1.61%         23.000us      7.667us       0.000us       0.00%         0.000us       0.000us        3
aten::mm                                                  0.21%         3.000us       7.20%         103.000us     103.000us     0.000us       0.00%         2.000us       2.000us        1
aten::stride                                              0.21%         3.000us       0.21%         3.000us       1.000us       0.000us       0.00%         0.000us       0.000us        3
cudaLaunchKernel                                          2.45%         35.000us      2.45%         35.000us      17.500us      0.000us       0.00%         0.000us       0.000us        2
aten::add                                                 0.49%         7.000us       4.27%         61.000us      61.000us      0.000us       0.00%         1.000us       1.000us        1
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  -------------  ------------
```

benchmark: https://gist.github.com/ilia-cher/a5a9eb6b68504542a3cad5150fc39b1a

Reviewed By: Chillee

Differential Revision: D25142223

Pulled By: ilia-cher

fbshipit-source-id: b0dff46c28da5fb0a8e01cf548aa4f2b723fde80
Committed by: Facebook GitHub Bot
Parent: e9efd8df1b
Commit: f7a8bf2855
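The change below lets the autograd profiler route collection through Kineto/CUPTI. A minimal usage sketch of the API added by this PR (ops and shapes are arbitrary; assumes a build with USE_KINETO=1 and a CUDA device, per the test plan above):

```
import torch
from torch.autograd import kineto_available
from torch.autograd.profiler import profile

x = torch.randn(32, 32, device="cuda")
y = torch.randn(32, 32, device="cuda")

# use_kineto=True enables CUPTI-based collection of CUDA kernel times
with profile(use_cuda=True, use_kineto=kineto_available()) as p:
    z = torch.mm(x, y) + y

print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))
```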
@@ -10,13 +10,13 @@ namespace {
// Used to generate unique callback handles
CallbackHandle next_unique_callback_handle() {
-  static std::atomic<uint64_t> unique_cb_id {0};
-  return CallbackHandle(++unique_cb_id);
+  static std::atomic<uint64_t> unique_cb_id {1};
+  return CallbackHandle(unique_cb_id++);
}

RecordFunctionHandle next_unique_record_function_handle() {
-  static std::atomic<uint64_t> unique_rf_id {0};
-  return RecordFunctionHandle(++unique_rf_id);
+  static std::atomic<uint64_t> unique_rf_id {1};
+  return RecordFunctionHandle(unique_rf_id++);
}

thread_local RecordFunctionTLS rf_tls_;
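The handle counters above now start at 1 and use post-increment, so the value 0 is never issued and can serve as an "invalid handle" sentinel. A minimal Python sketch of the same pattern (illustrative only, the names below are not PyTorch APIs):

```
import itertools

INVALID_HANDLE = 0                       # never produced by the counter below
_next_cb_id = itertools.count(start=1)   # analogous to unique_cb_id {1}

def next_unique_callback_handle() -> int:
    # post-increment semantics: return the current value, then advance
    return next(_next_cb_id)

h = next_unique_callback_handle()
assert h == 1 and h != INVALID_HANDLE
```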
@@ -1,10 +1,9 @@
import argparse
-import statistics
import sys
import timeit
import torch

-from torch.utils._benchmark import Timer
+from torch.utils.benchmark import Timer

PARALLEL_TASKS_NUM = 4
INTERNAL_ITER = None
@@ -34,12 +33,12 @@ if __name__ == '__main__':
    parser.add_argument('--with_cuda', action='store_true')
    parser.add_argument('--with_stack', action='store_true')
    parser.add_argument('--use_script', action='store_true')
+    parser.add_argument('--use_kineto', action='store_true')
    parser.add_argument('--profiling_tensor_size', default=1, type=int)
    parser.add_argument('--workload', default='loop', type=str)
    parser.add_argument('--internal_iter', default=256, type=int)
-    parser.add_argument('--n', default=100, type=int)
-    parser.add_argument('--use_timer', action='store_true')
-    parser.add_argument('--timer_min_run_time', default=100, type=int)
+    parser.add_argument('--timer_min_run_time', default=10, type=int)
+    parser.add_argument('--cuda_only', action='store_true')

    args = parser.parse_args()

@@ -47,16 +46,17 @@
        print("No CUDA available")
        sys.exit()

-    print("Payload: {}; {} iterations, N = {}\n".format(
-        args.workload, args.internal_iter, args.n))
+    print("Payload: {}, {} iterations; timer min. runtime = {}\n".format(
+        args.workload, args.internal_iter, args.timer_min_run_time))
    INTERNAL_ITER = args.internal_iter

    for profiling_enabled in [False, True]:
-        print("Profiling {}, tensor size {}x{}, use cuda: {}, with stacks: {}, use script: {}".format(
+        print("Profiling {}, tensor size {}x{}, use cuda: {}, use kineto: {}, with stacks: {}, use script: {}".format(
            "enabled" if profiling_enabled else "disabled",
            args.profiling_tensor_size,
            args.profiling_tensor_size,
            args.with_cuda,
+            args.use_kineto,
            args.with_stack,
            args.use_script))

@@ -83,27 +83,18 @@ if __name__ == '__main__':
                x = None
                with torch.autograd.profiler.profile(
                    use_cuda=args.with_cuda,
-                    with_stack=args.with_stack) as prof:
+                    with_stack=args.with_stack,
+                    use_kineto=args.use_kineto,
+                    use_cpu=not args.cuda_only) as prof:
                    x = workload(input_x)
                return x
        else:
            def payload():
                return workload(input_x)

-        if args.use_timer:
-            t = Timer(
-                "payload()",
-                globals={"payload": payload},
-                timer=timeit.default_timer,
-            ).blocked_autorange(min_run_time=args.timer_min_run_time)
-            print(t)
-        else:
-            runtimes = timeit.repeat(payload, repeat=args.n, number=1)
-            avg_time = statistics.mean(runtimes) * 1000.0
-            stddev_time = statistics.stdev(runtimes) * 1000.0
-            print("\tavg. time: {:.3f} ms, stddev: {:.3f} ms".format(
-                avg_time, stddev_time))
-            if args.workload == "loop":
-                print("\ttime per iteration: {:.3f} ms".format(
-                    avg_time / args.internal_iter))
-            print()
+        t = Timer(
+            "payload()",
+            globals={"payload": payload},
+            timer=timeit.default_timer,
+        ).blocked_autorange(min_run_time=args.timer_min_run_time)
+        print(t)
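The benchmark now drops the hand-rolled timeit/statistics path and always measures through torch.utils.benchmark.Timer. A small standalone sketch of that API (the workload and size here are made up for illustration):

```
import timeit
import torch
from torch.utils.benchmark import Timer

def payload():
    x = torch.randn(256, 256)
    return torch.mm(x, x)

# blocked_autorange() keeps running measurement blocks until at least
# min_run_time seconds have been spent, then reports per-run statistics.
t = Timer(
    "payload()",
    globals={"payload": payload},
    timer=timeit.default_timer,
).blocked_autorange(min_run_time=1)
print(t)
```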
@@ -1751,7 +1751,8 @@ endif()
#
# End ATen checks
#

+set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)

# Disable compiler feature checks for `fmt`.
@@ -1764,6 +1765,7 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)
set_target_properties(fmt-header-only PROPERTIES INTERFACE_COMPILE_FEATURES "")

list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only)
+set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)

# ---[ Kineto
if(USE_KINETO)
@@ -1774,8 +1776,34 @@ if(USE_KINETO)
  set(KINETO_LIBRARY_TYPE "static" CACHE STRING "")
  set(CUDA_SOURCE_DIR "${CUDA_TOOLKIT_ROOT_DIR}" CACHE STRING "")

+  message(STATUS "Configuring Kineto dependency:")
+  message(STATUS "  KINETO_SOURCE_DIR = ${KINETO_SOURCE_DIR}")
+  message(STATUS "  KINETO_BUILD_TESTS = ${KINETO_BUILD_TESTS}")
+  message(STATUS "  KINETO_LIBRARY_TYPE = ${KINETO_LIBRARY_TYPE}")
+  message(STATUS "  CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}")
+
+  if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/include)
+    set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/extras/CUPTI/include")
+  elseif(EXISTS ${CUDA_SOURCE_DIR}/include/cupti.h)
+    set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/include")
+  endif()
+
+  if((NOT DEFINED CUDA_cupti_LIBRARY) OR (${CUDA_cupti_LIBRARY} STREQUAL "CUDA_cupti_LIBRARY-NOTFOUND"))
+    if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a)
+      set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a")
+    elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti_static.a)
+      set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti_static.a")
+    elseif(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so)
+      set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so")
+    elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti.so)
+      set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti.so")
+    endif()
+  endif()
+  message(STATUS "  CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}")
+  message(STATUS "  CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}")
+
  add_subdirectory("${KINETO_SOURCE_DIR}")
-  message(STATUS "Configured libkineto as a dependency.")
+  message(STATUS "Configured Kineto as a dependency.")
endif()

list(APPEND Caffe2_DEPENDENCY_LIBS kineto)
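With USE_KINETO=1 the build statically links libkineto and picks up the CUPTI headers and library located above. A quick way to confirm that a given build actually carries Kineto support (a sketch; nothing here depends on paths beyond the flags in the test plan):

```
import torch
from torch.autograd import kineto_available

# True only when PyTorch was built with USE_KINETO=1
print("kineto available:", kineto_available())
# If True, profile(use_kineto=True) can collect CUPTI-based kernel timings.
```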
@@ -2163,7 +2163,7 @@ TEST(TLSFutureCallbacksTest, Basic) {
  // test running callbacks with propagation of TLS state.
  {
    // Enable the profiler in this thread
-    torch::autograd::profiler::enableProfiler(
+    torch::autograd::profiler::enableProfilerLegacy(
        torch::autograd::profiler::ProfilerConfig(
            torch::autograd::profiler::ProfilerState::CPU, false, false));
    auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2172,12 +2172,12 @@ TEST(TLSFutureCallbacksTest, Basic) {
    // Since we join here, we can ensure that all callbacks corresponding to
    // markCompleted() have finished.
    t.join();
-    torch::autograd::profiler::disableProfiler();
+    torch::autograd::profiler::disableProfilerLegacy();
  }
  // then() with TLS State
  {
    // Enable the profiler in this thread
-    torch::autograd::profiler::enableProfiler(
+    torch::autograd::profiler::enableProfilerLegacy(
        torch::autograd::profiler::ProfilerConfig(
            torch::autograd::profiler::ProfilerState::CPU, false, false));
    auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2190,7 +2190,7 @@ TEST(TLSFutureCallbacksTest, Basic) {
    std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
    t.join();
    s2->wait();
-    torch::autograd::profiler::disableProfiler();
+    torch::autograd::profiler::disableProfilerLegacy();
  }
}

@@ -2199,7 +2199,7 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
  auto profilerEnabledCb = []() {
    ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
  };
-  torch::autograd::profiler::enableProfiler(
+  torch::autograd::profiler::enableProfilerLegacy(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::CPU, false, false));
  auto s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2212,10 +2212,10 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
    // Don't cleanup TLSState, and just consolidate.
    auto opts = torch::autograd::profiler::ProfilerDisableOptions(false, true);
    auto thread_event_lists =
-        torch::autograd::profiler::disableProfiler(std::move(opts));
+        torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
    // Ensure that the events from this thread are still profiled and we obtain
    // the expected in events in our consolidated list when calling
-    // disableProfiler().
+    // disableProfilerLegacy().
    bool found_ones = false;
    bool found_add = false;
    for (const auto& li : thread_event_lists) {
@@ -2237,13 +2237,13 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
  s1->addCallback(verifyProfilerCb);
  // Disable the profiler, but do not consolidate results in the main thread.
  auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
-  torch::autograd::profiler::disableProfiler(std::move(opts));
+  torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
  std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
  t.join();

  // Similar to above test, but verifies correctness in the case where
  // continuation runs on the main thread.
-  torch::autograd::profiler::enableProfiler(
+  torch::autograd::profiler::enableProfilerLegacy(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::CPU, false, false));
  s1 = c10::make_intrusive<Future>(IntType::get());
@@ -2251,7 +2251,7 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
  // Runs callback inline
  s1->markCompleted(at::IValue(1));
  opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
-  torch::autograd::profiler::disableProfiler(std::move(opts));
+  torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
}

TEST(IValueKWargsTest, Basic) {
@@ -33,7 +33,7 @@ from torch.testing._internal.common_utils import (TEST_MKL, TEST_WITH_ROCM, Test
                                                  suppress_warnings, slowTest,
                                                  load_tests, random_symmetric_matrix,
                                                  IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
-from torch.autograd import Variable, Function, detect_anomaly
+from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
from torch.testing import randn_like
from torch.testing._internal.common_methods_invocations import (method_tests,
@@ -2954,7 +2954,7 @@ class TestAutograd(TestCase):
                     https://github.com/pytorch/pytorch/issues/34086""")
    def test_profiler_tracing(self):
        t1, t2 = torch.ones(1), torch.ones(1)
-        with torch.autograd.profiler.profile() as prof:
+        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
            torch.add(t1, t2)

        with tempfile.NamedTemporaryFile(mode="w+") as f:
@@ -2969,7 +2969,7 @@ class TestAutograd(TestCase):

        device = torch.device("cuda:0")
        t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
-        with torch.autograd.profiler.profile(use_cuda=True) as prof:
+        with torch.autograd.profiler.profile(use_cuda=True, use_kineto=kineto_available()) as prof:
            torch.add(t1, t2)

        with tempfile.NamedTemporaryFile(mode="w+") as f:
@@ -2980,7 +2980,7 @@ class TestAutograd(TestCase):
    def test_profiler(self):
        x = torch.randn(10, 10)

-        with profile() as p:
+        with profile(use_kineto=kineto_available()) as p:
            self.assertTrue(torch.autograd._profiler_enabled())
            y = x * 2 + 4

@@ -2991,22 +2991,21 @@ class TestAutograd(TestCase):
                 'aten::empty', 'aten::add', 'aten::to', 'aten::empty_strided',
                 'aten::copy_', 'aten::empty']
        top_level_names = ['aten::mul', 'aten::add']
-        top_level_iter = iter(top_level_names)
-        self.assertEqual(len(p.function_events), len(names))
-        for info, expected_name in zip(p.function_events, names):
-            if info.cpu_interval.start > last_end:
-                top_level_name_expected = next(top_level_iter)
-                self.assertEqual(info.name, top_level_name_expected)
-                last_end = info.cpu_interval.end
-            self.assertEqual(info.name, expected_name)
+        for evt in p.function_events:
+            if evt.time_range.start > last_end:
+                self.assertTrue(evt.name in top_level_names)
+                last_end = evt.time_range.end
+            self.assertTrue(evt.name in names)

    def test_profiler_seq_nr(self):
-        with profile() as p:
+        with profile(use_kineto=kineto_available()) as p:
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            s = z.sum()
            s.backward()
+        print(p.key_averages().table(
+            sort_by="self_cpu_time_total", row_limit=-1))
        # expecting aten::add, aten::sum to have the sequence numbers,
        # expecting the corresponding backward nodes to have the same numbers
        # as the forward ops
@@ -3049,7 +3048,7 @@ class TestAutograd(TestCase):
    def test_profiler_unboxed_only(self):
        x = torch.rand(3, 4)

-        with torch.autograd.profiler.profile() as prof:
+        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
            x.resize_([3, 2])

    def test_profiler_propagation(self):
@@ -3074,7 +3073,7 @@ class TestAutograd(TestCase):

        traced_bar = torch.jit.trace(bar, x)

-        with profile() as p:
+        with profile(use_kineto=kineto_available()) as p:
            traced_bar(x)

        found_foo = False
@@ -3096,7 +3095,7 @@ class TestAutograd(TestCase):

    def test_record_function_callbacks(self):
        x = torch.randn(10, 10)
-        with profile() as p:
+        with profile(use_kineto=kineto_available()) as p:
            with record_function("foo"):
                y = x * 2 + 4

@@ -3128,12 +3127,12 @@ class TestAutograd(TestCase):
                    node_id=0,
                    name="",
                    thread=thread,
-                    cpu_start=range[0],
-                    cpu_end=range[1],
+                    start_us=range[0],
+                    end_us=range[1],
                )
            )

-        events.populate_cpu_children()
+        events._populate_cpu_children()

        # Note that [1, 3] pushes out [0, 2] first. Then we record [1, 2]
        # as a child of [1, 3]
@@ -3152,7 +3151,7 @@ class TestAutograd(TestCase):
        """

        x = torch.randn(1024)
-        with torch.autograd.profiler.profile() as prof:
+        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
            torch.einsum("i->", x)

        prof_str = str(prof)
@@ -3162,8 +3161,8 @@ class TestAutograd(TestCase):

    def test_profiler_function_event_avg(self):
        avg = FunctionEventAvg()
-        avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, cpu_start=10, cpu_end=15))
-        avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, cpu_start=20, cpu_end=30))
+        avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15))
+        avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30))
        avg.add(avg)
        self.assertEqual(avg.key, "foo")

@@ -3182,7 +3181,7 @@ class TestAutograd(TestCase):
        layer1 = torch.nn.Linear(20, 30)
        layer2 = torch.nn.Linear(30, 40)
        input = torch.randn(128, 20)
-        with profile(record_shapes=True) as prof:
+        with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
            layer2(layer1(input))

        print(prof.function_events)
@@ -3198,18 +3197,18 @@ class TestAutograd(TestCase):
        last_end = 0

        for event in prof.function_events:
-            if event.cpu_interval.start > last_end:
+            if event.time_range.start > last_end:
                name_expected, input_shape_expected = next(expected_iter)
                if name_expected is not None:
                    self.assertEqual(event.name, name_expected)
                self.assertEqual(event.input_shapes, input_shape_expected)
-            last_end = event.cpu_interval.end
+            last_end = event.time_range.end

    def test_profiler_no_cuda(self):
        print("")
        layer = torch.nn.Linear(20, 30)
        x = torch.randn(128, 20)
-        with profile(use_cuda=False) as prof:
+        with profile(use_cuda=False, use_kineto=kineto_available()) as prof:
            layer(x)

        prof_str = str(prof)
@@ -3221,7 +3220,7 @@ class TestAutograd(TestCase):
        print("")
        rnn = torch.nn.LSTM(10, 20, 2)
        total_time_s = 0
-        with profile(record_shapes=True) as prof:
+        with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
            for i in range(20):
                input = torch.randn(5, 3, 10)
                h = torch.randn(2, 3, 20)
@@ -3258,7 +3257,7 @@ class TestAutograd(TestCase):
    def test_memory_profiler(self):
        def run_profiler(tensor_creation_fn, metric):
            # collecting allocs / deallocs
-            with profile(profile_memory=True, record_shapes=True) as prof:
+            with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
                x = None
                with record_function("test_user_scope_alloc"):
                    x = tensor_creation_fn()
@@ -3350,7 +3349,7 @@ class TestAutograd(TestCase):

        # check partial overlap of tensor allocation with memory profiler
        x = torch.rand(10, 10)
-        with profile(profile_memory=True, record_shapes=True) as prof:
+        with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
            del x
            x = torch.rand(10, 10)
        del x
@@ -3376,7 +3375,7 @@ class TestAutograd(TestCase):

        forward(x)

-        with profile() as p:
+        with profile(use_kineto=kineto_available()) as p:
            forward(x)

        events = p.function_events
@@ -3401,7 +3400,7 @@ class TestAutograd(TestCase):
        def f(x, y):
            return x + y

-        with profile() as p:
+        with profile(use_kineto=kineto_available()) as p:
            f(1, 2)

        self.assertTrue('my_func' in str(p))
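The tests above switch from the old FunctionEvent.cpu_interval attribute to time_range and pass use_kineto=kineto_available() so the Kineto path is exercised whenever the build supports it. A short sketch of reading events through the updated attribute:

```
import torch
from torch.autograd import kineto_available

x = torch.randn(10, 10)
with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
    y = x * 2 + 4

for evt in prof.function_events:
    # time_range replaces the old cpu_interval field
    print(evt.name, evt.time_range.start, evt.time_range.elapsed_us())
```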
@@ -44,6 +44,7 @@ from torch._six import PY37, StringIO
from torch.autograd import Variable
from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401
from torch.testing import FileCheck
+import torch.autograd.profiler
import torch.cuda
import torch.jit
import torch.jit._logging
@@ -2552,10 +2553,10 @@ graph(%Ra, %Rb):
        for e in prof.function_events:
            if e.name == "aten::mul":
                self.assertTrue(e.thread not in mul_events)
-                mul_events[e.thread] = e.cpu_interval.elapsed_us()
+                mul_events[e.thread] = e.time_range.elapsed_us()
            elif e.name == "other_fn":
                self.assertTrue(e.thread not in other_fn_events)
-                other_fn_events[e.thread] = e.cpu_interval.elapsed_us()
+                other_fn_events[e.thread] = e.time_range.elapsed_us()

        self.assertTrue(len(mul_events) == 2)
        self.assertTrue(len(other_fn_events) == 2)
@@ -8268,7 +8269,7 @@ dedent """

    def _test_dtype_op_shape(self, ops, args, input_dims=1):
        if input_dims < 1:
-            raise 'input dims must be at least 1'
+            raise RuntimeError("input dims must be at least 1")
        dtypes = [torch.float32, torch.float64, torch.int64, torch.int32]
        str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '')
        tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']')
@@ -7,6 +7,7 @@ import torch.nn as nn
from torch.testing._internal.common_utils import (
    TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS)
from torch.autograd.profiler import profile
+from torch.autograd import kineto_available

try:
    import psutil
@@ -73,7 +74,7 @@ class TestProfiler(TestCase):

        mod = DummyModule()

-        with profile(with_stack=True) as p:
+        with profile(with_stack=True, use_kineto=kineto_available()) as p:
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
@@ -99,6 +100,34 @@ class TestProfiler(TestCase):

        torch._C._set_graph_executor_optimize(prev_opt)

+    def payload(self):
+        x = torch.randn(10, 10).cuda()
+        y = torch.randn(10, 10).cuda()
+        z = torch.mm(x, y)
+        z = z + y
+        z = z.cpu()
+
+    @unittest.skipIf(not kineto_available(), "Kineto is required")
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
+    def test_kineto(self):
+        with profile(use_cuda=True, use_kineto=True):
+            self.payload()
+
+        # rerun to avoid initial start overhead
+        with profile(use_cuda=True, use_kineto=True) as p:
+            self.payload()
+        print(p.key_averages().table(
+            sort_by="self_cuda_time_total", row_limit=-1))
+        found_gemm = False
+        found_memcpy = False
+        for e in p.function_events:
+            if "gemm" in e.name:
+                found_gemm = True
+            if "Memcpy" in e.name or "memcpy" in e.name:
+                found_memcpy = True
+        self.assertTrue(found_gemm)
+        self.assertTrue(found_memcpy)
+        # p.export_chrome_trace("/tmp/test_trace.json")

if __name__ == '__main__':
    run_tests()
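test_kineto above checks that a GEMM kernel and memcpys appear in the event list when profiling through Kineto, and the commented-out line hints at trace export. Combined with the export_chrome_trace changes further down, a Kineto run can be saved for chrome://tracing. A hedged usage sketch (the output path is arbitrary):

```
import torch
from torch.autograd import kineto_available
from torch.autograd.profiler import profile

assert kineto_available() and torch.cuda.is_available()

x = torch.randn(10, 10).cuda()
with profile(use_cuda=True, use_kineto=True) as p:
    y = torch.mm(x, x).cpu()

# With Kineto active, the trace is saved by the underlying ProfilerResult
p.export_chrome_trace("kineto_trace.json")
print(p.key_averages().table(sort_by="self_cuda_time_total"))
```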
@@ -74,7 +74,8 @@ jit_core_sources = [
# list for the shared files.

core_sources_common = [
-    "torch/csrc/autograd/profiler.cpp",
+    "torch/csrc/autograd/profiler_legacy.cpp",
+    "torch/csrc/autograd/profiler_kineto.cpp",
    "torch/csrc/jit/frontend/edit_distance.cpp",
    "torch/csrc/jit/frontend/string_to_type.cpp",
    "torch/csrc/jit/mobile/type_parser.cpp",
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Set
from enum import Enum

# Defined in tools/autograd/init.cpp
@@ -8,7 +8,16 @@ class ProfilerState(Enum):
    CPU = ...
    CUDA = ...
    NVTX = ...
+    KINETO = ...
+
+class ProfilerActivity(Enum):
+    CPU = ...
+    CUDA = ...
+
+class DeviceType(Enum):
+    CPU = ...
+    CUDA = ...
    ...

class ProfilerConfig:
    def __init__(
@@ -37,9 +46,25 @@ class ProfilerEvent:
    def thread_id(self) -> int: ...
    ...

+class KinetoEvent:
+    def name(self) -> str: ...
+    def device_index(self) -> int: ...
+    def start_us(self) -> int: ...
+    def duration_us(self) -> int: ...
+    ...
+
-def _enable_profiler(config: ProfilerConfig) -> None: ...
-def _disable_profiler() -> List[List[ProfilerEvent]]: ...
+class ProfilerResult:
+    def events(self) -> List[KinetoEvent]: ...
+    def legacy_events(self) -> List[List[ProfilerEvent]]: ...
+    def save(self, str) -> None: ...
+
+def _enable_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ...
+def _prepare_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ...
+def _disable_profiler() -> ProfilerResult: ...
def _profiler_enabled() -> bool: ...
+def kineto_available() -> bool: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...

+def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
+def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
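The new stubs split the bindings into a Kineto path (_prepare_profiler and _enable_profiler take an activity set, _disable_profiler returns a ProfilerResult) and a legacy path (_enable_profiler_legacy/_disable_profiler_legacy). A rough sketch of the flow the high-level profile() context manager drives, based only on the signatures above (internal API, subject to change):

```
import torch

config = torch.autograd.ProfilerConfig(
    torch.autograd.ProfilerState.KINETO,
    False,   # record_shapes
    False,   # profile_memory
    False)   # with_stack
activities = {torch.autograd.ProfilerActivity.CPU,
              torch.autograd.ProfilerActivity.CUDA}

torch.autograd._prepare_profiler(config, activities)   # warm up CUPTI
torch.autograd._enable_profiler(config, activities)
# ... run the workload being profiled ...
result = torch.autograd._disable_profiler()            # ProfilerResult
result.save("trace.json")                              # raw Kineto trace
for evt in result.events():                            # KinetoEvent objects
    print(evt.name(), evt.start_us(), evt.duration_us())
```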
@@ -18,7 +18,6 @@ from .gradcheck import gradcheck, gradgradcheck
from .grad_mode import no_grad, enable_grad, set_grad_enabled
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from ..overrides import has_torch_function, handle_torch_function
-from . import profiler
from . import functional

__all__ = ['Variable', 'Function', 'backward', 'grad_mode']
@@ -251,6 +250,10 @@ if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")

# Import all native method/classes
-from torch._C._autograd import (ProfilerState, ProfilerConfig, ProfilerEvent,
-                                _enable_profiler, _disable_profiler, _profiler_enabled,
-                                _enable_record_function, _set_empty_test_observer)
+from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, ProfilerConfig, ProfilerEvent,
+                                _enable_profiler_legacy, _disable_profiler_legacy, _profiler_enabled,
+                                _enable_record_function, _set_empty_test_observer, kineto_available)
+
+if kineto_available():
+    from torch._C._autograd import (ProfilerResult, KinetoEvent,
+                                    _prepare_profiler, _enable_profiler, _disable_profiler)
@@ -1,12 +1,13 @@
import itertools
from typing import Any
import torch
+from torch.autograd import DeviceType
from torch.futures import Future

from collections import defaultdict, namedtuple
from operator import attrgetter

-from typing import List, Dict, Tuple, Optional
+from typing import Dict, List, Tuple, Optional

try:
    # Available in Python >= 3.2
@@ -37,14 +38,38 @@ class EventList(list):
        use_cuda = kwargs.pop('use_cuda', True)
        profile_memory = kwargs.pop('profile_memory', False)
        super(EventList, self).__init__(*args, **kwargs)
-        self._cpu_children_populated = False
        self._use_cuda = use_cuda
        self._profile_memory = profile_memory
+        self._tree_built = False
+
+    def _build_tree(self):
+        self._populate_cpu_children()
+        self._remove_dup_nodes()
+        self._set_backward_stacktraces()
+        self._tree_built = True

    def __str__(self):
        return self.table()

-    def populate_cpu_children(self):
+    def _remove_dup_nodes(self):
+        while True:
+            to_delete = []
+            for idx in range(len(self)):
+                if (self[idx].cpu_parent is not None and
+                        self[idx].cpu_parent.name == self[idx].name and
+                        len(self[idx].cpu_parent.cpu_children) == 1):
+                    self[idx].cpu_parent.cpu_children = self[idx].cpu_children
+                    self[idx].cpu_parent.kernels = self[idx].kernels  # lift kernels up
+                    for ch in self[idx].cpu_children:
+                        ch.cpu_parent = self[idx].cpu_parent
+                    to_delete.append(idx)
+            if len(to_delete) == 0:
+                break
+            new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
+            self.clear()
+            self.extend(new_evts)
+
+    def _populate_cpu_children(self):
        """Populates child events into each underlying FunctionEvent object.
        One event is a child of another if [s1, e1) is inside [s2, e2). Where
        s1 and e1 would be start and end of the child event's interval. And
@@ -56,13 +81,11 @@ class EventList(list):
        If for any reason two intervals intersect only partially, this function
        will not record a parent child relationship between then.
        """
-        if self.cpu_children_populated:
-            return

        # Some events can be async (i.e. start and end on different threads),
        # since it's generally undefined how to attribute children ranges to
        # async ranges, we do not use them when calculating nested ranges and stats
-        sync_events = [evt for evt in self if not evt.is_async]
+        sync_events = [evt for evt in self if not evt.is_async and evt.device_type == DeviceType.CPU]
        events = sorted(
            sync_events,
            key=attrgetter("thread"),
@@ -89,15 +112,15 @@ class EventList(list):
        for thread_id, thread_events in threads:
            thread_events_ = sorted(
                thread_events,
-                key=lambda event: [event.cpu_interval.start, -event.cpu_interval.end],
+                key=lambda event: [event.time_range.start, -event.time_range.end],
            )
            current_events: List[FunctionEvent] = []
            cur_end = 0
            for event in thread_events_:
                while len(current_events) > 0:
                    parent = current_events[-1]
-                    if event.cpu_interval.start >= parent.cpu_interval.end or \
-                            event.cpu_interval.end > parent.cpu_interval.end:
+                    if event.time_range.start >= parent.time_range.end or \
+                            event.time_range.end > parent.time_range.end:
                        # this can't be a parent
                        current_events.pop()
                    else:
@@ -112,22 +135,18 @@ class EventList(list):

                current_events.append(event)

-        self._cpu_children_populated = True
-
-    def set_backward_stacktraces(self):
-        self.populate_cpu_children()
-
+    def _set_backward_stacktraces(self):
        def bw_parent(evt):
            if evt is None:
                return None
-            elif evt.scope == 1:
+            elif evt.scope == 1:  # BACKWARD_FUNCTION
                return evt
            else:
                return bw_parent(evt.cpu_parent)

        fwd_stacks = {}
        for evt in self:
-            if bw_parent(evt) is None:
+            if bw_parent(evt) is None and evt.stack is not None:
                t = (evt.sequence_nr, evt.thread)
                if t not in fwd_stacks:
                    fwd_stacks[t] = evt.stack
@@ -142,15 +161,10 @@ class EventList(list):
            else:
                evt.stack = []

    @property
    def self_cpu_time_total(self):
        return sum([event.self_cpu_time_total for event in self])

-    @property
-    def cpu_children_populated(self):
-        return self._cpu_children_populated
-
    def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False):
        """Prints an EventList as a nicely formatted table.

@@ -205,8 +219,8 @@ class EventList(list):
                        '"args": {}}, '
                        % (
                            evt.name,
-                            evt.cpu_interval.start,
-                            evt.cpu_interval.elapsed_us(),
+                            evt.time_range.start,
+                            evt.time_range.elapsed_us(),
                            evt.thread
                            if not evt.is_remote
                            else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "',
@@ -222,7 +236,7 @@ class EventList(list):
                            '"pid": "CPU functions", '
                            '"id": %s, '
                            '"cat": "cpu_to_cuda", '
-                            '"args": {}}, ' % (evt.name, evt.cpu_interval.start,
+                            '"args": {}}, ' % (evt.name, evt.time_range.start,
                                               evt.thread, next_id))
                    f.write('{"name": "%s", '
                            '"ph": "f", '
@@ -262,11 +276,11 @@ class EventList(list):
        Returns:
            An EventList containing FunctionEventAvg objects.
        """
-        self.populate_cpu_children()
-        stats: Dict[Tuple[int, Tuple[int, int]], FunctionEventAvg] = defaultdict(FunctionEventAvg)
+        assert self._tree_built
+        stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)

-        def get_key(event, group_by_input_shapes, group_by_stack_n):
-            key = [str(event.key), str(event.node_id)]
+        def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]:
+            key = [str(event.key), str(event.node_id), str(event.device_type), str(event.is_legacy)]
            if group_by_input_shapes:
                key.append(str(event.input_shapes))
            if group_by_stack_n > 0:
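Per the _populate_cpu_children docstring above, an event becomes a child of another only when its half-open interval [s1, e1) lies entirely inside the parent's [s2, e2); partially overlapping intervals are left unrelated. A tiny self-contained sketch of that containment rule (illustration only, not PyTorch code):

```
def is_child(child, parent):
    # child, parent: (start, end) pairs on the same thread's timeline
    s1, e1 = child
    s2, e2 = parent
    # [s1, e1) must be fully inside [s2, e2)
    return s1 >= s2 and e1 <= e2

assert is_child((1, 2), (1, 3))        # nested -> parent/child
assert not is_child((0, 2), (1, 3))    # partial overlap -> unrelated
```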
@@ -326,6 +340,11 @@ class profile(object):

        with_stack (bool, optional): record source information (file and line number) for the ops

+        use_kineto (bool, default False): experimental support for Kineto profiler
+
+        use_cpu (default True) - whether to profile CPU events; setting to False requires
+            use_kineto=True and can be used to lower the overhead for GPU-only profiling
+
    .. warning:
        Enabling memory profiling or source attribution incurs additional profiler
        overhead
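The new use_cpu=False mode skips CPU-side op instrumentation and collects only device activity through Kineto, which the docstring above describes as a way to lower overhead for GPU-only profiling. A hedged usage sketch:

```
import torch
from torch.autograd import kineto_available
from torch.autograd.profiler import profile

assert kineto_available()  # use_cpu=False is only supported with use_kineto=True

x = torch.randn(1024, 1024, device="cuda")
with profile(use_cuda=True, use_kineto=True, use_cpu=False) as p:
    y = x @ x

# Only CUDA activities were requested, so the table is kernel/memcpy centric.
print(p.key_averages().table(sort_by="self_cuda_time_total"))
```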
@@ -365,44 +384,83 @@ class profile(object):
            use_cuda=False,
            record_shapes=False,
            profile_memory=False,
-            with_stack=False):
-        self.enabled = enabled
-        self.use_cuda = use_cuda
-        self.function_events = None
+            with_stack=False,
+            use_kineto=False,
+            use_cpu=True):
+        self.enabled: bool = enabled
+        if not self.enabled:
+            return
+        self.use_cuda = use_cuda
+        self.function_events = None
        self.entered = False
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.with_stack = with_stack
+        self.use_cpu = use_cpu
+        self.kineto_results = None
+        if not self.use_cpu:
+            assert use_kineto, \
+                "Device-only events supported only with Kineto (use_kineto=True)"
+
+        self.profiler_kind = None
+        self.kineto_activities = set()
+        if use_kineto:
+            self.profiler_kind = torch.autograd.ProfilerState.KINETO
+            if self.use_cpu:
+                self.kineto_activities.add(torch.autograd.ProfilerActivity.CPU)
+            if self.use_cuda:
+                self.kineto_activities.add(
+                    # uses CUPTI
+                    torch.autograd.ProfilerActivity.CUDA)
+            assert len(self.kineto_activities) > 0, \
+                "No activities specified for Kineto profiler"
+        elif self.use_cuda:
+            # legacy CUDA mode
+            self.profiler_kind = torch.autograd.ProfilerState.CUDA
+        else:
+            self.profiler_kind = torch.autograd.ProfilerState.CPU
+
+        if self.profiler_kind == torch.autograd.ProfilerState.KINETO:
+            assert (
+                torch.autograd.kineto_available()
+            ), """Requested Kineto profiling but Kineto is not available,
+                make sure PyTorch is built with USE_KINETO=1"""
+
+    def config(self):
+        assert self.profiler_kind is not None
+        return torch.autograd.ProfilerConfig(
+            self.profiler_kind,
+            self.record_shapes,
+            self.profile_memory,
+            self.with_stack)

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
-            raise RuntimeError("autograd profiler traces are not reentrant")
+            raise RuntimeError("profiler context manager is not reentrant")
        self.entered = True
-        profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \
-            else torch.autograd.ProfilerState.CPU
-
-        config = torch.autograd.ProfilerConfig(
-            profiler_kind,
-            self.record_shapes,
-            self.profile_memory,
-            self.with_stack)
-        torch.autograd._enable_profiler(config)
+        if self.kineto_activities:
+            torch.autograd._prepare_profiler(self.config(), self.kineto_activities)
+            torch.autograd._enable_profiler(self.config(), self.kineto_activities)
+        else:
+            torch.autograd._enable_profiler_legacy(self.config())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
-        records = torch.autograd._disable_profiler()
+        if self.kineto_activities:
+            self.kineto_results = torch.autograd._disable_profiler()
+            parsed_results = parse_kineto_results(self.kineto_results)
+        else:
+            records = torch.autograd._disable_profiler_legacy()
+            parsed_results = parse_legacy_records(records)
        self.function_events = EventList(
-            parse_event_records(records),
+            parsed_results,
            use_cuda=self.use_cuda,
            profile_memory=self.profile_memory)
-        if self.with_stack:
-            self.function_events.set_backward_stacktraces()
+        self.function_events._build_tree()
        return False

    def __repr__(self):
@@ -413,13 +471,11 @@ class profile(object):
    def __str__(self):
        if self.function_events is None:
            return '<unfinished torch.autograd.profile>'
-        self.function_events.populate_cpu_children()
        return str(self.function_events)

    def _check_finish(self):
        if self.function_events is None:
            raise RuntimeError("can't export a trace that didn't finish running")
-        self.function_events.populate_cpu_children()

    def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False):
        self._check_finish()
@@ -432,8 +488,11 @@ class profile(object):

    def export_chrome_trace(self, path):
        self._check_finish()
-        assert self.function_events is not None
-        return self.function_events.export_chrome_trace(path)
+        if self.kineto_results is not None:
+            self.kineto_results.save(path)
+        else:
+            assert self.function_events is not None
+            return self.function_events.export_chrome_trace(path)
    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
@@ -630,7 +689,7 @@ class emit_nvtx(object):
            raise RuntimeError("NVTX annotation context manager is not reentrant")
        self.entered = True
        torch.cuda.synchronize()
-        torch.autograd._enable_profiler(
+        torch.autograd._enable_profiler_legacy(
            torch.autograd.ProfilerConfig(
                torch.autograd.ProfilerState.NVTX,
                self.record_shapes,
@@ -643,7 +702,7 @@ class emit_nvtx(object):
        if not self.enabled:
            return
        torch.cuda.synchronize()
-        torch.autograd._disable_profiler()
+        torch.autograd._disable_profiler_legacy()
        return False

@@ -731,13 +790,14 @@ Kernel = namedtuple('Kernel', ['name', 'device', 'interval'])
class FunctionEvent(FormattedTimesMixin):
    """Profiling information about a single function."""
    def __init__(
-            self, id, node_id, name, thread, cpu_start, cpu_end, fwd_thread=None, input_shapes=None,
+            self, id, name, thread, start_us, end_us, fwd_thread=None, input_shapes=None,
            stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False,
-            is_remote=True, sequence_nr=-1):
+            is_remote=False, sequence_nr=-1, node_id=-1, device_type=DeviceType.CPU, device_index=0,
+            is_legacy=False):
        self.id: int = id
        self.node_id: int = node_id
        self.name: str = name
-        self.cpu_interval: Interval = Interval(cpu_start, cpu_end)
+        self.time_range: Interval = Interval(start_us, end_us)
        self.thread: int = thread
        self.fwd_thread: Optional[int] = fwd_thread
        self.kernels: List[Kernel] = []
@@ -752,8 +812,12 @@ class FunctionEvent(FormattedTimesMixin):
        self.is_async: bool = is_async
        self.is_remote: bool = is_remote
        self.sequence_nr: int = sequence_nr
+        self.device_type: DeviceType = device_type
+        self.device_index: int = device_index
+        self.is_legacy: bool = is_legacy

    def append_kernel(self, name, device, start, end):
+        assert self.device_type == DeviceType.CPU
        self.kernels.append(Kernel(name, device, Interval(start, end)))

    def append_cpu_child(self, child):
@@ -762,7 +826,9 @@ class FunctionEvent(FormattedTimesMixin):
        One is supposed to append only direct children to the event to have
        correct self cpu time being reported.
        """
+        assert(self.device_type == DeviceType.CPU)
        assert(isinstance(child, FunctionEvent))
+        assert(child.device_type == DeviceType.CPU)
        self.cpu_children.append(child)

    def set_cpu_parent(self, parent):
@@ -772,14 +838,16 @@ class FunctionEvent(FormattedTimesMixin):
        the child's range interval is completely inside the parent's. We use
        this connection to determine the event is from top-level op or not.
        """
+        assert(self.device_type == DeviceType.CPU)
        assert(isinstance(parent, FunctionEvent))
+        assert(parent.device_type == DeviceType.CPU)
        self.cpu_parent = parent

    # Note: async events don't have children, are not used when computing 'self'
    # metrics of other events, have only total cpu time
    @property
    def self_cpu_memory_usage(self):
-        if self.is_async:
+        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.cpu_memory_usage - sum(
            [child.cpu_memory_usage for child in self.cpu_children]
@@ -787,7 +855,7 @@ class FunctionEvent(FormattedTimesMixin):

    @property
    def self_cuda_memory_usage(self):
-        if self.is_async:
+        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.cuda_memory_usage - sum(
            [child.cuda_memory_usage for child in self.cpu_children]
@@ -795,7 +863,7 @@ class FunctionEvent(FormattedTimesMixin):

    @property
    def self_cpu_time_total(self):
-        if self.is_async:
+        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.cpu_time_total - sum(
            [child.cpu_time_total for child in self.cpu_children]
@@ -803,16 +871,37 @@ class FunctionEvent(FormattedTimesMixin):

    @property
    def cuda_time_total(self):
-        return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels)
+        if self.is_async:
+            return 0
+        if self.device_type == DeviceType.CPU:
+            if not self.is_legacy:
+                # account for the kernels in the children ops
+                return (sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) +
+                        sum(ch.cuda_time_total for ch in self.cpu_children))
+            else:
+                # each legacy cpu events has a single (fake) kernel
+                return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels)
+        else:
+            assert self.device_type == DeviceType.CUDA
+            return self.time_range.elapsed_us()

    @property
    def self_cuda_time_total(self):
-        return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) - \
-            sum([child.cuda_time_total for child in self.cpu_children])
+        if self.is_async:
+            return 0
+        if self.device_type == DeviceType.CPU:
+            return self.cuda_time_total - \
+                sum([child.cuda_time_total for child in self.cpu_children])
+        else:
+            assert(self.device_type == DeviceType.CUDA)
+            return self.cuda_time_total

    @property
    def cpu_time_total(self):
-        return self.cpu_interval.elapsed_us()
+        if self.device_type == DeviceType.CPU:
+            return self.time_range.elapsed_us()
+        else:
+            return 0

    @property
    def key(self):
@@ -820,14 +909,16 @@ class FunctionEvent(FormattedTimesMixin):

    def __repr__(self):
        return (
-            '<FunctionEvent id={} node_id={} cpu_time={} cpu_start={} cpu_end={} '
+            '<FunctionEvent id={} name={} device_type={} node_id={} cpu_time={} start_us={} end_us={} '
            'cpu_children={} cuda_time={} name={} thread={} input_shapes={} '
-            'cpu_memory_usage={} cuda_memory_usage={} is_async={} is_remote={} seq_nr={}>'.format(
+            'cpu_memory_usage={} cuda_memory_usage={} is_async={} is_remote={} seq_nr={} is_legacy={}>'.format(
                self.id,
+                self.name,
+                self.device_type,
                self.node_id,
                self.cpu_time_str,
-                self.cpu_interval.start,
-                self.cpu_interval.end,
+                self.time_range.start,
+                self.time_range.end,
                str([child.id for child in self.cpu_children]),
                self.cuda_time_str,
                self.name,
@@ -838,6 +929,7 @@ class FunctionEvent(FormattedTimesMixin):
                self.is_async,
                self.is_remote,
                self.sequence_nr,
+                self.is_legacy,
            )
        )

@@ -863,6 +955,8 @@ class FunctionEventAvg(FormattedTimesMixin):
        self.self_cuda_memory_usage: int = 0
        self.cpu_children: Optional[List[FunctionEvent]] = None
        self.cpu_parent: Optional[FunctionEvent] = None
+        self.device_type: DeviceType = DeviceType.CPU
+        self.is_legacy: bool = False

    def add(self, other):
        if self.key is None:
@@ -878,6 +972,8 @@ class FunctionEventAvg(FormattedTimesMixin):
            self.input_shapes = other.input_shapes
            self.stack = other.stack
            self.scope = other.scope
+            self.device_type = other.device_type
+            self.is_legacy = other.is_legacy

        assert isinstance(other, (FunctionEvent, FunctionEventAvg))
        assert other.key == self.key
@@ -923,10 +1019,111 @@ class StringTable(defaultdict):
        self[key] = torch._C._demangle(key) if len(key) > 1 else key
        return self[key]

-def parse_event_records(thread_records):
+def filter_stack_entry(entry):
+    filtered_entries = [
+        ("autograd/__init__", "_make_grads"),
+        ("autograd/__init__", "backward"),
+        ("torch/tensor", "backward"),
+        ("_internal/common_utils", "prof_callable"),
+        ("_internal/common_utils", "prof_func_call"),
+        ("_internal/common_utils", "prof_meth_call"),
+    ]
+    return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries])
+
+def filter_name(name):
+    # ignoring the following utility ops
+    filtered_out_names = [
+        "profiler::_record_function_enter",
+        "profiler::_record_function_exit",
+        "aten::is_leaf",
+        "aten::output_nr",
+        "aten::_version",
+    ]
+    return name in filtered_out_names
+
+# Parsing of kineto profiler events
+def parse_kineto_results(result):
+    # result.events() has most of the events - PyTorch op-level and device-level events
+    # result.legacy_events() has events not yet ported to kineto
+    # (e.g. start/stop marks, tensor memory allocator events)
+
+    # First, find __start_profile mark to get the absolute time of the start of the trace;
+    # save memory allocation records
+    start_record = None
+    mem_records = []
+    for record in itertools.chain(*result.legacy_events()):
+        if record.kind() == 'mark' and record.name() == '__start_profile':
+            assert start_record is None
+            start_record = record
+        if record.kind() == 'memory_alloc':
+            mem_records.append(record)
+    assert start_record is not None, "Invalid profiler output, __start_profile is missing"
+
+    # Create and return FunctionEvent list
+    string_table = StringTable()
+    function_events = []
+    cuda_corr_map: Dict[int, List[torch.autograd.KinetoEvent]] = {}
+    for kineto_event in result.events():
+        if filter_name(kineto_event.name()):
+            continue
+        rel_start_us = kineto_event.start_us() - start_record.start_us()
+        rel_end_us = rel_start_us + kineto_event.duration_us()
+        abs_end_us = kineto_event.start_us() + kineto_event.duration_us()
+
+        cpu_memory_usage = 0
+        cuda_memory_usage = 0
+        if kineto_event.device_type() == DeviceType.CPU:
+            # find the corresponding memory allocation events
+            for mem_record in mem_records:
+                if (mem_record.start_us() >= kineto_event.start_us() and
+                        mem_record.start_us() <= abs_end_us):
+                    cpu_memory_usage += mem_record.cpu_memory_usage()
+                    cuda_memory_usage += mem_record.cuda_memory_usage()
+        is_async = kineto_event.start_thread_id() != kineto_event.end_thread_id()
+        fe = FunctionEvent(
+            id=kineto_event.correlation_id(),
+            name=string_table[kineto_event.name()],
+            thread=kineto_event.start_thread_id(),
+            start_us=rel_start_us,
+            end_us=rel_end_us,
+            fwd_thread=kineto_event.fwd_thread_id(),
+            input_shapes=kineto_event.shapes(),
+            stack=[entry for entry in kineto_event.stack() if filter_stack_entry(entry)],
+            scope=kineto_event.scope(),
+            cpu_memory_usage=cpu_memory_usage,
+            cuda_memory_usage=cuda_memory_usage,
+            is_async=is_async,
+            sequence_nr=kineto_event.sequence_nr(),
+            device_type=kineto_event.device_type(),
+            device_index=kineto_event.device_index(),
+        )
+        function_events.append(fe)
+        if kineto_event.device_type() == DeviceType.CUDA:
+            corr_id = kineto_event.linked_correlation_id()
+            if corr_id > 0:
+                if corr_id not in cuda_corr_map:
+                    cuda_corr_map[corr_id] = []
+                cuda_corr_map[corr_id].append(kineto_event)
+
+    # associate CUDA kernels with CPU events
+    for fe in function_events:
+        if (fe.device_type == DeviceType.CPU and not fe.is_async and
+                fe.id in cuda_corr_map):
+            for k_evt in cuda_corr_map[fe.id]:
+                fe.append_kernel(
+                    k_evt.name(),
+                    k_evt.device_index(),
+                    k_evt.start_us(),
+                    k_evt.start_us() + k_evt.duration_us())
+
+    function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
+    return function_events
+
+# Parsing of legacy profiler events
+def parse_legacy_records(thread_records):
    def get_record_key(record):
        """
-        Returns a tuple to be used by parse_event_records for correlating start and
+        Returns a tuple to be used by parse_legacy_records for correlating start and
        end records.
        """
        return (record.handle(), record.node_id())
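parse_kineto_results above rebases Kineto timestamps onto the __start_profile mark, wraps each op in a FunctionEvent, and attaches device-side events to their CPU op through the linked correlation ID (fe.append_kernel). A hedged sketch of what that association looks like from user code after profiling:

```
import torch
from torch.autograd import kineto_available
from torch.autograd.profiler import profile

x = torch.randn(128, 128, device="cuda")
with profile(use_cuda=True, use_kineto=kineto_available()) as p:
    y = torch.mm(x, x)

for evt in p.function_events:
    if evt.name == "aten::mm":
        # kernels were appended from correlated CUDA events
        for k in evt.kernels:
            print(k.name, k.device, k.interval.elapsed_us())
```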
@ -938,26 +1135,6 @@ def parse_event_records(thread_records):
|
||||
record_stack = []
|
||||
string_table = StringTable()
|
||||
|
||||
# ignoring the following utility ops
|
||||
filtered_out_names = [
|
||||
"profiler::_record_function_enter",
|
||||
"profiler::_record_function_exit",
|
||||
"aten::is_leaf",
|
||||
"aten::output_nr",
|
||||
"aten::_version",
|
||||
]
|
||||
|
||||
def filter_stack_entry(entry):
|
||||
filtered_entries = [
|
||||
("autograd/__init__", "_make_grads"),
|
||||
("autograd/__init__", "backward"),
|
||||
("torch/tensor", "backward"),
|
||||
("_internal/common_utils", "prof_callable"),
|
||||
("_internal/common_utils", "prof_func_call"),
|
||||
("_internal/common_utils", "prof_meth_call"),
|
||||
]
|
||||
return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries])
|
||||
|
||||
# cuda start events and the overall profiler start event don't happen
|
||||
# at exactly the same time because we need to record an event on each device
|
||||
# and each record takes ~4us. So we adjust here by the difference
|
||||
@ -994,7 +1171,7 @@ def parse_event_records(thread_records):
|
||||
prev_record = None
|
||||
for record in thread_record_list:
|
||||
record_key = get_record_key(record)
|
||||
if (record.name() in filtered_out_names or
|
||||
if (filter_name(record.name()) or
|
||||
record_key in filtered_handles):
|
||||
filtered_handles.add(record_key)
|
||||
continue
|
||||
@ -1035,8 +1212,8 @@ def parse_event_records(thread_records):
|
||||
node_id=record.node_id(),
|
||||
name=string_table[start.name()],
|
||||
thread=start.thread_id(),
|
||||
cpu_start=start_record.cpu_elapsed_us(start),
|
||||
cpu_end=start_record.cpu_elapsed_us(record),
|
||||
start_us=start_record.cpu_elapsed_us(start),
|
||||
end_us=start_record.cpu_elapsed_us(record),
|
||||
fwd_thread=start.fwd_thread_id(),
|
||||
input_shapes=start.shapes(),
|
||||
stack=[entry for entry in start.stack() if filter_stack_entry(entry)],
|
||||
@ -1046,6 +1223,8 @@ def parse_event_records(thread_records):
|
||||
is_async=is_async,
|
||||
is_remote=is_remote_event,
|
||||
sequence_nr=start.sequence_nr(),
|
||||
device_type=DeviceType.CPU,
|
||||
is_legacy=True,
|
||||
)
|
||||
# note: async events have only cpu total time
|
||||
if not is_async and start.has_cuda():
|
||||
@ -1074,7 +1253,7 @@ def parse_event_records(thread_records):
|
||||
# granularity of the given clock tick)--we always show
|
||||
# the outermost nested call first. This adds stability
|
||||
# in how FunctionEvents appear
|
||||
functions.sort(key=lambda evt: [evt.cpu_interval.start, -evt.cpu_interval.end])
|
||||
functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
|
||||
return functions
|
||||
|
||||
|
||||
@ -1121,8 +1300,8 @@ def parse_nvprof_trace(path):
|
||||
node_id=0, # missing a node_id when calling FunctionEvent. This is just to ensure
|
||||
# that pytorch doesn't crash when creating a FunctionEvent() object
|
||||
name=strings[row['name']],
|
||||
cpu_start=row['start_time'],
|
||||
cpu_end=row['end_time'],
|
||||
start_us=row['start_time'],
|
||||
end_us=row['end_time'],
|
||||
thread=0) # TODO: find in sqlite database
|
||||
functions.append(evt)
|
||||
functions_map[evt.id] = evt
|
||||
@ -1153,7 +1332,7 @@ def parse_nvprof_trace(path):
|
||||
row['kernel_start'],
|
||||
row['kernel_end'])
|
||||
|
||||
functions.sort(key=lambda evt: evt.cpu_interval.start)
|
||||
functions.sort(key=lambda evt: evt.time_range.start)
|
||||
return functions
|
||||
|
||||
|
||||
@ -1182,7 +1361,9 @@ def build_table(
|
||||
has_input_shapes = any(
|
||||
[(event.input_shapes is not None and len(event.input_shapes) > 0) for event in events])
|
||||
|
||||
MAX_NAME_COLUMN_WIDTH = 55
|
||||
name_column_width = max([len(evt.key) for evt in events]) + 4
|
||||
name_column_width = min(name_column_width, MAX_NAME_COLUMN_WIDTH)
|
||||
|
||||
DEFAULT_COLUMN_WIDTH = 12
|
||||
|
||||
@ -1269,7 +1450,16 @@ def build_table(
|
||||
result.append('\n') # Yes, newline after the end as well
|
||||
|
||||
self_cpu_time_total = sum([event.self_cpu_time_total for event in events])
|
||||
cuda_time_total = sum([evt.self_cuda_time_total for evt in events])
|
||||
cuda_time_total = 0
|
||||
for evt in events:
|
||||
if evt.device_type == DeviceType.CPU:
|
||||
# in legacy profiler, kernel info is stored in cpu events
|
||||
if evt.is_legacy:
|
||||
cuda_time_total += evt.self_cuda_time_total
|
||||
elif evt.device_type == DeviceType.CUDA:
|
||||
# in kineto mode, there are events with the correct device type (e.g. CUDA)
|
||||
cuda_time_total += evt.self_cuda_time_total
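
The hunk above replaces the flat sum of `self_cuda_time_total` with a device-aware accumulation: legacy CPU events carry their kernel time themselves, while Kineto emits dedicated CUDA-device events, so each path is counted exactly once. A hedged, standalone C++ sketch of that rule (field and type names are stand-ins, not the real FunctionEvent API):

```
#include <cstdint>
#include <cstdio>
#include <vector>

enum class DeviceType { CPU, CUDA };

// Minimal stand-in for the fields build_table reads from each event.
struct Evt {
  DeviceType device_type;
  bool is_legacy;                // legacy profiler stores kernel info on CPU events
  int64_t self_cuda_time_total;  // microseconds
};

int64_t cudaTimeTotal(const std::vector<Evt>& events) {
  int64_t total = 0;
  for (const auto& evt : events) {
    if (evt.device_type == DeviceType::CPU) {
      if (evt.is_legacy) total += evt.self_cuda_time_total;  // legacy path
    } else if (evt.device_type == DeviceType::CUDA) {
      total += evt.self_cuda_time_total;                     // kineto device events
    }
  }
  return total;
}

int main() {
  std::vector<Evt> events = {{DeviceType::CPU, true, 5},
                             {DeviceType::CPU, false, 0},
                             {DeviceType::CUDA, false, 7}};
  std::printf("CUDA total: %lld us\n", (long long)cudaTimeTotal(events));  // 12 us
}
```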
|
||||
|
||||
# Actual printing
|
||||
if header is not None:
|
||||
append('=' * line_length)
|
||||
@ -1290,8 +1480,11 @@ def build_table(
|
||||
continue
|
||||
else:
|
||||
event_limit += 1
|
||||
name = evt.key
|
||||
if len(name) >= MAX_NAME_COLUMN_WIDTH - 3:
|
||||
name = name[:(MAX_NAME_COLUMN_WIDTH - 3)] + "..."
|
||||
row_values = [
|
||||
evt.key, # Name
|
||||
name,
|
||||
# Self CPU total, 0 for async events. %
|
||||
format_time_share(evt.self_cpu_time_total,
|
||||
self_cpu_time_total),
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/autograd/grad_mode.h>
|
||||
@ -39,37 +40,132 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
|
||||
.value("Disabled", ProfilerState::Disabled)
|
||||
.value("CPU", ProfilerState::CPU)
|
||||
.value("CUDA", ProfilerState::CUDA)
|
||||
.value("NVTX", ProfilerState::NVTX);
|
||||
.value("NVTX", ProfilerState::NVTX)
|
||||
.value("KINETO", ProfilerState::KINETO);
|
||||
|
||||
py::enum_<ActivityType>(m, "ProfilerActivity")
|
||||
.value("CPU", ActivityType::CPU)
|
||||
.value("CUDA", ActivityType::CUDA);
|
||||
|
||||
py::class_<ProfilerConfig>(m, "ProfilerConfig")
|
||||
.def(py::init<ProfilerState, bool, bool, bool>());
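
For reference, the bindings above follow the standard pybind11 pattern: enum values are registered with `.value(...)` and constructors are exposed through `py::init<...>`. A minimal sketch of that pattern with made-up types (`Mode`, `Config`, and the module name `exprof` are illustrative only, assuming pybind11 is available):

```
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Illustrative C++ types, not the PyTorch ones.
enum class Mode { Disabled, Sampling };
struct Config {
  Config(Mode m, bool shapes) : mode(m), shapes(shapes) {}
  Mode mode;
  bool shapes;
};

PYBIND11_MODULE(exprof, m) {
  py::enum_<Mode>(m, "Mode")
      .value("Disabled", Mode::Disabled)
      .value("Sampling", Mode::Sampling);

  // Expose the constructor so Python can build Config(Mode.Sampling, True).
  py::class_<Config>(m, "Config")
      .def(py::init<Mode, bool>());
}
```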
|
||||
|
||||
py::class_<Event>(m, "ProfilerEvent")
|
||||
.def("kind", &Event::kind)
|
||||
.def("name", [](const Event& e) { return e.name(); })
|
||||
.def("thread_id", &Event::threadId)
|
||||
.def("fwd_thread_id", &Event::fwdThreadId)
|
||||
.def("device", &Event::device)
|
||||
.def("cpu_elapsed_us", &Event::cpuElapsedUs)
|
||||
.def("cuda_elapsed_us", &Event::cudaElapsedUs)
|
||||
.def("has_cuda", &Event::hasCuda)
|
||||
.def("shapes", &Event::shapes)
|
||||
.def("cpu_memory_usage", &Event::cpuMemoryUsage)
|
||||
.def("cuda_memory_usage", &Event::cudaMemoryUsage)
|
||||
.def("handle", &Event::handle)
|
||||
.def("node_id", &Event::nodeId)
|
||||
.def("is_remote", &Event::isRemote)
|
||||
.def("sequence_nr", &Event::sequenceNr)
|
||||
.def("stack", &Event::stack)
|
||||
.def("scope", &Event::scope);
|
||||
py::class_<LegacyEvent>(m, "ProfilerEvent")
|
||||
.def("kind", &LegacyEvent::kindStr)
|
||||
.def("name", [](const LegacyEvent& e) { return e.name(); })
|
||||
.def("thread_id", &LegacyEvent::threadId)
|
||||
.def("fwd_thread_id", &LegacyEvent::fwdThreadId)
|
||||
.def("device", &LegacyEvent::device)
|
||||
.def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs)
|
||||
.def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs)
|
||||
.def("has_cuda", &LegacyEvent::hasCuda)
|
||||
.def("shapes", &LegacyEvent::shapes)
|
||||
.def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage)
|
||||
.def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage)
|
||||
.def("handle", &LegacyEvent::handle)
|
||||
.def("node_id", &LegacyEvent::nodeId)
|
||||
.def("is_remote", &LegacyEvent::isRemote)
|
||||
.def("sequence_nr", &LegacyEvent::sequenceNr)
|
||||
.def("stack", &LegacyEvent::stack)
|
||||
.def("scope", &LegacyEvent::scope)
|
||||
.def("correlation_id", &LegacyEvent::correlationId)
|
||||
.def("start_us", &LegacyEvent::cpuUs);
|
||||
|
||||
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
|
||||
.def(py::init<bool, bool>());
|
||||
py::enum_<c10::DeviceType>(m, "DeviceType")
|
||||
.value("CPU", c10::DeviceType::CPU)
|
||||
.value("CUDA", c10::DeviceType::CUDA)
|
||||
.value("MKLDNN", c10::DeviceType::MKLDNN)
|
||||
.value("OPENGL", c10::DeviceType::OPENGL)
|
||||
.value("OPENCL", c10::DeviceType::OPENCL)
|
||||
.value("IDEEP", c10::DeviceType::IDEEP)
|
||||
.value("HIP", c10::DeviceType::HIP)
|
||||
.value("FPGA", c10::DeviceType::FPGA)
|
||||
.value("MSNPU", c10::DeviceType::MSNPU)
|
||||
.value("XLA", c10::DeviceType::XLA)
|
||||
.value("Vulkan", c10::DeviceType::Vulkan)
|
||||
.value("Metal", c10::DeviceType::Metal);
|
||||
|
||||
#ifdef USE_KINETO
|
||||
py::class_<KinetoEvent>(m, "KinetoEvent")
|
||||
// name of the event
|
||||
.def("name", &KinetoEvent::name)
|
||||
// PyTorch thread id of the start callback
|
||||
.def("start_thread_id", [](const KinetoEvent& e) {
|
||||
return e.startThreadId();
|
||||
})
|
||||
// PyTorch thread id of the end callback
|
||||
.def("end_thread_id", [](const KinetoEvent& e) {
|
||||
return e.endThreadId();
|
||||
})
|
||||
// for events of scope BACKWARD_FUNCTION - PyTorch thread id
|
||||
// of the corresponding forward op
|
||||
.def("fwd_thread_id", [](const KinetoEvent& e) {
|
||||
return e.fwdThreadId();
|
||||
})
|
||||
// together with fwd_thread_id, used to uniquely identify
|
||||
// the forward op
|
||||
.def("sequence_nr", [](const KinetoEvent& e) {
|
||||
return e.sequenceNr();
|
||||
})
|
||||
// absolute start time (since unix epoch) in us
|
||||
.def("start_us", &KinetoEvent::startUs)
|
||||
// duration in us
|
||||
.def("duration_us", &KinetoEvent::durationUs)
|
||||
// used for correlation between high-level PyTorch events
|
||||
// and low-level device events
|
||||
.def("correlation_id", [](const KinetoEvent& e) {
|
||||
return e.correlationId();
|
||||
})
|
||||
// shapes of input tensors
|
||||
.def("shapes", [](const KinetoEvent& e) {
|
||||
if (e.hasShapes()) {
|
||||
return e.shapes();
|
||||
} else {
|
||||
return std::vector<std::vector<int64_t>>();
|
||||
}
|
||||
})
|
||||
// stack traces of the PyTorch CPU events
|
||||
.def("stack", [](const KinetoEvent& e) {
|
||||
if (e.hasStack()) {
|
||||
return e.stack();
|
||||
} else {
|
||||
return std::vector<std::string>();
|
||||
}
|
||||
})
|
||||
// type of the RecordFunction that generated a PyTorch CPU event
|
||||
// (op, torchscript function, user label, etc)
|
||||
.def("scope", [](const KinetoEvent& e) {
|
||||
return e.scope();
|
||||
})
|
||||
// device number, for CPU - process id
|
||||
.def("device_index", &KinetoEvent::deviceIndex)
|
||||
// for CUDA - stream id, for CPU - start thread id
|
||||
.def("device_resource_id", &KinetoEvent::deviceResourceId)
|
||||
// device type
|
||||
.def("device_type", [](const KinetoEvent& e) {
|
||||
return e.deviceType();
|
||||
})
|
||||
// correlation id of a linked event
|
||||
.def("linked_correlation_id", &KinetoEvent::linkedCorrelationId);
|
||||
|
||||
py::class_<ProfilerResult>(m, "ProfilerResult")
|
||||
.def("events", &ProfilerResult::events)
|
||||
.def("legacy_events", &ProfilerResult::legacy_events)
|
||||
.def("save", &ProfilerResult::save);
|
||||
|
||||
m.def("_enable_profiler", enableProfiler);
|
||||
m.def("_disable_profiler", disableProfiler);
|
||||
m.def("_prepare_profiler", prepareProfiler);
|
||||
#endif
|
||||
|
||||
m.def("kineto_available", kinetoAvailable);
|
||||
|
||||
m.def("_enable_profiler_legacy", enableProfilerLegacy);
|
||||
py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
|
||||
.def(py::init<bool, bool>());
|
||||
m.def(
|
||||
"_disable_profiler",
|
||||
disableProfiler,
|
||||
"_disable_profiler_legacy",
|
||||
disableProfilerLegacy,
|
||||
py::arg("profiler_disable_options") = ProfilerDisableOptions());
|
||||
m.def("_profiler_enabled", profilerEnabled);
|
||||
m.def("_enable_record_function", [](bool enable) {
|
||||
|
@ -1,461 +1,4 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <forward_list>
|
||||
#include <tuple>
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#ifndef _WIN32
|
||||
#include <ctime>
|
||||
#endif
|
||||
#if defined(C10_IOS) && defined(C10_MOBILE)
|
||||
#include <sys/time.h> // for gettimeofday()
|
||||
#endif
|
||||
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
struct CUevent_st;
|
||||
typedef std::shared_ptr<CUevent_st> CUDAEventStub;
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct Node;
|
||||
|
||||
namespace profiler {
|
||||
|
||||
struct TORCH_API CUDAStubs {
|
||||
virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) {
|
||||
fail();
|
||||
}
|
||||
virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) {
|
||||
fail();
|
||||
return 0.f;
|
||||
}
|
||||
virtual void nvtxMarkA(const char* name) {
|
||||
fail();
|
||||
}
|
||||
virtual void nvtxRangePushA(const char* name) {
|
||||
fail();
|
||||
}
|
||||
virtual void nvtxRangePop() {
|
||||
fail();
|
||||
}
|
||||
virtual bool enabled() {
|
||||
return false;
|
||||
}
|
||||
virtual void onEachDevice(std::function<void(int)> op) {
|
||||
fail();
|
||||
}
|
||||
virtual void synchronize() {
|
||||
fail();
|
||||
}
|
||||
virtual ~CUDAStubs();
|
||||
|
||||
private:
|
||||
void fail() {
|
||||
AT_ERROR("CUDA used in profiler but not enabled.");
|
||||
}
|
||||
};
|
||||
|
||||
TORCH_API void registerCUDAMethods(CUDAStubs* stubs);
|
||||
|
||||
constexpr inline size_t ceilToMultiple(size_t a, size_t b) {
|
||||
return ((a + b - 1) / b) * b;
|
||||
}
|
||||
|
||||
inline int64_t getTime() {
|
||||
#if defined(C10_IOS) && defined(C10_MOBILE)
|
||||
// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on
|
||||
// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not
|
||||
struct timeval now;
|
||||
gettimeofday(&now, NULL);
|
||||
return static_cast<int64_t>(now.tv_sec) * 1000000000 + static_cast<int64_t>(now.tv_usec) * 1000;
|
||||
#elif defined(_WIN32) || defined(__MACH__)
|
||||
using namespace std::chrono;
|
||||
using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
|
||||
return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
|
||||
#else
|
||||
// clock_gettime is *much* faster than std::chrono implementation on Linux
|
||||
struct timespec t{};
|
||||
clock_gettime(CLOCK_MONOTONIC, &t);
|
||||
return static_cast<int64_t>(t.tv_sec) * 1000000000 + static_cast<int64_t>(t.tv_nsec);
|
||||
#endif
|
||||
}
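
A portable sketch of the same idea, assuming only the standard library: take nanosecond timestamps from a monotonic clock and compare two of them. This is not the PyTorch helper itself, just the technique it implements:

```
#include <chrono>
#include <cstdint>
#include <cstdio>

// Nanosecond timestamp from a steady (monotonic) clock; a portable
// counterpart to the platform-specific getTime() above.
int64_t nowNs() {
  using namespace std::chrono;
  return duration_cast<nanoseconds>(steady_clock::now().time_since_epoch()).count();
}

int main() {
  int64_t t0 = nowNs();
  double x = 0;
  for (int i = 0; i < 1000000; ++i) x = x + i;  // something to time
  int64_t t1 = nowNs();
  std::printf("elapsed: %.3f us (x=%f)\n", (t1 - t0) / 1000.0, x);
}
```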
|
||||
|
||||
// A struct to control settings of disableProfiler options.
|
||||
struct TORCH_API ProfilerDisableOptions {
|
||||
ProfilerDisableOptions() = default;
|
||||
ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate)
|
||||
: cleanupTLSState(shouldCleanupTLSState),
|
||||
consolidate(shouldConsolidate) {}
|
||||
// Whether we should clean up profiler states that are thread local, such as
|
||||
// ThreadLocalDebugInfo and thread local RecordFunction callbacks.
|
||||
bool cleanupTLSState = true;
|
||||
// Whether we should consolidate all currently recorded profiled events. If
|
||||
// false, will not consolidate and other threads can continue to write to the
|
||||
// event lists.
|
||||
bool consolidate = true;
|
||||
};
|
||||
|
||||
enum class C10_API_ENUM ProfilerState {
|
||||
Disabled,
|
||||
CPU, // CPU-only profiling
|
||||
CUDA, // CPU + CUDA events
|
||||
NVTX, // only emit NVTX markers
|
||||
};
|
||||
|
||||
struct TORCH_API ProfilerConfig {
|
||||
ProfilerConfig(
|
||||
ProfilerState state,
|
||||
bool report_input_shapes = false,
|
||||
bool profile_memory = false,
|
||||
bool with_stack = false)
|
||||
: state(state),
|
||||
report_input_shapes(report_input_shapes),
|
||||
profile_memory(profile_memory),
|
||||
with_stack(with_stack) {}
|
||||
~ProfilerConfig();
|
||||
ProfilerState state;
|
||||
bool report_input_shapes;
|
||||
bool profile_memory;
|
||||
bool with_stack;
|
||||
|
||||
// Returns IValues corresponding to ProfilerConfig struct, to be used for
|
||||
// serialization.
|
||||
at::IValue toIValue() const;
|
||||
|
||||
// Reconstructs a ProfilerConfig from IValues given by toIValue.
|
||||
static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue);
|
||||
|
||||
};
|
||||
|
||||
enum class C10_API_ENUM EventKind : uint16_t {
|
||||
Mark,
|
||||
PushRange,
|
||||
PopRange,
|
||||
MemoryAlloc,
|
||||
};
|
||||
|
||||
struct TORCH_API Event final {
|
||||
Event(
|
||||
EventKind kind,
|
||||
at::StringView name,
|
||||
uint16_t thread_id,
|
||||
bool record_cuda,
|
||||
at::RecordFunctionHandle handle = 0,
|
||||
std::vector<std::vector<int64_t>>&& shapes = {},
|
||||
int node_id = -1)
|
||||
: name_(std::move(name)),
|
||||
kind_(kind),
|
||||
thread_id_(thread_id),
|
||||
handle_(handle),
|
||||
shapes_(shapes),
|
||||
node_id_(node_id) {
|
||||
record(record_cuda);
|
||||
}
|
||||
|
||||
// Constructor to be used in conjunction with Event::fromIValue.
|
||||
Event(
|
||||
EventKind kind,
|
||||
at::StringView name,
|
||||
uint16_t thread_id,
|
||||
at::RecordFunctionHandle handle,
|
||||
std::vector<std::vector<int64_t>>&& shapes,
|
||||
int node_id,
|
||||
bool is_remote,
|
||||
int64_t cpu_memory_usage,
|
||||
int64_t cpu_ns,
|
||||
bool cuda_recorded,
|
||||
int64_t cuda_memory_usage = 0,
|
||||
int device = -1,
|
||||
double cuda_us = -1)
|
||||
: cpu_ns_(cpu_ns),
|
||||
name_(std::move(name)),
|
||||
kind_(kind),
|
||||
thread_id_(thread_id),
|
||||
handle_(handle),
|
||||
shapes_(shapes),
|
||||
cpu_memory_usage_(cpu_memory_usage),
|
||||
cuda_memory_usage_(cuda_memory_usage),
|
||||
device_(device),
|
||||
node_id_(node_id),
|
||||
is_remote_(is_remote),
|
||||
cuda_us_(cuda_us) {
|
||||
// Sanity check values that were deserialized
|
||||
TORCH_INTERNAL_ASSERT(cpu_ns_ > 0);
|
||||
if (cuda_recorded) {
|
||||
TORCH_INTERNAL_ASSERT(device_ >= 0);
|
||||
TORCH_INTERNAL_ASSERT(cuda_us_ >= 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns IValues corresponding to event structure, to be used for
|
||||
// serialization.
|
||||
at::IValue toIValue() const;
|
||||
|
||||
// Reconstructs an event from IValues given by toIValue.
|
||||
static Event fromIValue(const at::IValue& eventIValue);
|
||||
|
||||
void record(bool record_cuda);
|
||||
std::string kind() const {
|
||||
switch(kind_) {
|
||||
case EventKind::Mark: return "mark";
|
||||
case EventKind::PushRange: return "push";
|
||||
case EventKind::PopRange: return "pop";
|
||||
case EventKind::MemoryAlloc: return "memory_alloc";
|
||||
}
|
||||
throw std::runtime_error("unknown EventKind");
|
||||
}
|
||||
|
||||
// Get enum kind of this event.
|
||||
EventKind eventKind() const {
|
||||
return kind_;
|
||||
}
|
||||
|
||||
const char* name() const {
|
||||
return name_.str();
|
||||
}
|
||||
|
||||
uint64_t threadId() const {
|
||||
return thread_id_;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> shapes() const {
|
||||
return shapes_;
|
||||
}
|
||||
|
||||
double cpuElapsedUs(const Event& e) const {
|
||||
return (e.cpu_ns_ - cpu_ns_)/(1000.0);
|
||||
}
|
||||
|
||||
double cpuUs() const {
|
||||
return cpu_ns_ / (1000.0);
|
||||
}
|
||||
|
||||
double cudaElapsedUs(const Event& e) const;
|
||||
|
||||
bool hasCuda() const {
|
||||
return cuda_event != nullptr || (isRemote() && device_ != -1);
|
||||
}
|
||||
|
||||
int device() const {
|
||||
return device_;
|
||||
}
|
||||
|
||||
void updateMemoryStats(int64_t alloc_size, c10::Device device) {
|
||||
if (device.type() == c10::DeviceType::CUDA ||
|
||||
device.type() == c10::DeviceType::HIP) {
|
||||
cuda_memory_usage_ = alloc_size;
|
||||
} else if (device.type() == c10::DeviceType::CPU ||
|
||||
device.type() == c10::DeviceType::MKLDNN ||
|
||||
device.type() == c10::DeviceType::IDEEP) {
|
||||
cpu_memory_usage_ = alloc_size;
|
||||
} else {
|
||||
LOG(WARNING) << "Unsupported memory profiling device: " << device;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t cpuMemoryUsage() const {
|
||||
return cpu_memory_usage_;
|
||||
}
|
||||
|
||||
int64_t cudaMemoryUsage() const {
|
||||
return cuda_memory_usage_;
|
||||
}
|
||||
|
||||
at::RecordFunctionHandle handle() const {
|
||||
return handle_;
|
||||
}
|
||||
|
||||
// Node ID corresponding to this event.
|
||||
int nodeId( ) const {
|
||||
return node_id_;
|
||||
}
|
||||
|
||||
// Set Node ID on this event.
|
||||
void setNodeId(int node_id) {
|
||||
node_id_ = node_id;
|
||||
}
|
||||
|
||||
void setName(at::StringView newName_) {
|
||||
name_ = std::move(newName_);
|
||||
}
|
||||
|
||||
bool isRemote() const {
|
||||
return is_remote_;
|
||||
}
|
||||
|
||||
void setCudaUs(int64_t cuda_us) {
|
||||
cuda_us_ = cuda_us;
|
||||
}
|
||||
|
||||
void setSequenceNr(int64_t sequence_nr) {
|
||||
sequence_nr_ = sequence_nr;
|
||||
}
|
||||
|
||||
int64_t sequenceNr() const {
|
||||
return sequence_nr_;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& stack() const {
|
||||
return stack_;
|
||||
}
|
||||
|
||||
void setStack(const std::vector<std::string>& stack) {
|
||||
stack_ = stack;
|
||||
}
|
||||
|
||||
uint64_t fwdThreadId() const {
|
||||
return fwd_thread_id_;
|
||||
}
|
||||
|
||||
void setFwdThreadId(uint64_t fwd_thread_id) {
|
||||
fwd_thread_id_ = fwd_thread_id;
|
||||
}
|
||||
|
||||
uint8_t scope() const {
|
||||
return scope_;
|
||||
}
|
||||
|
||||
void setScope(uint8_t scope) {
|
||||
scope_ = scope;
|
||||
}
|
||||
|
||||
private:
|
||||
// signed to allow for negative intervals, initialized for safety.
|
||||
int64_t cpu_ns_ = 0;
|
||||
at::StringView name_;
|
||||
EventKind kind_;
|
||||
uint64_t thread_id_;
|
||||
uint64_t fwd_thread_id_;
|
||||
at::RecordFunctionHandle handle_ {0};
|
||||
std::vector<std::vector<int64_t>> shapes_;
|
||||
int64_t cpu_memory_usage_ = 0;
|
||||
int64_t cuda_memory_usage_ = 0;
|
||||
int device_ = -1;
|
||||
CUDAEventStub cuda_event = nullptr;
|
||||
int node_id_ = 0;
|
||||
bool is_remote_ = false;
|
||||
int64_t cuda_us_ = -1;
|
||||
int64_t sequence_nr_ = -1;
|
||||
|
||||
std::vector<std::string> stack_;
|
||||
uint8_t scope_;
|
||||
};
|
||||
|
||||
// a linked-list of fixed sized vectors, to avoid
|
||||
// a std::vector resize from taking a large amount of time inside
|
||||
// a profiling event
|
||||
struct RangeEventList {
|
||||
RangeEventList() {
|
||||
events_.reserve(kReservedCapacity);
|
||||
}
|
||||
|
||||
template<typename... Args>
|
||||
void record(Args&&... args) {
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
events_.emplace_back(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
std::vector<Event> consolidate() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<Event> result;
|
||||
result.insert(
|
||||
result.begin(),
|
||||
std::make_move_iterator(events_.begin()),
|
||||
std::make_move_iterator(events_.end()));
|
||||
events_.erase(events_.begin(), events_.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t size() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return events_.size();
|
||||
}
|
||||
|
||||
private:
|
||||
// This mutex is used to serialize access when different threads are writing
|
||||
// to the same instance of RangeEventList.
|
||||
std::mutex mutex_;
|
||||
std::vector<Event> events_;
|
||||
|
||||
static const size_t kReservedCapacity = 1024;
|
||||
};
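
In the profiler each thread normally records into its own RangeEventList (via getEventList), so the mutex only serializes the occasional cross-thread access and the final consolidate(). A self-contained sketch of that locking discipline, with illustrative names and a multi-threaded driver:

```
#include <cstdio>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Toy event; the real list stores profiler Events.
struct Ev { int tid; std::string name; };

// Append-only list guarded by a mutex: many record() calls,
// one consolidate() at the end, mirroring RangeEventList.
struct EventList {
  template <typename... Args>
  void record(Args&&... args) {
    std::lock_guard<std::mutex> guard(mutex_);
    events_.emplace_back(std::forward<Args>(args)...);
  }
  std::vector<Ev> consolidate() {
    std::lock_guard<std::mutex> guard(mutex_);
    std::vector<Ev> out(std::make_move_iterator(events_.begin()),
                        std::make_move_iterator(events_.end()));
    events_.clear();
    return out;
  }
 private:
  std::mutex mutex_;
  std::vector<Ev> events_;
};

int main() {
  EventList list;
  std::vector<std::thread> threads;
  for (int t = 0; t < 4; ++t) {
    threads.emplace_back([&list, t] {
      for (int i = 0; i < 100; ++i) list.record(Ev{t, "op"});
    });
  }
  for (auto& th : threads) th.join();
  std::printf("recorded %zu events\n", list.consolidate().size());  // 400
}
```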
|
||||
|
||||
using thread_event_lists = std::vector<std::vector<Event>>;
|
||||
// NOTE: profiler mode is thread local, with automatic propagation
|
||||
// across thread boundary (e.g. at::launch tasks)
|
||||
TORCH_API void enableProfiler(const ProfilerConfig&);
|
||||
TORCH_API thread_event_lists disableProfiler(c10::optional<ProfilerDisableOptions> profilerDisableOptions = c10::nullopt);
|
||||
// adds profiledEvents to the current thread local recorded events. Each event
|
||||
// will be marked with node ID given by fromNodeId.
|
||||
TORCH_API void addEventList(std::vector<Event>&& profiledEvents);
|
||||
// Returns if the profiler is currently enabled in the current thread.
|
||||
TORCH_API bool profilerEnabled();
|
||||
// Retrieve the thread_local ProfilerConfig.
|
||||
TORCH_API ProfilerConfig getProfilerConfig();
|
||||
// Writes profiled events to a stream.
|
||||
TORCH_API void writeProfilerEventsToStream(std::ostream& out, const std::vector<Event*>& events);
|
||||
|
||||
// Usage:
|
||||
// {
|
||||
// RecordProfile guard("filename.trace");
|
||||
// // code you want to profile
|
||||
// }
|
||||
// Then open filename.trace in chrome://tracing
|
||||
struct TORCH_API RecordProfile {
|
||||
RecordProfile(std::ostream& out);
|
||||
RecordProfile(const std::string& filename);
|
||||
|
||||
~RecordProfile();
|
||||
private:
|
||||
void init();
|
||||
std::unique_ptr<std::ofstream> file_;
|
||||
std::ostream& out_;
|
||||
void processEvents(const std::vector<Event*>& events);
|
||||
};
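
RecordProfile is an RAII guard: events are gathered while it is alive and the trace is written when it goes out of scope. A hedged standalone sketch of the same shape, emitting the widely documented chrome://tracing "complete event" JSON form (the `TraceGuard`/`Span` names and the exact fields written are illustrative assumptions, not PyTorch's serializer):

```
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

struct Span { std::string name; int64_t ts_us; int64_t dur_us; };

// Minimal RAII guard: collect spans while alive, dump a chrome://tracing
// compatible JSON array when the guard goes out of scope.
class TraceGuard {
 public:
  explicit TraceGuard(const std::string& path) : out_(path) {}
  void add(Span s) { spans_.push_back(std::move(s)); }
  ~TraceGuard() {
    out_ << "[";
    for (size_t i = 0; i < spans_.size(); ++i) {
      if (i > 0) out_ << ",";
      out_ << "{\"name\":\"" << spans_[i].name << "\",\"ph\":\"X\""
           << ",\"ts\":" << spans_[i].ts_us << ",\"dur\":" << spans_[i].dur_us
           << ",\"pid\":0,\"tid\":0}";
    }
    out_ << "]\n";
  }
 private:
  std::ofstream out_;
  std::vector<Span> spans_;
};

int main() {
  TraceGuard guard("example.trace");  // open example.trace in chrome://tracing
  guard.add({"aten::mm", 10, 103});
  guard.add({"aten::add", 150, 61});
}  // trace written here
```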
|
||||
|
||||
// A guard that enables the profiler, taking in an optional callback to process
|
||||
// the results
|
||||
// Usage:
|
||||
// {
|
||||
// TLSProfilerGuard g([](thread_event_lists profilerResults) {
|
||||
// // process profilerResults
|
||||
// });
|
||||
// Code to profile
|
||||
// }
|
||||
struct TORCH_API TLSProfilerGuard {
|
||||
explicit TLSProfilerGuard(
|
||||
const ProfilerConfig& cfg,
|
||||
c10::optional<std::function<void(const thread_event_lists&)>>
|
||||
resultCallback = c10::nullopt,
|
||||
c10::optional<ProfilerDisableOptions> profilerDisableOptions =
|
||||
c10::nullopt)
|
||||
: cb_(std::move(resultCallback)),
|
||||
profilerDisableOptions_(std::move(profilerDisableOptions)) {
|
||||
enableProfiler(cfg);
|
||||
}
|
||||
~TLSProfilerGuard() {
|
||||
thread_event_lists event_lists = disableProfiler(profilerDisableOptions_);
|
||||
if (cb_) {
|
||||
try {
|
||||
(*cb_)(event_lists);
|
||||
} catch (const std::exception& e) {
|
||||
LOG(ERROR) << "Got error processing profiler events: " << e.what();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
c10::optional<std::function<void(const thread_event_lists&)>> cb_;
|
||||
const c10::optional<ProfilerDisableOptions> profilerDisableOptions_;
|
||||
};
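
TLSProfilerGuard follows the same guard pattern but hands the collected results to an optional callback in its destructor, swallowing exceptions so post-processing errors cannot escape a destructor. A minimal sketch with illustrative names (`ScopedCollector`, `Results`):

```
#include <exception>
#include <functional>
#include <iostream>
#include <vector>

using Results = std::vector<int>;  // stand-in for thread_event_lists

// Guard that "enables" on construction and hands results to an optional
// callback on destruction; exceptions from the callback are logged, not
// rethrown, because destructors must not throw.
struct ScopedCollector {
  explicit ScopedCollector(std::function<void(const Results&)> cb = nullptr)
      : cb_(std::move(cb)) {}
  void record(int v) { results_.push_back(v); }
  ~ScopedCollector() {
    if (cb_) {
      try {
        cb_(results_);
      } catch (const std::exception& e) {
        std::cerr << "error processing results: " << e.what() << "\n";
      }
    }
  }
 private:
  std::function<void(const Results&)> cb_;
  Results results_;
};

int main() {
  {
    ScopedCollector g([](const Results& r) {
      std::cout << "collected " << r.size() << " samples\n";
    });
    g.record(1);
    g.record(2);
  }  // callback runs here
}
```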
|
||||
|
||||
} // namespace profiler
|
||||
}} // namespace torch::autograd
|
||||
#include <torch/csrc/autograd/profiler_legacy.h>
|
||||
#include <torch/csrc/autograd/profiler_kineto.h>
|
||||
|
@ -32,7 +32,7 @@ static inline void cudaCheck(cudaError_t result, const char * file, int line) {
|
||||
#define TORCH_CUDA_CHECK(result) cudaCheck(result,__FILE__,__LINE__);
|
||||
|
||||
struct CUDAMethods : public CUDAStubs {
|
||||
void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) override {
|
||||
void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const override {
|
||||
TORCH_CUDA_CHECK(cudaGetDevice(device));
|
||||
CUevent_st* cuda_event_ptr;
|
||||
TORCH_CUDA_CHECK(cudaEventCreate(&cuda_event_ptr));
|
||||
@ -43,23 +43,28 @@ struct CUDAMethods : public CUDAStubs {
|
||||
*cpu_ns = getTime();
|
||||
TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream));
|
||||
}
|
||||
float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) override {
|
||||
|
||||
float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const override {
|
||||
TORCH_CUDA_CHECK(cudaEventSynchronize(event->get()));
|
||||
TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get()));
|
||||
float ms;
|
||||
TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event->get(), event2->get()));
|
||||
return ms*1000.0;
|
||||
}
|
||||
void nvtxMarkA(const char* name) override {
|
||||
|
||||
void nvtxMarkA(const char* name) const override {
|
||||
::nvtxMark(name);
|
||||
}
|
||||
void nvtxRangePushA(const char* name) override {
|
||||
|
||||
void nvtxRangePushA(const char* name) const override {
|
||||
::nvtxRangePushA(name);
|
||||
}
|
||||
void nvtxRangePop() override {
|
||||
|
||||
void nvtxRangePop() const override {
|
||||
::nvtxRangePop();
|
||||
}
|
||||
void onEachDevice(std::function<void(int)> op) override {
|
||||
|
||||
void onEachDevice(std::function<void(int)> op) const override {
|
||||
at::cuda::OptionalCUDAGuard device_guard;
|
||||
int count = at::cuda::device_count();
|
||||
for(int i = 0; i < count; i++) {
|
||||
@ -67,13 +72,14 @@ struct CUDAMethods : public CUDAStubs {
|
||||
op(i);
|
||||
}
|
||||
}
|
||||
void synchronize() override {
|
||||
|
||||
void synchronize() const override {
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
bool enabled() override {
|
||||
|
||||
bool enabled() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct RegisterCUDAMethods {
|
||||
|
368
torch/csrc/autograd/profiler_kineto.cpp
Normal file
@ -0,0 +1,368 @@
|
||||
#include <torch/csrc/autograd/profiler_kineto.h>
|
||||
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#ifdef USE_KINETO
|
||||
#include <pthread.h>
|
||||
#include <libkineto.h>
|
||||
#endif
|
||||
|
||||
namespace torch { namespace autograd { namespace profiler {
|
||||
|
||||
#ifdef USE_KINETO
|
||||
namespace {
|
||||
// TODO: consider TLS (tid + tls counter)
|
||||
uint64_t next_correlation_id() {
|
||||
static std::atomic<uint64_t> corr_id_ {1};
|
||||
return corr_id_++;
|
||||
}
|
||||
|
||||
inline int64_t getTimeUs() {
|
||||
using namespace std::chrono;
|
||||
return duration_cast<microseconds>(high_resolution_clock::now().time_since_epoch()).count();
|
||||
}
|
||||
|
||||
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes);
|
||||
|
||||
struct TORCH_API KinetoThreadLocalState : public ProfilerThreadLocalState {
|
||||
using ProfilerThreadLocalState::ProfilerThreadLocalState;
|
||||
virtual ~KinetoThreadLocalState() override = default;
|
||||
|
||||
void reportClientActivity(
|
||||
const at::RecordFunction& fn,
|
||||
const KinetoObserverContext* ctx) {
|
||||
if (!ctx) {
|
||||
return;
|
||||
}
|
||||
libkineto::ClientTraceActivity op;
|
||||
op.startTime = ctx->startUs;
|
||||
op.endTime = getTimeUs();
|
||||
op.opType = std::string(fn.name().str());
|
||||
op.device = 0;
|
||||
op.threadId = ctx->startThreadId;
|
||||
op.correlation = ctx->correlationId;
|
||||
// optimization - postpone shapesToStr till finalizeCPUTrace
|
||||
// is called from disableProfiler
|
||||
// if (ctx->shapes && !ctx->shapes->empty()) {
|
||||
// op.inputDims = shapesToStr(*ctx->shapes);
|
||||
// }
|
||||
|
||||
// Not set at the moment
|
||||
op.inputTypes = "[]";
|
||||
op.arguments = "[]";
|
||||
op.outputDims = "[]";
|
||||
op.outputTypes = "[]";
|
||||
op.inputNames = "[]";
|
||||
op.outputNames = "[]";
|
||||
|
||||
// overwrite threadId with the OS thread id (pthread_self)
|
||||
op.threadId = pthread_self();
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(state_mutex_);
|
||||
kineto_events_.emplace_back();
|
||||
kineto_events_.back()
|
||||
.activity(op)
|
||||
.startThreadId(ctx->startThreadId)
|
||||
.endThreadId(ctx->endThreadId)
|
||||
.sequenceNr(ctx->sequenceNr)
|
||||
.fwdThreadId(ctx->fwdThreadId)
|
||||
.scope(ctx->recFunScope);
|
||||
if (ctx->shapes && !ctx->shapes->empty()) {
|
||||
kineto_events_.back().shapes(*ctx->shapes);
|
||||
}
|
||||
if (ctx->stack && !ctx->stack->empty()) {
|
||||
kineto_events_.back().stack(*ctx->stack);
|
||||
}
|
||||
cpu_trace->activities.emplace_back(std::move(op));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: use kineto
|
||||
void reportMemoryUsage(
|
||||
void* /* unused */,
|
||||
int64_t alloc_size,
|
||||
c10::Device device) override {
|
||||
if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
|
||||
uint64_t thread_id = at::RecordFunction::currentThreadId();
|
||||
LegacyEvent evt(
|
||||
EventKind::MemoryAlloc,
|
||||
at::StringView(""),
|
||||
thread_id,
|
||||
config_.state == ProfilerState::CUDA);
|
||||
evt.setCpuUs(getTimeUs()); // update the timestamp using Kineto's clock
|
||||
evt.updateMemoryStats(alloc_size, device);
|
||||
getEventList(thread_id).record(std::move(evt));
|
||||
}
|
||||
}
|
||||
|
||||
void addTraceEvents(libkineto::ActivityTraceInterface& trace) {
|
||||
const auto& events = *(trace.activities());
|
||||
for (const auto& ev_ptr : events) {
|
||||
// ClientTraceActivity events are already processed
|
||||
if (ev_ptr->type() != libkineto::ActivityType::CPU_OP) {
|
||||
kineto_events_.emplace_back();
|
||||
kineto_events_.back()
|
||||
.activity(*ev_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void finalizeCPUTrace() {
|
||||
TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size());
|
||||
for (auto idx = 0; idx < cpu_trace->activities.size(); ++idx) {
|
||||
if (kineto_events_[idx].hasShapes()) {
|
||||
cpu_trace->activities[idx].inputDims = shapesToStr(kineto_events_[idx].shapes());
|
||||
} else {
|
||||
cpu_trace->activities[idx].inputDims = "[]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KinetoEvent> kineto_events_;
|
||||
std::unique_ptr<libkineto::CpuTraceBuffer> cpu_trace =
|
||||
std::make_unique<libkineto::CpuTraceBuffer>();
|
||||
};
|
||||
|
||||
KinetoThreadLocalState* getProfilerTLSState() {
|
||||
const auto& state = c10::ThreadLocalDebugInfo::get(
|
||||
c10::DebugInfoKind::PROFILER_STATE);
|
||||
return static_cast<KinetoThreadLocalState*>(state);
|
||||
}
|
||||
|
||||
void pushProfilingCallbacks() {
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
||||
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
|
||||
[](const at::RecordFunction& fn) {
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
|
||||
return std::make_unique<KinetoObserverContext>();
|
||||
}
|
||||
|
||||
auto corr_id = next_correlation_id();
|
||||
libkineto::api().activityProfiler().pushCorrelationId(corr_id);
|
||||
|
||||
auto ctx_ptr = std::make_unique<KinetoObserverContext>();
|
||||
ctx_ptr->startUs = getTimeUs();
|
||||
ctx_ptr->correlationId = corr_id;
|
||||
ctx_ptr->startThreadId = at::RecordFunction::currentThreadId();
|
||||
|
||||
if (state_ptr->config().report_input_shapes) {
|
||||
ctx_ptr->shapes = inputSizes(fn);
|
||||
}
|
||||
|
||||
ctx_ptr->sequenceNr = fn.seqNr();
|
||||
ctx_ptr->fwdThreadId = fn.forwardThreadId();
|
||||
ctx_ptr->recFunScope = (uint8_t)fn.scope();
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
// backward nodes source range corresponds to the forward node
|
||||
// TODO: consider using C++ stack trace
|
||||
if (state_ptr->config().with_stack &&
|
||||
fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
|
||||
auto cs = prepareCallstack(jit::currentCallstack());
|
||||
if (cs.empty()) {
|
||||
cs = prepareCallstack(jit::tracer::pythonCallstack());
|
||||
}
|
||||
ctx_ptr->stack = callstackStr(cs);
|
||||
}
|
||||
#endif
|
||||
return ctx_ptr;
|
||||
},
|
||||
[](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
|
||||
return;
|
||||
}
|
||||
auto* kineto_ctx_ptr = static_cast<KinetoObserverContext*>(ctx_ptr);
|
||||
TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr);
|
||||
|
||||
kineto_ctx_ptr->endThreadId = at::RecordFunction::currentThreadId();
|
||||
|
||||
state_ptr->reportClientActivity(fn, kineto_ctx_ptr);
|
||||
libkineto::api().activityProfiler().popCorrelationId();
|
||||
})
|
||||
.needsInputs(state_ptr->config().report_input_shapes)
|
||||
.needsIds(true));
|
||||
state_ptr->setCallbackHandle(handle);
|
||||
}
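
The callbacks registered above are a start/end observer pair: the start hook allocates a correlation id and a context object, and the end hook receives that context back to report the finished op. A self-contained sketch of that handshake (the names, and the direct function calls standing in for at::RecordFunction dispatch, are assumptions for illustration):

```
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>

// Per-invocation context handed from the start callback to the end callback.
struct ObserverContext {
  uint64_t correlation_id = 0;
  int64_t start_us = 0;
};

static std::atomic<uint64_t> next_corr_id{1};

static int64_t nowUs() {
  using namespace std::chrono;
  return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
}

// "Start" hook: allocate a correlation id and remember the start time.
std::unique_ptr<ObserverContext> onStart(const std::string& /*op_name*/) {
  auto ctx = std::make_unique<ObserverContext>();
  ctx->correlation_id = next_corr_id++;
  ctx->start_us = nowUs();
  return ctx;
}

// "End" hook: use the context to emit one completed span.
void onEnd(const std::string& op_name, const ObserverContext& ctx) {
  std::printf("%s corr=%llu dur=%lld us\n", op_name.c_str(),
              (unsigned long long)ctx.correlation_id,
              (long long)(nowUs() - ctx.start_us));
}

int main() {
  auto ctx = onStart("aten::mm");
  // ... the op would run here ...
  onEnd("aten::mm", *ctx);
}
```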
|
||||
|
||||
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
|
||||
std::ostringstream oss;
|
||||
oss << "[";
|
||||
for (auto t_idx = 0; t_idx < shapes.size(); ++t_idx) {
|
||||
if (t_idx > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << "[";
|
||||
for (auto s_idx = 0; s_idx < shapes[t_idx].size(); ++s_idx) {
|
||||
if (s_idx > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << shapes[t_idx][s_idx];
|
||||
}
|
||||
oss << "]";
|
||||
}
|
||||
oss << "]";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void prepareProfiler(
|
||||
const ProfilerConfig& config,
|
||||
const std::set<ActivityType>& activities) {
|
||||
TORCH_CHECK(config.state == ProfilerState::KINETO,
|
||||
"Supported only in Kineto profiler");
|
||||
|
||||
std::set<libkineto::ActivityType> cpuTypes = {
|
||||
libkineto::ActivityType::CPU_OP,
|
||||
libkineto::ActivityType::EXTERNAL_CORRELATION,
|
||||
libkineto::ActivityType::CUDA_RUNTIME,
|
||||
};
|
||||
|
||||
std::set<libkineto::ActivityType> cudaTypes = {
|
||||
libkineto::ActivityType::GPU_MEMCPY,
|
||||
libkineto::ActivityType::GPU_MEMSET,
|
||||
libkineto::ActivityType::CONCURRENT_KERNEL,
|
||||
// also including CUDA_RUNTIME
|
||||
libkineto::ActivityType::CUDA_RUNTIME,
|
||||
};
|
||||
|
||||
std::set<libkineto::ActivityType> k_activities;
|
||||
if (activities.count(ActivityType::CPU)) {
|
||||
k_activities.insert(cpuTypes.begin(), cpuTypes.end());
|
||||
}
|
||||
if (activities.count(ActivityType::CUDA)) {
|
||||
k_activities.insert(cudaTypes.begin(), cudaTypes.end());
|
||||
}
|
||||
|
||||
if (!libkineto::api().isProfilerRegistered()) {
|
||||
libkineto_init();
|
||||
}
|
||||
|
||||
if (!libkineto::api().isProfilerInitialized()) {
|
||||
libkineto::api().initProfilerIfRegistered();
|
||||
}
|
||||
|
||||
libkineto::api().activityProfiler().prepareTrace(k_activities);
|
||||
}
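
prepareProfiler expands the coarse CPU/CUDA selection into sets of backend activity kinds; because CUDA_RUNTIME appears in both sets, inserting into a std::set deduplicates it. A hedged sketch of the selection logic with stand-in enums:

```
#include <cstdio>
#include <set>

enum class Activity { CPU, CUDA };
// Stand-ins for the backend's finer-grained activity kinds.
enum class BackendKind { CpuOp, ExternalCorrelation, Runtime, Memcpy, Memset, Kernel };

std::set<BackendKind> selectKinds(const std::set<Activity>& activities) {
  const std::set<BackendKind> cpu = {BackendKind::CpuOp, BackendKind::ExternalCorrelation,
                                     BackendKind::Runtime};
  const std::set<BackendKind> cuda = {BackendKind::Memcpy, BackendKind::Memset,
                                      BackendKind::Kernel, BackendKind::Runtime};
  std::set<BackendKind> out;
  if (activities.count(Activity::CPU)) out.insert(cpu.begin(), cpu.end());
  if (activities.count(Activity::CUDA)) out.insert(cuda.begin(), cuda.end());
  return out;  // Runtime appears once even when both sets were requested
}

int main() {
  auto kinds = selectKinds({Activity::CPU, Activity::CUDA});
  std::printf("%zu backend activity kinds selected\n", kinds.size());  // 6
}
```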
|
||||
|
||||
void enableProfiler(
|
||||
const ProfilerConfig& config,
|
||||
const std::set<ActivityType>& activities) {
|
||||
TORCH_CHECK(config.state == ProfilerState::KINETO);
|
||||
TORCH_CHECK(!activities.empty(), "No activities specified for Kineto profiler");
|
||||
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
|
||||
auto state = std::make_shared<KinetoThreadLocalState>(config);
|
||||
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
|
||||
|
||||
state->cpu_trace = std::make_unique<libkineto::CpuTraceBuffer>();
|
||||
state->cpu_trace->span.startTime = getTimeUs();
|
||||
// TODO: number of GPU ops
|
||||
state->cpu_trace->gpuOpCount = -1;
|
||||
state->cpu_trace->span.name = "PyTorch Profiler";
|
||||
|
||||
if (activities.count(ActivityType::CPU)) {
|
||||
pushProfilingCallbacks();
|
||||
}
|
||||
|
||||
libkineto::api().activityProfiler().startTrace();
|
||||
|
||||
state->mark("__start_profile", false);
|
||||
}
|
||||
|
||||
std::unique_ptr<ProfilerResult> disableProfiler() {
|
||||
// all DebugInfoBase objects are scope-based and are supposed to use DebugInfoGuard
|
||||
auto state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
|
||||
|
||||
auto state_ptr = static_cast<KinetoThreadLocalState*>(state.get());
|
||||
TORCH_CHECK(state_ptr && state_ptr->config().state == ProfilerState::KINETO,
|
||||
"Can't disable Kineto profiler when it's not running");
|
||||
|
||||
if (state_ptr->hasCallbackHandle()) {
|
||||
at::removeCallback(state_ptr->callbackHandle());
|
||||
}
|
||||
|
||||
state_ptr->mark("__stop_profile");
|
||||
|
||||
state_ptr->cpu_trace->span.endTime = getTimeUs();
|
||||
|
||||
state_ptr->finalizeCPUTrace();
|
||||
libkineto::api().activityProfiler().transferCpuTrace(std::move(state_ptr->cpu_trace));
|
||||
|
||||
auto trace = std::move(libkineto::api().activityProfiler().stopTrace());
|
||||
TORCH_CHECK(trace);
|
||||
state_ptr->addTraceEvents(*trace);
|
||||
return std::make_unique<ProfilerResult>(
|
||||
std::move(state_ptr->kineto_events_),
|
||||
std::move(state_ptr->consolidate()),
|
||||
std::move(trace));
|
||||
}
|
||||
|
||||
KinetoEvent& KinetoEvent::activity(const libkineto::TraceActivity& activity) {
|
||||
name_ = activity.name();
|
||||
device_index_ = activity.deviceId();
|
||||
device_resource_id_ = activity.resourceId();
|
||||
start_us_ = activity.timestamp();
|
||||
duration_us_ = activity.duration();
|
||||
correlation_id_ = activity.correlationId();
|
||||
activity_type_ = (uint8_t)activity.type();
|
||||
if (activity.linkedActivity()) {
|
||||
linked_correlation_id_ = activity.linkedActivity()->correlationId();
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
c10::DeviceType KinetoEvent::deviceType() const {
|
||||
switch (activity_type_) {
|
||||
case (uint8_t)libkineto::ActivityType::CPU_OP:
|
||||
return c10::DeviceType::CPU;
|
||||
case (uint8_t)libkineto::ActivityType::GPU_MEMCPY:
|
||||
return c10::DeviceType::CUDA;
|
||||
case (uint8_t)libkineto::ActivityType::GPU_MEMSET:
|
||||
return c10::DeviceType::CUDA;
|
||||
case (uint8_t)libkineto::ActivityType::CONCURRENT_KERNEL:
|
||||
return c10::DeviceType::CUDA;
|
||||
case (uint8_t)libkineto::ActivityType::EXTERNAL_CORRELATION:
|
||||
return c10::DeviceType::CPU;
|
||||
case (uint8_t)libkineto::ActivityType::CUDA_RUNTIME:
|
||||
return c10::DeviceType::CPU;
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown activity type");
|
||||
}
|
||||
|
||||
KinetoEvent::KinetoEvent() : activity_type_((uint8_t)libkineto::ActivityType::CPU_OP) {}
|
||||
|
||||
ProfilerResult::ProfilerResult(
|
||||
std::vector<KinetoEvent> events,
|
||||
thread_event_lists legacy_events,
|
||||
std::unique_ptr<libkineto::ActivityTraceInterface> trace)
|
||||
: events_(std::move(events)),
|
||||
legacy_events_(std::move(legacy_events)),
|
||||
trace_(std::move(trace)) {}
|
||||
ProfilerResult::~ProfilerResult() {}
|
||||
|
||||
void ProfilerResult::save(const std::string& path) {
|
||||
// Kineto's save is destructive
|
||||
TORCH_CHECK(!saved_, "Trace is already saved");
|
||||
trace_->save(path);
|
||||
saved_ = true;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
bool kinetoAvailable() {
|
||||
#ifdef USE_KINETO
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
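
kinetoAvailable() turns a build-time flag into a runtime answer, which the binding `m.def("kineto_available", kinetoAvailable)` then exposes to Python. A tiny sketch of the same pattern with a hypothetical WITH_TRACER macro:

```
#include <cstdio>

// Define WITH_TRACER at build time (-DWITH_TRACER) to flip the answer.
bool tracerAvailable() {
#ifdef WITH_TRACER
  return true;
#else
  return false;
#endif
}

int main() {
  std::printf("tracer available: %s\n", tracerAvailable() ? "yes" : "no");
}
```

Callers on the Python side can branch on the bound availability check in the same way, falling back to the legacy profiler path when it returns false.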
|
||||
|
||||
}}}
|
213
torch/csrc/autograd/profiler_kineto.h
Normal file
@ -0,0 +1,213 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/autograd/profiler_legacy.h>
|
||||
|
||||
#ifdef USE_KINETO
|
||||
namespace libkineto {
|
||||
class TraceActivity;
|
||||
class ActivityTraceInterface;
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
namespace profiler {
|
||||
|
||||
enum class C10_API_ENUM ActivityType {
|
||||
CPU = 0,
|
||||
CUDA, // CUDA kernels, runtime
|
||||
NUM_KINETO_ACTIVITIES, // must be the last one
|
||||
};
|
||||
|
||||
#ifdef USE_KINETO
|
||||
|
||||
struct KinetoObserverContext : public at::ObserverContext {
|
||||
int64_t startUs;
|
||||
uint64_t correlationId;
|
||||
uint64_t startThreadId;
|
||||
uint64_t endThreadId;
|
||||
c10::optional<std::vector<std::vector<int64_t>>> shapes;
|
||||
int64_t sequenceNr;
|
||||
uint64_t fwdThreadId;
|
||||
uint8_t recFunScope;
|
||||
c10::optional<std::vector<std::string>> stack;
|
||||
};
|
||||
|
||||
struct TORCH_API KinetoEvent {
|
||||
KinetoEvent();
|
||||
|
||||
uint64_t startThreadId() const {
|
||||
return start_thread_id_;
|
||||
}
|
||||
|
||||
uint64_t endThreadId() const {
|
||||
return end_thread_id_;
|
||||
}
|
||||
|
||||
uint8_t activityType() const {
|
||||
return activity_type_;
|
||||
}
|
||||
|
||||
uint64_t fwdThreadId() const {
|
||||
return fwd_thread_id_;
|
||||
}
|
||||
|
||||
bool hasShapes() const {
|
||||
return shapes_ != c10::nullopt;
|
||||
}
|
||||
|
||||
const std::vector<std::vector<int64_t>>& shapes() const {
|
||||
return *shapes_;
|
||||
}
|
||||
|
||||
int64_t sequenceNr() const {
|
||||
return sequence_nr_;
|
||||
}
|
||||
|
||||
bool hasStack() const {
|
||||
return stack_ != c10::nullopt;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& stack() const {
|
||||
return *stack_;
|
||||
}
|
||||
|
||||
uint8_t scope() const {
|
||||
return scope_;
|
||||
}
|
||||
|
||||
KinetoEvent& startThreadId(uint64_t start_thread_id) {
|
||||
start_thread_id_ = start_thread_id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
KinetoEvent& endThreadId(uint64_t end_thread_id) {
|
||||
end_thread_id_ = end_thread_id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
KinetoEvent& fwdThreadId(uint64_t fwd_thread_id) {
|
||||
fwd_thread_id_ = fwd_thread_id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
KinetoEvent& shapes(const std::vector<std::vector<int64_t>>& shapes) {
|
||||
shapes_ = shapes;
|
||||
return *this;
|
||||
}
|
||||
|
||||
KinetoEvent& sequenceNr(int64_t sequence_nr) {
|
||||
sequence_nr_ = sequence_nr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
KinetoEvent& stack(const std::vector<std::string>& st) {
|
||||
stack_ = st;
|
||||
return *this;
|
||||
}
|
||||
|
||||
KinetoEvent& scope(uint8_t scope) {
|
||||
scope_ = scope;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Kineto fields
|
||||
|
||||
KinetoEvent& activity(const libkineto::TraceActivity& activity);
|
||||
|
||||
std::string name() const {
|
||||
return name_;
|
||||
}
|
||||
|
||||
uint64_t deviceIndex() const {
|
||||
return device_index_;
|
||||
}
|
||||
|
||||
uint64_t startUs() const {
|
||||
return start_us_;
|
||||
}
|
||||
|
||||
uint64_t durationUs() const {
|
||||
return duration_us_;
|
||||
}
|
||||
|
||||
uint64_t correlationId() const {
|
||||
return correlation_id_;
|
||||
}
|
||||
|
||||
KinetoEvent& correlationId(uint64_t correlation_id) {
|
||||
correlation_id_ = correlation_id;
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint64_t linkedCorrelationId() const {
|
||||
return linked_correlation_id_;
|
||||
}
|
||||
|
||||
int64_t deviceResourceId() const {
|
||||
return device_resource_id_;
|
||||
}
|
||||
|
||||
c10::DeviceType deviceType() const;
|
||||
|
||||
uint64_t start_thread_id_ = 0;
|
||||
uint64_t end_thread_id_ = 0;
|
||||
uint64_t fwd_thread_id_ = 0;
|
||||
int64_t sequence_nr_ = -1;
|
||||
uint8_t scope_ = 0;
|
||||
|
||||
uint8_t activity_type_;
|
||||
c10::optional<std::vector<std::vector<int64_t>>> shapes_;
|
||||
c10::optional<std::vector<std::string>> stack_;
|
||||
|
||||
std::string name_;
|
||||
uint64_t device_index_ = 0;
|
||||
uint64_t start_us_ = 0;
|
||||
uint64_t duration_us_ = 0;
|
||||
uint64_t correlation_id_ = 0;
|
||||
uint64_t linked_correlation_id_ = 0;
|
||||
int64_t device_resource_id_ = 0;
|
||||
};
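
The setters above return `*this`, so a KinetoEvent can be populated with one chained expression (as done in reportClientActivity). A standalone sketch of that fluent-setter pattern with illustrative names:

```
#include <cstdint>
#include <cstdio>
#include <string>

// Fluent setters returning *this, as in the KinetoEvent-style builder above.
struct SpanBuilder {
  SpanBuilder& name(std::string n) { name_ = std::move(n); return *this; }
  SpanBuilder& startThreadId(uint64_t t) { start_tid_ = t; return *this; }
  SpanBuilder& sequenceNr(int64_t s) { seq_ = s; return *this; }

  std::string name_;
  uint64_t start_tid_ = 0;
  int64_t seq_ = -1;
};

int main() {
  SpanBuilder b;
  b.name("aten::add").startThreadId(7).sequenceNr(42);  // chained population
  std::printf("%s tid=%llu seq=%lld\n", b.name_.c_str(),
              (unsigned long long)b.start_tid_, (long long)b.seq_);
}
```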
|
||||
|
||||
// Consolidating events returned directly from Kineto
|
||||
// with events manually created by us (e.g. start/stop marks,
|
||||
// memory allocation events)
|
||||
struct TORCH_API ProfilerResult {
|
||||
ProfilerResult(
|
||||
std::vector<KinetoEvent> events,
|
||||
thread_event_lists legacy_events,
|
||||
std::unique_ptr<libkineto::ActivityTraceInterface> trace);
|
||||
~ProfilerResult();
|
||||
|
||||
const std::vector<KinetoEvent>& events() const {
|
||||
return events_;
|
||||
}
|
||||
|
||||
const thread_event_lists& legacy_events() const {
|
||||
return legacy_events_;
|
||||
}
|
||||
|
||||
void save(const std::string& path);
|
||||
|
||||
private:
|
||||
bool saved_ = false;
|
||||
std::vector<KinetoEvent> events_;
|
||||
thread_event_lists legacy_events_;
|
||||
std::unique_ptr<libkineto::ActivityTraceInterface> trace_;
|
||||
};
|
||||
|
||||
TORCH_API void enableProfiler(
|
||||
const ProfilerConfig& config,
|
||||
const std::set<ActivityType>& activities);
|
||||
|
||||
TORCH_API std::unique_ptr<ProfilerResult> disableProfiler();
|
||||
|
||||
TORCH_API void prepareProfiler(
|
||||
const ProfilerConfig& config,
|
||||
const std::set<ActivityType>& activities);
|
||||
#endif // USE_KINETO
|
||||
|
||||
TORCH_API bool kinetoAvailable();
|
||||
|
||||
} // namespace profiler
|
||||
}} // namespace torch::autograd
|
@ -23,54 +23,33 @@
|
||||
|
||||
namespace torch { namespace autograd { namespace profiler {
|
||||
|
||||
namespace {
|
||||
std::vector<FileLineFunc> prepareCallstack(const std::vector<jit::StackEntry>& cs) {
|
||||
std::vector<FileLineFunc> entries;
|
||||
entries.reserve(cs.size());
|
||||
for (const auto& entry : cs) {
|
||||
auto& range = entry.range;
|
||||
if (range.source()) {
|
||||
auto& src = range.source();
|
||||
if (src && src->filename()) {
|
||||
auto line = src->starting_line_no() +
|
||||
src->lineno_for_offset(range.start());
|
||||
entries.emplace_back(FileLineFunc{*(src->filename()), line, entry.filename});
|
||||
}
|
||||
}
|
||||
}
|
||||
return entries;
|
||||
}
|
||||
|
||||
enum EventIValueIdx {
|
||||
KIND = 0,
|
||||
NAME,
|
||||
THREAD_ID,
|
||||
HANDLE,
|
||||
NODE_ID,
|
||||
CPU_MEM_USAGE,
|
||||
CPU_NS,
|
||||
CUDA_RECORDED,
|
||||
CUDA_MEM_USAGE,
|
||||
CUDA_DEVICE,
|
||||
CUDA_US,
|
||||
SHAPES,
|
||||
NUM_EVENT_IVALUE_IDX // must be last in list
|
||||
};
|
||||
|
||||
enum ProfilerIValueIdx {
|
||||
STATE = 0,
|
||||
REPORT_INPUT_SHAPES,
|
||||
PROFILE_MEMORY,
|
||||
NUM_PROFILER_CFG_IVALUE_IDX // must be last in list
|
||||
};
|
||||
|
||||
const std::unordered_set<std::string> disable_cuda_profiling = {
|
||||
"aten::view",
|
||||
"aten::t",
|
||||
"aten::transpose",
|
||||
"aten::stride",
|
||||
"aten::empty",
|
||||
"aten::empty_like",
|
||||
"aten::empty_strided",
|
||||
"aten::as_strided",
|
||||
"aten::expand",
|
||||
"aten::resize_",
|
||||
"aten::squeeze",
|
||||
"aten::unsqueeze",
|
||||
"aten::slice",
|
||||
"aten::_unsafe_view",
|
||||
"aten::size"
|
||||
};
|
||||
|
||||
CUDAStubs default_stubs;
|
||||
constexpr CUDAStubs* default_stubs_addr = &default_stubs;
|
||||
// Constant initialization, so it is guaranteed to be initialized before
|
||||
// static initialization calls which may invoke registerCUDAMethods
|
||||
static CUDAStubs* cuda_stubs = default_stubs_addr;
|
||||
std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) {
|
||||
std::vector<std::string> cs_str;
|
||||
cs_str.reserve(cs.size());
|
||||
for (const auto& entry : cs) {
|
||||
std::stringstream loc;
|
||||
loc << entry.filename << "(" << entry.line << "): " << entry.funcname;
|
||||
cs_str.push_back(loc.str());
|
||||
}
|
||||
return cs_str;
|
||||
}
|
||||
|
||||
// We decompose the profiler logic into the following components:
|
||||
//
|
||||
@ -163,252 +142,267 @@ static CUDAStubs* cuda_stubs = default_stubs_addr;
|
||||
// - save profiling events into the profiling state
|
||||
//
|
||||
|
||||
struct FileLineFunc {
|
||||
std::string filename;
|
||||
size_t line;
|
||||
std::string funcname;
|
||||
};
|
||||
namespace {
|
||||
const CUDAStubs default_stubs;
|
||||
constexpr const CUDAStubs* default_stubs_addr = &default_stubs;
|
||||
// Constant initialization, so it is guaranteed to be initialized before
|
||||
// static initialization calls which may invoke registerCUDAMethods
|
||||
inline const CUDAStubs*& cuda_stubs() {
|
||||
static const CUDAStubs* stubs_ = default_stubs_addr;
|
||||
return stubs_;
|
||||
}
|
||||
}
|
||||
|
||||
// Profiler state
|
||||
struct ProfilerThreadLocalState : public c10::MemoryReportingInfoBase {
|
||||
explicit ProfilerThreadLocalState(const ProfilerConfig& config)
|
||||
: config_(config), remoteProfiledEvents_{c10::nullopt} {}
|
||||
~ProfilerThreadLocalState() override = default;
|
||||
const ProfilerConfig& ProfilerThreadLocalState::config() const {
|
||||
return config_;
|
||||
}
|
||||
|
||||
inline const ProfilerConfig& config() const {
|
||||
return config_;
|
||||
thread_event_lists ProfilerThreadLocalState::consolidate() {
|
||||
std::lock_guard<std::mutex> g(state_mutex_);
|
||||
thread_event_lists result;
|
||||
for (auto& kv : event_lists_map_) {
|
||||
auto& list = kv.second;
|
||||
result.emplace_back(list->consolidate());
|
||||
}
|
||||
|
||||
thread_event_lists consolidate() {
|
||||
std::lock_guard<std::mutex> g(state_mutex_);
|
||||
thread_event_lists result;
|
||||
for (auto& kv : event_lists_map_) {
|
||||
auto& list = kv.second;
|
||||
result.emplace_back(list->consolidate());
|
||||
}
|
||||
// Consolidate remote events if applicable as well.
|
||||
if (remoteProfiledEvents_) {
|
||||
result.insert(
|
||||
result.end(),
|
||||
std::make_move_iterator(remoteProfiledEvents_->begin()),
|
||||
std::make_move_iterator(remoteProfiledEvents_->end()));
|
||||
}
|
||||
return result;
|
||||
// Consolidate remote events if applicable as well.
|
||||
if (remoteProfiledEvents_) {
|
||||
result.insert(
|
||||
result.end(),
|
||||
std::make_move_iterator(remoteProfiledEvents_->begin()),
|
||||
std::make_move_iterator(remoteProfiledEvents_->end()));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void mark(std::string name, bool include_cuda = true) {
|
||||
if (config_.state == ProfilerState::Disabled) {
|
||||
return;
|
||||
}
|
||||
if (config_.state == ProfilerState::NVTX) {
|
||||
cuda_stubs->nvtxMarkA(name.c_str());
|
||||
} else {
|
||||
Event evt(
|
||||
EventKind::Mark,
|
||||
at::StringView(std::move(name)),
|
||||
at::RecordFunction::currentThreadId(),
|
||||
include_cuda && config_.state == ProfilerState::CUDA);
|
||||
evt.setNodeId(at::RecordFunction::getDefaultNodeId());
|
||||
getEventList().record(std::move(evt));
|
||||
}
|
||||
void ProfilerThreadLocalState::mark(std::string name, bool include_cuda) {
|
||||
if (config_.state == ProfilerState::Disabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
void setOrAddRemoteProfiledEvents(
|
||||
std::vector<Event>&& remoteProfiledEvents) {
|
||||
// Lock to serialize access from multiple callback threads.
|
||||
std::lock_guard<std::mutex> guard(state_mutex_);
|
||||
if (remoteProfiledEvents_) {
|
||||
(*remoteProfiledEvents_).emplace_back(remoteProfiledEvents);
|
||||
} else {
|
||||
remoteProfiledEvents_ = {std::move(remoteProfiledEvents)};
|
||||
}
|
||||
if (config_.state == ProfilerState::NVTX) {
|
||||
cuda_stubs()->nvtxMarkA(name.c_str());
|
||||
} else {
|
||||
LegacyEvent evt(
|
||||
EventKind::Mark,
|
||||
at::StringView(std::move(name)),
|
||||
at::RecordFunction::currentThreadId(),
|
||||
include_cuda && config_.state == ProfilerState::CUDA);
|
||||
evt.setNodeId(at::RecordFunction::getDefaultNodeId());
|
||||
getEventList().record(std::move(evt));
|
||||
}
|
||||
}
|
||||
|
||||
void pushRange(
|
||||
const at::RecordFunction& fn,
|
||||
const bool record_cuda,
|
||||
const char* msg = "",
|
||||
std::vector<std::vector<int64_t>>&& shapes = {}) {
|
||||
if (config_.state == ProfilerState::Disabled) {
|
||||
return;
|
||||
}
|
||||
if (config_.state == ProfilerState::NVTX) {
|
||||
cuda_stubs->nvtxRangePushA(getNvtxStr(
|
||||
fn.name(), msg, fn.seqNr(), shapes).c_str());
|
||||
} else {
|
||||
Event evt(
|
||||
EventKind::PushRange,
|
||||
fn.name(),
|
||||
at::RecordFunction::currentThreadId(),
|
||||
record_cuda,
|
||||
fn.handle(),
|
||||
std::move(shapes),
|
||||
at::RecordFunction::getDefaultNodeId());
|
||||
evt.setSequenceNr(fn.seqNr());
|
||||
evt.setFwdThreadId(fn.forwardThreadId());
|
||||
evt.setScope((uint8_t)fn.scope());
|
||||
void ProfilerThreadLocalState::setOrAddRemoteProfiledEvents(
|
||||
std::vector<LegacyEvent>&& remoteProfiledEvents) {
|
||||
// Lock to serialize access from multiple callback threads.
|
||||
std::lock_guard<std::mutex> guard(state_mutex_);
|
||||
if (remoteProfiledEvents_) {
|
||||
(*remoteProfiledEvents_).emplace_back(remoteProfiledEvents);
|
||||
} else {
|
||||
remoteProfiledEvents_ = {std::move(remoteProfiledEvents)};
|
||||
}
|
||||
}
|
||||
|
||||
void ProfilerThreadLocalState::pushRange(
|
||||
const at::RecordFunction& fn,
|
||||
const bool record_cuda,
|
||||
const char* msg,
|
||||
std::vector<std::vector<int64_t>>&& shapes) {
|
||||
if (config_.state == ProfilerState::Disabled) {
|
||||
return;
|
||||
}
|
||||
if (config_.state == ProfilerState::NVTX) {
|
||||
cuda_stubs()->nvtxRangePushA(getNvtxStr(
|
||||
fn.name(), msg, fn.seqNr(), shapes).c_str());
|
||||
} else {
|
||||
LegacyEvent evt(
|
||||
EventKind::PushRange,
|
||||
fn.name(),
|
||||
at::RecordFunction::currentThreadId(),
|
||||
record_cuda,
|
||||
fn.handle(),
|
||||
std::move(shapes),
|
||||
at::RecordFunction::getDefaultNodeId());
|
||||
evt.setSequenceNr(fn.seqNr());
|
||||
evt.setFwdThreadId(fn.forwardThreadId());
|
||||
evt.setScope((uint8_t)fn.scope());
|
||||
#ifndef C10_MOBILE
|
||||
// backward nodes source range corresponds to the forward node
|
||||
// TODO: consider using C++ stack trace
|
||||
if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
|
||||
auto cs = prepareCallstack(jit::currentCallstack());
|
||||
if (cs.empty()) {
|
||||
cs = prepareCallstack(jit::tracer::pythonCallstack());
|
||||
}
|
||||
evt.setStack(callstackStr(cs));
|
||||
// backward nodes source range corresponds to the forward node
|
||||
// TODO: consider using C++ stack trace
|
||||
if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
|
||||
auto cs = prepareCallstack(jit::currentCallstack());
|
||||
if (cs.empty()) {
|
||||
cs = prepareCallstack(jit::tracer::pythonCallstack());
|
||||
}
|
||||
evt.setStack(callstackStr(cs));
|
||||
}
|
||||
#endif
|
||||
getEventList().record(std::move(evt));
|
||||
}
|
||||
getEventList().record(std::move(evt));
|
||||
}
|
||||
}
|
||||
|
||||
void popRange(const at::RecordFunction& fn, const bool record_cuda) {
|
||||
if (config_.state == ProfilerState::Disabled) {
|
||||
return;
|
||||
}
|
||||
if (config_.state == ProfilerState::NVTX) {
|
||||
cuda_stubs->nvtxRangePop();
|
||||
} else {
|
||||
// In some cases RecordFunction (and popRange) may be
|
||||
// called on a different thread than pushRange
|
||||
// As a convention, we put the async pop on the original
|
||||
// thread and save current thread id in pop event
|
||||
Event evt(
|
||||
EventKind::PopRange,
|
||||
at::StringView(""),
|
||||
at::RecordFunction::currentThreadId(),
|
||||
record_cuda,
|
||||
fn.handle());
|
||||
evt.setNodeId(at::RecordFunction::getDefaultNodeId());
|
||||
getEventList(fn.threadId()).record(std::move(evt));
|
||||
}
|
||||
void ProfilerThreadLocalState::popRange(const at::RecordFunction& fn, const bool record_cuda) {
|
||||
if (config_.state == ProfilerState::Disabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
void setCallbackHandle(at::CallbackHandle handle) {
|
||||
handle_ = handle;
|
||||
if (config_.state == ProfilerState::NVTX) {
|
||||
cuda_stubs()->nvtxRangePop();
|
||||
} else {
|
||||
// In some cases RecordFunction (and popRange) may be
|
||||
// called on a different thread than pushRange
|
||||
// As a convention, we put the async pop on the original
|
||||
// thread and save current thread id in pop event
|
||||
LegacyEvent evt(
|
||||
EventKind::PopRange,
|
||||
at::StringView(""),
|
||||
at::RecordFunction::currentThreadId(),
|
||||
record_cuda,
|
||||
fn.handle());
|
||||
evt.setNodeId(at::RecordFunction::getDefaultNodeId());
|
||||
getEventList(fn.threadId()).record(std::move(evt));
|
||||
}
|
||||
}
|
||||
|
||||
at::CallbackHandle callbackHandle() const {
|
||||
return handle_;
|
||||
void ProfilerThreadLocalState::reportMemoryUsage(
|
||||
void* /* unused */,
|
||||
int64_t alloc_size,
|
||||
c10::Device device) {
|
||||
if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
|
||||
uint64_t thread_id = at::RecordFunction::currentThreadId();
|
||||
LegacyEvent evt(
|
||||
EventKind::MemoryAlloc,
|
||||
at::StringView(""),
|
||||
thread_id,
|
||||
config_.state == ProfilerState::CUDA);
|
||||
evt.updateMemoryStats(alloc_size, device);
|
||||
getEventList(thread_id).record(std::move(evt));
|
||||
}
|
||||
}
|
||||
|
||||
void reportMemoryUsage(
|
||||
void* /* unused */,
|
||||
int64_t alloc_size,
|
||||
c10::Device device) override {
|
||||
if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
|
||||
uint64_t thread_id = at::RecordFunction::currentThreadId();
|
||||
Event evt(
|
||||
EventKind::MemoryAlloc,
|
||||
at::StringView(""),
|
||||
thread_id,
|
||||
config_.state == ProfilerState::CUDA);
|
||||
evt.updateMemoryStats(alloc_size, device);
|
||||
getEventList(thread_id).record(std::move(evt));
|
||||
}
|
||||
}
|
||||
bool ProfilerThreadLocalState::memoryProfilingEnabled() const {
|
||||
return config_.profile_memory;
|
||||
}
|
||||
|
||||
bool memoryProfilingEnabled() const override {
|
||||
return config_.profile_memory;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<FileLineFunc> prepareCallstack(const std::vector<jit::StackEntry>& cs) {
|
||||
std::vector<FileLineFunc> entries;
|
||||
entries.reserve(cs.size());
|
||||
for (const auto& entry : cs) {
|
||||
auto& range = entry.range;
|
||||
if (range.source()) {
|
||||
auto& src = range.source();
|
||||
if (src && src->filename()) {
|
||||
auto line = src->starting_line_no() +
|
||||
src->lineno_for_offset(range.start());
|
||||
entries.emplace_back(FileLineFunc{*(src->filename()), line, entry.filename});
|
||||
}
|
||||
}
|
||||
}
|
||||
return entries;
|
||||
}
|
||||
|
||||
std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) {
|
||||
std::vector<std::string> cs_str;
|
||||
cs_str.reserve(cs.size());
|
||||
for (const auto& entry : cs) {
|
||||
std::stringstream loc;
|
||||
loc << entry.filename << "(" << entry.line << "): " << entry.funcname;
|
||||
cs_str.push_back(loc.str());
|
||||
}
|
||||
return cs_str;
|
||||
}
|
||||
|
||||
std::string ProfilerThreadLocalState::getNvtxStr(
    const at::StringView& name,
    const char* msg,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes) const {
  if (sequence_nr >= 0 || shapes.size() > 0) {
    std::stringstream s;
#ifdef __HIP_PLATFORM_HCC__
    s << name.str();
#endif
    if (sequence_nr >= 0) {
#ifdef __HIP_PLATFORM_HCC__
      s << msg << sequence_nr;
#else
      s << name.str() << msg << sequence_nr;
#endif
    }
    if (shapes.size() > 0) {
      s << ", sizes = [";
      for (size_t idx = 0; idx < shapes.size(); ++idx) {
        if (shapes[idx].size() > 0) {
          s << "[";
          for (size_t dim = 0; dim < shapes[idx].size(); ++dim) {
            s << shapes[idx][dim];
            if (dim < shapes[idx].size() - 1) {
              s << ", ";
            }
          }
          s << "]";
        } else {
          s << "[]";
        }
        if (idx < shapes.size() - 1) {
          s << ", ";
        }
      }
      s << "]";
    }
    return s.str();
  } else {
    return name.str();
  }
}
|
||||
|
||||
RangeEventList& ProfilerThreadLocalState::getEventList(int64_t thread_id) {
|
||||
if (thread_id < 0) {
|
||||
thread_id = at::RecordFunction::currentThreadId();
|
||||
}
|
||||
RangeEventList* list_ptr = nullptr;
|
||||
std::lock_guard<std::mutex> guard(state_mutex_);
|
||||
auto it = event_lists_map_.find(thread_id);
|
||||
if (it != event_lists_map_.end()) {
|
||||
list_ptr = it->second.get();
|
||||
} else {
|
||||
auto event_list = std::make_shared<RangeEventList>();
|
||||
event_lists_map_[thread_id] = event_list;
|
||||
list_ptr = event_list.get();
|
||||
}
|
||||
return *list_ptr;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn) {
  std::vector<std::vector<int64_t>> sizes;
  sizes.reserve(fn.inputs().size());
  for (const c10::IValue& input : fn.inputs()) {
    if (!input.isTensor()) {
      sizes.emplace_back();
      continue;
    }
    const at::Tensor& tensor = input.toTensor();
    if (tensor.defined()) {
      sizes.push_back(input.toTensor().sizes().vec());
    } else {
      sizes.emplace_back();
    }
  }
  return sizes;
}
|
||||
|
||||
RangeEventList& getEventList(int64_t thread_id = -1) {
|
||||
if (thread_id < 0) {
|
||||
thread_id = at::RecordFunction::currentThreadId();
|
||||
}
|
||||
RangeEventList* list_ptr = nullptr;
|
||||
std::lock_guard<std::mutex> guard(state_mutex_);
|
||||
auto it = event_lists_map_.find(thread_id);
|
||||
if (it != event_lists_map_.end()) {
|
||||
list_ptr = it->second.get();
|
||||
} else {
|
||||
auto event_list = std::make_shared<RangeEventList>();
|
||||
event_lists_map_[thread_id] = event_list;
|
||||
list_ptr = event_list.get();
|
||||
}
|
||||
return *list_ptr;
|
||||
}
|
||||
namespace {
|
||||
|
||||
std::mutex state_mutex_;
|
||||
std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
|
||||
event_lists_map_;
|
||||
enum EventIValueIdx {
|
||||
KIND = 0,
|
||||
NAME,
|
||||
THREAD_ID,
|
||||
HANDLE,
|
||||
NODE_ID,
|
||||
CPU_MEM_USAGE,
|
||||
CPU_NS,
|
||||
CUDA_RECORDED,
|
||||
CUDA_MEM_USAGE,
|
||||
CUDA_DEVICE,
|
||||
CUDA_US,
|
||||
SHAPES,
|
||||
NUM_EVENT_IVALUE_IDX // must be last in list
|
||||
};
|
||||
|
||||
ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled);
|
||||
at::CallbackHandle handle_ = 0;
|
||||
c10::optional<std::vector<std::vector<Event>>> remoteProfiledEvents_;
|
||||
enum ProfilerIValueIdx {
|
||||
STATE = 0,
|
||||
REPORT_INPUT_SHAPES,
|
||||
PROFILE_MEMORY,
|
||||
NUM_PROFILER_CFG_IVALUE_IDX // must be last in list
|
||||
};
|
||||
|
||||
const std::unordered_set<std::string> disable_cuda_profiling = {
|
||||
"aten::view",
|
||||
"aten::t",
|
||||
"aten::transpose",
|
||||
"aten::stride",
|
||||
"aten::empty",
|
||||
"aten::empty_like",
|
||||
"aten::empty_strided",
|
||||
"aten::as_strided",
|
||||
"aten::expand",
|
||||
"aten::resize_",
|
||||
"aten::squeeze",
|
||||
"aten::unsqueeze",
|
||||
"aten::slice",
|
||||
"aten::_unsafe_view",
|
||||
"aten::size"
|
||||
};
|
||||
|
||||
ProfilerThreadLocalState* getProfilerTLSState() {
|
||||
@ -416,7 +410,7 @@ ProfilerThreadLocalState* getProfilerTLSState() {
|
||||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE));
|
||||
}
|
||||
|
||||
void pushProfilingCallbacks() {
|
||||
void pushProfilingCallbacksLegacy() {
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
||||
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
|
||||
@ -433,21 +427,8 @@ void pushProfilingCallbacks() {
|
||||
|
||||
auto* msg = (fn.seqNr() >= 0) ? ", seq = " : "";
|
||||
if (state_ptr->config().report_input_shapes) {
|
||||
std::vector<std::vector<int64_t>> inputSizes;
|
||||
inputSizes.reserve(fn.inputs().size());
|
||||
for (const c10::IValue& input : fn.inputs()) {
|
||||
if (!input.isTensor()) {
|
||||
inputSizes.emplace_back();
|
||||
continue;
|
||||
}
|
||||
const at::Tensor& tensor = input.toTensor();
|
||||
if (tensor.defined()) {
|
||||
inputSizes.push_back(input.toTensor().sizes().vec());
|
||||
} else {
|
||||
inputSizes.emplace_back();
|
||||
}
|
||||
}
|
||||
state_ptr->pushRange(fn, record_cuda, msg, std::move(inputSizes));
|
||||
auto sizes = inputSizes(fn);
|
||||
state_ptr->pushRange(fn, record_cuda, msg, std::move(sizes));
|
||||
} else {
|
||||
state_ptr->pushRange(fn, record_cuda, msg);
|
||||
}
|
||||
@ -474,11 +455,9 @@ const int kCUDAWarmupStart = 5;
|
||||
} // namespace
|
||||
|
||||
void registerCUDAMethods(CUDAStubs* stubs) {
|
||||
cuda_stubs = stubs;
|
||||
cuda_stubs() = stubs;
|
||||
}
|
||||
|
||||
ProfilerConfig::~ProfilerConfig() = default;
|
||||
|
||||
at::IValue ProfilerConfig::toIValue() const {
|
||||
c10::impl::GenericList eventIValueList(at::AnyType::get());
|
||||
eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX);
|
||||
@ -519,38 +498,40 @@ bool profilerEnabled() {
|
||||
return state_ptr && state_ptr->config().state != ProfilerState::Disabled;
|
||||
}
|
||||
|
||||
void enableProfiler(const ProfilerConfig& new_config) {
|
||||
TORCH_CHECK(new_config.state != ProfilerState::NVTX || cuda_stubs->enabled(),
|
||||
void enableProfilerLegacy(const ProfilerConfig& new_config) {
|
||||
TORCH_CHECK(new_config.state != ProfilerState::NVTX || cuda_stubs()->enabled(),
|
||||
"Can't use NVTX profiler - PyTorch was compiled without CUDA");
|
||||
|
||||
TORCH_CHECK(new_config.state != ProfilerState::KINETO);
|
||||
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
|
||||
auto state = std::make_shared<ProfilerThreadLocalState>(new_config);
|
||||
c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
|
||||
|
||||
pushProfilingCallbacks();
|
||||
pushProfilingCallbacksLegacy();
|
||||
|
||||
if (new_config.state == ProfilerState::CUDA) {
|
||||
// event recording appears to have some startup overhead, so we need to
|
||||
// to generate some dummy events first before recording synchronization events
|
||||
for (int idx = 0; idx < kCUDAWarmupStart; ++idx) {
|
||||
cuda_stubs->onEachDevice([state](int /* unused */) {
|
||||
cuda_stubs()->onEachDevice([state](int /* unused */) {
|
||||
state->mark("__cuda_startup");
|
||||
cuda_stubs->synchronize();
|
||||
cuda_stubs()->synchronize();
|
||||
});
|
||||
}
|
||||
|
||||
// cuda events must be on the same device, so we need a start event recorded
|
||||
// for each gpu. we then use this event to synchronize time on the GPU
|
||||
// with the CPU clock.
|
||||
cuda_stubs->onEachDevice([state](int d) {
|
||||
cuda_stubs()->onEachDevice([state](int d) {
|
||||
state->mark("__cuda_start_event");
|
||||
});
|
||||
}
|
||||
state->mark("__start_profile", false);
|
||||
}
|
||||
|
||||
thread_event_lists disableProfiler(c10::optional<ProfilerDisableOptions> profilerDisableOptions) {
|
||||
thread_event_lists disableProfilerLegacy(c10::optional<ProfilerDisableOptions> profilerDisableOptions) {
|
||||
auto cleanupTLSState = profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true;
|
||||
auto consolidate = profilerDisableOptions ? profilerDisableOptions->consolidate : true;
|
||||
// all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard
|
||||
@ -578,21 +559,21 @@ thread_event_lists disableProfiler(c10::optional<ProfilerDisableOptions> profile
|
||||
return state_ptr->consolidate();
|
||||
}
|
||||
|
||||
void addEventList(std::vector<Event>&& profiledEvents) {
|
||||
void addEventList(std::vector<LegacyEvent>&& profiledEvents) {
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
TORCH_CHECK(state_ptr, "Profiler must be enabled.");
|
||||
state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents));
|
||||
}
|
||||
|
||||
void Event::record(bool record_cuda) {
|
||||
void LegacyEvent::record(bool record_cuda) {
|
||||
if (record_cuda) {
|
||||
cuda_stubs->record(&device_, &cuda_event, &cpu_ns_);
|
||||
cuda_stubs()->record(&device_, &cuda_event, &cpu_ns_);
|
||||
return;
|
||||
}
|
||||
cpu_ns_ = getTime();
|
||||
}
|
||||
|
||||
/* static */ Event Event::fromIValue(const at::IValue& eventIValue) {
|
||||
/* static */ LegacyEvent LegacyEvent::fromIValue(const at::IValue& eventIValue) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
eventIValue.isList(),
|
||||
"Expected IValue to contain type c10::impl::GenericList");
|
||||
@ -601,7 +582,7 @@ void Event::record(bool record_cuda) {
|
||||
ivalues.size() >= NUM_EVENT_IVALUE_IDX,
|
||||
"Expected at least ",
|
||||
NUM_EVENT_IVALUE_IDX,
|
||||
" elements to reconstruct Event.");
|
||||
" elements to reconstruct LegacyEvent.");
|
||||
|
||||
// Reconstruct input shapes from ivalues.
|
||||
auto shapeListIValue = ivalues.get(EventIValueIdx::SHAPES);
|
||||
@ -627,7 +608,7 @@ void Event::record(bool record_cuda) {
|
||||
shapes.emplace_back(s);
|
||||
}
|
||||
|
||||
Event evt(
|
||||
LegacyEvent evt(
|
||||
static_cast<EventKind>(
|
||||
ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind
|
||||
at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name
|
||||
@ -647,7 +628,7 @@ void Event::record(bool record_cuda) {
|
||||
return evt;
|
||||
}
|
||||
|
||||
at::IValue Event::toIValue() const {
|
||||
at::IValue LegacyEvent::toIValue() const {
|
||||
c10::impl::GenericList eventIValueList(at::AnyType::get());
|
||||
eventIValueList.reserve(NUM_EVENT_IVALUE_IDX);
|
||||
eventIValueList.emplace_back(static_cast<int64_t>(kind_));
|
||||
@ -679,7 +660,7 @@ at::IValue Event::toIValue() const {
|
||||
return at::IValue(eventIValueList);
|
||||
}
|
||||
|
||||
double Event::cudaElapsedUs(const Event& e) const {
|
||||
double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const {
|
||||
TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA");
|
||||
TORCH_CHECK(
|
||||
e.device() == device(),
|
||||
@ -690,13 +671,12 @@ double Event::cudaElapsedUs(const Event& e) const {
|
||||
TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
|
||||
return static_cast<double>(e.cuda_us_ - cuda_us_);
|
||||
}
|
||||
return cuda_stubs->elapsed(&cuda_event, &e.cuda_event);
|
||||
return cuda_stubs()->elapsed(&cuda_event, &e.cuda_event);
|
||||
}
|
||||
|
||||
CUDAStubs::~CUDAStubs() = default;
|
||||
|
||||
|
||||
static jit::CodeTemplate event_template(R"(
|
||||
static const jit::CodeTemplate event_template(R"(
|
||||
{
|
||||
"name": "${name}",
|
||||
"ph": "X",
|
||||
@ -707,10 +687,10 @@ static jit::CodeTemplate event_template(R"(
|
||||
"args": {}
|
||||
})");
|
||||
|
||||
void writeProfilerEventsToStream(std::ostream& out, const std::vector<Event*>& events) {
|
||||
void writeProfilerEventsToStream(std::ostream& out, const std::vector<LegacyEvent*>& events) {
|
||||
TORCH_CHECK(out, "Could not open file");
|
||||
Event* profiler_start = nullptr;
|
||||
for (Event* e : events) {
|
||||
LegacyEvent* profiler_start = nullptr;
|
||||
for (LegacyEvent* e : events) {
|
||||
if (0 == strcmp(e->name(), "__start_profile")) {
|
||||
profiler_start = e;
|
||||
break;
|
||||
@ -724,20 +704,20 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector<Event*>& e
|
||||
return std::hash<at::RecordFunctionHandle>()(p.first) ^ std::hash<int64_t>()(p.second);
|
||||
}
|
||||
};
|
||||
std::unordered_map<std::pair<at::RecordFunctionHandle, int64_t>, Event*, PairHash> events_map;
|
||||
std::unordered_map<std::pair<at::RecordFunctionHandle, int64_t>, LegacyEvent*, PairHash> events_map;
|
||||
out << "[\n";
|
||||
bool first = true;
|
||||
for (Event* evt : events) {
|
||||
if (evt->kind() == "push") {
|
||||
for (LegacyEvent* evt : events) {
|
||||
if (evt->kindStr() == "push") {
|
||||
events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt;
|
||||
} else if (evt->kind() == "pop") {
|
||||
} else if (evt->kindStr() == "pop") {
|
||||
if (!first) {
|
||||
out << ",\n";
|
||||
}
|
||||
first = false;
|
||||
auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId()));
|
||||
TORCH_CHECK(it != events_map.end(), "Unmatched pop event");
|
||||
Event* evt_start = it->second;
|
||||
LegacyEvent* evt_start = it->second;
|
||||
events_map.erase(it);
|
||||
|
||||
jit::TemplateEnv env;
|
||||
@ -751,7 +731,6 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector<Event*>& e
|
||||
out << "]\n";
|
||||
}
|
||||
|
||||
|
||||
RecordProfile::RecordProfile(std::ostream& out)
|
||||
: out_(out) {
|
||||
init();
|
||||
@ -763,24 +742,27 @@ RecordProfile::RecordProfile(const std::string& filename)
|
||||
}
|
||||
|
||||
void RecordProfile::init() {
|
||||
enableProfiler(ProfilerConfig(ProfilerState::CPU));
|
||||
enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU));
|
||||
}
|
||||
|
||||
RecordProfile::~RecordProfile() {
|
||||
thread_event_lists event_lists = disableProfiler();
|
||||
std::vector<Event*> events;
|
||||
for (auto& l : event_lists) {
|
||||
for (auto& e : l) {
|
||||
events.push_back(&e);
|
||||
try {
|
||||
thread_event_lists event_lists = disableProfilerLegacy();
|
||||
std::vector<LegacyEvent*> events;
|
||||
for (auto& l : event_lists) {
|
||||
for (auto& e : l) {
|
||||
events.push_back(&e);
|
||||
}
|
||||
}
|
||||
}
|
||||
processEvents(events);
|
||||
if (file_){
|
||||
file_->close();
|
||||
processEvents(events);
|
||||
} catch (const std::exception& e) {
|
||||
LOG(ERROR) << e.what() << std::endl;
|
||||
} catch (...) {
|
||||
LOG(ERROR) << "Unknown error" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void RecordProfile::processEvents(const std::vector<Event*>& events) {
|
||||
void RecordProfile::processEvents(const std::vector<LegacyEvent*>& events) {
|
||||
writeProfilerEventsToStream(out_, events);
|
||||
}
|
||||
|
torch/csrc/autograd/profiler_legacy.h (new file, 544 lines)
@ -0,0 +1,544 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <forward_list>
|
||||
#include <tuple>
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#ifndef _WIN32
|
||||
#include <ctime>
|
||||
#endif
|
||||
#if defined(C10_IOS) && defined(C10_MOBILE)
|
||||
#include <sys/time.h> // for gettimeofday()
|
||||
#endif
|
||||
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
#include <torch/csrc/jit/frontend/source_range.h>
|
||||
|
||||
struct CUevent_st;
|
||||
typedef std::shared_ptr<CUevent_st> CUDAEventStub;
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct Node;
|
||||
|
||||
namespace profiler {
|
||||
|
||||
struct TORCH_API CUDAStubs {
|
||||
virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const {
|
||||
fail();
|
||||
}
|
||||
virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const {
|
||||
fail();
|
||||
return 0.f;
|
||||
}
|
||||
virtual void nvtxMarkA(const char* name) const {
|
||||
fail();
|
||||
}
|
||||
virtual void nvtxRangePushA(const char* name) const {
|
||||
fail();
|
||||
}
|
||||
virtual void nvtxRangePop() const {
|
||||
fail();
|
||||
}
|
||||
virtual bool enabled() const {
|
||||
return false;
|
||||
}
|
||||
virtual void onEachDevice(std::function<void(int)> op) const {
|
||||
fail();
|
||||
}
|
||||
virtual void synchronize() const {
|
||||
fail();
|
||||
}
|
||||
virtual ~CUDAStubs();
|
||||
|
||||
private:
|
||||
void fail() const {
|
||||
AT_ERROR("CUDA used in profiler but not enabled.");
|
||||
}
|
||||
};
|
||||
|
||||
TORCH_API void registerCUDAMethods(CUDAStubs* stubs);
|
||||
|
||||
constexpr inline size_t ceilToMultiple(size_t a, size_t b) {
|
||||
return ((a + b - 1) / b) * b;
|
||||
}
|
||||
|
||||
inline int64_t getTime() {
#if defined(C10_IOS) && defined(C10_MOBILE)
  // clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS
  // can't rely on CLOCK_REALTIME, as it is defined no matter if clock_gettime
  // is implemented or not.
  struct timeval now;
  gettimeofday(&now, NULL);
  return static_cast<int64_t>(now.tv_sec) * 1000000000 + static_cast<int64_t>(now.tv_usec) * 1000;
#elif defined(_WIN32) || defined(__MACH__)
  using namespace std::chrono;
  using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
  return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
#else
  // clock_gettime is *much* faster than the std::chrono implementation on Linux
  struct timespec t{};
  clock_gettime(CLOCK_MONOTONIC, &t);
  return static_cast<int64_t>(t.tv_sec) * 1000000000 + static_cast<int64_t>(t.tv_nsec);
#endif
}
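getTime() is the clock that stamps every CPU-side event. A minimal sketch (illustrative only, not part of this patch) of turning two samples into the microsecond intervals that LegacyEvent later reports:

```
// Illustrative only: time a region with the same clock the profiler uses.
int64_t t0 = getTime();                   // wall-clock nanoseconds
// ... work to be measured ...
int64_t t1 = getTime();
double elapsed_us = (t1 - t0) / 1000.0;   // same scaling as LegacyEvent::cpuElapsedUs
```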
|
||||
|
||||
enum class C10_API_ENUM EventKind : uint16_t {
|
||||
Mark,
|
||||
PushRange,
|
||||
PopRange,
|
||||
MemoryAlloc,
|
||||
};
|
||||
|
||||
// To be deprecated, once we switch to Kineto profiling
|
||||
struct TORCH_API LegacyEvent {
|
||||
LegacyEvent(
|
||||
EventKind kind,
|
||||
at::StringView name,
|
||||
uint16_t thread_id,
|
||||
bool record_cuda,
|
||||
at::RecordFunctionHandle handle = 0,
|
||||
std::vector<std::vector<int64_t>>&& shapes = {},
|
||||
int node_id = -1)
|
||||
: name_(std::move(name)),
|
||||
kind_(kind),
|
||||
thread_id_(thread_id),
|
||||
handle_(handle),
|
||||
shapes_(shapes),
|
||||
node_id_(node_id) {
|
||||
record(record_cuda);
|
||||
}
|
||||
|
||||
// Constructor to be used in conjunction with LegacyEvent::fromIValue.
|
||||
LegacyEvent(
|
||||
EventKind kind,
|
||||
at::StringView name,
|
||||
uint16_t thread_id,
|
||||
at::RecordFunctionHandle handle,
|
||||
std::vector<std::vector<int64_t>>&& shapes,
|
||||
int node_id,
|
||||
bool is_remote,
|
||||
int64_t cpu_memory_usage,
|
||||
int64_t cpu_ns,
|
||||
bool cuda_recorded,
|
||||
int64_t cuda_memory_usage = 0,
|
||||
int device = -1,
|
||||
double cuda_us = -1)
|
||||
: cpu_ns_(cpu_ns),
|
||||
name_(std::move(name)),
|
||||
kind_(kind),
|
||||
thread_id_(thread_id),
|
||||
handle_(handle),
|
||||
shapes_(shapes),
|
||||
cpu_memory_usage_(cpu_memory_usage),
|
||||
cuda_memory_usage_(cuda_memory_usage),
|
||||
device_(device),
|
||||
node_id_(node_id),
|
||||
is_remote_(is_remote),
|
||||
cuda_us_(cuda_us) {
|
||||
// Sanity check values that were deserialized
|
||||
TORCH_INTERNAL_ASSERT(cpu_ns_ > 0);
|
||||
if (cuda_recorded) {
|
||||
TORCH_INTERNAL_ASSERT(device_ >= 0);
|
||||
TORCH_INTERNAL_ASSERT(cuda_us_ >= 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns IValues corresponding to event structure, to be used for
|
||||
// serialization.
|
||||
at::IValue toIValue() const;
|
||||
|
||||
// Reconstructs an event from IValues given by toIValue.
|
||||
static LegacyEvent fromIValue(const at::IValue& eventIValue);
|
||||
|
||||
void record(bool record_cuda);
|
||||
|
||||
std::string kindStr() const {
|
||||
switch (kind_) {
|
||||
case EventKind::Mark: return "mark";
|
||||
case EventKind::PushRange: return "push";
|
||||
case EventKind::PopRange: return "pop";
|
||||
case EventKind::MemoryAlloc: return "memory_alloc";
|
||||
}
|
||||
throw std::runtime_error("unknown event kind");
|
||||
}
|
||||
|
||||
const char* name() const {
|
||||
return name_.str();
|
||||
}
|
||||
|
||||
uint64_t threadId() const {
|
||||
return thread_id_;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> shapes() const {
|
||||
return shapes_;
|
||||
}
|
||||
|
||||
double cpuElapsedUs(const LegacyEvent& e) const {
|
||||
return (e.cpu_ns_ - cpu_ns_)/(1000.0);
|
||||
}
|
||||
|
||||
void setCpuUs(int64_t cpu_us) {
|
||||
cpu_ns_ = cpu_us * 1000.0;
|
||||
}
|
||||
|
||||
double cpuUs() const {
|
||||
return cpu_ns_ / (1000.0);
|
||||
}
|
||||
|
||||
double cudaElapsedUs(const LegacyEvent& e) const;
|
||||
|
||||
bool hasCuda() const {
|
||||
return cuda_event != nullptr || (isRemote() && device_ != -1);
|
||||
}
|
||||
|
||||
int device() const {
|
||||
return device_;
|
||||
}
|
||||
|
||||
void updateMemoryStats(int64_t alloc_size, c10::Device device) {
|
||||
if (device.type() == c10::DeviceType::CUDA ||
|
||||
device.type() == c10::DeviceType::HIP) {
|
||||
cuda_memory_usage_ = alloc_size;
|
||||
} else if (device.type() == c10::DeviceType::CPU ||
|
||||
device.type() == c10::DeviceType::MKLDNN ||
|
||||
device.type() == c10::DeviceType::IDEEP) {
|
||||
cpu_memory_usage_ = alloc_size;
|
||||
} else {
|
||||
LOG(WARNING) << "Unsupported memory profiling device: " << device;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t cpuMemoryUsage() const {
|
||||
return cpu_memory_usage_;
|
||||
}
|
||||
|
||||
int64_t cudaMemoryUsage() const {
|
||||
return cuda_memory_usage_;
|
||||
}
|
||||
|
||||
at::RecordFunctionHandle handle() const {
|
||||
return handle_;
|
||||
}
|
||||
|
||||
// Node ID corresponding to this event.
|
||||
int nodeId() const {
|
||||
return node_id_;
|
||||
}
|
||||
|
||||
// Set Node ID on this event.
|
||||
void setNodeId(int node_id) {
|
||||
node_id_ = node_id;
|
||||
}
|
||||
|
||||
void setName(at::StringView newName_) {
|
||||
name_ = std::move(newName_);
|
||||
}
|
||||
|
||||
bool isRemote() const {
|
||||
return is_remote_;
|
||||
}
|
||||
|
||||
void setCudaUs(int64_t cuda_us) {
|
||||
cuda_us_ = cuda_us;
|
||||
}
|
||||
|
||||
void setSequenceNr(int64_t sequence_nr) {
|
||||
sequence_nr_ = sequence_nr;
|
||||
}
|
||||
|
||||
int64_t sequenceNr() const {
|
||||
return sequence_nr_;
|
||||
}
|
||||
|
||||
void setCorrelationId(uint64_t correlation_id) {
|
||||
correlation_id_ = correlation_id;
|
||||
}
|
||||
|
||||
uint64_t correlationId() const {
|
||||
return correlation_id_;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& stack() const {
|
||||
return stack_;
|
||||
}
|
||||
|
||||
void setStack(const std::vector<std::string>& stack) {
|
||||
stack_ = stack;
|
||||
}
|
||||
|
||||
uint64_t fwdThreadId() const {
|
||||
return fwd_thread_id_;
|
||||
}
|
||||
|
||||
void setFwdThreadId(uint64_t fwd_thread_id) {
|
||||
fwd_thread_id_ = fwd_thread_id;
|
||||
}
|
||||
|
||||
uint8_t scope() const {
|
||||
return scope_;
|
||||
}
|
||||
|
||||
void setScope(uint8_t scope) {
|
||||
scope_ = scope;
|
||||
}
|
||||
|
||||
private:
|
||||
// signed to allow for negative intervals, initialized for safety.
|
||||
int64_t cpu_ns_ = 0;
|
||||
at::StringView name_;
|
||||
EventKind kind_;
|
||||
uint64_t thread_id_;
|
||||
uint64_t fwd_thread_id_;
|
||||
at::RecordFunctionHandle handle_ {0};
|
||||
std::vector<std::vector<int64_t>> shapes_;
|
||||
int64_t cpu_memory_usage_ = 0;
|
||||
int64_t cuda_memory_usage_ = 0;
|
||||
int device_ = -1;
|
||||
CUDAEventStub cuda_event = nullptr;
|
||||
int node_id_ = 0;
|
||||
bool is_remote_ = false;
|
||||
int64_t cuda_us_ = -1;
|
||||
int64_t sequence_nr_ = -1;
|
||||
|
||||
std::vector<std::string> stack_;
|
||||
uint8_t scope_;
|
||||
uint64_t correlation_id_;
|
||||
};
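As a quick orientation to the accessors above, a hedged sketch (not part of the patch) of dumping one event; the helper name is made up for illustration:

```
// Illustrative only: print the fields most post-processing code reads.
void printEvent(const LegacyEvent& evt) {
  std::cout << evt.kindStr() << " " << evt.name()
            << " thread=" << evt.threadId()
            << " cpu_mem=" << evt.cpuMemoryUsage()
            << " cuda_mem=" << evt.cudaMemoryUsage() << "\n";
}
```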
|
||||
|
||||
// a linked-list of fixed sized vectors, to avoid
// a std::vector resize from taking a large amount of time inside
// a profiling event
struct RangeEventList {
  RangeEventList() {
    events_.reserve(kReservedCapacity);
  }

  template<typename... Args>
  void record(Args&&... args) {
    std::lock_guard<std::mutex> guard(mutex_);
    events_.emplace_back(std::forward<Args>(args)...);
  }

  std::vector<LegacyEvent> consolidate() {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<LegacyEvent> result;
    result.insert(
        result.begin(),
        std::make_move_iterator(events_.begin()),
        std::make_move_iterator(events_.end()));
    events_.erase(events_.begin(), events_.end());
    return result;
  }

  size_t size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return events_.size();
  }

 private:
  // This mutex is used to serialize access when different threads are writing
  // to the same instance of RangeEventList.
  std::mutex mutex_;
  std::vector<LegacyEvent> events_;

  static const size_t kReservedCapacity = 1024;
};
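A minimal sketch (illustrative only, not part of this patch) of how the profiler state drives this container: record() forwards its arguments to a LegacyEvent constructor under the list's mutex, and consolidate() moves the events out for reporting.

```
// Illustrative only: one per-thread event list.
RangeEventList list;
list.record(
    EventKind::Mark,
    at::StringView("__example_mark"),
    at::RecordFunction::currentThreadId(),
    /*record_cuda=*/false);
std::vector<LegacyEvent> events = list.consolidate();
```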
|
||||
|
||||
enum class C10_API_ENUM ProfilerState {
|
||||
Disabled = 0,
|
||||
CPU, // CPU-only profiling
|
||||
CUDA, // CPU + CUDA events
|
||||
NVTX, // only emit NVTX markers
|
||||
KINETO, // use libkineto
|
||||
NUM_PROFILER_STATES, // must be the last one
|
||||
};
|
||||
|
||||
struct TORCH_API ProfilerConfig {
|
||||
ProfilerConfig(
|
||||
ProfilerState state,
|
||||
bool report_input_shapes = false,
|
||||
bool profile_memory = false,
|
||||
bool with_stack = false)
|
||||
: state(state),
|
||||
report_input_shapes(report_input_shapes),
|
||||
profile_memory(profile_memory),
|
||||
with_stack(with_stack) {}
|
||||
~ProfilerConfig() = default;
|
||||
ProfilerState state;
|
||||
bool report_input_shapes;
|
||||
bool profile_memory;
|
||||
bool with_stack;
|
||||
|
||||
// Returns IValues corresponding to ProfilerConfig struct, to be used for
|
||||
// serialization.
|
||||
at::IValue toIValue() const;
|
||||
|
||||
// Reconstructs a ProfilerConfig from IValues given by toIValue.
|
||||
static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue);
|
||||
};
|
||||
|
||||
// A struct to control settings of disableProfiler options.
|
||||
struct TORCH_API ProfilerDisableOptions {
|
||||
ProfilerDisableOptions() = default;
|
||||
ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate)
|
||||
: cleanupTLSState(shouldCleanupTLSState),
|
||||
consolidate(shouldConsolidate) {}
|
||||
// Whether we should clean up profiler states that are thread local, such as
|
||||
// ThreadLocalDebugInfo and thread local RecordFunction callbacks.
|
||||
bool cleanupTLSState = true;
|
||||
// Whether we should consolidate all currently recorded profiled events. If
|
||||
// false, will not consolidate and other threads can continue to write to the
|
||||
// event lists.
|
||||
bool consolidate = true;
|
||||
};
|
||||
|
||||
// NOTE: profiler mode is thread local, with automatic propagation
|
||||
// across thread boundary (e.g. at::launch tasks)
|
||||
TORCH_API void enableProfilerLegacy(const ProfilerConfig&);
|
||||
using thread_event_lists = std::vector<std::vector<LegacyEvent>>;
|
||||
TORCH_API thread_event_lists disableProfilerLegacy(c10::optional<ProfilerDisableOptions> profilerDisableOptions = c10::nullopt);
|
||||
|
||||
// adds profiledEvents to the current thread local recorded events. Each event
|
||||
// will be marked with node ID given by fromNodeId.
|
||||
TORCH_API void addEventList(std::vector<LegacyEvent>&& profiledEvents);
|
||||
// Returns if the profiler is currently enabled in the current thread.
|
||||
TORCH_API bool profilerEnabled();
|
||||
// Retrieve the thread_local ProfilerConfig.
|
||||
TORCH_API ProfilerConfig getProfilerConfig();
|
||||
// Writes profiled events to a stream.
|
||||
TORCH_API void writeProfilerEventsToStream(std::ostream& out, const std::vector<LegacyEvent*>& events);
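A hedged end-to-end sketch of these entry points (illustrative only, not part of this patch; the output path is an assumption): events recorded between enable and disable come back grouped per thread, and the pointers can be handed to writeProfilerEventsToStream for a chrome://tracing file.

```
// Illustrative only; needs <fstream>.
enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU));
// ... run the code to be profiled ...
thread_event_lists lists = disableProfilerLegacy();

std::vector<LegacyEvent*> events;
for (auto& per_thread : lists) {
  for (auto& evt : per_thread) {
    events.push_back(&evt);
  }
}
std::ofstream out("trace.json");  // hypothetical output path
writeProfilerEventsToStream(out, events);
```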
|
||||
|
||||
// Usage:
//   {
//     RecordProfile guard("filename.trace");
//     // code you want to profile
//   }
// Then open filename.trace in chrome://tracing
struct TORCH_API RecordProfile {
  RecordProfile(std::ostream& out);
  RecordProfile(const std::string& filename);

  ~RecordProfile();

 private:
  void init();
  std::unique_ptr<std::ofstream> file_;
  std::ostream& out_;
  void processEvents(const std::vector<LegacyEvent*>& events);
};
|
||||
|
||||
// A guard that enables the profiler, taking in an optional callback to process
// the results
// Usage:
// {
//   TLSProfilerGuard g([](thread_event_lists profilerResults) {
//     // process profilerResults
//   });
//   Code to profile
// }
struct TORCH_API TLSProfilerGuard {
  explicit TLSProfilerGuard(
      const ProfilerConfig& cfg,
      c10::optional<std::function<void(const thread_event_lists&)>>
          resultCallback = c10::nullopt,
      c10::optional<ProfilerDisableOptions> profilerDisableOptions =
          c10::nullopt)
      : cb_(std::move(resultCallback)),
        profilerDisableOptions_(std::move(profilerDisableOptions)) {
    enableProfilerLegacy(cfg);
  }
  ~TLSProfilerGuard() {
    thread_event_lists event_lists = disableProfilerLegacy(profilerDisableOptions_);
    if (cb_) {
      try {
        (*cb_)(event_lists);
      } catch (const std::exception& e) {
        LOG(ERROR) << "Got error processing profiler events: " << e.what();
      }
    }
  }

 private:
  c10::optional<std::function<void(const thread_event_lists&)>> cb_;
  const c10::optional<ProfilerDisableOptions> profilerDisableOptions_;
};
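A short usage sketch for the guard (illustrative only, not part of this patch); the callback simply counts events here, as an assumed stand-in for real post-processing:

```
// Illustrative only: RAII profiling; the callback runs when the guard
// goes out of scope and the profiler is disabled.
{
  TLSProfilerGuard guard(
      ProfilerConfig(ProfilerState::CPU),
      [](const thread_event_lists& lists) {
        size_t total = 0;
        for (const auto& per_thread : lists) {
          total += per_thread.size();
        }
        LOG(INFO) << "recorded " << total << " profiler events";
      });
  // ... code to profile ...
}
```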
|
||||
|
||||
struct TORCH_API FileLineFunc {
|
||||
std::string filename;
|
||||
size_t line;
|
||||
std::string funcname;
|
||||
};
|
||||
TORCH_API std::vector<FileLineFunc> prepareCallstack(const std::vector<jit::StackEntry>& cs);
|
||||
TORCH_API std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs);
|
||||
TORCH_API std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn);
|
||||
|
||||
struct TORCH_API ProfilerThreadLocalState : public c10::MemoryReportingInfoBase {
|
||||
explicit ProfilerThreadLocalState(const ProfilerConfig& config)
|
||||
: config_(config), remoteProfiledEvents_{c10::nullopt} {}
|
||||
~ProfilerThreadLocalState() override = default;
|
||||
|
||||
const ProfilerConfig& config() const;
|
||||
|
||||
thread_event_lists consolidate();
|
||||
|
||||
void mark(std::string name, bool include_cuda = true);
|
||||
|
||||
void setOrAddRemoteProfiledEvents(
|
||||
std::vector<LegacyEvent>&& remoteProfiledEvents);
|
||||
|
||||
void pushRange(
|
||||
const at::RecordFunction& fn,
|
||||
const bool record_cuda,
|
||||
const char* msg = "",
|
||||
std::vector<std::vector<int64_t>>&& shapes = {});
|
||||
|
||||
void popRange(const at::RecordFunction& fn, const bool record_cuda);
|
||||
|
||||
void setCallbackHandle(at::CallbackHandle handle) {
|
||||
handle_ = handle;
|
||||
}
|
||||
|
||||
at::CallbackHandle callbackHandle() const {
|
||||
return handle_;
|
||||
}
|
||||
|
||||
bool hasCallbackHandle() {
|
||||
return handle_ > 0;
|
||||
}
|
||||
|
||||
void reportMemoryUsage(
|
||||
void* /* unused */,
|
||||
int64_t alloc_size,
|
||||
c10::Device device) override;
|
||||
|
||||
bool memoryProfilingEnabled() const override;
|
||||
|
||||
protected:
|
||||
std::string getNvtxStr(
|
||||
const at::StringView& name,
|
||||
const char* msg,
|
||||
int64_t sequence_nr,
|
||||
const std::vector<std::vector<int64_t>>& shapes) const;
|
||||
|
||||
RangeEventList& getEventList(int64_t thread_id = -1);
|
||||
|
||||
std::mutex state_mutex_;
|
||||
std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
|
||||
event_lists_map_;
|
||||
|
||||
ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled);
|
||||
at::CallbackHandle handle_ = 0;
|
||||
c10::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
|
||||
};
|
||||
|
||||
|
||||
} // namespace profiler
|
||||
}} // namespace torch::autograd
|
@ -13,7 +13,7 @@ constexpr auto kProfileEventsStartIdx = 3;
|
||||
RpcWithProfilingResp::RpcWithProfilingResp(
|
||||
rpc::MessageType messageType,
|
||||
rpc::Message&& wrappedMessage,
|
||||
std::vector<torch::autograd::profiler::Event> profiledEvents,
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
|
||||
rpc::ProfilingId profilingId)
|
||||
: messageType_(messageType),
|
||||
wrappedMessage_(std::move(wrappedMessage)),
|
||||
@ -32,7 +32,7 @@ RpcWithProfilingResp::RpcWithProfilingResp(
|
||||
std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
|
||||
rpc::MessageType wrappedMessageType,
|
||||
std::vector<torch::Tensor> tensors,
|
||||
std::vector<torch::autograd::profiler::Event> profiledEvents,
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
|
||||
rpc::ProfilingId profilingId)
|
||||
: messageType_(messageType),
|
||||
wrappedRpc_(std::move(wrappedRpc)),
|
||||
@ -52,7 +52,7 @@ rpc::MessageType RpcWithProfilingResp::wrappedMessageType() const {
|
||||
return wrappedMessageType_;
|
||||
}
|
||||
|
||||
std::vector<torch::autograd::profiler::Event> RpcWithProfilingResp::
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> RpcWithProfilingResp::
|
||||
getProfiledEvents() const {
|
||||
return profiledEvents_;
|
||||
}
|
||||
@ -119,15 +119,15 @@ std::unique_ptr<RpcWithProfilingResp> RpcWithProfilingResp::fromMessage(
|
||||
static_cast<rpc::MessageType>(tupleElements[0].toInt());
|
||||
rpc::ProfilingId profilingId = rpc::ProfilingId::fromIValue(tupleElements[1]);
|
||||
int profiledEventsSize = tupleElements[2].toInt();
|
||||
std::vector<torch::autograd::profiler::Event> remoteEvents;
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> remoteEvents;
|
||||
remoteEvents.reserve(profiledEventsSize);
|
||||
for (int i = kProfileEventsStartIdx;
|
||||
i < kProfileEventsStartIdx + profiledEventsSize;
|
||||
++i) {
|
||||
TORCH_CHECK(i < tupleElements.size());
|
||||
// Reconstruct remote event from the ivalues.
|
||||
torch::autograd::profiler::Event fromIvalueEvent =
|
||||
torch::autograd::profiler::Event::fromIValue(tupleElements[i]);
|
||||
torch::autograd::profiler::LegacyEvent fromIvalueEvent =
|
||||
torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]);
|
||||
remoteEvents.push_back(std::move(fromIvalueEvent));
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
|
||||
RpcWithProfilingResp(
|
||||
rpc::MessageType messageType,
|
||||
rpc::Message&& wrappedMessage,
|
||||
std::vector<torch::autograd::profiler::Event> profiledEvents,
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
|
||||
rpc::ProfilingId profilingId);
|
||||
|
||||
// For receiving RPCs. Used in fromMessage when converting a message received
|
||||
@ -25,13 +25,13 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
|
||||
std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
|
||||
rpc::MessageType wrappedMessageType,
|
||||
std::vector<torch::Tensor> tensors,
|
||||
std::vector<torch::autograd::profiler::Event> profiledEvents,
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
|
||||
rpc::ProfilingId profilingId);
|
||||
rpc::Message toMessageImpl() && override;
|
||||
static std::unique_ptr<RpcWithProfilingResp> fromMessage(
|
||||
const rpc::Message& message);
|
||||
// Retrieve remote Events
|
||||
std::vector<torch::autograd::profiler::Event> getProfiledEvents() const;
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> getProfiledEvents() const;
|
||||
// Retrieve the globally unique profiling ID corresponding to this command.
|
||||
const rpc::ProfilingId& getProfilingId() const;
|
||||
// Retrieve the original RPC which this ProfilingRPC wraps.
|
||||
@ -51,7 +51,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
|
||||
std::unique_ptr<RpcCommandBase> wrappedRpc_;
|
||||
rpc::MessageType wrappedMessageType_;
|
||||
std::vector<torch::Tensor> tensors_;
|
||||
const std::vector<torch::autograd::profiler::Event> profiledEvents_;
|
||||
const std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents_;
|
||||
const rpc::ProfilingId profilingId_;
|
||||
};
|
||||
} // namespace autograd
|
||||
|
@ -81,7 +81,7 @@ std::shared_ptr<FutureMessage> RequestCallbackNoPython::processMessage(
|
||||
if (serverProcessGlobalProfilerStateStackEntryPtr) {
|
||||
// Initialize thread-local profiler state from process-global
|
||||
// profiler state.
|
||||
::torch::autograd::profiler::enableProfiler(
|
||||
::torch::autograd::profiler::enableProfilerLegacy(
|
||||
serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
|
||||
->config());
|
||||
}
|
||||
@ -93,7 +93,7 @@ std::shared_ptr<FutureMessage> RequestCallbackNoPython::processMessage(
|
||||
if (serverProcessGlobalProfilerStateStackEntryPtr) {
|
||||
// Restore thread-local profiler state.
|
||||
::torch::autograd::profiler::thread_event_lists event_lists =
|
||||
::torch::autograd::profiler::disableProfiler();
|
||||
::torch::autograd::profiler::disableProfilerLegacy();
|
||||
// Put thread_local event_lists into the process-global profiler
|
||||
// state.
|
||||
profiler::processglobal::pushResultRecursive(
|
||||
@ -509,7 +509,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq(
|
||||
responseFuture,
|
||||
profilingKeyId,
|
||||
profilingConfig] {
|
||||
std::vector<torch::autograd::profiler::Event> profiledEvents;
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents;
|
||||
// Defer consolidation of profiler events until async work has
|
||||
// completed (such as async UDF)
|
||||
|
||||
@ -521,7 +521,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq(
|
||||
// they will be cleaned up by main thread, and consolidate all
|
||||
// events so we obtain asynchronously run events.
|
||||
torch::autograd::profiler::ProfilerDisableOptions opts(false, true);
|
||||
auto event_lists = torch::autograd::profiler::disableProfiler(opts);
|
||||
auto event_lists = torch::autograd::profiler::disableProfilerLegacy(opts);
|
||||
if (wrappedRpcResponseFuture->hasError()) {
|
||||
// Propagate error
|
||||
// No need to propagate remote events in the case of an error.
|
||||
|
@ -36,7 +36,7 @@ void processRemoteProfiledEvents(
|
||||
"Profiler was expected to be enabled. This can happen in callback "
|
||||
" continutations that run in different threads, and the TLS of the "
|
||||
" profiler was not propagated.");
|
||||
std::vector<torch::autograd::profiler::Event> events =
|
||||
std::vector<torch::autograd::profiler::LegacyEvent> events =
|
||||
rpcWithProfilingResp.getProfiledEvents();
|
||||
const auto& profilingId = rpcWithProfilingResp.getProfilingId();
|
||||
auto& remoteProfilerManager = RemoteProfilerManager::getInstance();
|
||||
@ -46,7 +46,7 @@ void processRemoteProfiledEvents(
|
||||
std::for_each(
|
||||
events.begin(),
|
||||
events.end(),
|
||||
[&keyPrefixStr](torch::autograd::profiler::Event& event) {
|
||||
[&keyPrefixStr](torch::autograd::profiler::LegacyEvent& event) {
|
||||
std::string name = keyPrefixStr + std::string(event.name());
|
||||
event.setName(at::StringView(name));
|
||||
});
|
||||
@ -511,9 +511,9 @@ std::vector<at::IValue> readWrappedPayload(
|
||||
}
|
||||
|
||||
void populateRemoteProfiledEvents(
|
||||
std::vector<torch::autograd::profiler::Event>& profiledEvents,
|
||||
std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
|
||||
const torch::autograd::profiler::ProfilerConfig& profilingConfig,
|
||||
const std::vector<std::vector<torch::autograd::profiler::Event>>&
|
||||
const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
|
||||
eventLists) {
|
||||
// Gather all events into a vector
|
||||
for (auto& l : eventLists) {
|
||||
@ -525,11 +525,11 @@ void populateRemoteProfiledEvents(
|
||||
bool cudaProfilingEnabled =
|
||||
profilingConfig.state == torch::autograd::profiler::ProfilerState::CUDA;
|
||||
bool foundCpuStart = false;
|
||||
const torch::autograd::profiler::Event* profilerStart = nullptr;
|
||||
const torch::autograd::profiler::LegacyEvent* profilerStart = nullptr;
|
||||
// Each device has its own cudaProfilerStart, so we must take
|
||||
// care to use the correct one depending on the device the
|
||||
// operation ran on.
|
||||
std::unordered_map<int, const torch::autograd::profiler::Event*>
|
||||
std::unordered_map<int, const torch::autograd::profiler::LegacyEvent*>
|
||||
cudaProfilerStarts;
|
||||
for (auto& e : profiledEvents) {
|
||||
if (!foundCpuStart && 0 == strcmp(e.name(), "__start_profile")) {
|
||||
|
@ -82,9 +82,9 @@ TORCH_API std::vector<at::IValue> readWrappedPayload(
|
||||
// Takes a list of events from autograd profiler and populates them into
|
||||
// profiledEvents to be carried over RPC.
|
||||
TORCH_API void populateRemoteProfiledEvents(
|
||||
std::vector<torch::autograd::profiler::Event>& profiledEvents,
|
||||
std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
|
||||
const torch::autograd::profiler::ProfilerConfig& profilerConfig,
|
||||
const std::vector<std::vector<torch::autograd::profiler::Event>>&
|
||||
const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
|
||||
eventLists);
|
||||
|
||||
} // namespace rpc
|
||||
|
@ -103,7 +103,7 @@ class _server_process_global_profile(profile):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
if self.entered:
|
||||
if self.entered: # type: ignore[has-type]
|
||||
raise RuntimeError("autograd profiler traces are not reentrant")
|
||||
self.entered = True
|
||||
|
||||
@ -145,13 +145,13 @@ class _server_process_global_profile(profile):
|
||||
process_global_function_events = []
|
||||
for thread_local_events in process_global_events:
|
||||
# Parse from ``Event``s to ``FunctionEvent``s.
|
||||
thread_local_function_events = torch.autograd.profiler.parse_event_records(
|
||||
thread_local_function_events = torch.autograd.profiler.parse_legacy_records(
|
||||
thread_local_events
|
||||
)
|
||||
thread_local_function_events.sort(
|
||||
key=lambda function_event: [
|
||||
function_event.cpu_interval.start,
|
||||
-(function_event.cpu_interval.end),
|
||||
function_event.time_range.start,
|
||||
-(function_event.time_range.end),
|
||||
]
|
||||
)
|
||||
process_global_function_events.append(thread_local_function_events)
|
||||
@ -164,6 +164,7 @@ class _server_process_global_profile(profile):
|
||||
use_cuda=self.use_cuda,
|
||||
profile_memory=self.profile_memory,
|
||||
)
|
||||
self.function_events._build_tree()
|
||||
|
||||
self.process_global_function_events = process_global_function_events
|
||||
|
||||
|
@ -177,9 +177,9 @@ class AllreduceNCCLTest : public NCCLTest {
|
||||
// Make sure enabling profile does not make any issue. Note, in single
|
||||
// process multi-device mode we do not expect any events be populated for
|
||||
// collective operations, since profiling for that mode is not supported.
|
||||
enableProfiler({ProfilerState::CPU});
|
||||
enableProfilerLegacy({ProfilerState::CPU});
|
||||
auto results = pg_->allreduce(tensors_);
|
||||
disableProfiler();
|
||||
disableProfilerLegacy();
|
||||
return results;
|
||||
}
|
||||
};
|
||||
|
@ -1567,8 +1567,8 @@ class RpcTest(RpcAgentTestFixture):
|
||||
scope_event = get_function_event(events, "foo")
|
||||
# Since RPC call is within the scope, its CPU interval should be
|
||||
# contained within foo's interval.
|
||||
self.assertTrue(scope_event.cpu_interval.start < rpc_event.cpu_interval.start)
|
||||
self.assertTrue(scope_event.cpu_interval.end > rpc_event.cpu_interval.end)
|
||||
self.assertTrue(scope_event.time_range.start < rpc_event.time_range.start)
|
||||
self.assertTrue(scope_event.time_range.end > rpc_event.time_range.end)
|
||||
# the sender, dest worker, function run, and type of RPC should all
|
||||
# be recorded.
|
||||
self_worker_name = worker_name(self.rank)
|
||||
@ -1760,10 +1760,10 @@ class RpcTest(RpcAgentTestFixture):
|
||||
last_end_time = 0
|
||||
for event in thread_local_events:
|
||||
event_name = event.name
|
||||
cpu_interval = event.cpu_interval
|
||||
if cpu_interval.start > last_end_time:
|
||||
time_range = event.time_range
|
||||
if time_range.start > last_end_time:
|
||||
top_level_event_names.append(event_name)
|
||||
last_end_time = cpu_interval.end
|
||||
last_end_time = time_range.end
|
||||
self.assertEqual(sorted(top_level_event_names), sorted(expected_top_level_event_names))
|
||||
|
||||
@dist_init
|
||||
|