mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enable Intel® VTune™ Profiler's Instrumentation and Tracing Technology APIs (ITT) to PyTorch (#63289)
More detailed description of benefits can be found at #41001. This is Intel's counterpart of NVidia’s NVTX (https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.emit_nvtx). ITT is a functionality for labeling trace data during application execution across different Intel tools. For integrating Intel(R) VTune Profiler into Kineto, ITT needs to be integrated into PyTorch first. It works with both standalone VTune Profiler [(https://www.intel.com/content/www/us/en/developer/tools/oneapi/vtune-profiler.html](https://www.intel.com/content/www/us/en/developer/tools/oneapi/vtune-profiler.html)) and Kineto-integrated VTune functionality in the future. It works for both Intel CPU and Intel XPU devices. Pitch Add VTune Profiler's ITT API function calls to annotate PyTorch ops, as well as developer customized code scopes on CPU, like NVTX for NVidia GPU. This PR rebases the code changes at https://github.com/pytorch/pytorch/pull/61335 to the latest master branch. Usage example: ``` with torch.autograd.profiler.emit_itt(): for i in range(10): torch.itt.range_push('step_{}'.format(i)) model(input) torch.itt.range_pop() ``` cc @ilia-cher @robieta @chaekit @gdankel @bitfort @ngimel @orionr @nbcsm @guotuofeng @guyang3532 @gaoteng-git Pull Request resolved: https://github.com/pytorch/pytorch/pull/63289 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
937ca69f15
commit
3c7044728b
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -139,6 +139,9 @@
|
||||
[submodule "third_party/pocketfft"]
|
||||
path = third_party/pocketfft
|
||||
url = https://github.com/mreineck/pocketfft
|
||||
[submodule "third_party/ittapi"]
|
||||
path = third_party/ittapi
|
||||
url = https://github.com/intel/ittapi.git
|
||||
[submodule "third_party/flatbuffers"]
|
||||
path = third_party/flatbuffers
|
||||
url = https://github.com/google/flatbuffers.git
|
||||
|
@ -295,6 +295,10 @@ if(NOT USE_XNNPACK AND CMAKE_VERSION VERSION_LESS ${XNNPACK_MIN_CMAKE_VER})
|
||||
endif()
|
||||
option(USE_ZMQ "Use ZMQ" OFF)
|
||||
option(USE_ZSTD "Use ZSTD" OFF)
|
||||
# Ensure that an ITT build is the default for x86 CPUs
|
||||
cmake_dependent_option(
|
||||
USE_ITT "Use Intel(R) VTune Profiler ITT functionality" ON
|
||||
"CPU_INTEL" OFF)
|
||||
# Ensure that an MKLDNN build is the default for x86 CPUs
|
||||
# but optional for AArch64 (dependent on -DUSE_MKLDNN).
|
||||
cmake_dependent_option(
|
||||
|
@ -132,6 +132,7 @@ libtorch_profiler_sources = [
|
||||
"torch/csrc/profiler/kineto_shim.cpp",
|
||||
"torch/csrc/profiler/nvtx_observer.cpp",
|
||||
"torch/csrc/profiler/kineto_client_interface.cpp",
|
||||
"torch/csrc/profiler/itt_observer.cpp",
|
||||
"torch/csrc/monitor/counters.cpp",
|
||||
"torch/csrc/monitor/events.cpp",
|
||||
]
|
||||
|
@ -609,6 +609,13 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
)
|
||||
endif()
|
||||
|
||||
if(${USE_ITT})
|
||||
list(APPEND TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/itt_wrapper.cpp
|
||||
${TORCH_SRC_DIR}/csrc/profiler/itt.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER)
|
||||
list(APPEND TORCH_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
|
||||
|
@ -42,6 +42,7 @@ static_assert(
|
||||
#cmakedefine CAFFE2_USE_MKL
|
||||
#cmakedefine CAFFE2_USE_MKLDNN
|
||||
#cmakedefine CAFFE2_USE_NVTX
|
||||
#cmakedefine CAFFE2_USE_ITT
|
||||
#cmakedefine CAFFE2_USE_TRT
|
||||
|
||||
#ifndef EIGEN_MPL2_ONLY
|
||||
@ -82,5 +83,6 @@ static_assert(
|
||||
{"USE_MKL", "${CAFFE2_USE_MKL}"}, \
|
||||
{"USE_MKLDNN", "${CAFFE2_USE_MKLDNN}"}, \
|
||||
{"USE_NVTX", "${CAFFE2_USE_NVTX}"}, \
|
||||
{"USE_ITT", "${CAFFE2_USE_ITT}"}, \
|
||||
{"USE_TRT", "${CAFFE2_USE_TRT}"}, \
|
||||
}
|
||||
|
@ -962,6 +962,19 @@ if(USE_FFMPEG)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(USE_ITT)
|
||||
find_package(ITT)
|
||||
if(ITT_FOUND)
|
||||
include_directories(SYSTEM ${ITT_INCLUDE_DIR})
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS ${ITT_LIBRARIES})
|
||||
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ITT_LIBRARIES})
|
||||
else()
|
||||
message(WARNING "Not compiling with ITT. Suppress this warning with -DUSE_ITT=OFF")
|
||||
set(USE_ITT OFF CACHE BOOL "" FORCE)
|
||||
caffe2_update_option(USE_ITT OFF)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ---[ Caffe2 depends on FP16 library for half-precision conversions
|
||||
if(NOT TARGET fp16 AND NOT USE_SYSTEM_FP16)
|
||||
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
||||
|
21
cmake/Modules/FindITT.cmake
Normal file
21
cmake/Modules/FindITT.cmake
Normal file
@ -0,0 +1,21 @@
|
||||
# - Try to find ITT
|
||||
#
|
||||
# The following are set after configuration is done:
|
||||
# ITT_FOUND : set to true if ITT is found.
|
||||
# ITT_INCLUDE_DIR : path to ITT include dir.
|
||||
# ITT_LIBRARIES : list of libraries for ITT
|
||||
|
||||
IF (NOT ITT_FOUND)
|
||||
SET(ITT_FOUND OFF)
|
||||
|
||||
SET(ITT_INCLUDE_DIR)
|
||||
SET(ITT_LIBRARIES)
|
||||
|
||||
SET(ITT_ROOT "${PROJECT_SOURCE_DIR}/third_party/ittapi")
|
||||
FIND_PATH(ITT_INCLUDE_DIR ittnotify.h PATHS ${ITT_ROOT} PATH_SUFFIXES include)
|
||||
IF (ITT_INCLUDE_DIR)
|
||||
ADD_SUBDIRECTORY(${ITT_ROOT})
|
||||
SET(ITT_LIBRARIES ittnotify)
|
||||
SET(ITT_FOUND ON)
|
||||
ENDIF (ITT_INCLUDE_DIR)
|
||||
ENDIF(NOT ITT_FOUND)
|
@ -150,6 +150,7 @@ function(caffe2_print_configuration_summary)
|
||||
if(${USE_UCC})
|
||||
message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}")
|
||||
endif()
|
||||
message(STATUS " USE_ITT : ${USE_ITT}")
|
||||
message(STATUS " USE_NCCL : ${USE_NCCL}")
|
||||
if(${USE_NCCL})
|
||||
message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}")
|
||||
|
@ -223,10 +223,12 @@ Profiler
|
||||
^^^^^^^^
|
||||
|
||||
Autograd includes a profiler that lets you inspect the cost of different
|
||||
operators inside your model - both on the CPU and GPU. There are two modes
|
||||
operators inside your model - both on the CPU and GPU. There are three modes
|
||||
implemented at the moment - CPU-only using :class:`~torch.autograd.profiler.profile`.
|
||||
and nvprof based (registers both CPU and GPU activity) using
|
||||
nvprof based (registers both CPU and GPU activity) using
|
||||
:class:`~torch.autograd.profiler.emit_nvtx`.
|
||||
and vtune profiler based using
|
||||
:class:`~torch.autograd.profiler.emit_itt`.
|
||||
|
||||
.. autoclass:: torch.autograd.profiler.profile
|
||||
|
||||
@ -240,6 +242,7 @@ and nvprof based (registers both CPU and GPU activity) using
|
||||
profiler.profile.total_average
|
||||
|
||||
.. autoclass:: torch.autograd.profiler.emit_nvtx
|
||||
.. autoclass:: torch.autograd.profiler.emit_itt
|
||||
|
||||
|
||||
.. autosummary::
|
||||
|
@ -47,7 +47,9 @@ where [args] are any number of arguments to `script.py`, or run
|
||||
evaluating. If the profiler outputs don't help, you could try looking at
|
||||
the result of :func:`torch.autograd.profiler.emit_nvtx()` with ``nvprof``.
|
||||
However, please take into account that the NVTX overhead is very high and
|
||||
often gives a heavily skewed timeline.
|
||||
often gives a heavily skewed timeline. Similarly, Intel VTune Profiler helps
|
||||
to analyze performance on Intel platforms further with
|
||||
:func:`torch.autograd.profiler.emit_nvtx()`.
|
||||
|
||||
.. warning::
|
||||
If you are profiling CUDA code, the first profiler that ``bottleneck`` runs
|
||||
|
@ -135,6 +135,7 @@ else
|
||||
fi
|
||||
# Disable unused dependencies
|
||||
CMAKE_ARGS+=("-DUSE_CUDA=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_ITT=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_GFLAGS=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_OPENCV=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_LMDB=OFF")
|
||||
|
@ -104,6 +104,7 @@ CMAKE_ARGS+=("-DBUILD_PYTHON=OFF")
|
||||
|
||||
# Disable unused dependencies
|
||||
CMAKE_ARGS+=("-DUSE_CUDA=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_ITT=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_GFLAGS=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_OPENCV=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_LMDB=OFF")
|
||||
|
@ -38,6 +38,7 @@ fi
|
||||
# Disable unused dependencies
|
||||
CMAKE_ARGS+=("-DUSE_ROCM=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_CUDA=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_ITT=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_GFLAGS=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_OPENCV=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_LMDB=OFF")
|
||||
|
@ -112,6 +112,7 @@ cd $BUILD_ROOT
|
||||
cmake "$CAFFE2_ROOT" \
|
||||
-DCMAKE_VERBOSE_MAKEFILE=1 \
|
||||
-DUSE_CUDA=OFF \
|
||||
-DUSE_ITT=OFF \
|
||||
-DUSE_OPENCV=OFF \
|
||||
-DUSE_LMDB=OFF \
|
||||
-DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \
|
||||
|
7
setup.py
7
setup.py
@ -52,6 +52,8 @@
|
||||
#
|
||||
# USE_STATIC_MKL
|
||||
# Prefer to link with MKL statically - Unix only
|
||||
# USE_ITT=0
|
||||
# disable use of Intel(R) VTune Profiler's ITT functionality
|
||||
#
|
||||
# USE_NNPACK=0
|
||||
# disables NNPACK build
|
||||
@ -541,6 +543,11 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
||||
if cmake_cache_vars['USE_LIGHTWEIGHT_DISPATCH']:
|
||||
report('-- Using lightweight dispatch')
|
||||
|
||||
if cmake_cache_vars['USE_ITT']:
|
||||
report('-- Using ITT')
|
||||
else:
|
||||
report('-- Not using ITT')
|
||||
|
||||
# Do not use clang to compile extensions if `-fstack-clash-protection` is defined
|
||||
# in system CFLAGS
|
||||
c_flags = str(os.getenv('CFLAGS', ''))
|
||||
|
1
third_party/ittapi
vendored
Submodule
1
third_party/ittapi
vendored
Submodule
Submodule third_party/ittapi added at 5b8a7d7422
@ -117,6 +117,13 @@ if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
|
||||
-Wno-writable-strings)
|
||||
endif()
|
||||
|
||||
if(USE_ITT)
|
||||
list(APPEND TORCH_PYTHON_SRCS
|
||||
${TORCH_SRC_DIR}/csrc/itt.cpp
|
||||
)
|
||||
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_ITT)
|
||||
endif()
|
||||
|
||||
if(USE_CUDA)
|
||||
include(${TORCH_ROOT}/cmake/public/cuda.cmake)
|
||||
append_filelist("libtorch_python_cuda_core_sources" TORCH_PYTHON_SRCS)
|
||||
|
@ -10,6 +10,7 @@ class ProfilerState(Enum):
|
||||
CPU = ...
|
||||
CUDA = ...
|
||||
NVTX = ...
|
||||
ITT = ...
|
||||
KINETO = ...
|
||||
KINETO_GPU_FALLBACK = ...
|
||||
|
||||
|
4
torch/_C/_itt.pyi
Normal file
4
torch/_C/_itt.pyi
Normal file
@ -0,0 +1,4 @@
|
||||
# Defined in torch/csrc/itt.cpp
|
||||
def rangePush(message: str) -> None: ...
|
||||
def rangePop() -> None: ...
|
||||
def mark(message: str) -> None: ...
|
@ -479,6 +479,70 @@ class record_function(ContextDecorator):
|
||||
return profiled_future
|
||||
|
||||
|
||||
class emit_itt(object):
|
||||
"""Context manager that makes every autograd operation emit an ITT range.
|
||||
|
||||
It is useful when running the program under Intel(R) VTune Profiler::
|
||||
|
||||
vtune <--vtune_flags> <regular command here>
|
||||
|
||||
The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
|
||||
control the collection of trace data during its execution across different Intel tools.
|
||||
This context manager is to annotate Intel(R) VTune Profiling trace. With help of this context manager,
|
||||
you will be able to see labled ranges in Intel(R) VTune Profiler GUI.
|
||||
|
||||
.. warning:
|
||||
This context manager should not be called recursively, i.e. at most one
|
||||
instance should be enabled at any given time.
|
||||
|
||||
Args:
|
||||
enabled (bool, optional, default=True): Setting ``enabled=False`` makes this context manager a no-op.
|
||||
Default: ``True``.
|
||||
record_shapes (bool, optional, default=False): If ``record_shapes=True``, the itt range wrapping
|
||||
each autograd op will append information about the sizes of Tensor arguments received
|
||||
by that op, in the following format:
|
||||
``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
|
||||
Non-tensor arguments will be represented by ``[]``.
|
||||
Arguments will be listed in the order they are received by the backend op.
|
||||
Please note that this order may not match the order in which those arguments were passed
|
||||
on the Python side. Also note that shape recording may increase the overhead of itt range creation.
|
||||
|
||||
Example:
|
||||
>>> with torch.autograd.profiler.emit_itt():
|
||||
... model(x)
|
||||
|
||||
"""
|
||||
def __init__(self, enabled=True, record_shapes=False):
|
||||
self.enabled = enabled
|
||||
self.entered = False
|
||||
self.record_shapes = record_shapes
|
||||
|
||||
def __enter__(self):
|
||||
if not self.enabled:
|
||||
return
|
||||
if self.entered:
|
||||
raise RuntimeError("ITT annotation context manager is not reentrant")
|
||||
self.entered = True
|
||||
_enable_profiler(
|
||||
ProfilerConfig(
|
||||
ProfilerState.ITT,
|
||||
self.record_shapes,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
_ExperimentalConfig()),
|
||||
set()
|
||||
)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if not self.enabled:
|
||||
return
|
||||
_disable_profiler()
|
||||
return False
|
||||
|
||||
|
||||
class emit_nvtx(object):
|
||||
"""Context manager that makes every autograd operation emit an NVTX range.
|
||||
|
||||
|
@ -910,6 +910,14 @@ void initModule(PyObject* module);
|
||||
} // namespace torch
|
||||
#endif
|
||||
|
||||
#ifdef USE_ITT
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
void initIttBindings(PyObject* module);
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
||||
#endif
|
||||
|
||||
static std::vector<PyMethodDef> methods;
|
||||
|
||||
// In Python we can't use the trick of C10_LOG_API_USAGE_ONCE
|
||||
@ -1008,6 +1016,9 @@ PyObject* initModule() {
|
||||
torch::autograd::init_legacy_variable(module);
|
||||
torch::python::init_bindings(module);
|
||||
torch::lazy::initLazyBindings(module);
|
||||
#ifdef USE_ITT
|
||||
torch::profiler::initIttBindings(module);
|
||||
#endif
|
||||
#ifdef USE_CUDA
|
||||
torch::cuda::initModule(module);
|
||||
#endif
|
||||
|
@ -85,6 +85,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
.value("CPU", ProfilerState::CPU)
|
||||
.value("CUDA", ProfilerState::CUDA)
|
||||
.value("NVTX", ProfilerState::NVTX)
|
||||
.value("ITT", ProfilerState::ITT)
|
||||
.value("KINETO", ProfilerState::KINETO)
|
||||
.value("KINETO_GPU_FALLBACK", ProfilerState::KINETO_GPU_FALLBACK);
|
||||
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <torch/csrc/profiler/api.h>
|
||||
#include <torch/csrc/profiler/collection.h>
|
||||
#include <torch/csrc/profiler/containers.h>
|
||||
#include <torch/csrc/profiler/itt_observer.h>
|
||||
#include <torch/csrc/profiler/kineto_shim.h>
|
||||
#include <torch/csrc/profiler/nvtx_observer.h>
|
||||
|
||||
@ -626,7 +627,8 @@ void reportBackendEventToActiveKinetoProfiler(
|
||||
void prepareProfiler(
|
||||
const torch::profiler::impl::ProfilerConfig& config,
|
||||
const std::set<torch::profiler::impl::ActivityType>& activities) {
|
||||
if (config.state == ProfilerState::NVTX) {
|
||||
if (config.state == ProfilerState::NVTX ||
|
||||
config.state == ProfilerState::ITT) {
|
||||
return;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
@ -645,6 +647,9 @@ void enableProfilerWithEventPostProcess(
|
||||
TORCH_CHECK(
|
||||
config.state != ProfilerState::NVTX,
|
||||
"NVTX does not support post processing callback.");
|
||||
TORCH_CHECK(
|
||||
config.state != ProfilerState::ITT,
|
||||
"ITT does not support post processing callback.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
GlobalStateManager::get() == nullptr,
|
||||
"On-demand profiling does not support post processing callback");
|
||||
@ -662,6 +667,9 @@ void enableProfiler(
|
||||
if (config.state == ProfilerState::NVTX) {
|
||||
torch::profiler::impl::pushNVTXCallbacks(config, scopes);
|
||||
return;
|
||||
} else if (config.state == ProfilerState::ITT) {
|
||||
torch::profiler::impl::pushITTCallbacks(config, scopes);
|
||||
return;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
@ -705,7 +713,8 @@ std::unique_ptr<ProfilerResult> disableProfiler() {
|
||||
(config.state == ProfilerState::KINETO ||
|
||||
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
|
||||
config.state == ProfilerState::KINETO_ONDEMAND ||
|
||||
config.state == ProfilerState::NVTX),
|
||||
config.state == ProfilerState::NVTX ||
|
||||
config.state == ProfilerState::ITT),
|
||||
"Can't disable Kineto profiler when it's not running");
|
||||
|
||||
if (state_ptr->hasCallbackHandle()) {
|
||||
|
@ -279,8 +279,8 @@ struct TORCH_API KinetoEvent {
|
||||
int64_t debug_handle_{-1};
|
||||
std::string backend_;
|
||||
|
||||
torch::profiler::impl::CUDAEventStub cuda_event_start_ = nullptr;
|
||||
torch::profiler::impl::CUDAEventStub cuda_event_end_ = nullptr;
|
||||
torch::profiler::impl::ProfilerEventStub cuda_event_start_ = nullptr;
|
||||
torch::profiler::impl::ProfilerEventStub cuda_event_end_ = nullptr;
|
||||
bool is_python_function_;
|
||||
};
|
||||
|
||||
|
@ -197,7 +197,7 @@ void ProfilerLegacyThreadLocalState::mark(std::string name, bool include_cuda) {
|
||||
return;
|
||||
}
|
||||
if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
|
||||
torch::profiler::impl::cudaStubs()->nvtxMarkA(name.c_str());
|
||||
torch::profiler::impl::cudaStubs()->mark(name.c_str());
|
||||
} else {
|
||||
LegacyEvent evt(
|
||||
EventKind::Mark,
|
||||
@ -229,7 +229,7 @@ void ProfilerLegacyThreadLocalState::pushRange(
|
||||
return;
|
||||
}
|
||||
if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
|
||||
torch::profiler::impl::cudaStubs()->nvtxRangePushA(
|
||||
torch::profiler::impl::cudaStubs()->rangePush(
|
||||
torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes)
|
||||
.c_str());
|
||||
} else {
|
||||
@ -277,7 +277,7 @@ void ProfilerLegacyThreadLocalState::popRange(
|
||||
return;
|
||||
}
|
||||
if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
|
||||
torch::profiler::impl::cudaStubs()->nvtxRangePop();
|
||||
torch::profiler::impl::cudaStubs()->rangePop();
|
||||
} else {
|
||||
// In some cases RecordFunction (and popRange) may be
|
||||
// called on a different thread than pushRange
|
||||
|
@ -266,7 +266,7 @@ struct TORCH_API LegacyEvent {
|
||||
int64_t cpu_memory_usage_ = 0;
|
||||
int64_t cuda_memory_usage_ = 0;
|
||||
int device_ = -1;
|
||||
torch::profiler::impl::CUDAEventStub cuda_event = nullptr;
|
||||
torch::profiler::impl::ProfilerEventStub cuda_event = nullptr;
|
||||
int node_id_ = 0;
|
||||
bool is_remote_ = false;
|
||||
int64_t cuda_us_ = -1;
|
||||
|
15
torch/csrc/itt.cpp
Normal file
15
torch/csrc/itt.cpp
Normal file
@ -0,0 +1,15 @@
|
||||
#include <torch/csrc/itt_wrapper.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
void initIttBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
auto itt = m.def_submodule("_itt", "VTune ITT bindings");
|
||||
itt.def("rangePush", itt_range_push);
|
||||
itt.def("rangePop", itt_range_pop);
|
||||
itt.def("mark", itt_mark);
|
||||
}
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
23
torch/csrc/itt_wrapper.cpp
Normal file
23
torch/csrc/itt_wrapper.cpp
Normal file
@ -0,0 +1,23 @@
|
||||
#include <c10/macros/Export.h>
|
||||
#include <ittnotify.h>
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
__itt_domain* _itt_domain = __itt_domain_create("PyTorch");
|
||||
|
||||
TORCH_API void itt_range_push(const char* msg) {
|
||||
__itt_string_handle* hsMsg = __itt_string_handle_create(msg);
|
||||
__itt_task_begin(_itt_domain, __itt_null, __itt_null, hsMsg);
|
||||
}
|
||||
|
||||
TORCH_API void itt_range_pop() {
|
||||
__itt_task_end(_itt_domain);
|
||||
}
|
||||
|
||||
TORCH_API void itt_mark(const char* msg) {
|
||||
__itt_string_handle* hsMsg = __itt_string_handle_create(msg);
|
||||
__itt_task_begin(_itt_domain, __itt_null, __itt_null, hsMsg);
|
||||
__itt_task_end(_itt_domain);
|
||||
}
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
12
torch/csrc/itt_wrapper.h
Normal file
12
torch/csrc/itt_wrapper.h
Normal file
@ -0,0 +1,12 @@
|
||||
#ifndef PROFILER_ITT_H
|
||||
#define PROFILER_ITT_H
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
void itt_range_push(const char* msg);
|
||||
void itt_range_pop();
|
||||
void itt_mark(const char* msg);
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
||||
|
||||
#endif // PROFILER_ITT_H
|
@ -61,26 +61,29 @@ torch::profiler::impl::ProfilerConfig getProfilerConfig() {
|
||||
return state_ptr->config();
|
||||
}
|
||||
|
||||
CUDAStubs::~CUDAStubs() = default;
|
||||
ProfilerStubs::~ProfilerStubs() = default;
|
||||
|
||||
namespace {
|
||||
struct DefaultCUDAStubs : public CUDAStubs {
|
||||
void record(int* /*device*/, CUDAEventStub* /*event*/, int64_t* /*cpu_ns*/)
|
||||
const override {
|
||||
struct DefaultCUDAStubs : public ProfilerStubs {
|
||||
void record(
|
||||
int* /*device*/,
|
||||
ProfilerEventStub* /*event*/,
|
||||
int64_t* /*cpu_ns*/) const override {
|
||||
fail();
|
||||
}
|
||||
float elapsed(const CUDAEventStub* /*event*/, const CUDAEventStub* /*event2*/)
|
||||
const override {
|
||||
float elapsed(
|
||||
const ProfilerEventStub* /*event*/,
|
||||
const ProfilerEventStub* /*event2*/) const override {
|
||||
fail();
|
||||
return 0.f;
|
||||
}
|
||||
void nvtxMarkA(const char* /*name*/) const override {
|
||||
void mark(const char* /*name*/) const override {
|
||||
fail();
|
||||
}
|
||||
void nvtxRangePushA(const char* /*name*/) const override {
|
||||
void rangePush(const char* /*name*/) const override {
|
||||
fail();
|
||||
}
|
||||
void nvtxRangePop() const override {
|
||||
void rangePop() const override {
|
||||
fail();
|
||||
}
|
||||
bool enabled() const override {
|
||||
@ -100,25 +103,82 @@ struct DefaultCUDAStubs : public CUDAStubs {
|
||||
}
|
||||
};
|
||||
|
||||
const DefaultCUDAStubs default_stubs;
|
||||
constexpr const DefaultCUDAStubs* default_stubs_addr = &default_stubs;
|
||||
const DefaultCUDAStubs default_cuda_stubs;
|
||||
constexpr const DefaultCUDAStubs* default_cuda_stubs_addr = &default_cuda_stubs;
|
||||
// Constant initialization, so it is guaranteed to be initialized before
|
||||
// static initialization calls which may invoke registerCUDAMethods
|
||||
inline const CUDAStubs*& cuda_stubs() {
|
||||
static const CUDAStubs* stubs_ =
|
||||
static_cast<const CUDAStubs*>(default_stubs_addr);
|
||||
inline const ProfilerStubs*& cuda_stubs() {
|
||||
static const ProfilerStubs* stubs_ =
|
||||
static_cast<const ProfilerStubs*>(default_cuda_stubs_addr);
|
||||
return stubs_;
|
||||
}
|
||||
|
||||
struct DefaultITTStubs : public ProfilerStubs {
|
||||
void record(
|
||||
int* /*device*/,
|
||||
ProfilerEventStub* /*event*/,
|
||||
int64_t* /*cpu_ns*/) const override {
|
||||
fail();
|
||||
}
|
||||
float elapsed(
|
||||
const ProfilerEventStub* /*event*/,
|
||||
const ProfilerEventStub* /*event2*/) const override {
|
||||
fail();
|
||||
return 0.f;
|
||||
}
|
||||
void mark(const char* /*name*/) const override {
|
||||
fail();
|
||||
}
|
||||
void rangePush(const char* /*name*/) const override {
|
||||
fail();
|
||||
}
|
||||
void rangePop() const override {
|
||||
fail();
|
||||
}
|
||||
bool enabled() const override {
|
||||
return false;
|
||||
}
|
||||
void onEachDevice(std::function<void(int)> /*op*/) const override {
|
||||
fail();
|
||||
}
|
||||
void synchronize() const override {
|
||||
fail();
|
||||
}
|
||||
~DefaultITTStubs() override = default;
|
||||
|
||||
private:
|
||||
void fail() const {
|
||||
AT_ERROR("ITT used in profiler but not enabled.");
|
||||
}
|
||||
};
|
||||
|
||||
const DefaultITTStubs default_itt_stubs;
|
||||
constexpr const DefaultITTStubs* default_itt_stubs_addr = &default_itt_stubs;
|
||||
// Constant initialization, so it is guaranteed to be initialized before
|
||||
// static initialization calls which may invoke registerITTMethods
|
||||
inline const ProfilerStubs*& itt_stubs() {
|
||||
static const ProfilerStubs* stubs_ =
|
||||
static_cast<const ProfilerStubs*>(default_itt_stubs_addr);
|
||||
return stubs_;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const CUDAStubs* cudaStubs() {
|
||||
const ProfilerStubs* cudaStubs() {
|
||||
return cuda_stubs();
|
||||
}
|
||||
|
||||
void registerCUDAMethods(CUDAStubs* stubs) {
|
||||
void registerCUDAMethods(ProfilerStubs* stubs) {
|
||||
cuda_stubs() = stubs;
|
||||
}
|
||||
|
||||
const ProfilerStubs* ittStubs() {
|
||||
return itt_stubs();
|
||||
}
|
||||
|
||||
void registerITTMethods(ProfilerStubs* stubs) {
|
||||
itt_stubs() = stubs;
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
||||
|
@ -23,13 +23,20 @@ enum class C10_API_ENUM ProfilerState {
|
||||
CPU, // CPU-only profiling
|
||||
CUDA, // CPU + CUDA events
|
||||
NVTX, // only emit NVTX markers
|
||||
ITT, // only emit ITT markers
|
||||
KINETO, // use libkineto
|
||||
KINETO_GPU_FALLBACK, // use CUDA events when CUPTI is not available
|
||||
KINETO_ONDEMAND, // run the profiler in on-demand mode
|
||||
NUM_PROFILER_STATES, // must be the last one
|
||||
};
|
||||
|
||||
enum class C10_API_ENUM ActiveProfilerType { NONE = 0, LEGACY, KINETO, NVTX };
|
||||
enum class C10_API_ENUM ActiveProfilerType {
|
||||
NONE = 0,
|
||||
LEGACY,
|
||||
KINETO,
|
||||
NVTX,
|
||||
ITT
|
||||
};
|
||||
|
||||
struct TORCH_API ExperimentalConfig {
|
||||
explicit ExperimentalConfig(
|
||||
@ -130,28 +137,31 @@ TORCH_API ActiveProfilerType profilerType();
|
||||
TORCH_API ProfilerConfig getProfilerConfig();
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// -- CUDA --------------------------------------------------------------------
|
||||
// -- Annotation --------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------
|
||||
using CUDAEventStub = std::shared_ptr<CUevent_st>;
|
||||
using ProfilerEventStub = std::shared_ptr<CUevent_st>;
|
||||
|
||||
struct TORCH_API CUDAStubs {
|
||||
virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns)
|
||||
struct TORCH_API ProfilerStubs {
|
||||
virtual void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
|
||||
const = 0;
|
||||
virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2)
|
||||
const = 0;
|
||||
virtual void nvtxMarkA(const char* name) const = 0;
|
||||
virtual void nvtxRangePushA(const char* name) const = 0;
|
||||
virtual void nvtxRangePop() const = 0;
|
||||
virtual float elapsed(
|
||||
const ProfilerEventStub* event,
|
||||
const ProfilerEventStub* event2) const = 0;
|
||||
virtual void mark(const char* name) const = 0;
|
||||
virtual void rangePush(const char* name) const = 0;
|
||||
virtual void rangePop() const = 0;
|
||||
virtual bool enabled() const {
|
||||
return false;
|
||||
}
|
||||
virtual void onEachDevice(std::function<void(int)> op) const = 0;
|
||||
virtual void synchronize() const = 0;
|
||||
virtual ~CUDAStubs();
|
||||
virtual ~ProfilerStubs();
|
||||
};
|
||||
|
||||
TORCH_API void registerCUDAMethods(CUDAStubs* stubs);
|
||||
TORCH_API const CUDAStubs* cudaStubs();
|
||||
TORCH_API void registerCUDAMethods(ProfilerStubs* stubs);
|
||||
TORCH_API const ProfilerStubs* cudaStubs();
|
||||
TORCH_API void registerITTMethods(ProfilerStubs* stubs);
|
||||
TORCH_API const ProfilerStubs* ittStubs();
|
||||
|
||||
} // namespace impl
|
||||
} // namespace profiler
|
||||
|
@ -54,8 +54,8 @@ using jit_modules_t = std::vector<std::string>;
|
||||
using extra_args_t = std::unordered_map<std::string, c10::IValue>;
|
||||
|
||||
struct FallbackPair {
|
||||
CUDAEventStub cuda_event_start_ = nullptr;
|
||||
CUDAEventStub cuda_event_end_ = nullptr;
|
||||
ProfilerEventStub cuda_event_start_ = nullptr;
|
||||
ProfilerEventStub cuda_event_end_ = nullptr;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
@ -33,8 +33,8 @@ static inline void cudaCheck(cudaError_t result, const char* file, int line) {
|
||||
}
|
||||
#define TORCH_CUDA_CHECK(result) cudaCheck(result, __FILE__, __LINE__);
|
||||
|
||||
struct CUDAMethods : public CUDAStubs {
|
||||
void record(int* device, CUDAEventStub* event, int64_t* cpu_ns)
|
||||
struct CUDAMethods : public ProfilerStubs {
|
||||
void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
|
||||
const override {
|
||||
if (device) {
|
||||
TORCH_CUDA_CHECK(cudaGetDevice(device));
|
||||
@ -52,7 +52,7 @@ struct CUDAMethods : public CUDAStubs {
|
||||
TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream));
|
||||
}
|
||||
|
||||
float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2)
|
||||
float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2)
|
||||
const override {
|
||||
TORCH_CUDA_CHECK(cudaEventSynchronize(event->get()));
|
||||
TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get()));
|
||||
@ -63,17 +63,17 @@ struct CUDAMethods : public CUDAStubs {
|
||||
return ms * 1000.0;
|
||||
}
|
||||
|
||||
void nvtxMarkA(const char* name) const override {
|
||||
void mark(const char* name) const override {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
::nvtxMark(name);
|
||||
}
|
||||
|
||||
void nvtxRangePushA(const char* name) const override {
|
||||
void rangePush(const char* name) const override {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
::nvtxRangePushA(name);
|
||||
}
|
||||
|
||||
void nvtxRangePop() const override {
|
||||
void rangePop() const override {
|
||||
::nvtxRangePop();
|
||||
}
|
||||
|
||||
|
55
torch/csrc/profiler/itt.cpp
Normal file
55
torch/csrc/profiler/itt.cpp
Normal file
@ -0,0 +1,55 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/autograd/profiler.h>
|
||||
#include <torch/csrc/itt_wrapper.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
namespace impl {
|
||||
namespace {
|
||||
|
||||
struct ITTMethods : public ProfilerStubs {
|
||||
void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns)
|
||||
const override {}
|
||||
|
||||
float elapsed(const ProfilerEventStub* event, const ProfilerEventStub* event2)
|
||||
const override {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void mark(const char* name) const override {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
torch::profiler::itt_mark(name);
|
||||
}
|
||||
|
||||
void rangePush(const char* name) const override {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
torch::profiler::itt_range_push(name);
|
||||
}
|
||||
|
||||
void rangePop() const override {
|
||||
torch::profiler::itt_range_pop();
|
||||
}
|
||||
|
||||
void onEachDevice(std::function<void(int)> op) const override {}
|
||||
|
||||
void synchronize() const override {}
|
||||
|
||||
bool enabled() const override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
struct RegisterITTMethods {
|
||||
RegisterITTMethods() {
|
||||
static ITTMethods methods;
|
||||
registerITTMethods(&methods);
|
||||
}
|
||||
};
|
||||
RegisterITTMethods reg;
|
||||
|
||||
} // namespace
|
||||
} // namespace impl
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
72
torch/csrc/profiler/itt_observer.cpp
Normal file
72
torch/csrc/profiler/itt_observer.cpp
Normal file
@ -0,0 +1,72 @@
|
||||
#include <torch/csrc/profiler/itt_observer.h>
|
||||
|
||||
#include <torch/csrc/profiler/util.h>
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
namespace impl {
|
||||
|
||||
struct ITTThreadLocalState : ProfilerThreadLocalStateBase {
|
||||
explicit ITTThreadLocalState(const ProfilerConfig& config)
|
||||
: ProfilerThreadLocalStateBase(config) {
|
||||
// Only `report_input_shapes` makes sense in this context.
|
||||
TORCH_CHECK(!config.profile_memory);
|
||||
TORCH_CHECK(!config.with_stack);
|
||||
TORCH_CHECK(!config.with_flops);
|
||||
TORCH_CHECK(!config.with_modules);
|
||||
}
|
||||
~ITTThreadLocalState() override = default;
|
||||
|
||||
ActiveProfilerType profilerType() override {
|
||||
return ActiveProfilerType::ITT;
|
||||
}
|
||||
|
||||
void reportMemoryUsage(void*, int64_t, int64_t, int64_t, c10::Device)
|
||||
override {}
|
||||
|
||||
static ITTThreadLocalState* getTLS() {
|
||||
auto tls = ProfilerThreadLocalStateBase::getTLS();
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
tls == nullptr || tls->profilerType() == ActiveProfilerType::ITT);
|
||||
return static_cast<ITTThreadLocalState*>(tls);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool report_input_shapes>
|
||||
std::unique_ptr<at::ObserverContext> enterITT(const at::RecordFunction& fn) {
|
||||
if (ITTThreadLocalState::getTLS() != nullptr) {
|
||||
torch::profiler::impl::ittStubs()->rangePush(fn.name());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void pushITTCallbacks(
|
||||
const ProfilerConfig& config,
|
||||
const std::unordered_set<at::RecordScope>& scopes) {
|
||||
TORCH_CHECK(
|
||||
torch::profiler::impl::ittStubs()->enabled(),
|
||||
"Can't use ITT profiler - PyTorch was compiled without ITT");
|
||||
|
||||
c10::ThreadLocalDebugInfo::_push(
|
||||
c10::DebugInfoKind::PROFILER_STATE,
|
||||
std::make_shared<ITTThreadLocalState>(config));
|
||||
|
||||
auto state_ptr = ITTThreadLocalState::getTLS();
|
||||
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
||||
|
||||
auto handle = at::addThreadLocalCallback(
|
||||
at::RecordFunctionCallback(
|
||||
state_ptr->config().report_input_shapes
|
||||
? &enterITT</*report_input_shapes=*/true>
|
||||
: &enterITT</*report_input_shapes=*/false>,
|
||||
[](const at::RecordFunction&, at::ObserverContext*) {
|
||||
torch::profiler::impl::ittStubs()->rangePop();
|
||||
})
|
||||
.needsInputs(config.report_input_shapes)
|
||||
.scopes(scopes));
|
||||
state_ptr->setCallbackHandle(handle);
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
13
torch/csrc/profiler/itt_observer.h
Normal file
13
torch/csrc/profiler/itt_observer.h
Normal file
@ -0,0 +1,13 @@
|
||||
#include <torch/csrc/profiler/api.h>
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
namespace impl {
|
||||
|
||||
void pushITTCallbacks(
|
||||
const ProfilerConfig& config,
|
||||
const std::unordered_set<at::RecordScope>& scopes);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace profiler
|
||||
} // namespace torch
|
@ -129,7 +129,7 @@ template <bool report_input_shapes>
|
||||
std::unique_ptr<at::ObserverContext> enterNVTX(const at::RecordFunction& fn) {
|
||||
if (NVTXThreadLocalState::getTLS() != nullptr) {
|
||||
auto input_op_ids = getInputTensorOpIds(fn);
|
||||
torch::profiler::impl::cudaStubs()->nvtxRangePushA(
|
||||
torch::profiler::impl::cudaStubs()->rangePush(
|
||||
torch::profiler::impl::getNvtxStr(
|
||||
fn.name(),
|
||||
fn.seqNr(),
|
||||
@ -164,7 +164,7 @@ void pushNVTXCallbacks(
|
||||
? &enterNVTX</*report_input_shapes=*/true>
|
||||
: &enterNVTX</*report_input_shapes=*/false>,
|
||||
[](const at::RecordFunction& fn, at::ObserverContext* ctx) {
|
||||
torch::profiler::impl::cudaStubs()->nvtxRangePop();
|
||||
torch::profiler::impl::cudaStubs()->rangePop();
|
||||
updateOutputTensorTracker(fn);
|
||||
})
|
||||
.needsInputs(config.report_input_shapes)
|
||||
|
@ -16,3 +16,5 @@ from torch.autograd.profiler import record_function
|
||||
__all__ = ['profile', 'schedule', 'supported_activities',
|
||||
'tensorboard_trace_handler', 'ProfilerAction', 'ProfilerActivity',
|
||||
'kineto_available', 'DeviceType', 'record_function', 'ExecutionGraphObserver']
|
||||
|
||||
from . import itt
|
||||
|
56
torch/profiler/itt.py
Normal file
56
torch/profiler/itt.py
Normal file
@ -0,0 +1,56 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
try:
|
||||
from torch._C import _itt
|
||||
except ImportError:
|
||||
class _ITTStub(object):
|
||||
@staticmethod
|
||||
def _fail(*args, **kwargs):
|
||||
raise RuntimeError("ITT functions not installed. Are you sure you have a ITT build?")
|
||||
|
||||
rangePush = _fail
|
||||
rangePop = _fail
|
||||
mark = _fail
|
||||
|
||||
_itt = _ITTStub() # type: ignore[assignment]
|
||||
|
||||
|
||||
__all__ = ['range_push', 'range_pop', 'mark', 'range']
|
||||
|
||||
|
||||
def range_push(msg):
|
||||
"""
|
||||
Arguments:
|
||||
msg (string): ASCII message to associate with range
|
||||
"""
|
||||
return _itt.rangePush(msg)
|
||||
|
||||
|
||||
def range_pop():
|
||||
"""
|
||||
"""
|
||||
return _itt.rangePop()
|
||||
|
||||
|
||||
def mark(msg):
|
||||
"""
|
||||
Describe an instantaneous event that occurred at some point.
|
||||
Arguments:
|
||||
msg (string): ASCII message to associate with the event.
|
||||
"""
|
||||
return _itt.mark(msg)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def range(msg, *args, **kwargs):
|
||||
"""
|
||||
Context manager / decorator that pushes an ITT range at the beginning
|
||||
of its scope, and pops it at the end. If extra arguments are given,
|
||||
they are passed as arguments to msg.format().
|
||||
|
||||
Args:
|
||||
msg (string): message to associate with the range
|
||||
"""
|
||||
range_push(msg.format(*args, **kwargs))
|
||||
yield
|
||||
range_pop()
|
Reference in New Issue
Block a user