Changes to support input sequence ID tracking (#70264)
Summary: This change adds input sequence ID tracking to the NVTX markers. The marker string gains an additional field, e.g. seq_ids=[101, 102, 103], giving the sequence id of the op that produced each input tensor, indexed by the tensor's position in the array: in this example, input tensor 0 was produced by the node with sequence id 101, input tensor 1 by node 102, and input tensor 2 by node 103. This is the same way the sizes array is organized. Knowing the sequence id of a node and the sequence ids of its input edges is enough information to construct the network graph.

Fixes https://github.com/pytorch/pytorch/issues/66105

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70264
Reviewed By: chaekit
Differential Revision: D34792707
Pulled By: robieta
fbshipit-source-id: 4407b853c929a737505803b0db77a8ecd966cce2
(cherry picked from commit cd3c0c8c9d4d63d7897f60521c407883240d1d5b)
Committed by: PyTorch MergeBot
Parent: fd4ad5d72c
Commit: c0a6add7ee
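For context, here is a minimal usage sketch (not part of this commit) of how the feature is exercised from Python. The sample marker text in the comment is illustrative only; the field names follow the fmt::format calls in the diff below, while the values are made up.

import torch

# Run a small graph under the NVTX emitter with shape recording enabled.
# When profiled under Nsight Systems or nvprof, each op pushes an NVTX
# range whose message carries the extra fields added by this change.
x = torch.randn(10, 10, requires_grad=True)
y = torch.randn(10, 10, requires_grad=True)

with torch.autograd.profiler.emit_nvtx(record_shapes=True):
    z = (x + y).sum()
    z.backward()

# Illustrative marker text (values invented for this example):
#   aten::add, seq = 101, op_id = 412, sizes = [[10, 10], [10, 10]], input_op_ids = [(410,0), (411,0)]
# Each (op_id, output_nr) pair names the op that produced the corresponding
# input tensor and which of that op's outputs it was.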
@@ -64,6 +64,31 @@ class TestProfilerCUDA(TestCase):
         self.assertTrue(not (is_increasing and max_diff > 100 * 1024),
                         msg='memory usage is increasing, {}'.format(str(last_rss)))

+    def test_custom_module_input_op_ids(self):
+        class MyFunc(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                return x
+
+            @staticmethod
+            def backward(ctx, gO):
+                x, = ctx.saved_tensors
+                return x
+
+        def custom_layer(input_ten):
+            return MyFunc.apply(input_ten)
+
+        # Only testing that emit_nvtx runs when
+        # record_shapes option is enabled.
+        with torch.autograd.profiler.emit_nvtx(record_shapes=True) as prof:
+            x = torch.randn(10, 10, requires_grad=True)
+            y = torch.randn(10, 10, requires_grad=True)
+            z = x + y
+            s = custom_layer(z)
+            q = s.sum()
+            q.backward()
+
 class TestRecordFunction(TestCase):
     def _record_function_with_param(self):
         u = torch.randn(3, 4, 5, requires_grad=True)
@@ -682,10 +682,19 @@ PyObject* THPFunction_name(PyObject *self, PyObject* noargs) {
 PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
 {
   HANDLE_TH_ERRORS

+  // save a local copy of seq_id before it gets incremented
+  int seq_id = at::sequence_number::peek();
+  auto info_pair = unpack_input<false>(inputs);
+  UnpackedInput& unpacked_input = info_pair.first;
+  InputFlags& input_info = info_pair.second;
+
   // Call record function after all the inputs have been decoded, but
   // before context has been allocated.
   RECORD_FUNCTION(
     ((PyTypeObject*)cls)->tp_name,
-    std::vector<c10::IValue>(),
-    at::sequence_number::peek());
+    std::vector<c10::IValue>(unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()),
+    seq_id);

   // Temporary hack to improve functorch UX. We'll find a better solution.
   const auto& functorch_tls = at::functorch::functorchTLSAccessor();
@@ -702,11 +711,6 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
   auto cdata = std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
   ctx->cdata = cdata;

-  // Prepare inputs and allocate context (grad fn)
-  auto info_pair = unpack_input<false>(inputs);
-  UnpackedInput& unpacked_input = info_pair.first;
-  InputFlags& input_info = info_pair.second;
-
   // Record input nodes if tracing
   auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);

@@ -716,6 +720,7 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
   ctx->needs_input_grad = input_info.needs_input_grad.release();
   ctx->is_variable_input = std::move(input_info.is_variable_input);

+
   // Prepend ctx to input_tuple, in preparation for static method call
   auto num_args = PyTuple_GET_SIZE(inputs);
   THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
@@ -30,17 +30,104 @@ struct NVTXThreadLocalState : ProfilerThreadLocalStateBase {
         tls == nullptr || tls->profilerType() == ActiveProfilerType::NVTX);
     return static_cast<NVTXThreadLocalState*>(tls);
   }
+  std::pair<at::RecordFunctionHandle, int> getOpIdFromInput(const at::Tensor& tensor);
+
+  void setProducerTensorMap(at::TensorImpl *tensor, at::RecordFunctionHandle op_id, int output_nr){
+    producer_tensor_map_[(void*)tensor] = std::pair<at::RecordFunctionHandle, int> {op_id, output_nr};
+  }
+
+ protected:
+  // Maps the address of an output Tensor to a unique op id and output
+  // index of the tensor.
+  // at::TensorImpl* is the actual type of the key, but using void*
+  // to indicate the pointer is just being used as a key
+  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
+  std::unordered_map<void*, std::pair<at::RecordFunctionHandle, int>> producer_tensor_map_;
 };

+std::pair<at::RecordFunctionHandle, int> NVTXThreadLocalState::getOpIdFromInput(const at::Tensor& tensor) {
+  std::pair<at::RecordFunctionHandle, int> producer_op_pair(0, -1);
+  if (tensor.defined()) {
+    at::TensorImpl *ten_addr = tensor.unsafeGetTensorImpl();
+    // See if Address is in the map already
+    if (producer_tensor_map_.count((void*)ten_addr) > 0) {
+      producer_op_pair = producer_tensor_map_[(void*)ten_addr];
+    }
+  }
+  return producer_op_pair;
+}
+
+std::list<std::pair<at::RecordFunctionHandle, int>> flattenOpIdList(c10::List<c10::IValue> list, std::string fn_name) {
+  std::list<std::pair<at::RecordFunctionHandle, int>> input_op_id_list;
+  auto state_ptr = NVTXThreadLocalState::getTLS();
+  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
+  for (const c10::IValue input : list) {
+    if (input.isTensor()) {
+      const at::Tensor& tensor = input.toTensor();
+      auto producer_op_pair = state_ptr->getOpIdFromInput(tensor);
+      input_op_id_list.push_back(producer_op_pair);
+    }
+  }
+  return input_op_id_list;
+}
+
+std::list<std::pair<at::RecordFunctionHandle, int>> getInputTensorOpIds(const at::RecordFunction& fn) {
+  int num_inputs = fn.inputs().size();
+  std::pair<at::RecordFunctionHandle, int> undefined_op_pair(0,-1);
+  std::list<std::pair<at::RecordFunctionHandle, int>> input_producer_ops_;
+  auto state_ptr = NVTXThreadLocalState::getTLS();
+  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
+  for (const c10::IValue& input_item : fn.inputs()) {
+    if(input_item.isTensor()) {
+      const at::Tensor& tensor = input_item.toTensor();
+      auto producer_pair = state_ptr->getOpIdFromInput(tensor);
+      input_producer_ops_.push_back(producer_pair);
+    } else {
+      if (input_item.isList()) {
+        std::list<std::pair<at::RecordFunctionHandle, int>> tmp_op_ids = flattenOpIdList(input_item.toList(), std::string(fn.name()));
+        // Extend the current sizes array by the array returned from input sizes
+        if (!tmp_op_ids.empty()) {
+          input_producer_ops_.splice(input_producer_ops_.end(), tmp_op_ids);
+        } else {
+          input_producer_ops_.emplace_back(undefined_op_pair);
+        }
+      } else {
+        input_producer_ops_.emplace_back(undefined_op_pair);
+      }
+    }
+  }
+  return input_producer_ops_;
+}
+
+void updateOutputTensorTracker(const at::RecordFunction& fn) {
+  int output_nr = 0;
+  auto state_ptr = NVTXThreadLocalState::getTLS();
+  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
+  for (const c10::IValue& s_tensor : fn.outputs()){
+    if(s_tensor.isTensor()) {
+      const at::Tensor& tensor = s_tensor.toTensor();
+      if (tensor.defined()) {
+        auto ten_addr = tensor.unsafeGetTensorImpl();
+        state_ptr->setProducerTensorMap(ten_addr, fn.handle(), output_nr);
+      }
+    }
+    output_nr++;
+  }
+}
+
 template <bool report_input_shapes>
 std::unique_ptr<at::ObserverContext> enterNVTX(const at::RecordFunction& fn) {
   if (NVTXThreadLocalState::getTLS() != nullptr) {
+    auto input_op_ids = getInputTensorOpIds(fn);
     torch::profiler::impl::cudaStubs()->nvtxRangePushA(
         torch::profiler::impl::getNvtxStr(
             fn.name(),
             fn.seqNr(),
-            report_input_shapes ? torch::profiler::impl::inputSizes(fn)
-                                : std::vector<std::vector<int64_t>>())
+            report_input_shapes ? torch::profiler::impl::inputSizes(fn, true)
+                                : std::vector<std::vector<int64_t>>(),
+            fn.handle(),
+            report_input_shapes ? input_op_ids
+                                : std::list<std::pair<at::RecordFunctionHandle, int>>())
             .c_str());
   }
   return nullptr;
@@ -65,10 +152,13 @@ void pushNVTXCallbacks(
          state_ptr->config().report_input_shapes
              ? &enterNVTX</*report_input_shapes=*/true>
              : &enterNVTX</*report_input_shapes=*/false>,
-          [](const at::RecordFunction&, at::ObserverContext*) {
+          [](const at::RecordFunction& fn, at::ObserverContext *ctx) {
            torch::profiler::impl::cudaStubs()->nvtxRangePop();
+            updateOutputTensorTracker(fn);
          })
          .needsInputs(config.report_input_shapes)
+          .needsOutputs(config.report_input_shapes)
+          .needsIds(true)
          .scopes(scopes));
  state_ptr->setCallbackHandle(handle);
 }
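As a mental model of the bookkeeping in the hunk above, here is a minimal Python sketch (not part of the diff; the function and variable names are invented for illustration) of how the observer remembers which op produced each tensor and later resolves the producers of the next op's inputs:

# Sketch of the producer-tensor map kept by NVTXThreadLocalState.
# Keys stand in for TensorImpl addresses; values are (op_id, output_nr).
producer_tensor_map = {}

def on_op_exit(op_id, outputs):
    # Mirrors updateOutputTensorTracker: remember which op produced each
    # output tensor and at which output index.
    for output_nr, tensor in enumerate(outputs):
        if tensor is not None:
            producer_tensor_map[id(tensor)] = (op_id, output_nr)

def on_op_enter(inputs):
    # Mirrors getInputTensorOpIds: look up the producer of every input
    # tensor; (0, -1) marks inputs with no recorded producer (e.g. leaves).
    return [producer_tensor_map.get(id(t), (0, -1)) for t in inputs]

# Example: op 7 produces a tensor that op 8 then consumes.
t = object()
on_op_exit(7, [t])
assert on_op_enter([t]) == [(7, 0)]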
@@ -1,4 +1,5 @@
 #include <torch/csrc/profiler/util.h>
+#include <torch/csrc/autograd/function.h>
 #include <torch/csrc/profiler/kineto_shim.h>

 #include <c10/util/ArrayRef.h>
@@ -89,7 +90,9 @@ ApproximateClockToUnixTimeConverter::makeConverter() {
 std::string getNvtxStr(
     const char* name,
     int64_t sequence_nr,
-    const std::vector<std::vector<int64_t>>& shapes) {
+    const std::vector<std::vector<int64_t>>& shapes,
+    at::RecordFunctionHandle op_id,
+    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
   if (sequence_nr >= -1 || shapes.size() > 0) {
     std::string str;
     if (sequence_nr >= 0) {
@@ -102,31 +105,17 @@ std::string getNvtxStr(
       str = name;
 #endif
     }
-    if (shapes.size() > 0) {
-      std::stringstream s;
-      s << str;
-      s << ", sizes = [";
-      for (const auto idx : c10::irange(shapes.size())) {
-        if (shapes[idx].size() > 0) {
-          s << "[";
-          for (const auto dim : c10::irange(shapes[idx].size())) {
-            s << shapes[idx][dim];
-            if (dim < shapes[idx].size() - 1) {
-              s << ", ";
-            }
-          }
-          s << "]";
-        } else {
-          s << "[]";
-        }
-        if (idx < shapes.size() - 1) {
-          s << ", ";
-        }
-      }
-      s << "]";
-      return s.str();
+    if (op_id > 0) {
+      str = fmt::format("{}, op_id = {}", str, op_id);
+    }
+    if (shapes.size() > 0) {
+      str = fmt::format("{}, sizes = {}", str, shapesToStr(shapes));
+    }
+    // Include the op ids of the input edges so
+    // you can build the network graph
+    if (input_op_ids.size() > 0) {
+      str = fmt::format("{}, input_op_ids = {}", str, inputOpIdsToStr(input_op_ids));
     }
+
     return str;
   } else {
     return name;
@@ -185,17 +174,41 @@ std::string stacksToStr(
   return "\"" + rc + "\"";
 }

-std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn) {
+std::vector<std::vector<int64_t>> flattenList(c10::List<c10::IValue> list, std::string fn_name) {
+  std::vector<std::vector<int64_t>> tensor_dims;
+  for (const c10::IValue input : list) {
+    if (input.isTensor()) {
+      const at::Tensor& tensor = input.toTensor();
+      if (tensor.defined()) {
+        tensor_dims.push_back(input.toTensor().sizes().vec());
+      }
+    }
+  }
+  return tensor_dims;
+}
+
+std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn, bool flatten_list_enabled) {
   std::vector<std::vector<int64_t>> sizes;
   sizes.reserve(fn.inputs().size());
   for (const c10::IValue& input : fn.inputs()) {
-    if (!input.isTensor()) {
-      sizes.emplace_back();
-      continue;
-    }
-    const at::Tensor& tensor = input.toTensor();
-    if (tensor.defined()) {
-      sizes.push_back(input.toTensor().sizes().vec());
+    if (input.isTensor()) {
+      const at::Tensor& tensor = input.toTensor();
+      if (tensor.defined()) {
+        sizes.push_back(input.toTensor().sizes().vec());
+      } else {
+        sizes.emplace_back();
+      }
+    } else if (input.isList()) {
+      std::vector<std::vector<int64_t>> tmp_sizes;
+      if (flatten_list_enabled) {
+        tmp_sizes = flattenList(input.toList(), std::string(fn.name()));
+      }
+      // Extend the current sizes array by the array returned from input sizes
+      if (!tmp_sizes.empty()) {
+        sizes.insert(sizes.end(), tmp_sizes.begin(), tmp_sizes.end());
+      } else {
+        sizes.emplace_back();
+      }
     } else {
       sizes.emplace_back();
     }
@@ -204,23 +217,37 @@ std::vector<std::vector<int64_t>> inputSizes(const at::RecordFunction& fn) {
 }

 std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
-  std::ostringstream oss;
-  oss << "[";
+  std::string str("[");
   for (const auto t_idx : c10::irange(shapes.size())) {
     if (t_idx > 0) {
-      oss << ", ";
+      str = fmt::format("{}, ", str);
     }
-    oss << "[";
+    str = fmt::format("{}[", str);
     for (const auto s_idx : c10::irange(shapes[t_idx].size())) {
       if (s_idx > 0) {
-        oss << ", ";
+        str = fmt::format("{}, ", str);
       }
-      oss << shapes[t_idx][s_idx];
+      str = fmt::format("{}{}", str, shapes[t_idx][s_idx]);
     }
-    oss << "]";
+    str = fmt::format("{}]", str);
   }
-  oss << "]";
-  return oss.str();
+  str = fmt::format("{}]", str);
+  return str;
 }

+std::string inputOpIdsToStr(const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
+  std::string str("[");
+  int idx = 0;
+
+  for (const auto& op_id_info_pair : input_op_ids) {
+    if (idx++ > 0) {
+      str = fmt::format("{}, ", str);
+    }
+    // (OpId,OutputNr)
+    str = fmt::format("{}({},{})", str, op_id_info_pair.first, op_id_info_pair.second);
+  }
+  str = fmt::format("{}]", str);
+  return str;
+}
+
 std::string dtypesToStr(const std::vector<std::string>& types) {
@@ -5,6 +5,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <list>

 #include <c10/macros/Macros.h>
 #include <ATen/record_function.h>
@@ -120,7 +121,9 @@ class ApproximateClockToUnixTimeConverter final {
 std::string getNvtxStr(
     const char* name,
     int64_t sequence_nr,
-    const std::vector<std::vector<int64_t>>& shapes);
+    const std::vector<std::vector<int64_t>>& shapes,
+    at::RecordFunctionHandle op_id = 0,
+    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids = {});

 struct TORCH_API FileLineFunc {
   std::string filename;
@@ -136,10 +139,12 @@ TORCH_API std::string stacksToStr(
     const std::vector<std::string>& stacks,
     const char* delim);
 TORCH_API std::vector<std::vector<int64_t>> inputSizes(
-    const at::RecordFunction& fn);
+    const at::RecordFunction& fn,
+    const bool flatten_list_enabled=false);
 TORCH_API std::string shapesToStr(
     const std::vector<std::vector<int64_t>>& shapes);
 TORCH_API std::string dtypesToStr(const std::vector<std::string>& types);
+TORCH_API std::string inputOpIdsToStr(const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids);
 TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn);

 std::unordered_map<std::string, c10::IValue> TORCH_API