Files
pytorch/torch/csrc/autograd/profiler_kineto.cpp
Mihai Polceanu 6fa3715c12 Expose Kineto event metadata in PyTorch Profiler events (#161624)
## Overview
This PR lets profiler users access `Kineto` and `TorchOp` event metadata as a JSON string through a new `metadata_json` attribute on `FunctionEvent` objects. The feature is enabled by a new `expose_kineto_event_metadata` flag in `ExperimentalConfig`.

## Testing
A unit test was added to validate functionality.

## Documentation
Added/updated function doc strings where appropriate.

## Example output
```python
import torch
from torch.profiler import profile

with profile(experimental_config=torch._C._profiler._ExperimentalConfig(expose_kineto_event_metadata=True)) as prof:
    res = torch.mm(torch.rand(1024, 1024), torch.rand(1024, 1024))

for event in prof.events():
    print(f'name: {event.key}, metadata: {event.metadata_json}')
```

```
name: aten::rand, metadata: "Ev Idx": 0
name: aten::empty, metadata: "Ev Idx": 1
name: aten::uniform_, metadata: "Ev Idx": 2
name: aten::rand, metadata: "Ev Idx": 3
name: aten::empty, metadata: "Ev Idx": 4
name: aten::uniform_, metadata: "Ev Idx": 5
name: aten::mm, metadata: "Ev Idx": 6
name: aten::resolve_conj, metadata: "Ev Idx": 7
name: aten::resolve_conj, metadata: "Ev Idx": 8
name: aten::resolve_conj, metadata: "Ev Idx": 9
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161624
Approved by: https://github.com/sraikund16
2025-09-25 14:58:30 +00:00

1171 lines
40 KiB
C++

#include <cstring>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <torch/csrc/autograd/profiler_kineto.h>
#include <c10/macros/Export.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/overloaded.h>
#include <torch/csrc/profiler/api.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/events.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/profiler/orchestration/observer.h>
#include <torch/csrc/profiler/perf.h>
#include <torch/csrc/profiler/standalone/itt_observer.h>
#include <torch/csrc/profiler/standalone/nvtx_observer.h>
#include <torch/csrc/profiler/standalone/privateuse1_observer.h>
#include <torch/csrc/profiler/util.h>
#include <stdexcept>
#include <utility>
#ifdef USE_KINETO
#include <ApproximateClock.h>
#include <libkineto.h>
#include <time_since_epoch.h>
#ifndef _MSC_VER
// TODO: To be removed once this works properly from libkineto
// Literal copy-n-paste from third_party/kineto/libkineto/src/WeakSymbols.cpp
extern "C" {
// Weak fallback definition of acc_get_device_type(). It is needed to avoid a
// superfluous dependency on the GNU OpenMP library when cuPTI is linked
// statically. For more details see
// https://github.com/pytorch/pytorch/issues/51026
// Because the symbol is declared weak, a strong definition supplied elsewhere
// at link time takes precedence over this one.
__attribute__((weak)) int acc_get_device_type();
__attribute__((weak)) int acc_get_device_type() {
  // Only meant to satisfy the linker: reaching this body is an error, so it
  // unconditionally raises via TORCH_CHECK(false, ...).
  TORCH_CHECK(
      false,
      "Dummy implementation of acc_get_device_type is not supposed to be called!");
}
} // extern "C"
#endif // _MSC_VER
#endif // USE_KINETO
namespace torch {
namespace autograd::profiler {
namespace {
// Current wall-clock time in nanoseconds. With Kineto, this is
// libkineto::timeSinceEpoch() applied to std::chrono::system_clock::now();
// without Kineto it falls back to c10::getTime().
inline int64_t getTimeNs() {
#ifndef USE_KINETO
  return c10::getTime();
#else
  const auto wall_now = std::chrono::system_clock::now();
  return libkineto::timeSinceEpoch(wall_now);
#endif // USE_KINETO
}
using torch::profiler::impl::ActiveProfilerType;
using torch::profiler::impl::EventType;
using torch::profiler::impl::ExtraFields;
using torch::profiler::impl::get_record_concrete_inputs_enabled;
using torch::profiler::impl::ivalueListToStr;
using torch::profiler::impl::ivalueToStr;
using torch::profiler::impl::op_input_t;
using torch::profiler::impl::ProfilerStateBase;
using torch::profiler::impl::PyExtraFieldsBase;
using torch::profiler::impl::Result;
using torch::profiler::impl::shape;
using torch::profiler::impl::shapesToStr;
using torch::profiler::impl::stacksToStr;
using torch::profiler::impl::strListToStr;
using torch::profiler::impl::TensorMetadata;
using torch::profiler::impl::variantShapesToStr;
// Flattened view of a TorchOp's recorded inputs, produced by parseArgData().
// Every vector is indexed by input position.
struct OpArgData {
  // False when the op recorded no inputs; all vectors are then empty.
  bool hasData{false};
  // Per-input sizes; a TensorList input holds one size vector per tensor.
  std::vector<shape> shapes;
  // Per-input type label: the tensor dtype name, "TensorList", "Scalar",
  // "ScalarList", or empty for other input kinds.
  std::vector<std::string> dtypes;
  // Concrete input values when they were recorded; empty otherwise.
  std::vector<c10::IValue> concreteInputs;
  // Per-input sizes for plain tensor inputs only (other entries stay empty).
  std::vector<std::vector<int64_t>> shapesForKinetoEvent;
  // Per-input strides, laid out like `shapes`.
  std::vector<shape> strides;
};
// Converts a TorchOp's recorded inputs (and, when available, its concrete
// input values) into an OpArgData.
//
// - Tensor inputs contribute sizes, strides and the dtype name.
// - TensorList inputs contribute per-tensor sizes/strides and "TensorList".
// - Scalar (IValue) inputs are tagged "Scalar".
// - When exactly one concrete input was recorded per input, IValues found in
//   `input_shapes` and then in `concreteInputs` are collected (the latter
//   wins), and inputs whose concrete value is an IValue are re-tagged
//   "ScalarList".
//
// Returns an OpArgData with hasData == false when there are no inputs.
auto parseArgData(
    const std::vector<op_input_t>& input_shapes,
    const std::vector<op_input_t>& concreteInputs) {
  if (input_shapes.empty()) {
    return OpArgData{false, {}, {}, {}, {}, {}};
  }
  const auto num_inputs = input_shapes.size();
  std::vector<shape> shapes(num_inputs);
  std::vector<shape> strides(num_inputs);
  std::vector<std::vector<int64_t>> shapesForKinetoEvent(num_inputs);
  std::vector<std::string> dtypes(num_inputs);
  std::vector<c10::IValue> concrete_inputs_list;
  for (const auto i : c10::irange(num_inputs)) {
    std::visit(
        c10::overloaded(
            [&](const TensorMetadata& t) {
              shapes[i] = t.sizes_;
              shapesForKinetoEvent[i] = t.sizes_;
              dtypes[i] = std::string(scalarTypeToTypeMeta(t.dtype_).name());
              strides[i] = t.strides_;
            },
            [&](const std::vector<TensorMetadata>& l) {
              // Named so they do not shadow the `shape` type alias.
              std::vector<std::vector<int64_t>> list_sizes;
              list_sizes.reserve(l.size());
              std::vector<std::vector<int64_t>> list_strides;
              list_strides.reserve(l.size());
              for (const auto& t : l) {
                list_sizes.emplace_back(t.sizes_);
                list_strides.emplace_back(t.strides_);
              }
              shapes[i] = std::move(list_sizes);
              strides[i] = std::move(list_strides);
              dtypes[i] = "TensorList";
            },
            [&](const c10::IValue&) { dtypes[i] = "Scalar"; },
            [&](const auto&) {}),
        input_shapes[i]);
  }
  // If we recorded concrete inputs (one per input), then parse them.
  // `num_inputs` is non-zero here, so no separate emptiness check is needed.
  if (concreteInputs.size() == num_inputs) {
    concrete_inputs_list.resize(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      std::visit(
          c10::overloaded(
              [&](const c10::IValue& val) { concrete_inputs_list[i] = val; },
              [&](const auto&) {}),
          input_shapes[i]);
      std::visit(
          c10::overloaded(
              [&](const c10::IValue& val) {
                concrete_inputs_list[i] = val;
                dtypes[i] = "ScalarList";
              },
              [&](const auto&) {}),
          concreteInputs[i]);
    }
  }
  // Move the locals into the result instead of copying five vectors per op.
  return OpArgData{
      true,
      std::move(shapes),
      std::move(dtypes),
      std::move(concrete_inputs_list),
      std::move(shapesForKinetoEvent),
      std::move(strides)};
}
// Base class for the metadata visitors below. It holds a (possibly null)
// pointer to the Kineto activity of a Result and writes key/value metadata
// onto that activity.
struct MetadataBase {
  /* implicit */ MetadataBase(const std::shared_ptr<Result>& result)
      : kinetoActivity_{result->kineto_activity_} {
    const bool kineto_provided =
        std::holds_alternative<ExtraFields<EventType::Kineto>>(
            result->extra_fields_);
    if (!kineto_provided) {
      return;
    }
    // Adding metadata requires downcasting `libkineto::ITraceActivity` to
    // `libkineto::GenericTraceActivity`. Activities created by PyTorch are of
    // that type, but Kineto profilers can (and do) add events deriving
    // directly from ITraceActivity, so a Result built from a Kineto-provided
    // event is unsafe to cast: drop its activity if one is attached.
    const bool has_no_activity = SOFT_ASSERT(!hasKinetoActivity());
    if (!has_no_activity) {
      result->kineto_activity_ = nullptr;
    }
    kinetoActivity_ = result->kineto_activity_;
  }

  // Attaches `key: value` to the activity. Skipped when there is no activity
  // or when the value is empty or an empty quoted string.
  void addMetadata(const std::string& key, const std::string& value) {
    if (kinetoActivity_ == nullptr || value.empty() || value == "\"\"") {
      return;
    }
    torch::profiler::impl::kineto::addMetadata(
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        const_cast<torch::profiler::impl::kineto::activity_t*>(kinetoActivity_),
        key,
        value);
  }

  bool hasKinetoActivity() const {
    return static_cast<bool>(kinetoActivity_ != nullptr);
  }

 private:
  const torch::profiler::impl::kineto::activity_t* kinetoActivity_{nullptr};
};
// Visitor that attaches TensorBoard-oriented metadata to the Kineto activity
// of `result`: module hierarchy, call stack, and, for Python events, the
// Python id of the event and of its nearest Python ancestor.
struct AddTensorboardFields : public MetadataBase {
  AddTensorboardFields(
      const std::shared_ptr<Result>& result,
      KinetoEvent& kineto_event)
      : MetadataBase(result) {
    result->visit(*this);
    addMetadata(
        "Module Hierarchy",
        stacksToStr(kineto_event.moduleHierarchy().vec(), "."));
    addMetadata("Call stack", stacksToStr(kineto_event.stack().vec(), ";"));
    result->visit_if_base<PyExtraFieldsBase>([&, this](const auto& py) -> void {
      this->addMetadata("Python id", std::to_string(py.id_));
      // Walk up the tree until the first Python ancestor is found.
      std::optional<std::string> parent_id;
      for (auto ancestor = result->parent_.lock();
           ancestor != nullptr && !parent_id.has_value();
           ancestor = ancestor->parent_.lock()) {
        ancestor->visit_if_base<PyExtraFieldsBase>(
            [&](const auto& p) { parent_id = std::to_string(p.id_); });
      }
      this->addMetadata("Python parent id", parent_id.value_or("null"));
    });
  }

  // Python calls made on a module also record that module's id.
  void operator()(const ExtraFields<EventType::PyCall>& py_call) {
    if (!py_call.module_.has_value()) {
      return;
    }
    addMetadata("Python module id", std::to_string(py_call.module_->id_));
  }

  // All other event kinds carry no TensorBoard-specific fields.
  template <typename T>
  void operator()(const T&) {}
};
struct AddGenericMetadata : public MetadataBase {
AddGenericMetadata(
std::shared_ptr<Result>& result,
const torch::profiler::impl::ProfilerConfig* config)
: MetadataBase(result), config_(config) {
result->visit(*this);
if (config->experimental_config.verbose) {
result->visit_if_base<PyExtraFieldsBase>(
[&, this](const auto& i) -> void {
this->addMetadata("Python thread", std::to_string(i.python_tid_));
});
}
}
void operator()(ExtraFields<EventType::TorchOp>& op_event) {
const auto arg_data =
parseArgData(op_event.inputs_, op_event.concrete_inputs_);
if (arg_data.hasData) {
if (get_record_concrete_inputs_enabled()) {
addMetadata("Input Dims", variantShapesToStr(arg_data.shapes));
addMetadata("Input Strides", variantShapesToStr(arg_data.strides));
} else {
addMetadata("Input Dims", shapesToStr(arg_data.shapesForKinetoEvent));
}
addMetadata("Input type", strListToStr(arg_data.dtypes));
if (!arg_data.concreteInputs.empty()) {
addMetadata(
"Concrete Inputs", ivalueListToStr(arg_data.concreteInputs));
}
}
// Add metadata for kwinputs if exist
for (const auto& [key, val] : op_event.kwinputs_) {
if (key == "stream" && !val.isInt()) {
LOG(WARNING) << "Inputted stream is not an int for op: "
<< op_event.name_ << " skipping";
continue;
}
// Until needed, lets limit the kwargs to only ints, doubles, strings,
// bools, and list of strings
bool isValidType =
val.isInt() || val.isDouble() || val.isString() || val.isBool();
bool isStringList = false;
if (!isValidType && val.isList()) {
// Check if it's a list of strings
auto list = val.toListRef();
isStringList =
std::all_of(list.begin(), list.end(), [](const c10::IValue& item) {
return item.isString();
});
}
if (!isValidType && !isStringList) {
LOG(WARNING)
<< "Inputted kwarg: " << key
<< " is not an int, double, string, bool, or list of strings for op: "
<< op_event.name_ << " skipping";
continue;
}
if (isStringList) {
// For list of strings, use ivalueListToStr
auto list = val.toListRef();
std::vector<c10::IValue> stringList(list.begin(), list.end());
addMetadata(key, ivalueListToStr(stringList));
} else {
bool isString = val.isString();
addMetadata(key, ivalueToStr(val, isString));
}
}
// Add extra metadata if any
for (const auto& [key, val] : op_event.extra_meta_) {
addMetadata(key, val);
}
if (config_ && !config_->experimental_config.performance_events.empty()) {
auto& event_names = config_->experimental_config.performance_events;
for (const auto i : c10::irange(op_event.perf_event_counters_->size())) {
addMetadata(
event_names[i],
std::to_string((*op_event.perf_event_counters_)[i]));
}
}
// add information about an associated forward op, if a sequence number
// is available (e.g. during training)
if (op_event.sequence_number_ >= 0) {
addMetadata("Fwd thread id", std::to_string(op_event.forward_tid_));
addMetadata("Sequence number", std::to_string(op_event.sequence_number_));
}
addMetadata(
"Record function id", std::to_string(op_event.record_function_id_));
}
void operator()(ExtraFields<EventType::Backend>& backend_event) {
if (!backend_event.backend_.empty()) {
addMetadata("Backend", "\"" + backend_event.backend_ + "\"");
}
}
void operator()(const ExtraFields<EventType::Allocation>& alloc) {
addMetadata("Device Type", std::to_string((int8_t)alloc.device_type_));
addMetadata("Device Id", std::to_string(alloc.device_index_));
addMetadata("Addr", std::to_string(reinterpret_cast<intptr_t>(alloc.ptr_)));
addMetadata("Bytes", std::to_string(alloc.alloc_size_));
addMetadata("Total Allocated", std::to_string(alloc.total_allocated_));
addMetadata("Total Reserved", std::to_string(alloc.total_reserved_));
}
void operator()(const ExtraFields<EventType::OutOfMemory>& alloc) {
addMetadata("Device Type", std::to_string((int8_t)alloc.device_type_));
addMetadata("Device Id", std::to_string(alloc.device_index_));
addMetadata("Bytes", std::to_string(alloc.alloc_size_));
addMetadata("Total Allocated", std::to_string(alloc.total_allocated_));
addMetadata("Total Reserved", std::to_string(alloc.total_reserved_));
}
template <typename T>
void operator()(const T&) {}
private:
/* To get names of the performance events */
const torch::profiler::impl::ProfilerConfig* config_;
};
// Profiler state for a Kineto-backed profiling session. Collected events are
// buffered in `recordQueue`. finalizeTrace() converts them into
// `kinetoEvents` / `eventTree` and returns the Kineto activity trace.
struct KinetoThreadLocalState : public ProfilerStateBase {
  // Records the session start time and creates the record queue for the
  // requested activities.
  explicit KinetoThreadLocalState(
      const ProfilerConfig& config,
      std::set<torch::profiler::impl::ActivityType> activities)
      : ProfilerStateBase(config),
        startTime(getTimeNs()),
        recordQueue(config, std::move(activities)) {}
  ~KinetoThreadLocalState() override = default;
  // Returns the installed profiler state (global or thread-local, selected by
  // `global`) as a KinetoThreadLocalState, or nullptr if none is installed.
  // Debug builds assert that the installed state really is a Kineto profiler.
  static KinetoThreadLocalState* get(bool global) {
    auto* state = ProfilerStateBase::get(/*global=*/global);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        state == nullptr ||
        state->profilerType() == ActiveProfilerType::KINETO);
    return static_cast<KinetoThreadLocalState*>(state);
  }
  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::KINETO;
  }
  // Queues a Vulkan event on the calling thread's subqueue unless the
  // profiler is disabled.
  void reportVulkanEventToProfiler(torch::profiler::impl::vulkan_id_t id) {
    if (!config_.disabled()) {
      recordQueue.getSubqueue()->emplace_vulkan_event(
          c10::getApproximateTime(), id);
    }
  }
  // Queues an allocation event when memory profiling is on and the profiler
  // is not disabled.
  void reportMemoryUsage(
      void* ptr,
      int64_t alloc_size,
      size_t total_allocated,
      size_t total_reserved,
      c10::Device device) override {
    if (config_.profile_memory && !config_.disabled()) {
      recordQueue.getSubqueue()->emplace_allocation_event(
          c10::getApproximateTime(),
          ptr,
          alloc_size,
          total_allocated,
          total_reserved,
          device.type(),
          device.index());
    }
  }
  // Like reportMemoryUsage(), but queues an out-of-memory event (no pointer).
  void reportOutOfMemory(
      int64_t alloc_size,
      size_t total_allocated,
      size_t total_reserved,
      c10::Device device) override {
    if (config_.profile_memory && !config_.disabled()) {
      recordQueue.getSubqueue()->emplace_ooms_event(
          c10::getApproximateTime(),
          alloc_size,
          total_allocated,
          total_reserved,
          device.type(),
          device.index());
    }
  }
  // Installs the callback that materializeOpEvents() invokes for TorchOp and
  // Backend events.
  void setEventPostProcessingCallback(post_process_t&& cb) {
    eventPostProcessCb = std::move(cb);
  }
  // Stop / restart the record queue (see togglePythonCollectionDynamic()).
  void pausePython() {
    recordQueue.stop();
  }
  void resumePython() {
    recordQueue.restart();
  }
  // Ends collection and converts the buffered records into `kinetoEvents`
  // and `eventTree`. Returns the Kineto activity trace. Holds `state_mutex_`
  // while materializing.
  std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>
  finalizeTrace() {
    auto end_time = getTimeNs();
    recordQueue.stop();
    std::lock_guard<std::mutex> guard(state_mutex_);
    auto converter = clockConverter.makeConverter();
#ifdef USE_KINETO
    // Hand libkineto the same approximate-clock -> wall-clock converter.
    libkineto::get_time_converter() = converter;
#endif
    auto records_and_trace =
        recordQueue.getRecords(std::move(converter), startTime, end_time);
    materializeOpEvents(records_and_trace.first);
    // `kinetoEvents` does not include Python events. Instead they are exposed
    // through the Python stack of the remaining events
    // (see KinetoEvent::stack()).
    kinetoEvents.erase(
        std::remove_if(
            kinetoEvents.begin(),
            kinetoEvents.end(),
            [](const auto& i) { return i.isPythonFunction(); }),
        kinetoEvents.end());
    return std::move(records_and_trace.second);
  }
  // Forwards an event's debug handle, JIT stack and JIT modules to the
  // post-processing callback, if one is installed.
  template <typename T>
  void invokeCallback(T& t) {
    if (eventPostProcessCb) {
      eventPostProcessCb(t.debug_handle_, t.jit_stack_, t.jit_modules_);
    }
  }
  // Processes each result in turn. CPU roots (results without a live parent)
  // go into `eventTree`. Each finished result gets the post-processing
  // callback (TorchOp/Backend only) and a KinetoEvent. Its Kineto activity is
  // then decorated with metadata and detached.
  void materializeOpEvents(std::vector<std::shared_ptr<Result>>& events) {
    for (auto& e : events) {
      if (e->parent_.expired() && e->deviceType() == c10::DeviceType::CPU) {
        eventTree.push_back(e);
      }
      if (e->finished_) {
        e->visit(c10::overloaded(
            [this](ExtraFields<EventType::TorchOp>& i) { invokeCallback(i); },
            [this](ExtraFields<EventType::Backend>& i) { invokeCallback(i); },
            [](auto&) {}));
        kinetoEvents.emplace_back(e, config_.experimental_config.verbose);
        AddTensorboardFields add_tb(e, kinetoEvents.back());
        AddGenericMetadata add_generic(e, &config_);
        // It is not safe to use the activity after post processing.
        e->kineto_activity_ = nullptr;
      }
    }
  }
  // Session start time from getTimeNs().
  uint64_t startTime;
  c10::ApproximateClockToUnixTimeConverter clockConverter;
  torch::profiler::impl::RecordQueue recordQueue;
  std::vector<KinetoEvent> kinetoEvents;
  std::vector<experimental_event_t> eventTree;
  // Optional, if event post-processing is enabled.
  post_process_t eventPostProcessCb;
};
// RecordFunction "enter" observer: opens an op on the calling thread's
// subqueue of the active Kineto profiler. Returns nullptr when no Kineto
// profiler state is installed.
template <bool use_global_state_ptr = false>
std::unique_ptr<at::ObserverContext> onFunctionEnter(
    const at::RecordFunction& fn) {
  if (auto* const profiler_state =
          KinetoThreadLocalState::get(use_global_state_ptr)) {
    return profiler_state->recordQueue.getSubqueue()->begin_op(fn);
  }
  return nullptr;
}
// RecordFunction "exit" observer. It finalizes the event opened by
// onFunctionEnter(): end time, perf counters, end thread id and NCCL output
// metadata. It also records the end of any device fallback event and pops
// the external correlation id.
// @lint-ignore CLANGTIDY clang-diagnostic-unused-parameter
template <bool use_global_state_ptr = false>
void onFunctionExit(
    const at::RecordFunction& fn,
    at::ObserverContext* ctx_ptr) {
  auto* const profiler_state =
      KinetoThreadLocalState::get(use_global_state_ptr);
  if (profiler_state == nullptr) {
    return;
  }
  const auto& cfg = profiler_state->config();
  auto* kineto_ctx_ptr =
      static_cast<torch::profiler::impl::KinetoObserverContext*>(ctx_ptr);
  TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr);
  auto& op_event = *kineto_ctx_ptr->event_;

  // Record the end timestamp before any other exit bookkeeping.
  op_event.end_time_ = c10::getApproximateTime();
  if (!cfg.experimental_config.performance_events.empty()) {
    profiler_state->recordQueue.getSubqueue()->disable_perf_profiler(
        *op_event.counters_);
  }
  op_event.basic_fields_.end_tid_ = at::RecordFunction::currentThreadId();

  if (fn.isNcclMeta()) {
    // Record only the outputs in this exit callback of the record function.
    torch::profiler::impl::SaveNcclMetaConfig outputs_only{
        true, false, false, true};
    auto output_meta = torch::profiler::impl::saveNcclMeta(fn, outputs_only);
    op_event.extra_nccl_meta_->insert(output_meta.begin(), output_meta.end());
  }

  const bool gpu_fallback = cfg.state == ProfilerState::KINETO_GPU_FALLBACK;
  const bool privateuse1_fallback =
      cfg.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK;
  if (gpu_fallback) {
    try {
      auto fallback = kineto_ctx_ptr->fallback_;
      TORCH_INTERNAL_ASSERT(fallback != nullptr);
      torch::profiler::impl::cudaStubs()->record(
          nullptr, &fallback->device_event_end_, nullptr);
    } catch (const std::exception& e) {
      LOG(WARNING) << "Failed to record CUDA event. " << e.what();
    }
  } else if (privateuse1_fallback) {
    auto fallback = kineto_ctx_ptr->fallback_;
    TORCH_INTERNAL_ASSERT(fallback != nullptr);
    torch::profiler::impl::privateuse1Stubs()->record(
        nullptr, &fallback->device_event_end_, nullptr);
  }

  if (cfg.experimental_config.disable_external_correlation) {
    return;
  }
  if (fn.scope() == at::RecordScope::USER_SCOPE) {
    torch::profiler::impl::kineto::popUserCorrelationId();
  } else {
    torch::profiler::impl::kineto::popCorrelationId();
  }
}
// Registers the Kineto enter/exit observers for `scopes`. The callback is
// global or thread-local depending on `use_global_callback`. Its handle is
// stored on the profiler state so it can be removed later.
template <bool use_global_callback = false>
void pushProfilingCallbacks(const std::unordered_set<at::RecordScope>& scopes) {
  auto* const registration_state_ptr =
      KinetoThreadLocalState::get(use_global_callback);
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  const bool wants_inputs = registration_state_ptr->config().report_input_shapes;

  auto callback = at::RecordFunctionCallback(
      onFunctionEnter<use_global_callback>,
      onFunctionExit<use_global_callback>);
  callback.needsInputs(wants_inputs).scopes(scopes);

  if constexpr (!use_global_callback) {
    registration_state_ptr->setCallbackHandle(
        at::addThreadLocalCallback(callback));
  } else {
    registration_state_ptr->setCallbackHandle(at::addGlobalCallback(callback));
  }
}
// What enableProfiler() publishes when CPU tracing is on, so that
// enableProfilerInChildThread() can attach other threads to the same state.
struct ProfilerStateInfo {
  // Profiler state shared with child threads.
  std::shared_ptr<KinetoThreadLocalState> state_ptr;
  // RecordFunction scopes the callbacks were registered for.
  std::unordered_set<at::RecordScope> scopes;
};
// Set by enableProfiler() when CPU activity is requested; reset by
// disableProfiler(). NOTE(review): appears to be read from child threads
// without synchronization — confirm the intended threading contract.
std::shared_ptr<ProfilerStateInfo> profiler_state_info_ptr{nullptr};
} // namespace
// Queues a backend event on the current thread's Kineto profiler, if any.
// This is not supported under on-demand (global) profiling.
void reportBackendEventToActiveKinetoProfiler(
    const int64_t start_time_us,
    const int64_t end_time_us,
    const int64_t debug_handle,
    const at::RecordScope scope,
    const std::string& event_name,
    const std::string& backend_name) {
  TORCH_INTERNAL_ASSERT(
      KinetoThreadLocalState::get(/*global=*/true) == nullptr,
      "On-demand profiling does not support post processing callback");
  auto* const thread_state = KinetoThreadLocalState::get(/*global=*/false);
  if (thread_state != nullptr) {
    thread_state->recordQueue.getSubqueue()->emplace_backend_event(
        start_time_us,
        end_time_us,
        debug_handle,
        scope,
        event_name,
        backend_name);
  }
  // TODO: input shapes / dtypes are not recorded for backend events yet.
}
// Prepares a Kineto trace for the given config and activities. NVTX and ITT
// need no preparation; any other non-Kineto state is rejected. When hardware
// performance events are requested, CPU activity is required, and
// non-standard event names trigger a warning (they are still forwarded, and
// the backend may abort on unsupported ones).
void prepareProfiler(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities) {
  const bool needs_no_kineto =
      config.state == ProfilerState::NVTX || config.state == ProfilerState::ITT;
  if (needs_no_kineto) {
    return;
  }
  TORCH_CHECK(
      config.state == ProfilerState::KINETO ||
          config.state == ProfilerState::KINETO_GPU_FALLBACK ||
          config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK,
      "Supported only in Kineto profiler");

  const bool has_device_backend = at::hasCUDA() || at::hasXPU() ||
      at::hasMTIA() || c10::get_privateuse1_backend() != "privateuseone";
  torch::profiler::impl::kineto::prepareTrace(
      /*cpuOnly=*/!has_device_backend,
      activities,
      config.experimental_config,
      config.trace_id);

  const auto& perf_events = config.experimental_config.performance_events;
  if (perf_events.empty()) {
    return;
  }
  /* For now only CPU activity is supported */
  TORCH_CHECK(
      activities.count(torch::autograd::profiler::ActivityType::CPU),
      "Cannot run cpu hardware profiler without CPU activities, please only use CPU activity type");
  // TODO: Should we gracefully drop the invalid event if we have at least one
  // valid?
  for (const auto& requested : perf_events) {
    bool is_standard = false;
    for (const auto& known : torch::profiler::ProfilerPerfEvents) {
      if (std::strcmp(requested.c_str(), known) == 0) {
        is_standard = true;
        break;
      }
    }
    if (!is_standard) {
      TORCH_WARN("Forwarding a non-standard CPU performance event : ", requested);
    }
  }
}
static void toggleTorchOpCollectionDynamic(bool enable) {
auto state_ptr = ProfilerStateBase::get();
if (state_ptr) {
const auto& config = state_ptr->config();
if (enable) {
auto scopes = profiler_state_info_ptr->scopes;
config.global() ? pushProfilingCallbacks</*global=*/true>(scopes)
: pushProfilingCallbacks</*global=*/false>(scopes);
} else {
state_ptr->removeCallback();
}
}
}
// Set this function to be unused as profiler implementation needs more
// refactoring to support Python ops collection dynamic toggling
#ifdef _MSC_VER
#define UNUSED
#else
#define UNUSED __attribute__((unused))
#endif
static UNUSED void togglePythonCollectionDynamic(bool enable) {
auto state_ptr = ProfilerStateBase::get();
if (state_ptr) {
auto global = state_ptr->config().global();
KinetoThreadLocalState* kineto_thread_local_state_ptr =
KinetoThreadLocalState::get(global);
if (enable) {
kineto_thread_local_state_ptr->resumePython();
} else {
kineto_thread_local_state_ptr->pausePython();
}
}
}
// Toggles CPU-side collection. Only TorchOp collection is toggled for now.
// Toggling Python ops as well would require string parsing to drop the
// toggling events and other unfinished events, plus changes to the stack
// logic (togglePythonCollectionDynamic() is kept for that future work).
static void toggleCPUCollectionDynamic(bool enable) {
  toggleTorchOpCollectionDynamic(/*enable=*/enable);
}
// Dynamically enables or disables collection for each activity in
// `activities`. GPU activities (CUDA/XPU) are toggled in Kineto, and CPU
// toggles TorchOp collection. Other activity types are skipped with a
// warning.
//
// A warning is also logged when only one side of the CPU/GPU pair is toggled,
// because the untoggled side keeps recording. The CPU-only check used to be
// `CUDA == 0 || XPU == 0`, which also fired when CPU and CUDA were toggled
// together. It now requires that no GPU activity is toggled.
void toggleCollectionDynamic(
    const bool enable,
    const std::set<torch::profiler::impl::ActivityType>& activities) {
  const bool toggles_cpu =
      activities.count(torch::autograd::profiler::ActivityType::CPU) > 0;
  const bool toggles_gpu =
      activities.count(torch::autograd::profiler::ActivityType::CUDA) > 0 ||
      activities.count(torch::autograd::profiler::ActivityType::XPU) > 0;
  if (toggles_cpu && !toggles_gpu) {
    LOG(WARNING)
        << "Toggling CPU activity with GPU activity on may result in traces with GPU events on artibrary tracks";
  } else if (toggles_gpu && !toggles_cpu) {
    LOG(WARNING)
        << "Toggling GPU activity with CPU activity on may result in traces with incorrect correlation between CPU and GPU events";
  }
  for (const auto act : activities) {
    if (act == torch::autograd::profiler::ActivityType::CUDA ||
        act == torch::autograd::profiler::ActivityType::XPU) {
      torch::profiler::impl::kineto::toggleCollectionDynamic(enable);
    } else if (act == torch::autograd::profiler::ActivityType::CPU) {
      toggleCPUCollectionDynamic(enable);
    } else {
      LOG(WARNING)
          << "Dynamic toggle is only supported for CPU/GPU activity, skipping toggling of "
          << actToString(act);
    }
  }
}
// Enables the profiler and installs `cb` as the event post-processing
// callback. Only Kineto states are supported. NVTX, ITT and PRIVATEUSE1 are
// rejected up front: enableProfiler() returns for them without creating a
// KinetoThreadLocalState, so there would be nothing to attach `cb` to (and
// PRIVATEUSE1 previously dereferenced a null state).
void enableProfilerWithEventPostProcess(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities,
    post_process_t&& cb,
    const std::unordered_set<at::RecordScope>& scopes) {
  TORCH_CHECK(
      config.state != ProfilerState::NVTX,
      "NVTX does not support post processing callback.");
  TORCH_CHECK(
      config.state != ProfilerState::ITT,
      "ITT does not support post processing callback.");
  TORCH_CHECK(
      config.state != ProfilerState::PRIVATEUSE1,
      "PRIVATEUSE1 does not support post processing callback.");
  TORCH_INTERNAL_ASSERT(
      KinetoThreadLocalState::get(/*global=*/true) == nullptr,
      "On-demand profiling does not support post processing callback");
  enableProfiler(config, activities, scopes);
  auto state_ptr = KinetoThreadLocalState::get(config.global());
  TORCH_INTERNAL_ASSERT(
      state_ptr != nullptr,
      "Expected a Kineto profiler state after enableProfiler");
  state_ptr->setEventPostProcessingCallback(std::move(cb));
}
// Enables profiling for `config`. NVTX, ITT and PRIVATEUSE1 only register
// their own callbacks. Kineto states push a KinetoThreadLocalState and
// register the CPU observers when CPU activity is requested. They also start
// the Kineto trace (non-global configs only) and publish the state for child
// threads (CPU only).
void enableProfiler(
    const torch::profiler::impl::ProfilerConfig& config,
    const std::set<torch::profiler::impl::ActivityType>& activities,
    const std::unordered_set<at::RecordScope>& scopes) {
  const bool has_cpu = activities.count(ActivityType::CPU) > 0;
  TORCH_CHECK(
      KinetoThreadLocalState::get(/*global=*/config.global()) == nullptr,
      "Profiler is already enabled",
      (config.global() ? "." : " on this thread."));

  // These states use their own observers and no Kineto state.
  if (config.state == ProfilerState::NVTX) {
    torch::profiler::impl::pushNVTXCallbacks(config, scopes);
    return;
  }
  if (config.state == ProfilerState::ITT) {
    torch::profiler::impl::pushITTCallbacks(config, scopes);
    return;
  }
  if (config.state == ProfilerState::PRIVATEUSE1) {
    torch::profiler::impl::pushPRIVATEUSE1CallbacksStub(config, scopes);
    return;
  }

  TORCH_CHECK(
      config.state == ProfilerState::KINETO ||
      config.state == ProfilerState::KINETO_GPU_FALLBACK ||
      config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK ||
      config.global());
  TORCH_CHECK(!activities.empty(), "No activities specified.");
  TORCH_INTERNAL_ASSERT(
      has_cpu || !config.global(),
      "Ondemand profiling must enable CPU tracing");

  auto state_ptr = std::make_shared<KinetoThreadLocalState>(config, activities);
  KinetoThreadLocalState::push(state_ptr);
  if (has_cpu) {
    if (config.pushGlobalCallbacks()) {
      pushProfilingCallbacks</*global=*/true>(scopes);
    } else {
      pushProfilingCallbacks</*global=*/false>(scopes);
    }
  }
  if (!config.global()) {
    torch::profiler::impl::kineto::startTrace();
  }
  if (!has_cpu) {
    return;
  }
  auto published = std::make_shared<ProfilerStateInfo>();
  published->state_ptr = state_ptr;
  published->scopes = scopes;
  profiler_state_info_ptr = std::move(published);
}
// True while enableProfiler() has published CPU-tracing state for child
// threads (i.e. until disableProfiler() resets it).
bool isProfilerEnabledInMainThread() {
  return static_cast<bool>(profiler_state_info_ptr);
}
// Attaches the calling thread to the profiler state published by the main
// thread and registers thread-local callbacks with the same scopes.
void enableProfilerInChildThread() {
  // Local copy keeps the published info alive for the rest of this call.
  auto state_info_ptr = profiler_state_info_ptr;
  TORCH_CHECK(state_info_ptr, "Profiler is not enabled in main thread.");
  TORCH_CHECK(
      KinetoThreadLocalState::get(/*global=*/false) == nullptr,
      "Profiler is already enabled in this thread.");
  const auto& [shared_state, shared_scopes] = *state_info_ptr;
  KinetoThreadLocalState::push(shared_state);
  pushProfilingCallbacks</*global=*/false>(shared_scopes);
}
// Detaches the calling thread from the shared profiler state and removes its
// callbacks. Fails if this thread has no profiler state.
void disableProfilerInChildThread() {
  const auto detached_state = ProfilerStateBase::pop();
  TORCH_CHECK(
      detached_state,
      "Can't disable Kineto profiler when it's not running in this thread");
  detached_state->removeCallback();
}
std::unique_ptr<ProfilerResult> disableProfiler() {
// releasing to inform child threads to stop profiling
profiler_state_info_ptr = nullptr;
auto state_ptr = ProfilerStateBase::pop();
const auto& config = state_ptr->config();
TORCH_CHECK(
state_ptr &&
(config.state == ProfilerState::KINETO ||
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK ||
config.state == ProfilerState::KINETO_ONDEMAND ||
config.state == ProfilerState::NVTX ||
config.state == ProfilerState::ITT ||
config.state == ProfilerState::PRIVATEUSE1),
"Can't disable Kineto profiler when it's not running");
state_ptr->removeCallback();
// Traces are converged via libkineto automatically for ondemand flow
if (state_ptr->config().global()) {
(void)std::static_pointer_cast<KinetoThreadLocalState>(state_ptr)
->finalizeTrace();
return std::make_unique<ProfilerResult>();
}
// Shared among NVTX, PRIVATEUSE1, KINETO, KINETO_GPU_FALLBACK,
// KINETO_PRIVATEUSE1_FALLBACK
std::unique_ptr<ProfilerResult> result;
if (state_ptr->config().state == ProfilerState::NVTX ||
state_ptr->config().state == ProfilerState::PRIVATEUSE1) {
result = std::make_unique<ProfilerResult>();
}
if (config.state == ProfilerState::KINETO ||
config.state == ProfilerState::KINETO_GPU_FALLBACK ||
config.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
auto kineto_state_ptr =
std::static_pointer_cast<KinetoThreadLocalState>(state_ptr);
auto trace = kineto_state_ptr->finalizeTrace();
result = std::make_unique<ProfilerResult>(
kineto_state_ptr->startTime,
std::move(kineto_state_ptr->kinetoEvents),
std::move(trace),
std::move(kineto_state_ptr->eventTree));
}
return result;
}
namespace tracer = torch::profiler::impl::python_tracer;
static std::unique_ptr<tracer::PythonMemoryTracerBase> memory_tracer;
void startMemoryProfile() {
if (memory_tracer == nullptr) {
memory_tracer = tracer::PythonMemoryTracerBase::make();
}
memory_tracer->start();
}
void stopMemoryProfile() {
memory_tracer->stop();
}
void exportMemoryProfile(const std::string& filename) {
memory_tracer->export_memory_history(filename);
}
// Wraps a profiler Result. With `verbose`, the names of Python ancestors are
// collected into `python_stack_`, nearest ancestor first. For TorchOp events
// the input shapes, dtypes, concrete inputs and kwargs are cached.
KinetoEvent::KinetoEvent(
    const std::shared_ptr<const torch::profiler::impl::Result>& result,
    const bool verbose)
    : result_{result} {
  TORCH_INTERNAL_ASSERT(result != nullptr);
  if (verbose) {
    // Populate Python stack
    for (auto parent = result_->parent_.lock(); parent != nullptr;
         parent = parent->parent_.lock()) {
      parent->visit_if_base<PyExtraFieldsBase>(
          [&](const auto&) { python_stack_.push_back(parent->name()); });
    }
  }
  result->visit_if_base<ExtraFields<EventType::TorchOp>>([&](const auto& op) {
    auto arg_data = parseArgData(op.inputs_, op.concrete_inputs_);
    shapes_ = std::move(arg_data.shapesForKinetoEvent);
    dtypes_ = std::move(arg_data.dtypes);
    concrete_inputs_ = std::move(arg_data.concreteInputs);
    // `op` is const, so this is a copy; std::move here only obscured that.
    kwinputs_ = op.kwinputs_;
  });
}
// True when the wrapped Result's extra fields derive from PyExtraFieldsBase.
bool KinetoEvent::isPythonFunction() const {
  bool is_python = false;
  result_->visit_if_base<PyExtraFieldsBase>(
      [&is_python](const auto&) { is_python = true; });
  return is_python;
}
// Accessors over the input data cached by the constructor.
bool KinetoEvent::hasShapes() const {
  return shapes_.size() != 0;
}
const c10::ArrayRef<std::vector<int64_t>> KinetoEvent::shapes() const {
  return shapes_;
}
bool KinetoEvent::hasTypes() const {
  return dtypes_.size() != 0;
}
const c10::ArrayRef<std::string> KinetoEvent::dtypes() const {
  return dtypes_;
}
bool KinetoEvent::hasConcreteInputs() const {
  return concrete_inputs_.size() != 0;
}
const c10::ArrayRef<c10::IValue> KinetoEvent::concreteInputs() const {
  return concrete_inputs_;
}
bool KinetoEvent::hasKwinputs() const {
  return kwinputs_.size() != 0;
}
// True when the wrapped Result is marked hidden.
bool KinetoEvent::isHiddenEvent() const {
  if (!result_) {
    return false;
  }
  return result_->hidden_;
}
// Returns a copy of the cached kwargs.
const std::unordered_map<std::string, c10::IValue> KinetoEvent::kwinputs()
    const {
  return kwinputs_;
}
// Call stack of the event. TorchOp and Backend events use their JIT stack
// when it is non-empty. Otherwise, and for all other events, the Python stack
// gathered by the constructor is returned.
const c10::ArrayRef<std::string> KinetoEvent::stack() const {
  const auto choose = [this](const auto& fields) -> const auto& {
    if (fields.jit_stack_.empty()) {
      return python_stack_;
    }
    return fields.jit_stack_;
  };
  const auto& extra = result_->extra_fields_;
  if (const auto* op = std::get_if<ExtraFields<EventType::TorchOp>>(&extra)) {
    return choose(*op);
  }
  if (const auto* backend =
          std::get_if<ExtraFields<EventType::Backend>>(&extra)) {
    return choose(*backend);
  }
  return python_stack_;
}
// JIT module hierarchy recorded on TorchOp and Backend events; empty for all
// other event kinds.
const c10::ArrayRef<std::string> KinetoEvent::moduleHierarchy() const {
  const auto& extra = result_->extra_fields_;
  if (const auto* op = std::get_if<ExtraFields<EventType::TorchOp>>(&extra)) {
    return op->jit_modules_;
  }
  if (const auto* backend =
          std::get_if<ExtraFields<EventType::Backend>>(&extra)) {
    return backend->jit_modules_;
  }
  return {};
}
// End timestamp of the wrapped Result, in nanoseconds.
uint64_t KinetoEvent::endNs() const {
  const auto end_ns = result_->endTimeNS();
  return end_ns;
}
// Difference between the Result's end and start timestamps, in nanoseconds.
uint64_t KinetoEvent::durationNs() const {
  const auto start_ns = result_->start_time_ns_;
  return result_->endTimeNS() - start_ns;
}
int64_t KinetoEvent::debugHandle() const {
return result_->visit(c10::overloaded(
[](const ExtraFields<EventType::TorchOp>& i) { return i.debug_handle_; },
[](const ExtraFields<EventType::Backend>& i) { return i.debug_handle_; },
[](const auto&) -> int64_t { return -1; }));
}
int KinetoEvent::deviceIndex() const {
return result_->visit(c10::overloaded(
[](const ExtraFields<EventType::Allocation>& i) {
return static_cast<int>(i.device_index_);
},
[](const ExtraFields<EventType::OutOfMemory>& i) {
return static_cast<int>(i.device_index_);
},
[&](const auto&) {
return static_cast<int>(result_->kineto_info_.device);
}));
}
bool KinetoEvent::hasStack() const {
return !stack().empty();
}
// Returns the value reported by the CUDA stubs' elapsed() for this event's
// fallback start/end events, truncated to int64_t.
// Returns -1 when either fallback event is missing, or when the measurement
// throws (a warning with the exception message is logged in that case).
int64_t KinetoEvent::cudaElapsedUs() const {
  auto cuda_event_start = fallbackStart();
  auto cuda_event_end = fallbackEnd();
  if (!cuda_event_start || !cuda_event_end) {
    return -1;
  }
  try {
    return static_cast<int64_t>(torch::profiler::impl::cudaStubs()->elapsed(
        &cuda_event_start, &cuda_event_end));
  } catch (const std::exception& e) {
    // Best-effort measurement: report the failure but do not propagate it.
    LOG(WARNING) << "Failed to measure time between two CUDA events. "
                 << e.what();
  }
  return -1;
}
// Returns the value reported by the PrivateUse1 stubs' elapsed() for this
// event's fallback start/end events, truncated to int64_t.
// Returns -1 when either fallback event is missing.
// NOTE(review): unlike cudaElapsedUs(), exceptions thrown by the stub are not
// caught here and propagate to the caller. Consider aligning the two.
int64_t KinetoEvent::privateuse1ElapsedUs() const {
  auto privateuse1_event_start = fallbackStart();
  auto privateuse1_event_end = fallbackEnd();
  if (!privateuse1_event_start || !privateuse1_event_end) {
    return -1;
  }
  return static_cast<int64_t>(
      torch::profiler::impl::privateuse1Stubs()->elapsed(
          &privateuse1_event_start, &privateuse1_event_end));
}
// Copies the perf-event counter values recorded for a TorchOp event into the
// leading slots of `in`. `in` is grown (zero-filled) if it is shorter than
// the recorded counters. Entries beyond the recorded count are left as-is.
// Non-TorchOp events, and TorchOp events with no counter storage, leave `in`
// unchanged.
void KinetoEvent::getPerfEventCounters(std::vector<uint64_t>& in) const {
  result_->visit(c10::overloaded(
      [&in](const ExtraFields<EventType::TorchOp>& e) -> void {
        // The counters are held behind a pointer. Treat a null pointer as
        // "no counters" instead of dereferencing it.
        if (!e.perf_event_counters_) {
          return;
        }
        const auto& counters = *e.perf_event_counters_;
        const size_t n = counters.size();
        // should be rare
        if (in.size() < n) {
          in.resize(n, 0);
        }
        for (size_t i = 0; i < n; ++i) {
          in[i] = counters[i];
        }
      },
      [](const auto&) -> void {}));
}
std::string KinetoEvent::metadataJson() const {
return result_->visit(c10::overloaded(
[](const ExtraFields<EventType::TorchOp>& op) -> std::string {
return op.metadata_json_;
},
[](const ExtraFields<EventType::Kineto>& op) -> std::string {
return op.metadata_json_;
},
[](const auto&) -> std::string { return std::string(""); }));
}
// FORWARD_FROM_RESULT(method_name, result_expr) defines
// `KinetoEvent::method_name() const` as a plain read of
// `result_->result_expr`, static_cast to whatever return type the header
// declares for `method_name`. The return type is recovered with decltype, so
// it does not have to be repeated here.
// Comments must stay outside the macro body: a `//` comment on a
// backslash-continued line would swallow the continuation.
#define FORWARD_FROM_RESULT(method_name, result_expr) \
decltype(std::declval<KinetoEvent>().method_name()) \
KinetoEvent::method_name() const { \
return static_cast<decltype(std::declval<KinetoEvent>().method_name())>( \
result_->result_expr); \
}
// Thread ids, activity type and naming.
FORWARD_FROM_RESULT(startThreadId, start_tid_)
FORWARD_FROM_RESULT(endThreadId, endTID())
FORWARD_FROM_RESULT(activityType, kinetoType())
FORWARD_FROM_RESULT(name, name())
FORWARD_FROM_RESULT(overload_name, overload_name())
// Device, timing and correlation information.
FORWARD_FROM_RESULT(deviceType, deviceType())
FORWARD_FROM_RESULT(startNs, start_time_ns_)
FORWARD_FROM_RESULT(correlationId, correlationID())
FORWARD_FROM_RESULT(deviceResourceId, kineto_info_.resource)
#undef FORWARD_FROM_RESULT
// Most of the fields in `KinetoEvent` only make sense for a single event type
// (generally TorchOp). For all other types they simply return a default
// value. These macros provide a succinct way of expressing this behavior.
//
// TYPED_ATTR_WITH_DEFAULT(event_type, method_name, expression, default_value)
// defines `KinetoEvent::method_name() const`. It visits the result's extra
// fields:
//   - when they hold `ExtraFields<EventType::event_type>` (bound as `e`), it
//     returns `expression`;
//   - for every other alternative, it returns `default_value`.
// Both are converted to the return type declared in the header, recovered
// via decltype. The lambdas capture nothing, so `expression` may only refer
// to `e`.
#define TYPED_ATTR_WITH_DEFAULT( \
event_type, method_name, expression, default_value) \
decltype(std::declval<KinetoEvent>().method_name()) \
KinetoEvent::method_name() const { \
using out_t = decltype(std::declval<KinetoEvent>().method_name()); \
return result_->visit(c10::overloaded( \
[](const ExtraFields<EventType::event_type>& e) -> out_t { \
return expression; \
}, \
[](const auto&) -> out_t { return default_value; })); \
}
// TYPED_ATTR is TYPED_ATTR_WITH_DEFAULT with a value-initialized (`{}`)
// default for non-matching event types.
#define TYPED_ATTR(event_type, method_name, expression) \
TYPED_ATTR_WITH_DEFAULT(event_type, method_name, expression, {})
// TorchOp attributes. sequenceNr() uses -1 rather than `{}` as its default.
TYPED_ATTR_WITH_DEFAULT(TorchOp, sequenceNr, e.sequence_number_, -1)
// The forward thread id is reported only when the sequence number is
// non-negative; otherwise 0.
TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0)
TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_))
TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty())
TYPED_ATTR(TorchOp, isAsync, e.is_async_)
TYPED_ATTR(TorchOp, extraMeta, e.extra_meta_)
TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_)
TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_)
// FLOPs are computed from the op name and its extra args. Ops with no extra
// args report 0.
TYPED_ATTR(
TorchOp,
flops,
!e.extra_args_.empty()
? torch::profiler::impl::computeFlops(e.name_, e.extra_args_)
: 0)
TYPED_ATTR(Backend, backend, e.backend_)
TYPED_ATTR(Allocation, nBytes, e.alloc_size_)
// `linked_activity_` is lock()ed before use. If that yields nothing, the
// linked correlation id is reported as 0.
TYPED_ATTR(Kineto, linkedCorrelationId, [&]() {
const auto linked = e.linked_activity_.lock();
return linked ? linked->correlationID() : 0;
}())
#undef TYPED_ATTR
#undef TYPED_ATTR_WITH_DEFAULT
ProfilerResult::ProfilerResult(
uint64_t start_time,
std::vector<KinetoEvent> events,
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>&&
trace,
std::vector<experimental_event_t>&& event_tree)
: trace_start_ns_(start_time),
events_(std::move(events)),
trace_(std::move(trace)),
event_tree_(std::move(event_tree)) {}
ProfilerResult::ProfilerResult() = default;
ProfilerResult::~ProfilerResult() = default;
void ProfilerResult::save(const std::string& path) {
trace_->save(path);
}
} // namespace autograd::profiler
namespace profiler::impl {
// Forwards a Vulkan event id to the non-global KinetoThreadLocalState.
// If get() returns no state, the call is a no-op.
void _reportVulkanEventToProfiler(vulkan_id_t id) {
  using ::torch::autograd::profiler::KinetoThreadLocalState;
  if (auto state = KinetoThreadLocalState::get(/*global=*/false)) {
    state->reportVulkanEventToProfiler(id);
  }
}
} // namespace profiler::impl
} // namespace torch