mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[RFC] enable oneMKL&oneDNN on-demands verbose functinality (#63212)
**RFC: Problem statement** Intel oneMKL and oneDNN are used to accelerate performance on Intel platforms. Both these 2 libraries provide verbose functionality to dump detailed operator execution information as well as execution time. These verbose messages are very helpful to performance profiling. However, the verbose functionality works for the entire execution. In many scenarios, though, we only would like to profile partial of the execution process. This feature is to expose PyTorch API functions to control oneDNN and oneMKL verbose functionality in runtime. **Additional context** The most used performance profiling steps are shown as the following code snippet: ``` def inference(model, inputs): # step0 (optional): jit model = torch.jit.trace(model, inputs) # step1: warmup for _ in range(100): model(inputs) # step2: performance profiling. We only care the profiling result, as well as oneDNN and oneMKL verbose messages, of this step model(inputs) # step3 (optional): benchmarking t0 = time.time() for _ in range(100): model(inputs) t1 = time.time() print(‘dur: {}’.format((t1-t0)/100)) return model(inputs) ``` Since environment variables MKL_VERBOSE and DNNL_VERBOSE will be effect to the entire progress, we will get a great number of verbose messages for all of 101 iterations (if step3 is not involved). However, we only care about the verbose messages dumped in step2. It is very difficult to filter unnecessary verbose messages out if we are running into a complicated usages scenario. Also, jit trace will also bring more undesired verbose messages. Furthermore, there are more complicated topologies or usages like cascaded topologies as below: ``` model1 = Model1() model2 = Model2() model3 = Model3() x1 = inference(model1, x) x2 = inference(model2, x1) y = inference(model3, x2) ``` There are many cases that it is very hard to split these child topologies out. In this scenario, it is not possible to investigate performance of each individual topology with `DNNL_VERBOSE` and `MKL_VERBOSE`. To solve this issue, oneDNN and oneMKL provide API functions to make it possible to control verbose functionality in runtime. ``` int mkl_verbose (int enable) status dnnl::set_verbose(int level) ``` oneDNN and oneMKL print verbose messages to stdout when oneMKL or oneDNN ops are executed. Sample verbose messages: ``` MKL_VERBOSE SGEMM(t,n,768,2048,3072,0x7fff64115800,0x7fa1aca58040,3072,0x1041f5c0,3072,0x7fff64115820,0x981f0c0,768) 8.52ms CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:44 dnnl_verbose,exec,cpu,inner_product,brgemm:avx512_core,forward_training,src_f32::blocked:ab:f0 wei_f32::blocked:AB16b64a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:ab:f0,,,mb16ic768oc768,0.0839844 ``` **Design and implementation** The design is to make python-interfaced wrap functions to invoke mkl_verbose and dnnl::set_verbose functions. **Design concern** - Need to add wrapper C++ functions for mkl_verbose and dnnl::set_verbose functions in torch/csrc and aten/csrc. - Python API functions will be added to device-specific backends - with torch.backends.mkl.verbose(1): - with torch.backends.mkldnn.verbose(1): **Use cases** ``` def inference(model, inputs): # step0 (optional): jit model = torch.jit.trace(model, inputs) # step1: warmup for _ in range(100): model(inputs) # step2: performance profiling with torch.backends.mkl.verbose(1), torch.backends.mkldnn.verbose(1): model(inputs) # step3 (optional): benchmarking t0 = time.time() for _ in range(100): model(inputs) t1 = time.time() print(‘dur: {}’.format((t1-t0)/100)) return model(inputs) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/63212 Approved by: https://github.com/VitalyFedyunin, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
de9b3fb3e5
commit
0e95746580
@ -117,6 +117,10 @@ ideep::tensor itensor_from_tensor(const Tensor& tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
int set_verbose(int level) {
|
||||
return ideep::utils::set_verbose(level);
|
||||
}
|
||||
|
||||
}}
|
||||
|
||||
#endif // AT_MKLDNN_ENABLED()
|
||||
|
@ -24,6 +24,9 @@ TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor);
|
||||
// Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor.
|
||||
TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor);
|
||||
|
||||
// Set MKLDNN verbose level
|
||||
TORCH_API int set_verbose(int level);
|
||||
|
||||
}}
|
||||
|
||||
#endif // AT_MKLDNN_ENABLED
|
||||
|
32
aten/src/ATen/native/verbose_wrapper.cpp
Normal file
32
aten/src/ATen/native/verbose_wrapper.cpp
Normal file
@ -0,0 +1,32 @@
|
||||
#include <ATen/Config.h>
|
||||
#include <c10/macros/Export.h>
|
||||
|
||||
#if AT_MKL_ENABLED()
|
||||
#include <mkl.h>
|
||||
#endif
|
||||
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
#include <ATen/native/mkldnn/MKLDNNCommon.h>
|
||||
#endif
|
||||
|
||||
namespace torch {
|
||||
namespace verbose {
|
||||
|
||||
TORCH_API int _mkl_set_verbose(int enable) {
|
||||
#if AT_MKL_ENABLED()
|
||||
return mkl_verbose(enable);
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_API int _mkldnn_set_verbose(int level) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
return at::native::set_verbose(level);
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace verbose
|
||||
} // namespace torch
|
11
aten/src/ATen/native/verbose_wrapper.h
Normal file
11
aten/src/ATen/native/verbose_wrapper.h
Normal file
@ -0,0 +1,11 @@
|
||||
#ifndef VERBOSE_WRAPPER_H
|
||||
#define VERBOSE_WRAPPER_H
|
||||
|
||||
namespace torch {
|
||||
namespace verbose {
|
||||
int _mkl_set_verbose(int enable);
|
||||
int _mkldnn_set_verbose(int level);
|
||||
} // namespace verbose
|
||||
} // namespace torch
|
||||
|
||||
#endif // VERBOSE_WRAPPER_H
|
@ -937,6 +937,7 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/utils/tensor_numpy.cpp",
|
||||
"torch/csrc/utils/tensor_types.cpp",
|
||||
"torch/csrc/utils/disable_torch_function.cpp",
|
||||
"torch/csrc/utils/verbose.cpp",
|
||||
] + lazy_tensor_core_python_sources
|
||||
|
||||
libtorch_python_distributed_core_sources = [
|
||||
@ -1096,6 +1097,7 @@ aten_cpu_source_non_codegen_list = [
|
||||
"aten/src/ATen/nnapi/nnapi_wrapper.cpp",
|
||||
"aten/src/ATen/nnapi/nnapi_model_loader.cpp",
|
||||
"aten/src/ATen/native/prim_native_functions.cpp",
|
||||
"aten/src/ATen/native/verbose_wrapper.cpp",
|
||||
]
|
||||
|
||||
aten_cpu_source_codegen_list = [
|
||||
|
@ -101,6 +101,8 @@ torch.backends.mkl
|
||||
|
||||
.. autofunction:: torch.backends.mkl.is_available
|
||||
|
||||
.. autoclass:: torch.backends.mkl.verbose
|
||||
|
||||
|
||||
torch.backends.mkldnn
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
@ -108,6 +110,8 @@ torch.backends.mkldnn
|
||||
|
||||
.. autofunction:: torch.backends.mkldnn.is_available
|
||||
|
||||
.. autoclass:: torch.backends.mkldnn.verbose
|
||||
|
||||
|
||||
torch.backends.openmp
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
17
test/mkl_verbose.py
Normal file
17
test/mkl_verbose.py
Normal file
@ -0,0 +1,17 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
def run_model(level):
|
||||
m = torch.nn.Linear(20, 30)
|
||||
input = torch.randn(128, 20)
|
||||
with torch.backends.mkl.verbose(level):
|
||||
m(input)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--verbose-level", default=0, type=int)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
run_model(args.verbose_level)
|
||||
except Exception as e:
|
||||
print(e)
|
26
test/mkldnn_verbose.py
Normal file
26
test/mkldnn_verbose.py
Normal file
@ -0,0 +1,26 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Module, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(1, 10, 5, 1)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv(x)
|
||||
return y
|
||||
|
||||
def run_model(level):
|
||||
m = Module().eval()
|
||||
d = torch.rand(1, 1, 112, 112)
|
||||
with torch.backends.mkldnn.verbose(level):
|
||||
m(d)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--verbose-level", default=0, type=int)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
run_model(args.verbose_level)
|
||||
except Exception as e:
|
||||
print(e)
|
34
test/test_mkl_verbose.py
Normal file
34
test/test_mkl_verbose.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
class TestMKLVerbose(TestCase):
|
||||
def test_verbose_on(self):
|
||||
num = 0
|
||||
loc = os.path.dirname(os.path.abspath(__file__))
|
||||
with subprocess.Popen(f'{sys.executable} -u {loc}/mkl_verbose.py --verbose-level=1', shell=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
|
||||
for line in p.stdout.readlines():
|
||||
line = str(line, 'utf-8').strip()
|
||||
if line.startswith("MKL_VERBOSE"):
|
||||
num = num + 1
|
||||
elif line == 'Failed to set MKL into verbose mode. Please consider to disable this verbose scope.':
|
||||
return
|
||||
self.assertTrue(num > 0, 'oneMKL verbose messages not found.')
|
||||
|
||||
def test_verbose_off(self):
|
||||
num = 0
|
||||
loc = os.path.dirname(os.path.abspath(__file__))
|
||||
with subprocess.Popen(f'{sys.executable} -u {loc}/mkl_verbose.py --verbose-level=0', shell=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
|
||||
for line in p.stdout.readlines():
|
||||
line = str(line, 'utf-8').strip()
|
||||
if line.startswith("MKL_VERBOSE"):
|
||||
num = num + 1
|
||||
self.assertEqual(num, 0, 'unexpected oneMKL verbose messages found.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
34
test/test_mkldnn_verbose.py
Normal file
34
test/test_mkldnn_verbose.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
class TestMKLDNNVerbose(TestCase):
|
||||
def test_verbose_on(self):
|
||||
num = 0
|
||||
loc = os.path.dirname(os.path.abspath(__file__))
|
||||
with subprocess.Popen(f'{sys.executable} -u {loc}/mkldnn_verbose.py --verbose-level=1', shell=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
|
||||
for line in p.stdout.readlines():
|
||||
line = str(line, 'utf-8').strip()
|
||||
if line.startswith("onednn_verbose"):
|
||||
num = num + 1
|
||||
elif line == 'Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope.':
|
||||
return
|
||||
self.assertTrue(num > 0, 'oneDNN verbose messages not found.')
|
||||
|
||||
def test_verbose_off(self):
|
||||
num = 0
|
||||
loc = os.path.dirname(os.path.abspath(__file__))
|
||||
with subprocess.Popen(f'{sys.executable} -u {loc}/mkldnn_verbose.py --verbose-level=0', shell=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
|
||||
for line in p.stdout.readlines():
|
||||
line = str(line, 'utf-8').strip()
|
||||
if line.startswith("onednn_verbose"):
|
||||
num = num + 1
|
||||
self.assertEqual(num, 0, 'unexpected oneDNN verbose messages found.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
3
torch/_C/_verbose.pyi
Normal file
3
torch/_C/_verbose.pyi
Normal file
@ -0,0 +1,3 @@
|
||||
# Defined in torch/csrc/utils/verbose.cpp
|
||||
def mkl_set_verbose(enable: int) -> int: ...
|
||||
def mkldnn_set_verbose(level: int) -> int: ...
|
@ -1,6 +1,49 @@
|
||||
import torch
|
||||
|
||||
|
||||
def is_available():
|
||||
r"""Returns whether PyTorch is built with MKL support."""
|
||||
return torch._C.has_mkl
|
||||
|
||||
VERBOSE_OFF = 0
|
||||
VERBOSE_ON = 1
|
||||
class verbose(object):
|
||||
"""
|
||||
On-demand oneMKL verbosing functionality
|
||||
To make it easier to debug performance issues, oneMKL can dump verbose
|
||||
messages containing execution information like duration while executing
|
||||
the kernel. The verbosing functionality can be invoked via an environment
|
||||
variable named `MKL_VERBOSE`. However, this methodology dumps messages in
|
||||
all steps. Those are a large amount of verbose messages. Moreover, for
|
||||
investigating the performance issues, generally taking verbose messages
|
||||
for one single iteration is enough. This on-demand verbosing functionality
|
||||
makes it possible to control scope for verbose message dumping. In the
|
||||
following example, verbose messages will be dumped out for the second
|
||||
inference only.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
model(data)
|
||||
with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
|
||||
model(data)
|
||||
|
||||
Args:
|
||||
level: Verbose level
|
||||
- ``VERBOSE_OFF``: Disable verbosing
|
||||
- ``VERBOSE_ON``: Enable verbosing
|
||||
"""
|
||||
|
||||
def __init__(self, enable):
|
||||
self.enable = enable
|
||||
|
||||
def __enter__(self):
|
||||
if self.enable == VERBOSE_OFF:
|
||||
return
|
||||
st = torch._C._verbose.mkl_set_verbose(self.enable)
|
||||
assert st, "Failed to set MKL into verbose mode. Please consider to disable this verbose scope."
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
torch._C._verbose.mkl_set_verbose(VERBOSE_OFF)
|
||||
return False
|
||||
|
@ -7,6 +7,52 @@ def is_available():
|
||||
r"""Returns whether PyTorch is built with MKL-DNN support."""
|
||||
return torch._C.has_mkldnn
|
||||
|
||||
VERBOSE_OFF = 0
|
||||
VERBOSE_ON = 1
|
||||
VERBOSE_ON_CREATION = 2
|
||||
class verbose(object):
|
||||
"""
|
||||
On-demand oneDNN (former MKL-DNN) verbosing functionality
|
||||
To make it easier to debug performance issues, oneDNN can dump verbose
|
||||
messages containing information like kernel size, input data size and
|
||||
execution duration while executing the kernel. The verbosing functionality
|
||||
can be invoked via an environment variable named `DNNL_VERBOSE`. However,
|
||||
this methodology dumps messages in all steps. Those are a large amount of
|
||||
verbose messages. Moreover, for investigating the performance issues,
|
||||
generally taking verbose messages for one single iteration is enough.
|
||||
This on-demand verbosing functionality makes it possible to control scope
|
||||
for verbose message dumping. In the following example, verbose messages
|
||||
will be dumped out for the second inference only.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
model(data)
|
||||
with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
|
||||
model(data)
|
||||
|
||||
Args:
|
||||
level: Verbose level
|
||||
- ``VERBOSE_OFF``: Disable verbosing
|
||||
- ``VERBOSE_ON``: Enable verbosing
|
||||
- ``VERBOSE_ON_CREATION``: Enable verbosing, including oneDNN kernel creation
|
||||
"""
|
||||
|
||||
def __init__(self, level):
|
||||
self.level = level
|
||||
|
||||
def __enter__(self):
|
||||
if self.level == VERBOSE_OFF:
|
||||
return
|
||||
st = torch._C._verbose.mkldnn_set_verbose(self.level)
|
||||
assert st, "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope."
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF)
|
||||
return False
|
||||
|
||||
def set_flags(_enabled):
|
||||
orig_flags = (torch._C._get_mkldnn_enabled(),)
|
||||
torch._C._set_mkldnn_enabled(_enabled)
|
||||
|
@ -918,6 +918,10 @@ void initIttBindings(PyObject* module);
|
||||
} // namespace torch
|
||||
#endif
|
||||
|
||||
namespace torch {
|
||||
void initVerboseBindings(PyObject* module);
|
||||
} // namespace torch
|
||||
|
||||
static std::vector<PyMethodDef> methods;
|
||||
|
||||
// In Python we can't use the trick of C10_LOG_API_USAGE_ONCE
|
||||
@ -1022,6 +1026,7 @@ PyObject* initModule() {
|
||||
#ifdef USE_CUDA
|
||||
torch::cuda::initModule(module);
|
||||
#endif
|
||||
torch::initVerboseBindings(module);
|
||||
ASSERT_TRUE(THPStorage_init(module));
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
14
torch/csrc/utils/verbose.cpp
Normal file
14
torch/csrc/utils/verbose.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
#include <ATen/native/verbose_wrapper.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
|
||||
void initVerboseBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
auto verbose = m.def_submodule("_verbose", "MKL, MKLDNN verbose");
|
||||
verbose.def("mkl_set_verbose", torch::verbose::_mkl_set_verbose);
|
||||
verbose.def("mkldnn_set_verbose", torch::verbose::_mkldnn_set_verbose);
|
||||
}
|
||||
|
||||
} // namespace torch
|
Reference in New Issue
Block a user