Files
pytorch/torch/csrc/autograd/profiler_python.cpp
Shivam Raikundalia 3373b074f5 [Profiler] Add GC Events to Python Stack Tracer (#161209)
Summary:
Adds Python Garbage Collection to Kineto Traces and Profiler FunctionEvents. Create custom cpp callback in profiler_python.cpp. Then define a python function with cpp and register that callback for all python garbage collection. We don't worry about thread safety in this case because we are only doing init/teardown for main thread while holding GIL.

Currently we are hiding this behind experimental config because python tracing tends to be unstable especially when adding any new feature. If this is found to not add too much overhead we can set this to on by default. NOTE: To enable this you need both with_stack=True and the experimental config on!

Test Plan:
Ran trace with GC induced and saw it on trace

Also added a test

Rollback Plan:

Differential Revision: D80491146

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161209
Approved by: https://github.com/ngimel
2025-08-22 22:11:25 +00:00

1600 lines
52 KiB
C++

#include <torch/csrc/autograd/profiler_python.h>
#include <atomic>
#include <cstdint>
#include <deque>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include <Python.h>
#include <frameobject.h>
#include <ATen/core/TensorBase.h>
#include <c10/macros/Macros.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/orchestration/python_tracer.h>
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <optional>
namespace py = pybind11;
namespace torch::profiler::impl {
namespace {
// Kinds of Python calls the tracer distinguishes. The enumerators are also
// used as tuple indices (`std::get<CallType::...>(...)` and CallTypeHelper),
// so they must start at zero and stay contiguous.
enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
// Number of CallType enumerators; keep in sync with the enum above.
static constexpr size_t CallTypeSize = 4;
// Ephemeral-argument type for call kinds that need nothing beyond their key
// (see "Ephemeral arguments" below).
using no_ephemeral_t = std::tuple<>;
// Maximum uint64_t; presumably a "no thread id" sentinel used later in this
// file (its uses are not visible here) — TODO confirm.
static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();
// ============================================================================
// == Miscellaneous structs and utils =========================================
// ============================================================================
// Source location of a Python frame: (filename, function name, line number).
//
// `filename_` and `name_` are borrowed `const char*` pointers obtained from
// the frame's code object (`co_filename` / `co_name`); no string data is
// copied. Equality compares those pointers directly, not string contents.
struct CodeLocation {
  CodeLocation() = default;

  explicit CodeLocation(PyFrameObject* frame)
      : line_number_{PyFrame_GetLineNumber(frame)} {
    // PyFrame_GetCode returns a strong reference; the smart pointer drops it
    // when the constructor finishes.
    auto frame_code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    filename_ = THPUtils_unpackStringView(frame_code->co_filename).data();
    name_ = THPUtils_unpackStringView(frame_code->co_name).data();
  }

  bool operator==(const CodeLocation& rhs) const {
    return line_number_ == rhs.line_number_ && filename_ == rhs.filename_ &&
        name_ == rhs.name_;
  }

  const char* filename_{nullptr};
  const char* name_{nullptr};
  int line_number_{0};
};
template <CallType C>
PyCodeObject* getCode();
template <>
PyCodeObject* getCode<CallType::PyModuleCall>() {
static auto module_call_code = []() {
pybind11::gil_scoped_acquire gil;
auto res = py::module::import("torch.nn")
.attr("Module")
.attr("__call__")
.attr("__code__")
.ptr();
TORCH_INTERNAL_ASSERT(PyCode_Check(res));
return (PyCodeObject*)res;
}();
return module_call_code;
}
template <>
PyCodeObject* getCode<CallType::PyOptimizerCall>() {
static auto optimizer_step_code = []() {
pybind11::gil_scoped_acquire gil;
auto res = py::module::import("torch.optim")
.attr("Optimizer")
.attr("_optimizer_step_code")
.attr("__code__")
.ptr();
TORCH_INTERNAL_ASSERT(PyCode_Check(res));
return (PyCodeObject*)res;
}();
return optimizer_step_code;
}
} // namespace
} // namespace torch::profiler::impl
// Hashes a CodeLocation over (filename pointer, name pointer, line number),
// consistent with CodeLocation::operator==, which compares the same fields.
template <>
struct std::hash<torch::profiler::impl::CodeLocation> {
  // `const` is required by the standard Hash requirements: hash containers
  // may invoke the hasher through a const reference.
  size_t operator()(const torch::profiler::impl::CodeLocation& x) const {
    return c10::get_hash(x.filename_, x.name_, x.line_number_);
  }
};
namespace torch::profiler::impl {
namespace {
// ============================================================================
// == CallTypeHelper: Tools for generic programming on specializations. =======
// ============================================================================
// Compile-time helper for code that must treat every CallType uniformly.
// `tuple_type` is `std::tuple<ClassT<PyCall>, ClassT<PyModuleCall>, ...>`,
// with one element per CallType in enum order, so `std::get<C>(t)` selects
// the element for CallType `C`. `map(t, f, args...)` invokes
// `f(element, args...)` on each element, in index order.
template <template <CallType> class ClassT>
class CallTypeHelper final {
private:
static_assert(
CallType::PyCall == 0,
"CallTypeHelper uses integer math which depends on a zero start.");
static constexpr size_t End = CallTypeSize;
// Declared but never defined: only used inside `decltype` below to spell the
// tuple type.
template <size_t... I>
static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
std::index_sequence<I...>);
// Recursive step: apply `f` to element `C`, then recurse to `C + 1` until
// `End`. `args` reach `f` as lvalues, so no single call consumes them.
template <size_t C, typename T, typename FunctorT, typename... Args>
static void map(T& t, FunctorT& f, Args&&... args) {
f(std::get<C>(t), args...);
if constexpr (C + 1 < End) {
map<C + 1>(t, f, std::forward<Args>(args)...);
}
}
public:
using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));
// Applies `f` to every element of `t`, starting at index 0.
template <typename FunctorT, typename... Args>
static void map(tuple_type& t, FunctorT& f, Args&&... args) {
map<0>(t, f, std::forward<Args>(args)...);
}
};
// ============================================================================
// == Event type definitions. =================================================
// ============================================================================
// When we are tracing a Python program, the general procedure is to record
// every time we enter or exit a function and later replay these events during
// post processing. Thus, during the profiling phase we want to do the MINIMAL
// amount of work to capture all of the information that we need; otherwise we
// will distort the profile. (While we don't wish to be terribly inefficient
// during post processing, we are willing to do extra fixup work in post if it
// reduces overhead in the profiling phase.)
//
// When the tracer first enters a frame, it constructs a CallKey for that
// location. The contents of the key vary by context. For a python function
// the key is a `CodeLocation`: the (filename, function name, line number) of
// the frame. For an `nn.Module` the key is a (non-owning) pointer to `self`.
// For a bound C function it is a (non-owning) pointer to the bound function.
// A CallKey should be small, inexpensive, and POD.
//
// We then collect a CallKey<CallType::PyCall> for the calling frame for better
// source tracking. This pair is a `Callsite`, and serves as a first level key
// during tracing. We lookup the Callsite in a thread local cache which maps
// Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
// TraceKey and return. On a cache miss, we use a global value cache to store
// whatever fields we need from the two CallKeys, generate a new TraceKey, and
// update the local cache.
//
// During post processing we:
// 1) Determine the type represented by a TraceKey by checking which
// sub-cache it appears in in the thread local cache.
// 2) Look up the pair of CallKeys from the thread local cache.
// 3) Look up the expanded values of each CallKey from the global value cache.
//
// To add a new event type to the cache:
// 1) Add an entry to the `CallType` enum.
// 2) Add a specialization of Config which defines key_t, ephemeral_t and
// cache_t.
// 3) Add a specialization of ValueCache::store and ValueCache::load.
//
// -------------------------
// -- Ephemeral arguments --
// -------------------------
// The value cache mechanism assumes that `key_t` is enough to specify the
// correct value. However it may not be possible to materialize a value using
// only an instance of `key_t`. As a result, the cache also accepts "ephemeral"
// inputs which can be used to populate the value cache. Ephemeral inputs come
// with two caveats:
// 1) They are NOT safe to save, and cannot be used after `ValueCache::store`.
// 2) They should be used to access data that is not expected to change from
// call to call, such as the name of a function.
// Per-CallType configuration: the key type, the ephemeral argument type, the
// value-cache type, and the EventType reported. Only the specializations
// below are defined.
template <CallType>
struct Config;
// Plain Python function calls: keyed by CodeLocation, no ephemeral data; the
// cache maps each location to its PyFrameState.
template <>
struct Config<CallType::PyCall> {
using key_t = CodeLocation;
using ephemeral_t = no_ephemeral_t;
using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
static constexpr EventType event_type = EventType::PyCall;
};
// Shared configuration for Python calls that also record per-instance data
// (nn.Module and Optimizer). They are keyed by `self`; the ephemeral frame is
// used by set_class() the first time a key is stored. They are reported as
// EventType::PyCall.
template <typename Key, typename Cls, typename ParameterInfo>
struct ExtendedPyCallConfig {
using key_t = Key;
using cls_t = Cls;
using ephemeral_t = PyFrameObject*;
// The class of one instance plus snapshots of its parameters.
struct ClsAndParameters {
cls_t cls_;
std::vector<ParameterInfo> parameters_;
};
struct Cache {
// Location of the wrapped code object (`nn.Module.__call__` or
// `optim.Optimizer._optimizer_step_code`, as returned by getCode<C>()),
// recorded the first time any instance is stored.
std::optional<CodeLocation> location_;
// Per-instance class and parameter snapshots.
ska::flat_hash_map<key_t, ClsAndParameters> cls_and_parameters_;
// Per-class `__name__`.
ska::flat_hash_map<cls_t, at::StringView> cls_names_;
};
using cache_t = Cache;
static constexpr EventType event_type = EventType::PyCall;
};
template <>
struct Config<CallType::PyModuleCall> : ExtendedPyCallConfig<
PyModuleSelf,
PyModuleCls,
NNModuleInfo::ParameterInfo> {};
template <>
struct Config<CallType::PyOptimizerCall> : ExtendedPyCallConfig<
PyOptimizerSelf,
PyOptimizerCls,
OptimizerInfo::ParameterInfo> {};
// Calls into C functions: keyed by PyMethod; the ephemeral PyObject* is used
// on first store to compute the `repr()` cached as the call's name.
template <>
struct Config<CallType::PyCCall> {
using key_t = PyMethod;
using ephemeral_t = PyObject*;
using cache_t = ska::flat_hash_map<key_t, at::StringView>;
static constexpr EventType event_type = EventType::PyCCall;
};
// ============================================================================
// == Callsite & ValueCache: Storage during profiling =========================
// ============================================================================
// First-level tracing key: the callee's key paired with the location of the
// calling frame.
template <CallType C>
class Callsite {
 public:
  static constexpr CallType call_type = C;
  using key_t = typename Config<C>::key_t;

  static_assert(
      std::is_trivially_copyable_v<key_t>,
      "Key should be trivial, as it is passed by value.");

  // `value` is converted to key_t; the caller location is read from `f_back`.
  template <typename U>
  Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {}

  bool operator==(const Callsite<C>& rhs) const {
    if (!(value_ == rhs.value_)) {
      return false;
    }
    return caller_ == rhs.caller_;
  }

  key_t value_;
  Config<CallType::PyCall>::key_t caller_;
};
// ============================================================================
// == Type specific store and load implementations. ===========================
// ============================================================================
// Shorthand for the key type of each CallType.
using PyCallKey = Config<CallType::PyCall>::key_t;
using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
using PyCCallKey = Config<CallType::PyCCall>::key_t;
using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;
// Stores the expanded values behind each CallKey. During profiling `store<C>`
// fills it in; each store specialization is a no-op for a key it has already
// seen. During post processing `load<C>` turns a Callsite back into the
// ExtraFields for its event.
class ValueCache {
 public:
  ValueCache() = default;
  // Movable, but neither copyable nor assignable.
  ValueCache(const ValueCache&) = delete;
  ValueCache& operator=(const ValueCache&) = delete;
  ValueCache(ValueCache&&) = default;
  ValueCache& operator=(ValueCache&&) = delete;
  ~ValueCache() = default;

  // Caches the value for `key` on first sight. `ephemeral` may only be used
  // during this call (see "Ephemeral arguments" above).
  template <CallType C>
  void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);

  // Builds the ExtraFields for one recorded call: the caller's frame state
  // plus the callee's cached values. The end time is initialized to the
  // minimum c10::time_t (presumably overwritten later in post processing).
  template <CallType C>
  auto load(const Callsite<C>& callsite, size_t python_tid) const {
    auto caller = load<CallType::PyCall>(callsite.caller_);
    // Callers are plain Python frames, which never carry module info.
    TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
    return ExtraFields<Config<C>::event_type>{
        /*end_time_ns=*/std::numeric_limits<c10::time_t>::min(),
        python_tid,
        caller.frame_state_,
        load<C>(callsite.value_)};
  }

  // Returns metadata for `p` if it is exactly a torch.Tensor, else nullopt.
  std::optional<TensorMetadata> recordIfTensor(py::handle p);
  // Returns (key, metadata) for every str-keyed, exact-Tensor dict entry.
  std::vector<std::pair<std::string, TensorMetadata>> unpackTensorMap(
      const py::dict& tensor_map);
  // Strips known path prefixes from the cached PyCall filenames.
  void trimPrefixes();

 private:
  // Expands a single key into the event-specific args.
  template <CallType C>
  typename ExtraFields<Config<C>::event_type>::args_t load(
      const typename Config<C>::key_t&) const;

  template <CallType C>
  using State = typename Config<C>::cache_t;
  // One cache per CallType, indexed by the CallType value.
  CallTypeHelper<State>::tuple_type state_;
};
// Shared first-sight bookkeeping for the nn.Module / Optimizer `store`
// specializations.
//
// The first time it runs for `cache`, it asserts that `frame` is executing
// the code object returned by getCode<C>() and records that frame's location,
// also storing it as a PyCall value. Every call returns the `__class__` of
// `key`, caching the class's `__name__` the first time that class is seen.
template <CallType C>
typename Config<C>::cls_t set_class(
    ValueCache* value_cache,
    typename Config<C>::cache_t& cache,
    const typename Config<C>::key_t& key,
    const typename Config<C>::ephemeral_t& frame) {
  if (C10_UNLIKELY(!cache.location_.has_value())) {
    // `code` holds a strong reference that is released at the end of scope.
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    TORCH_INTERNAL_ASSERT(code.get() == getCode<C>());
    cache.location_ = PyCallKey(frame);
    value_cache->store<CallType::PyCall>(*cache.location_, no_ephemeral_t());
  }

  py::object cls_object = py::handle((PyObject*)key).attr("__class__");
  auto cls = typename Config<C>::cls_t(cls_object.ptr());
  auto& names = cache.cls_names_;
  if (names.find(cls) == names.end()) {
    names[cls] = at::StringView(py::str(cls_object.attr("__name__")));
  }
  return cls;
}
// Snapshots the metadata of `self`, which must be exactly a torch.Tensor
// (asserted). Strides are captured only for strided layouts; any other layout
// gets an empty stride vector.
TensorMetadata toTensorMetadata(PyObject* self) {
  TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self));
  const auto& tensor = THPVariable_Unpack(self);
  RawTensorMetadata raw{tensor};
  auto sizes = tensor.sizes().vec();
  std::vector<int64_t> strides;
  if (raw.layout_ == at::kStrided) {
    strides = tensor.strides().vec();
  }
  return TensorMetadata{raw, std::move(sizes), std::move(strides)};
}
// Metadata for `p` when it is exactly a torch.Tensor; nullopt otherwise.
std::optional<TensorMetadata> ValueCache::recordIfTensor(py::handle p) {
  if (!THPVariable_CheckExact(p.ptr())) {
    return std::nullopt;
  }
  return toTensorMetadata(p.ptr());
}
// Collects (key, metadata) for every entry of `tensor_map` whose key is a str
// and whose value is exactly a torch.Tensor; other entries are skipped.
std::vector<std::pair<std::string, TensorMetadata>> ValueCache::unpackTensorMap(
    const py::dict& tensor_map) {
  std::vector<std::pair<std::string, TensorMetadata>> tensors;
  for (const auto& [entry_key, entry_value] : tensor_map) {
    if (!py::isinstance<py::str>(entry_key) ||
        !THPVariable_CheckExact(entry_value.ptr())) {
      continue;
    }
    tensors.emplace_back(
        py::cast<std::string>(entry_key), toTensorMetadata(entry_value.ptr()));
  }
  return tensors;
}
// Records the PyFrameState (line number, filename, function name) of a code
// location the first time it is seen; later stores are no-ops.
template <>
void ValueCache::store<CallType::PyCall>(const PyCallKey& key, no_ephemeral_t) {
  auto& frame_states = std::get<CallType::PyCall>(state_);
  if (C10_LIKELY(frame_states.find(key) != frame_states.end())) {
    return;
  }
  frame_states[key] = {
      key.line_number_,
      at::StringView(key.filename_),
      at::StringView(key.name_)};
}
// Plain Python calls carry only a frame state; module info is left empty.
template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
    const PyCallKey& key) const {
  const auto& frame_state = std::get<CallType::PyCall>(state_).at(key);
  return {frame_state, std::nullopt};
}
// On the first call for a given nn.Module instance, records its class (via
// set_class) and snapshots every `self._parameters` entry whose key is a str
// and whose value is exactly a torch.Tensor, along with that tensor's `.grad`
// when it is also exactly a Tensor. Later calls for the same instance are
// no-ops.
template <>
void ValueCache::store<CallType::PyModuleCall>(
    const PyModuleCallKey& key,
    Config<CallType::PyModuleCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  auto& per_instance = cache.cls_and_parameters_;
  if (C10_LIKELY(per_instance.find(key) != per_instance.end())) {
    return;
  }
  auto cls = set_class<CallType::PyModuleCall>(this, cache, key, frame);
  py::dict parameters = py::handle((PyObject*)key).attr("_parameters");
  std::vector<NNModuleInfo::ParameterInfo> snapshots;
  for (const auto& [param_name, param] : parameters) {
    if (!py::isinstance<py::str>(param_name) ||
        !THPVariable_CheckExact(param.ptr())) {
      continue;
    }
    snapshots.push_back(
        {param_name.cast<std::string>(),
         toTensorMetadata(param.ptr()),
         recordIfTensor(py::getattr(param, "grad", py::none()))});
  }
  per_instance[key] = {cls, std::move(snapshots)};
}
// Rebuilds the args for an nn.Module call. The frame state is that of the
// shared wrapped code location (`cache.location_`), and the module info is
// rebuilt from the per-instance snapshot.
template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
    const PyModuleCallKey& key) const {
  const auto& cache = std::get<CallType::PyModuleCall>(state_);
  TORCH_INTERNAL_ASSERT(cache.location_.has_value());
  const auto& entry = cache.cls_and_parameters_.at(key);
  NNModuleInfo info{
      key, entry.cls_, cache.cls_names_.at(entry.cls_), entry.parameters_};
  const auto& frame_states = std::get<CallType::PyCall>(state_);
  return {
      /*frame_state_=*/frame_states.at(*cache.location_),
      /*module_info_=*/std::move(info),
      /*optimizer_info_=*/std::nullopt};
}
// On the first call for a given optimizer instance, records its class (via
// set_class) and snapshots every exact-Tensor parameter across its
// `param_groups`: the tensor, its `.grad` (if that is an exact Tensor), and
// the tensor entries of `self.state[param]`. Later calls are no-ops.
template <>
void ValueCache::store<CallType::PyOptimizerCall>(
    const PyOptimizerCallKey& key,
    Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  auto& per_instance = cache.cls_and_parameters_;
  if (C10_LIKELY(per_instance.find(key) != per_instance.end())) {
    return;
  }
  auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
  const py::handle optimizer{(PyObject*)key};
  std::vector<OptimizerInfo::ParameterInfo> snapshots;
  const py::list param_groups(optimizer.attr("param_groups"));
  for (const auto& group : param_groups) {
    for (const auto& param : py::cast<py::dict>(group).attr("get")("params")) {
      if (!THPVariable_CheckExact(param.ptr())) {
        continue;
      }
      // `self.state` may store data arbitrarily, but all generic optimizers
      // (SGD, Adam, etc) key it by param, since the state is tied to a
      // particular parameter. This assumption can be relaxed if needed.
      snapshots.push_back(
          {toTensorMetadata(param.ptr()),
           recordIfTensor(py::getattr(param, "grad", py::none())),
           unpackTensorMap(py::cast<py::dict>(optimizer.attr("state"))
                               .attr("get")(param, py::dict()))});
    }
  }
  per_instance[key] = {cls, std::move(snapshots)};
}
// Rebuilds the args for an optimizer call. The frame state is that of the
// shared wrapped code location (`cache.location_`), and the optimizer info is
// rebuilt from the per-instance snapshot.
template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<
    CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  // store<PyOptimizerCall> calls set_class(), which sets `location_`, before
  // inserting any entry, so a missing location is an internal bug. This check
  // matches load<PyModuleCall> instead of relying on optional::value().
  TORCH_INTERNAL_ASSERT(cache.location_.has_value());
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  const auto& cls = cls_and_parameters.cls_;
  OptimizerInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/std::nullopt,
      /*optimizer_info_=*/std::move(info)};
}
// C calls are named by the `repr()` of the ephemeral `arg` object, computed
// the first time each PyMethod key is seen.
template <>
void ValueCache::store<CallType::PyCCall>(
    const PyCCallKey& key,
    Config<CallType::PyCCall>::ephemeral_t arg) {
  auto& reprs = std::get<CallType::PyCCall>(state_);
  if (C10_LIKELY(reprs.find(key) != reprs.end())) {
    return;
  }
  reprs[key] = at::StringView(py::repr(arg));
}
template <>
ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
    const PyCCallKey& key) const {
  const auto& reprs = std::get<CallType::PyCCall>(state_);
  return reprs.at(key);
}
// TODO: Use re2.
// Strips from each cached PyCall filename the first entry of
// `torch.profiler.python_tracer._prefix_regex()` that it starts with. Despite
// that helper's name, entries are matched as literal prefixes here. The
// prefix list is fetched once per process.
void ValueCache::trimPrefixes() {
  static const auto prefixes = []() {
    pybind11::gil_scoped_acquire gil;
    return py::module::import("torch.profiler.python_tracer")
        .attr("_prefix_regex")()
        .cast<std::vector<std::string>>();
  }();
  const auto has_prefix = [](const std::string& text,
                             const std::string& prefix) {
    return text.compare(0, prefix.size(), prefix) == 0;
  };
  for (auto& entry : std::get<CallType::PyCall>(state_)) {
    auto& frame_state = entry.second;
    std::string filename = frame_state.filename_.str();
    for (const auto& prefix : prefixes) {
      if (!has_prefix(filename, prefix)) {
        continue;
      }
      filename.erase(0, prefix.size());
      frame_state.filename_ = at::StringView(filename);
      break;
    }
  }
}
// ============================================================================
// == TraceKey cache ==========================================================
// ============================================================================
using python_tracer::TraceKey;
// Returns a fresh TraceKey. Keys start at 1 and increase by one per call; the
// atomic counter means concurrent callers never receive the same key.
TraceKey nextKey() {
  static std::atomic<uint64_t> counter{0};
  const uint64_t issued = counter.fetch_add(1) + 1;
  return TraceKey{issued};
}
// Thread-local first-level cache: maps each Callsite of kind `C` to the
// TraceKey that represents it in the event stream.
template <CallType C>
struct TraceKeyCacheState {
  // Hashes both halves of a Callsite, consistent with Callsite::operator==.
  // `const` is required by the standard Hash requirements.
  struct Hash {
    size_t operator()(const Callsite<C>& key) const {
      return c10::get_hash(key.value_, key.caller_);
    }
  };

  // Returns the TraceKey for `callsite`, assigning a new one on first sight.
  // On a miss, the values for the callee key and the caller location are
  // stored in `value_cache`; `ephemeral` is only used for that store.
  TraceKey intern(
      Callsite<C> callsite,
      typename Config<C>::ephemeral_t ephemeral,
      ValueCache& value_cache) {
    auto it = state_.find(callsite);
    if (C10_UNLIKELY(it == state_.end())) {
      value_cache.store<C>(callsite.value_, ephemeral);
      value_cache.store<CallType::PyCall>(callsite.caller_, no_ephemeral_t());
      it = state_.insert({callsite, nextKey()}).first;
    }
    return it->second;
  }

  // Returns (callee args, caller args) for `callsite`. `callsite` is only
  // read, so it is taken by const reference.
  // NOTE(review): this calls the single-argument ValueCache::load overload,
  // which is private, so this member would not compile if instantiated. It
  // appears unused in the visible code — confirm before relying on it.
  auto lookup(const Callsite<C>& callsite, ValueCache& value_cache) const {
    return std::make_pair(
        value_cache.load<C>(callsite.value_),
        value_cache.load<CallType::PyCall>(callsite.caller_));
  }

  ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
};
// ============================================================================
// == Core CPython data types =================================================
// ============================================================================
// `TraceContext` (defined just below) is the PyObject that allows different
// threads to record events without colliding. It is passed as the second
// argument when enabling tracing via `PyEval_SetProfile`.
// Forward declaration; ThreadLocalResults is defined in the thread local
// cache section.
struct ThreadLocalResults;
// A Python object whose only payload is a back-pointer to the owning thread's
// ThreadLocalResults, so profiling callbacks can find per-thread state from
// the profile object they are handed.
struct TraceContext {
PyObject_HEAD
// Non-owning; set by the ThreadLocalResults constructor to that object.
ThreadLocalResults* thread_local_results_;
};
// CPython boilerplate to define `TraceContext` as a proper python object.
// Slots are positional, in PyTypeObject field order. `tp_alloc` is null here,
// yet ThreadLocalResults calls `TraceContextType.tp_alloc`; presumably
// PyType_Ready is run on this type elsewhere (not visible here) so that
// tp_alloc is inherited from `object` — TODO confirm.
static PyTypeObject TraceContextType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"TraceContext", /* tp_name */
sizeof(TraceContext), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"Python tracer TLS", /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
PyType_GenericNew, /* tp_new */
nullptr /* tp_free */
};
// Scope guard: holds the GIL for its lifetime and, on destruction, swaps the
// current thread state back to whatever was current at construction. Code in
// between can therefore `PyThreadState_Swap` to other threads.
class gil_and_restore_thread {
public:
// `gil_` is declared before `initial_thread_state_`, so it is constructed
// first; the GIL is already held when PyThreadState_Get() runs.
gil_and_restore_thread() : initial_thread_state_{PyThreadState_Get()} {}
~gil_and_restore_thread() {
// Runs before `gil_` is destroyed, i.e. while the GIL is still held.
PyThreadState_Swap(initial_thread_state_);
// `gil_scoped_acquire` is a bit fragile in on-demand mode:
// https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
if (!Py_IsInitialized()) {
gil_.disarm();
}
}
// The thread state that was current when this guard was created.
PyThreadState* initial_thread_state() const {
return initial_thread_state_;
}
private:
// Must remain declared before `initial_thread_state_` (see the constructor).
pybind11::gil_scoped_acquire gil_;
PyThreadState* initial_thread_state_;
};
// ============================================================================
// == Thread local cache ======================================================
// ============================================================================
class PythonTracer;
// Per-thread tracing state: the TraceContext passed to PyEval_SetProfile, the
// thread's TraceKey caches, and timestamp buffers for Python / C call exits
// (filled elsewhere). Instances are neither copyable nor movable, because the
// TraceContext stores a raw back-pointer to `this`.
struct ThreadLocalResults {
  ThreadLocalResults(
      PyThreadState* thread_state,
      ValueCache* value_cache,
      PythonTracer* active_tracer)
      : thread_state_{thread_state},
        ctx_{reinterpret_cast<TraceContext*>(
            TraceContextType.tp_alloc(&TraceContextType, 0))},
        value_cache_{value_cache},
        active_tracer_{active_tracer} {
    // tp_alloc returns nullptr on failure; fail loudly instead of
    // dereferencing it.
    TORCH_CHECK(
        ctx_ != nullptr, "Failed to allocate a TraceContext for the Python tracer");
    ctx_->thread_local_results_ = this;
  }
  ThreadLocalResults() = delete;
  ThreadLocalResults(const ThreadLocalResults&) = delete;
  ThreadLocalResults(ThreadLocalResults&&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
  ThreadLocalResults& operator=(ThreadLocalResults&&) = delete;
  ~ThreadLocalResults() {
    // Currently, there is a bug in Profiler when using Python 3.12 that causes
    // a segfault when decrementing the refcount of a TraceContext during
    // on-demand. We are purposefully allowing for a small leak in this
    // situation to avoid the segfault. This should be fixed in the future.
#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 12)
    Py_DECREF((PyObject*)ctx_);
#endif
  }

  // Returns the TraceKey for the Callsite built from `args` (callee key,
  // calling frame), interning it on first sight. `E` must equal the
  // CallType's configured event type; this is checked at compile time.
  template <CallType C, EventType E, typename Ephemeral, typename... Args>
  TraceKey intern(Ephemeral ephemeral, Args... args) {
    static_assert(
        Config<C>::event_type == E,
        "ThreadLocalResults.intern called from the wrong typed context.");
    auto callsite = Callsite<C>(std::forward<Args>(args)...);
    return std::get<C>(trace_keys_).intern(callsite, ephemeral, *value_cache_);
  }

  static constexpr size_t BLOCK_SIZE = 1024;

  PyThreadState* thread_state_;
  // Owned reference (released in the destructor on Python < 3.12).
  TraceContext* ctx_;
  ValueCache* value_cache_;
  PythonTracer* active_tracer_;
  CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
  AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> exit_times_;
  AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> c_exit_times_;
  int active_frames_{0};
  int remaining_start_frames_{0};
};
// ============================================================================
// == Tracing implementation ==================================================
// ============================================================================
// True only when compiling against CPython 3.12 (any micro release). It gates
// the sys.monitoring workaround for missing C-call events.
#define IS_PYTHON_3_12 (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 12)
#if IS_PYTHON_3_12
// forward declarations
// The event-handler object registered with sys.monitoring, and its vectorcall
// entry point (befriended by PythonTracer so it can call recordCCall).
struct _PyEventHandler;
static PyObject* c_call_callback(
_PyEventHandler* self,
PyObject* const* args,
size_t nargsf,
PyObject* kwnames);
#endif
// Python-level tracer. The constructor installs `pyProfileFn` (via
// PyEval_SetProfile) on every thread of the current interpreter; events are
// recorded into per-thread ThreadLocalResults and `queue_`.
class PythonTracer final : public python_tracer::PythonTracerBase {
public:
PythonTracer(torch::profiler::impl::RecordQueue* queue);
// NOLINTNEXTLINE(bugprone-exception-escape)
~PythonTracer() override;
// Profile function passed to PyEval_SetProfile; `obj` is the thread's
// TraceContext.
static int pyProfileFn(
PyObject* obj,
PyFrameObject* frame,
int what,
PyObject* arg);
void register_gc_callback() override;
void stop() override;
void restart() override;
std::vector<std::shared_ptr<Result>> getEvents(
std::function<c10::time_t(c10::approx_time_t)> time_converter,
std::vector<python_tracer::CompressedEvent>& enters,
c10::time_t end_time_ns) override;
struct StartFrame {
TraceKey trace_key_;
c10::approx_time_t start_time{};
};
private:
// `is_startup_frame` is true for frames already on the stack when tracing
// began (see the constructor).
void recordPyCall(
ThreadLocalResults& tls,
PyFrameObject* frame,
bool is_startup_frame);
// Python GC callback; `self` is expected to be a capsule wrapping `this`.
static PyObject* gc_event_callback(PyObject* self, PyObject* args);
void recordCCall(
ThreadLocalResults& tls,
PyFrameObject* frame,
PyObject* arg,
bool start_frame = false);
// Snapshot of the thread states of `interpreter_`.
const std::vector<PyThreadState*> interpreterThreads() const;
// Meant to prevent two tracers from registering at once.
// NOTE(review): this is a non-static member, so every PythonTracer starts
// with its own `false` flag and the constructor's compare_exchange cannot
// observe another instance. Verify whether a shared flag was intended.
std::atomic<bool> active_lock_{false};
// True if this instance won `active_lock_` in its constructor.
bool active_{false};
bool gc_callback_registered_{false};
torch::profiler::impl::RecordQueue* queue_;
PyInterpreterState* interpreter_{nullptr};
// Code objects from getCode<PyModuleCall>() / getCode<PyOptimizerCall>().
PyCodeObject* module_call_code_;
PyCodeObject* optimizer_hook_;
std::vector<StartFrame> start_frames_;
// One entry per traced thread. A deque never relocates existing elements on
// emplace_back, which matters because TraceContext holds a raw pointer back
// to its ThreadLocalResults.
std::deque<ThreadLocalResults> thread_local_results_;
ValueCache value_cache_;
#if IS_PYTHON_3_12
friend PyObject* c_call_callback(
_PyEventHandler* self,
PyObject* const* args,
size_t nargsf,
PyObject* kwnames);
#endif
};
#if IS_PYTHON_3_12
// sys.monitoring tool id claimed by this tracer (sys.monitoring.PROFILER_ID).
#define PROFILER_ID 2
// Bit index of the CALL event; the event mask passed to sys.monitoring is
// `1 << PY_MONITORING_EVENT_CALL` (sys.monitoring.events.CALL).
#define PY_MONITORING_EVENT_CALL 4
static bool should_compensate_c_call_events() {
static const bool result = []() {
const char* version = Py_GetVersion();
const char micro = version[5];
return micro == '0' || (micro <= '4' && version[6] == ' ');
}();
return result;
}
// Minimal callable object handed to sys.monitoring.register_callback. Calls
// on it dispatch through its `vectorcall` field, which
// registerMonitoringCallback points at `c_call_callback`.
struct _PyEventHandler {
PyObject_HEAD
vectorcallfunc vectorcall;
};
// Type for _PyEventHandler. It is vectorcall-enabled (tp_vectorcall_offset
// names the `vectorcall` field and tp_call is PyVectorcall_Call), its
// instances are freed with PyObject_Free, and Python code cannot instantiate
// it.
static PyTypeObject _PyEventHandler_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) /* ob_base */
"torch.profiler.python_tracer_event_handler", /* tp_name */
sizeof(_PyEventHandler), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)PyObject_Free, /* tp_dealloc */
offsetof(_PyEventHandler, vectorcall), /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
PyVectorcall_Call, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_VECTORCALL |
Py_TPFLAGS_DISALLOW_INSTANTIATION, /* tp_flags */
};
// sys.monitoring CALL-event handler, used on CPython 3.12.0 - 3.12.4.
//
// Those releases drop the legacy profile C-call event when the callable is a
// bound method wrapping a C function. This handler records that one case
// through the active PythonTracer; every other call is already reported by
// CPython's legacy tracing and is ignored here. The logic follows
// `sys_profile_call_or_return` in
// https://github.com/python/cpython/blob/v3.12.5/Python/legacy_tracing.c
//
// `args` follows the sys.monitoring CALL callback signature
// (code, instruction_offset, callable, arg0); only `args[2]` is used.
static PyObject* c_call_callback(
    _PyEventHandler* self,
    PyObject* const* args,
    size_t nargsf,
    PyObject* kwnames) {
  PyThreadState* tstate = PyThreadState_GET();
  if (tstate->c_profilefunc != PythonTracer::pyProfileFn) {
    // We don't care about this case: our profile function is not installed on
    // this thread, so just return normally.
    Py_RETURN_NONE;
  }
  PyObject* callable = args[2];
  if (Py_TYPE(callable) == &PyMethod_Type) {
    // The call event of a method with a C function is missing on
    // 3.12.0-3.12.4. See
    // https://github.com/python/cpython/commit/257c413cd16ddabcedde413288d0bb93bf872da7
    // Other cases have already been handled by legacy tracing, so only this
    // case needs handling here. The error branches keep the same behavior as
    // CPython.
    PyObject* func = PyMethod_GET_FUNCTION(callable);
    if (!func) {
      return nullptr;
    }
    if (PyCFunction_Check(func)) {
      // PyEval_GetFrame returns a borrowed reference; hold a strong one while
      // recording.
      PyFrameObject* frame = PyEval_GetFrame();
      if (!frame) {
        PyErr_SetString(
            PyExc_SystemError, "Missing frame when calling profile function.");
        return nullptr;
      }
      Py_INCREF(frame);
      auto& local_results =
          *reinterpret_cast<TraceContext*>(tstate->c_profileobj)
               ->thread_local_results_;
      local_results.active_tracer_->recordCCall(local_results, frame, func);
      Py_DECREF(frame);
    }
  }
  Py_RETURN_NONE;
}
static void registerMonitoringCallback() {
if (!should_compensate_c_call_events()) {
return;
}
auto sys_module = THPObjectPtr(PyImport_ImportModule("sys"));
if (!sys_module) {
TORCH_WARN("Failed to import sys module.");
PyErr_Clear();
return;
}
auto monitoring =
THPObjectPtr(PyObject_GetAttrString(sys_module, "monitoring"));
if (!monitoring) {
TORCH_WARN("Failed to get monitoring from sys module.");
PyErr_Clear();
return;
}
auto result = THPObjectPtr(PyObject_CallMethod(
monitoring, "use_tool_id", "is", PROFILER_ID, "PyTorch Profiler"));
if (!result) {
TORCH_WARN("Failed to call sys.monitoring.use_tool_id");
PyErr_Clear();
return;
}
auto handler = THPObjectPtr(PyObject_NEW(PyObject, &_PyEventHandler_Type));
if (!handler) {
TORCH_WARN("Failed to create _PyEventHandler object.");
PyErr_Clear();
return;
}
reinterpret_cast<_PyEventHandler*>(handler.get())->vectorcall =
(vectorcallfunc)c_call_callback;
result = THPObjectPtr(PyObject_CallMethod(
monitoring,
"register_callback",
"iiO",
PROFILER_ID,
1 << PY_MONITORING_EVENT_CALL,
handler.get()));
if (!result) {
TORCH_WARN("Failed to call sys.monitoring.register_callback.");
PyErr_Clear();
return;
}
result = THPObjectPtr(PyObject_CallMethod(
monitoring,
"set_events",
"ii",
PROFILER_ID,
1 << PY_MONITORING_EVENT_CALL));
if (!result) {
TORCH_WARN("Failed to call sys.monitoring.set_events.");
PyErr_Clear();
return;
}
}
// Undoes registerMonitoringCallback. If tool id PROFILER_ID is currently held
// by "PyTorch Profiler", it unregisters the CALL callback, disables all events
// for that id, and frees the id. If the id is unclaimed or owned by another
// tool, nothing changes. Failures are non-fatal: warn, clear the Python error
// indicator, and stop.
static void unregisterMonitoringCallback() {
  if (!should_compensate_c_call_events()) {
    return;
  }
  auto sys_module = THPObjectPtr(PyImport_ImportModule("sys"));
  if (!sys_module) {
    TORCH_WARN("Failed to import sys module.");
    PyErr_Clear();
    return;
  }
  auto monitoring =
      THPObjectPtr(PyObject_GetAttrString(sys_module, "monitoring"));
  if (!monitoring) {
    TORCH_WARN("Failed to get monitoring from sys module.");
    PyErr_Clear();
    return;
  }
  // get_tool() returns the tool's registered name, or None if unclaimed.
  auto tool_name = THPObjectPtr(
      PyObject_CallMethod(monitoring, "get_tool", "i", PROFILER_ID));
  if (!tool_name) {
    TORCH_WARN("Failed to call sys.monitoring.get_tool.");
    PyErr_Clear();
    return;
  }
  if (!THPUtils_checkString(tool_name)) {
    return;
  }
  const char* str = THPUtils_unpackStringView(tool_name).data();
  if (strcmp(str, "PyTorch Profiler") != 0) {
    // Another tool owns this id; leave it alone.
    return;
  }
  // THPObjectPtr steals a reference, so take one on Py_None first.
  auto none = THPObjectPtr(Py_None);
  Py_INCREF(Py_None);
  auto result = THPObjectPtr(PyObject_CallMethod(
      monitoring,
      "register_callback",
      "iiO",
      PROFILER_ID,
      1 << PY_MONITORING_EVENT_CALL,
      none.get()));
  if (!result) {
    TORCH_WARN("Failed to call sys.monitoring.register_callback.");
    PyErr_Clear();
    return;
  }
  result = THPObjectPtr(
      PyObject_CallMethod(monitoring, "set_events", "ii", PROFILER_ID, 0));
  if (!result) {
    TORCH_WARN("Failed to call sys.monitoring.set_events.");
    PyErr_Clear();
    return;
  }
  result = THPObjectPtr(
      PyObject_CallMethod(monitoring, "free_tool_id", "i", PROFILER_ID));
  if (!result) {
    TORCH_WARN("Failed to call sys.monitoring.free_tool_id.");
    PyErr_Clear();
    return;
  }
}
#endif
const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
pybind11::gil_scoped_acquire gil;
std::vector<PyThreadState*> out;
if (SOFT_ASSERT(interpreter_)) {
auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
while (thread_state != nullptr) {
out.push_back(thread_state);
thread_state = PyThreadState_Next(thread_state);
}
}
return out;
}
// Python-level handle for the GC callback, presumably set and cleared by the
// gc callback (un)registration code. That code only runs on the main thread
// while holding the GIL, so this global needs no extra synchronization.
static PyObject* py_gc_callback = nullptr;
// C entry point invoked by Python's GC with `(phase, info)`. `self` must be a
// capsule wrapping the owning PythonTracer. Enqueues a GC event carrying the
// phase string and the current approximate time.
PyObject* PythonTracer::gc_event_callback(PyObject* self, PyObject* args) {
  const char* phase = nullptr;
  PyObject* info = nullptr;
  if (!PyArg_ParseTuple(args, "sO", &phase, &info)) {
    return nullptr;
  }
  auto* tracer =
      static_cast<PythonTracer*>(PyCapsule_GetPointer(self, nullptr));
  if (tracer == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Invalid tracer instance");
    return nullptr;
  }
  tracer->queue_->getSubqueue()->emplace_gc_call(
      phase, c10::getApproximateTime());
  Py_RETURN_NONE;
}
// Starts tracing. It claims `active_lock_`, then for every thread of the
// current interpreter it:
// - creates a ThreadLocalResults,
// - replays the frames already on that thread's stack as startup frames, and
// - installs `pyProfileFn`.
// On Python 3.12 it also registers the sys.monitoring workaround. If the lock
// is already taken, it warns and installs nothing.
PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
: queue_(queue),
module_call_code_(getCode<CallType::PyModuleCall>()),
optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
TORCH_CHECK(queue_ != nullptr);
bool expected{false};
active_ = active_lock_.compare_exchange_strong(expected, true);
if (!active_) {
TORCH_WARN(
"There is already an active Python tracer. "
"Refusing to register profile functions.");
return;
}
// Holds the GIL and restores the caller's thread state on scope exit; the
// loop below swaps to each traced thread in turn.
gil_and_restore_thread gil;
interpreter_ = PyInterpreterState_Get();
if (!gil.initial_thread_state()) {
TORCH_WARN("PyThreadState_Get returned NULL");
return;
}
// Register the tracer in each thread.
for (const auto thread_state : interpreterThreads()) {
// Make `thread_state` current so PyEval_GetFrame and PyEval_SetProfile below
// act on that thread.
PyThreadState_Swap(thread_state);
thread_local_results_.emplace_back(thread_state, &value_cache_, this);
auto& tls = thread_local_results_.back();
auto* ctx = tls.ctx_;
// When we begin profiling there are already frames on the Python
// interpreter stack. To ensure a complete trace, we must push calls
// to all the prior frames onto our event stack. (We stop at depth=128)
std::vector<THPFrameObjectPtr> current_stack;
// PyEval_GetFrame returns a borrowed reference; take our own so that every
// entry of `current_stack` owns one reference.
auto frame = PyEval_GetFrame();
Py_XINCREF(frame);
size_t depth = 0; // Make sure we can't infinite loop.
// Collected innermost-first...
while (frame != nullptr) {
current_stack.emplace_back(frame);
if (++depth == 128) {
break;
}
// NB: `PyFrame_GetBack` returns a strong reference.
frame = PyFrame_GetBack(frame);
}
// ...then replayed outermost-first, so the synthetic enter events nest the
// same way the real stack does.
for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
recordPyCall(tls, it->get(), true);
auto frame_refcount = Py_REFCNT(it->get());
// We hold one reference in `current_stack`, and the interpreter holds
// another.
TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
}
// Presumably recordPyCall counts the replayed frames in `active_frames_`;
// remember how many of the active frames are synthetic startup frames.
tls.remaining_start_frames_ = tls.active_frames_;
// Note:
// This profile will not compose with other CPython profilers, and
// cannot be round tripped via `sys.settrace(sys.gettrace())`
PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
}
#if IS_PYTHON_3_12
registerMonitoringCallback();
#endif
}
void unregister_gc_callback() {
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* gc_module = PyImport_ImportModule("gc");
if (!gc_module) {
PyErr_Print();
PyGILState_Release(gstate);
return;
}
PyObject* callbacks = PyObject_GetAttrString(gc_module, "callbacks");
if (!callbacks || !PyList_Check(callbacks)) {
PyErr_Print();
Py_XDECREF(gc_module);
Py_XDECREF(callbacks);
PyGILState_Release(gstate);
return;
}
Py_ssize_t idx = PySequence_Index(callbacks, py_gc_callback);
if (idx >= 0) {
PySequence_DelItem(callbacks, idx);
} else {
// Not found, maybe already removed
}
Py_DECREF(callbacks);
Py_DECREF(gc_module);
Py_XDECREF(py_gc_callback);
py_gc_callback = nullptr;
PyGILState_Release(gstate);
}
void PythonTracer::register_gc_callback() {
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* gc_module = PyImport_ImportModule("gc");
if (!gc_module) {
PyErr_Print();
PyGILState_Release(gstate);
return;
}
PyObject* callbacks = PyObject_GetAttrString(gc_module, "callbacks");
if (!callbacks || !PyList_Check(callbacks)) {
PyErr_Print();
Py_XDECREF(gc_module);
Py_XDECREF(callbacks);
PyGILState_Release(gstate);
return;
}
static PyMethodDef method_def = {
"gc_event_callback",
(PyCFunction)gc_event_callback,
METH_VARARGS,
nullptr};
PyObject* capsule = PyCapsule_New(this, nullptr, nullptr);
py_gc_callback = PyCFunction_New(&method_def, capsule);
Py_DECREF(capsule); // PyCFunction_New increments refcount
if (PyList_Append(callbacks, py_gc_callback) < 0) {
PyErr_Print();
}
gc_callback_registered_ = true;
Py_DECREF(callbacks);
Py_DECREF(gc_module);
PyGILState_Release(gstate);
}
// Stops tracing.
//
// The GC callback is removed whenever it is registered, independent of
// `active_`. If this tracer is active, pyProfileFn is also uninstalled from
// every thread that still uses it; threads with a different profile function
// are left alone. On 3.12+ the sys.monitoring callback is then unregistered,
// and finally the global tracer lock is released.
void PythonTracer::stop() {
  gil_and_restore_thread gil;

  if (gc_callback_registered_) {
    unregister_gc_callback();
    gc_callback_registered_ = false;
  }

  if (!active_) {
    return;
  }

  for (const auto thread_state : interpreterThreads()) {
    if (thread_state->c_profilefunc != &PythonTracer::pyProfileFn) {
      continue;
    }
    PyThreadState_Swap(thread_state);
    PyEval_SetProfile(nullptr, nullptr);
  }

#if IS_PYTHON_3_12
  unregisterMonitoringCallback();
#endif

  // `active_` (true here) is the expected value, so the exchange succeeds iff
  // this tracer still owns the lock.
  auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
  active_ = false;
  SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
}
// Re-enables tracing after stop().
//
// Re-acquires the global tracer lock, then reinstalls pyProfileFn on every
// thread that currently has no profile function. Each thread uses the
// ThreadLocalResults entry created for it by the constructor.
void PythonTracer::restart() {
  gil_and_restore_thread gil;
  // `active_` doubles as the expected value: after stop() it is false, so this
  // claims the lock exactly like the constructor does.
  active_ = active_lock_.compare_exchange_strong(active_, true);
  if (!active_) {
    TORCH_WARN(
        "There is already an active Python tracer. "
        "Refusing to register profile functions.");
    return;
  }

  // The constructor appended one `thread_local_results_` entry per thread in
  // interpreterThreads() order, so walk both sequences in lockstep. Previously
  // the index was never advanced, which routed every thread's events into
  // thread 0's results (and indexed an empty container out of bounds when the
  // constructor registered no threads).
  // NOTE(review): positional pairing assumes the interpreter's thread list is
  // unchanged since construction. Threads without a matching entry are
  // skipped instead of being read out of bounds.
  size_t cur_thread = 0;
  for (const auto thread_state : interpreterThreads()) {
    const size_t idx = cur_thread++;
    if (idx >= thread_local_results_.size()) {
      break;
    }
    if (thread_state->c_profilefunc == nullptr) {
      auto* ctx = thread_local_results_[idx].ctx_;
      PyThreadState_Swap(thread_state);
      PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
    }
  }
  // NOTE(review): stop() unregisters the GC callback but nothing here
  // re-registers it; confirm whether GC events are expected after a restart.
#if IS_PYTHON_3_12
  registerMonitoringCallback();
#endif
}
// Destroys the tracer, stopping it first if needed.
//
// If tracing is still active, this warns (stop() should have been called
// explicitly) and then stops. If tracing is inactive but a GC callback is
// still registered, stop() is also called. That callback's capsule holds a raw
// pointer to this object, and leaving it in `gc.callbacks` would make the next
// collection dereference a destroyed tracer.
// NOLINTNEXTLINE(bugprone-exception-escape)
PythonTracer::~PythonTracer() {
  if (active_) {
    TORCH_WARN("`PythonTracer::stop()` was not called.");
    stop();
  } else if (gc_callback_registered_) {
    stop();
  }
}
// Records entry into the Python frame `frame` for the thread owning `tls`.
//
// Frames whose code object is `module_call_code_` or `optimizer_hook_` are
// interned together with the frame's local `self` and its caller frame. All
// other frames are interned as plain PyCalls, with the caller frame (or
// `frame` itself when there is no caller).
//
// When `is_startup_frame` is true, the event is buffered in `start_frames_`
// (the frame was already on the stack when tracing began). Otherwise it is
// pushed onto the record queue. Either way `tls.active_frames_` is
// incremented.
void PythonTracer::recordPyCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    bool is_startup_frame) {
  static constexpr auto E = EventType::PyCall;

  // Returns an owned reference to the local variable `self` of `f`, or a null
  // pointer if the lookup fails.
  //
  // By default, CPython stores locals in a "fast" format, with an array
  // of names and an array of values. Consequently, frame->f_locals is
  // NULL since the interpreter has no need to populate it.
  //
  // If these arrays were part of the public API then we could very
  // quickly access `self`. Unfortunately they are not, and moreover are
  // not stable across versions. As a result, we are forced to call
  // `PyFrame_GetLocals` which forces the interpreter to materialize
  // the full dict of locals.
  const auto get_self = [](PyFrameObject* f) -> THPObjectPtr {
    auto locals = THPObjectPtr(PyFrame_GetLocals(f));
#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 13)
    // `PyDict_GetItemString` returns a *borrowed* reference; take our own
    // before handing it to the owning THPObjectPtr.
    PyObject* self = PyDict_GetItemString(locals, "self");
    Py_XINCREF(self);
    return THPObjectPtr(self);
#else
    // In Python-3.13+ `PyFrame_GetLocals()` returns an instance of
    // PyFrameLocalsProxy_Type (see PEP 667). `PyMapping_GetItemString`
    // already returns a *new* reference, so no extra Py_INCREF is needed;
    // adding one would leak `self` on every traced module/optimizer call.
    return THPObjectPtr(PyMapping_GetItemString(locals, "self"));
#endif
  };

  const auto key = [&]() -> TraceKey {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    if (code.get() == module_call_code_) {
      auto self = get_self(frame);
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyModuleCall, E>(
          frame, self.get(), back.get());
    } else if (code.get() == optimizer_hook_) {
      auto self = get_self(frame);
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyOptimizerCall, E>(
          frame, self.get(), back.get());
    } else {
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      auto f_back = (back.get() != nullptr) ? back.get() : frame;
      return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
    }
  }();
  const auto time = c10::getApproximateTime();
  is_startup_frame ? start_frames_.push_back({key, time})
                   : queue_->getSubqueue()->emplace_py_call(key, time);
  ++tls.active_frames_;
}
// Records entry into the C function `arg`, called from Python frame `frame`,
// for the thread owning `tls`. The key combines the callable, its method
// definition pointer and the calling frame. The enter is pushed onto the
// record queue and `tls.active_frames_` is incremented.
void PythonTracer::recordCCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    PyObject* arg,
    bool start_frame) {
  // Start frames may duplicate callable Python functions (to avoid empty C
  // frames in the trace on exit), so the type check applies only to ordinary
  // C calls.
  if (!start_frame) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg));
  }
  auto c_function = reinterpret_cast<PyCFunctionObject*>(arg);

  // A C call does not create a new frame, so `frame` itself (not its parent)
  // is the caller.
  auto trace_key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
      arg, (void*)(c_function->m_ml), frame);

  auto* subqueue = queue_->getSubqueue();
  subqueue->emplace_py_call(trace_key, c10::getApproximateTime());
  tls.active_frames_++;
}
// ============================================================================
// == Post processing =========================================================
// ============================================================================
// An exit timestamp tagged with the Python thread (`python_tid_`) it belongs
// to. `operator>` compares timestamps only, so a std::priority_queue using
// std::greater<> yields the earliest exit first.
struct Exit {
  bool operator>(const Exit& other) const {
    return other.t_ < t_;
  }

  c10::time_t t_;
  size_t python_tid_;
};
// Replays the raw enter/exit records gathered by PythonTracer into Result
// objects with start and end times.
//
// Exits are kept per event type (PyCall / PyCCall) in a min-heap ordered by
// time. Enters are visited in time order and matched against a call stack
// kept per Python thread id.
class PostProcess {
 public:
  // Gathers, for every Python thread in `tls` (indexed by python_tid), the
  // ExtraFields of each interned trace key (loaded from `value_cache`) and the
  // converted exit times of both Python and C calls. `end_time_ns` is used to
  // close any call that is still open at the end of the replay.
  PostProcess(
      std::function<c10::time_t(c10::approx_time_t)> time_converter,
      std::deque<ThreadLocalResults>& tls,
      const ValueCache& value_cache,
      c10::time_t end_time_ns)
      : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} {
    for (size_t python_tid : c10::irange(tls.size())) {
      CallTypeHelper<TraceKeyCacheState>::map(
          tls[python_tid].trace_keys_, *this, value_cache, python_tid);
      addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
      addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
    }
  }

  // Appends one enter per frame that was already on the stack when tracing
  // began. These carry NoTID; populate() later assigns them the system thread
  // of the next observed event on the same Python thread.
  void set_start_frames(
      const std::vector<PythonTracer::StartFrame>& start_frames,
      std::vector<python_tracer::CompressedEvent>& enters) {
    enters.reserve(enters.size() + start_frames.size());
    for (const auto& frame : start_frames) {
      enters.push_back(
          {frame.trace_key_,
           NoTID, // Allows us to detect unhandled start frames
           {},
           time_converter_(frame.start_time)});
    }
  }

  // Invoked through CallTypeHelper::map (see the constructor) once per
  // CallType: records the ExtraFields of every trace key interned on
  // `python_tid`. Each key must be unique.
  template <CallType C>
  void operator()(
      const TraceKeyCacheState<C>& trace_cache,
      const ValueCache& value_cache,
      size_t python_tid) {
    for (const auto& it : trace_cache.state_) {
      const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
          {it.second, value_cache.load(it.first, python_tid)});
      TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
    }
  }

  // Converts and enqueues one thread's exit timestamps for event type E.
  template <EventType E, size_t N>
  void addExits(
      AppendOnlyList<c10::approx_time_t, N>& exits,
      size_t python_tid) {
    for (const auto i : exits) {
      get_state<E>().exits_.push({time_converter_(i), python_tid});
    }
  }

  // Stable-sorts `enters` in place by enter time, then builds Results for
  // Python calls followed by C calls.
  std::vector<std::shared_ptr<Result>> run(
      std::vector<python_tracer::CompressedEvent>& enters) {
    // Compare by reference; taking the events by value copied two
    // CompressedEvents on every comparison.
    std::stable_sort(
        enters.begin(), enters.end(), [](const auto& a, const auto& b) {
          return a.enter_t_ < b.enter_t_;
        });
    std::vector<std::shared_ptr<Result>> out;
    populate<EventType::PyCall>(enters, out);
    populate<EventType::PyCCall>(enters, out);
    return out;
  }

 private:
  // Creates a Result for every enter of event type E in the time-sorted
  // `enters` and appends it to `out`. Each exit pops the innermost open call
  // on its Python thread and stamps that call's end time.
  template <EventType E>
  void populate(
      std::vector<python_tracer::CompressedEvent>& enters,
      std::vector<std::shared_ptr<Result>>& out) {
    using stack_t = std::vector<std::shared_ptr<Result>>;
    const auto initial_size = out.size();
    auto pop = [](stack_t& stack, c10::time_t t) {
      if (!stack.empty()) {
        std::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
        stack.pop_back();
      } else {
        TORCH_WARN_ONCE(
            "Python replay stack is empty during pop operation! May result in incorrect stack tracing.");
      }
    };

    ska::flat_hash_map<size_t, stack_t> stacks;
    auto& state = get_state<E>();

    // Applies the earliest pending exit to the stack of its Python thread.
    auto replay_next_exit = [&]() {
      const auto& next = state.exits_.top();
      pop(stacks[next.python_tid_], next.t_);
      state.exits_.pop();
    };

    // We already own the GIL at this point
    for (const auto& enter : enters) {
      auto fields_it = state.fields_.find(enter.key_);
      if (fields_it != state.fields_.end()) {
        while (!state.exits_.empty() &&
               state.exits_.top().t_ < enter.enter_t_) {
          replay_next_exit();
        }
        out.push_back(Result::create(
            enter.enter_t_,
            enter.system_tid_,
            enter.kineto_info_,
            fields_it->second));
        stacks[fields_it->second.python_tid_].push_back(out.back());
      }
    }

    // Exits recorded after the last matching enter belong to calls that
    // returned while profiling was still running. Replay them so those calls
    // get their real end times instead of `end_time_`.
    while (!state.exits_.empty()) {
      replay_next_exit();
    }

    // Handle events which were still running when profiling ended.
    for (auto& i : stacks) {
      while (!i.second.empty()) {
        pop(i.second, end_time_);
      }
    }

    // Assign system TIDs to start events based on the system TID of the next
    // observed event with the same Python TID.
    ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
        tid_map;
    auto it = out.rbegin();
    for ([[maybe_unused]] auto _ : c10::irange(initial_size, out.size())) {
      const auto python_tid =
          std::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
      if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
        const auto& tid_info =
            tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
                .first->second;
        (*it)->start_tid_ = tid_info.first;
        (*it)->kineto_info_ = tid_info.second;
      }
      tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
      ++it;
    }
  }

  // Per-event-type replay state: ExtraFields by trace key, plus a min-heap of
  // exits (earliest first).
  template <EventType E>
  struct State {
    ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
    std::priority_queue<Exit, std::vector<Exit>, std::greater<>> exits_;
  };

  template <EventType E>
  auto& get_state() {
    return std::get<(E == EventType::PyCall) ? 0 : 1>(state_);
  }

  c10::time_t end_time_;
  std::function<c10::time_t(c10::approx_time_t)> time_converter_;
  std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
};
struct PythonIDVisitor {
void operator()(ExtraFields<EventType::PyCall>& py_call) {
py_call.id_ = ++current_python_id_;
if (py_call.module_.has_value()) {
auto& m = py_call.module_;
auto& module_ids = module_ids_[m->cls_];
m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
}
}
void operator()(ExtraFields<EventType::PyCCall>& py_call) {
py_call.id_ = ++current_python_id_;
}
template <typename T>
void operator()(T&) {}
size_t current_python_id_{0};
ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
module_ids_;
};
// Turns everything recorded by this tracer into Results.
//
// Calls value_cache_.trimPrefixes() first. It then appends the start frames
// to `enters` and replays enters/exits through PostProcess. The output is
// stable-sorted by start time, and Python ids are then assigned in that order
// with PythonIDVisitor.
std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
    std::function<c10::time_t(c10::approx_time_t)> time_converter,
    std::vector<python_tracer::CompressedEvent>& enters,
    c10::time_t end_time_ns) {
  value_cache_.trimPrefixes();

  PostProcess replay(
      std::move(time_converter),
      thread_local_results_,
      value_cache_,
      end_time_ns);
  replay.set_start_frames(start_frames_, enters);
  auto events = replay.run(enters);

  const auto by_start_time = [](const auto& lhs, const auto& rhs) {
    return lhs->start_time_ns_ < rhs->start_time_ns_;
  };
  std::stable_sort(events.begin(), events.end(), by_start_time);

  PythonIDVisitor id_visitor;
  for (auto& event : events) {
    std::visit(id_visitor, event->extra_fields_);
  }
  return events;
}
// ============================================================================
// == Memory Tracer ======================================================
// ============================================================================
// Memory tracer implemented by calling into `torch.cuda.memory` from C++.
//
// start() and stop() toggle `_record_memory_history_impl`, and
// export_memory_history() calls `_dump_snapshot(path)`. The base class
// python_tracer::PythonMemoryTracerBase is declared outside this file.
class PythonMemoryTracer final : public python_tracer::PythonMemoryTracerBase {
 public:
  explicit PythonMemoryTracer() = default;
  ~PythonMemoryTracer() override = default;

  void start() override;
  void stop() override;
  void export_memory_history(const std::string& path) override;
};
static void toggle_memory_tracing(bool enable) {
pybind11::gil_scoped_acquire gil;
THPObjectPtr torch_cuda_memory_module(
PyImport_ImportModule("torch.cuda.memory"));
if (!torch_cuda_memory_module) {
return;
}
THPObjectPtr snapshot_func(PyObject_GetAttrString(
torch_cuda_memory_module.get(), "_record_memory_history_impl"));
if (!snapshot_func) {
return;
}
// Call the function with arguments
PyObject* args = PyTuple_New(6);
PyTuple_SetItem(args, 0, enable ? PyUnicode_FromString("all") : Py_None);
PyTuple_SetItem(args, 1, PyUnicode_FromString("all")); // context
PyTuple_SetItem(args, 2, PyUnicode_FromString("all")); // stacks
PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries
PyTuple_SetItem(args, 4, Py_None); // device (None)
PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False)
PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr);
Py_DECREF(args);
if (result == nullptr) {
return;
}
}
// Turns memory-history recording on (see toggle_memory_tracing).
void PythonMemoryTracer::start() {
  constexpr bool kEnable = true;
  toggle_memory_tracing(kEnable);
}
void PythonMemoryTracer::export_memory_history(const std::string& path) {
pybind11::gil_scoped_acquire gil;
THPObjectPtr torch_cuda_memory_module(
PyImport_ImportModule("torch.cuda.memory"));
if (!torch_cuda_memory_module) {
return;
}
THPObjectPtr snapshot_func(
PyObject_GetAttrString(torch_cuda_memory_module.get(), "_dump_snapshot"));
if (!snapshot_func) {
return;
}
PyObject* py_filename = PyUnicode_FromString(path.c_str());
// Call the function with arguments (e.g., a file path)
PyObject* args = PyTuple_Pack(1, py_filename);
PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr);
Py_DECREF(args);
if (result == nullptr) {
return;
}
}
// Turns memory-history recording off (see toggle_memory_tracing).
void PythonMemoryTracer::stop() {
  constexpr bool kEnable = false;
  toggle_memory_tracing(kEnable);
}
// ============================================================================
// == API =====================================================================
// ============================================================================
// CPython profile hook, installed per thread with PyEval_SetProfile.
//
// `obj` is the TraceContext registered for the current thread, and `what` is
// the PyTrace_* event kind:
//   * Python and C calls are forwarded to recordPyCall / recordCCall.
//   * Returns are timestamped.
//   * C returns/exceptions are timestamped only while the live depth is above
//     the remaining start-frame count.
// All other event kinds are ignored. Always returns 0.
int PythonTracer::pyProfileFn(
    PyObject* obj,
    PyFrameObject* frame,
    int what,
    PyObject* arg) {
  auto* trace_ctx = reinterpret_cast<TraceContext*>(obj);
  auto& results = *trace_ctx->thread_local_results_;

  if (what == PyTrace_CALL) {
    results.active_tracer_->recordPyCall(results, frame, false);
  } else if (what == PyTrace_C_CALL) {
    results.active_tracer_->recordCCall(results, frame, arg);
  } else if (what == PyTrace_RETURN) {
    results.exit_times_.emplace_back(c10::getApproximateTime());
    results.active_frames_--;
    // Unwinding below the inherited start frames lowers that watermark too.
    if (results.active_frames_ < results.remaining_start_frames_) {
      results.remaining_start_frames_ = results.active_frames_;
    }
  } else if (what == PyTrace_C_EXCEPTION || what == PyTrace_C_RETURN) {
    // Ignore C exits that would unwind into the inherited start frames.
    if (results.active_frames_ > results.remaining_start_frames_) {
      results.c_exit_times_.emplace_back(c10::getApproximateTime());
      results.active_frames_--;
    }
  }
  return 0;
}
// Tracer factory handed to python_tracer::registerTracer() by init(). It
// builds a PythonTracer that records into `queue`.
std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
    torch::profiler::impl::RecordQueue* queue) {
  std::unique_ptr<python_tracer::PythonTracerBase> tracer =
      std::make_unique<PythonTracer>(queue);
  return tracer;
}
// Memory-tracer factory handed to python_tracer::registerMemoryTracer() by
// init().
std::unique_ptr<python_tracer::PythonMemoryTracerBase> getMemoryTracer() {
  std::unique_ptr<python_tracer::PythonMemoryTracerBase> tracer =
      std::make_unique<PythonMemoryTracer>();
  return tracer;
}
} // namespace
} // namespace torch::profiler::impl
namespace torch::autograd::profiler::python_tracer {
void init() {
pybind11::gil_scoped_acquire gil;
TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
torch::profiler::impl::python_tracer::registerTracer(
&torch::profiler::impl::getTracer);
torch::profiler::impl::python_tracer::registerMemoryTracer(
&torch::profiler::impl::getMemoryTracer);
}
} // namespace torch::autograd::profiler::python_tracer