#include <torch/csrc/autograd/profiler_python.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <limits>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include <Python.h>
#include <frameobject.h>

#include <ATen/core/TensorBase.h>
#include <c10/macros/Macros.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/containers.h>
#include <torch/csrc/profiler/orchestration/python_tracer.h>
#include <torch/csrc/profiler/util.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>

namespace py = pybind11;

namespace torch::profiler::impl {
namespace {
enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
static constexpr size_t CallTypeSize = 4;
using no_ephemeral_t = std::tuple<>;
static constexpr uint64_t NoTID = std::numeric_limits<uint64_t>::max();

// ============================================================================
// == Miscellaneous structs and utils =========================================
// ============================================================================
struct CodeLocation {
  CodeLocation() = default;
  explicit CodeLocation(PyFrameObject* frame)
      : line_number_{PyFrame_GetLineNumber(frame)} {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    filename_ = THPUtils_unpackStringView(code->co_filename).data();
    name_ = THPUtils_unpackStringView(code->co_name).data();
  }

  bool operator==(const CodeLocation& other) const {
    return filename_ == other.filename_ && name_ == other.name_ &&
        line_number_ == other.line_number_;
  }

  const char* filename_{nullptr};
  const char* name_{nullptr};
  int line_number_{0};
};

template <CallType C>
PyCodeObject* getCode();

template <>
PyCodeObject* getCode<CallType::PyModuleCall>() {
  static auto module_call_code = []() {
    pybind11::gil_scoped_acquire gil;
    auto res = py::module::import("torch.nn")
                   .attr("Module")
                   .attr("__call__")
                   .attr("__code__")
                   .ptr();
    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
    return (PyCodeObject*)res;
  }();
  return module_call_code;
}

template <>
PyCodeObject* getCode<CallType::PyOptimizerCall>() {
  static auto optimizer_step_code = []() {
    pybind11::gil_scoped_acquire gil;
    auto res = py::module::import("torch.optim")
                   .attr("Optimizer")
                   .attr("_optimizer_step_code")
                   .attr("__code__")
                   .ptr();
    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
    return (PyCodeObject*)res;
  }();
  return optimizer_step_code;
}

} // namespace
} // namespace torch::profiler::impl

template <>
struct std::hash<torch::profiler::impl::CodeLocation> {
  size_t operator()(const torch::profiler::impl::CodeLocation& x) const {
    return c10::get_hash(x.filename_, x.name_, x.line_number_);
  }
};

namespace torch::profiler::impl {
namespace {
// ============================================================================
// == CallTypeHelper: Tools for generic programming on specializations. ======
// ============================================================================
template <template <CallType> class ClassT>
class CallTypeHelper final {
 private:
  static_assert(
      CallType::PyCall == 0,
      "CallTypeHelper uses integer math which depends on a zero start.");
  static constexpr size_t End = CallTypeSize;

  template <size_t... I>
  static constexpr std::tuple<ClassT<(CallType)I>...> make_tuple_impl(
      std::index_sequence<I...>);

  template <size_t C, typename T, typename FunctorT, typename... Args>
  static void map(T& t, FunctorT& f, Args&&... args) {
    f(std::get<C>(t), args...);
    if constexpr (C + 1 < End) {
      map<C + 1>(t, f, std::forward<Args>(args)...);
    }
  }

 public:
  using tuple_type = decltype(make_tuple_impl(std::make_index_sequence<End>{}));

  template <typename FunctorT, typename... Args>
  static void map(tuple_type& t, FunctorT& f, Args&&... args) {
    map<0>(t, f, std::forward<Args>(args)...);
  }
};

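// A minimal usage sketch of CallTypeHelper (illustrative only;
// `ExamplePerTypeCache`, `ExampleCountEntries`, and `exampleTotalEntries` are
// hypothetical names, not part of the tracer): `CallTypeHelper<T>::map`
// applies one functor to the cache of every CallType specialization in a
// single call, which is how the per-type caches below are visited.
template <CallType C>
struct ExamplePerTypeCache {
  size_t entries_{0};
};

struct ExampleCountEntries {
  template <CallType C>
  void operator()(ExamplePerTypeCache<C>& cache, size_t& total) {
    total += cache.entries_;
  }
};

[[maybe_unused]] static size_t exampleTotalEntries() {
  CallTypeHelper<ExamplePerTypeCache>::tuple_type caches{};
  ExampleCountEntries count;
  size_t total = 0;
  CallTypeHelper<ExamplePerTypeCache>::map(caches, count, total);
  return total; // 0 here; each real cache would contribute its size.
}
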
// ============================================================================
// == Event type definitions. =================================================
// ============================================================================
// When we are tracing a Python program, the general procedure is to record
// every time we enter or exit a function and later replay these events during
// post processing. Thus, during the profiling phase we want to do the MINIMAL
// amount of work to capture all of the information that we need; otherwise we
// will distort the profile. (While we don't wish to be terribly inefficient
// during post processing, we are willing to do extra fixup work in post if it
// reduces overhead in the profiling phase.)
//
// When the tracer first enters a frame, it constructs a CallKey for that
// location. The contents of the key vary by context. For a Python function
// the key is the (PyCodeObject*, int) pair that defines the bytecode of the
// function. For an `nn.Module` the key is a (non-owning) pointer to `self`.
// For a bound C function it is a (non-owning) pointer to the bound function.
// A CallKey should be small, inexpensive, and POD.
//
// We then collect a CallKey<CallType::PyCall> for the calling frame for better
// source tracking. This pair is a `Callsite`, and serves as a first level key
// during tracing. We look up the Callsite in a thread local cache which maps
// Callsite to a unique integer `TraceKey`. On a cache hit, we simply store the
// TraceKey and return. On a cache miss, we use a global value cache to store
// whatever fields we need from the two CallKeys, generate a new TraceKey, and
// update the local cache.
//
// During post processing we:
//   1) Determine the type represented by a TraceKey by checking which
//      sub-cache of the thread local cache it appears in.
//   2) Look up the pair of CallKeys from the thread local cache.
//   3) Look up the expanded values of each CallKey from the global value cache.
//
// To add a new event type to the cache:
//   1) Add an entry to the `CallType` enum.
//   2) Add a specialization of Config which defines key_t, ephemeral_t and
//      cache_t.
//   3) Add a specialization of ValueCache::store and ValueCache::load.
//
// -------------------------
// -- Ephemeral arguments --
// -------------------------
// The value cache mechanism assumes that `key_t` is enough to specify the
// correct value. However it may not be possible to materialize a value using
// only an instance of `key_t`. As a result, the cache also accepts "ephemeral"
// inputs which can be used to populate the value cache. Ephemeral inputs come
// with two caveats:
//   1) They are NOT safe to save, and cannot be used after `ValueCache::store`.
//   2) They should only be used to access data that is not expected to change
//      from call to call, such as the name of a function.

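// A minimal, self-contained sketch of the interning scheme described above
// (illustrative only; `ExampleCallsite`, `ExampleCallsiteHash`, and
// `exampleIntern` are hypothetical names, not part of the tracer). The hot
// path pays for one hash lookup per call; the expensive expansion of a
// callsite happens at most once, on the first miss.
struct ExampleCallsite {
  const void* value; // e.g. a PyCodeObject* or an `nn.Module` self pointer
  const void* caller; // the calling frame's location
  bool operator==(const ExampleCallsite& other) const {
    return value == other.value && caller == other.caller;
  }
};

struct ExampleCallsiteHash {
  size_t operator()(const ExampleCallsite& c) const {
    return c10::get_hash(c.value, c.caller);
  }
};

[[maybe_unused]] static uint64_t exampleIntern(const ExampleCallsite& c) {
  static ska::flat_hash_map<ExampleCallsite, uint64_t, ExampleCallsiteHash>
      cache;
  auto it = cache.find(c);
  if (it == cache.end()) {
    // Cache miss: this is where the real tracer stores the expanded CallKey
    // fields in the global ValueCache before minting a fresh TraceKey.
    static uint64_t next_key = 0;
    it = cache.insert({c, ++next_key}).first;
  }
  return it->second;
}
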
template <CallType>
struct Config;

template <>
struct Config<CallType::PyCall> {
  using key_t = CodeLocation;
  using ephemeral_t = no_ephemeral_t;
  using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
  static constexpr EventType event_type = EventType::PyCall;
};

template <typename Key, typename Cls, typename ParameterInfo>
struct ExtendedPyCallConfig {
  using key_t = Key;
  using cls_t = Cls;
  using ephemeral_t = PyFrameObject*;

  struct ClsAndParameters {
    cls_t cls_;
    std::vector<ParameterInfo> parameters_;
  };

  struct Cache {
    // `nn.Module.forward` or `optim.Optimizer._optimizer_step_code`
    std::optional<CodeLocation> location_;
    ska::flat_hash_map<key_t, ClsAndParameters> cls_and_parameters_;
    ska::flat_hash_map<cls_t, at::StringView> cls_names_;
  };
  using cache_t = Cache;

  static constexpr EventType event_type = EventType::PyCall;
};

template <>
struct Config<CallType::PyModuleCall> : ExtendedPyCallConfig<
                                            PyModuleSelf,
                                            PyModuleCls,
                                            NNModuleInfo::ParameterInfo> {};

template <>
struct Config<CallType::PyOptimizerCall> : ExtendedPyCallConfig<
                                               PyOptimizerSelf,
                                               PyOptimizerCls,
                                               OptimizerInfo::ParameterInfo> {};

template <>
struct Config<CallType::PyCCall> {
  using key_t = PyMethod;
  using ephemeral_t = PyObject*;
  using cache_t = ska::flat_hash_map<key_t, at::StringView>;
  static constexpr EventType event_type = EventType::PyCCall;
};

// ============================================================================
// == Callsite & ValueCache: Storage during profiling =========================
// ============================================================================
template <CallType C>
class Callsite {
 public:
  static constexpr CallType call_type = C;
  using key_t = typename Config<C>::key_t;

  static_assert(
      std::is_trivially_copyable_v<key_t>,
      "Key should be trivial, as it is passed by value.");

  template <typename U>
  Callsite(U value, PyFrameObject* f_back) : value_(value), caller_(f_back) {}

  bool operator==(const Callsite<C>& other) const {
    return value_ == other.value_ && caller_ == other.caller_;
  }

  key_t value_;
  Config<CallType::PyCall>::key_t caller_;
};

// ============================================================================
// == Type specific store and load implementations. ===========================
// ============================================================================
using PyCallKey = Config<CallType::PyCall>::key_t;
using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
using PyCCallKey = Config<CallType::PyCCall>::key_t;
using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;

class ValueCache {
 public:
  ValueCache() = default;
  ValueCache(const ValueCache&) = delete;
  ValueCache& operator=(const ValueCache&) = delete;
  ValueCache(ValueCache&&) = default;
  ValueCache& operator=(ValueCache&&) = delete;
  ~ValueCache() = default;

  template <CallType C>
  void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);

  template <CallType C>
  auto load(const Callsite<C>& callsite, size_t python_tid) const {
    auto caller = load<CallType::PyCall>(callsite.caller_);
    TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
    return ExtraFields<Config<C>::event_type>{
        /*end_time_ns=*/std::numeric_limits<c10::time_t>::min(),
        python_tid,
        caller.frame_state_,
        load<C>(callsite.value_)};
  }

  std::optional<TensorMetadata> recordIfTensor(py::handle p);
  std::vector<std::pair<std::string, TensorMetadata>> unpackTensorMap(
      const py::dict& tensor_map);
  void trimPrefixes();

 private:
  template <CallType C>
  typename ExtraFields<Config<C>::event_type>::args_t load(
      const typename Config<C>::key_t&) const;

  template <CallType C>
  using State = typename Config<C>::cache_t;

  CallTypeHelper<State>::tuple_type state_;
};

template <CallType C>
typename Config<C>::cls_t set_class(
    ValueCache* value_cache,
    typename Config<C>::cache_t& cache,
    const typename Config<C>::key_t& key,
    const typename Config<C>::ephemeral_t& frame) {
  if (C10_UNLIKELY(!cache.location_.has_value())) {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    TORCH_INTERNAL_ASSERT(code.get() == getCode<C>());
    cache.location_ = PyCallKey(frame);
    value_cache->store<CallType::PyCall>(*cache.location_, no_ephemeral_t());
  }

  auto cls_handle = py::handle((PyObject*)key).attr("__class__");
  auto cls = typename Config<C>::cls_t(cls_handle.ptr());
  if (cache.cls_names_.find(cls) == cache.cls_names_.end()) {
    cache.cls_names_[cls] =
        at::StringView(py::str(cls_handle.attr("__name__")));
  }
  return cls;
}

TensorMetadata toTensorMetadata(PyObject* self) {
  TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self));
  const auto& t = THPVariable_Unpack(self);
  RawTensorMetadata m{t};
  return TensorMetadata{
      m,
      t.sizes().vec(),
      m.layout_ == at::kStrided ? t.strides().vec() : std::vector<int64_t>()};
}

std::optional<TensorMetadata> ValueCache::recordIfTensor(py::handle p) {
  return THPVariable_CheckExact(p.ptr())
      ? std::optional<TensorMetadata>{toTensorMetadata(p.ptr())}
      : std::nullopt;
}

std::vector<std::pair<std::string, TensorMetadata>> ValueCache::unpackTensorMap(
    const py::dict& tensor_map) {
  std::vector<std::pair<std::string, TensorMetadata>> out;
  for (auto& it : tensor_map) {
    auto* value = it.second.ptr();
    if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(value)) {
      out.emplace_back(
          py::cast<std::string>(it.first), toTensorMetadata(value));
    }
  }
  return out;
}

template <>
void ValueCache::store<CallType::PyCall>(const PyCallKey& key, no_ephemeral_t) {
  auto& locations = std::get<CallType::PyCall>(state_);
  if (C10_UNLIKELY(locations.find(key) == locations.end())) {
    locations[key] = {
        key.line_number_,
        at::StringView(key.filename_),
        at::StringView(key.name_)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
    const PyCallKey& key) const {
  return {std::get<CallType::PyCall>(state_).at(key), std::nullopt};
}

template <>
void ValueCache::store<CallType::PyModuleCall>(
    const PyModuleCallKey& key,
    Config<CallType::PyModuleCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  if (C10_UNLIKELY(
          cache.cls_and_parameters_.find(key) ==
          cache.cls_and_parameters_.end())) {
    auto cls = set_class<CallType::PyModuleCall>(this, cache, key, frame);

    py::dict params = py::handle((PyObject*)key).attr("_parameters");
    std::vector<NNModuleInfo::ParameterInfo> params_;
    for (auto& it : params) {
      auto* p = it.second.ptr();
      if (py::isinstance<py::str>(it.first) && THPVariable_CheckExact(p)) {
        params_.push_back(
            {it.first.cast<std::string>(),
             toTensorMetadata(p),
             recordIfTensor(py::getattr(it.second, "grad", py::none()))});
      }
    }
    cache.cls_and_parameters_[key] = {cls, std::move(params_)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
    const PyModuleCallKey& key) const {
  auto& cache = std::get<CallType::PyModuleCall>(state_);
  TORCH_INTERNAL_ASSERT(cache.location_.has_value());
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  const auto& cls = cls_and_parameters.cls_;
  NNModuleInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(*cache.location_),
      /*module_info_=*/std::move(info),
      /*optimizer_info_=*/std::nullopt};
}

template <>
void ValueCache::store<CallType::PyOptimizerCall>(
    const PyOptimizerCallKey& key,
    Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  if (C10_UNLIKELY(
          cache.cls_and_parameters_.find(key) ==
          cache.cls_and_parameters_.end())) {
    auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
    const py::handle self{(PyObject*)key};
    std::vector<OptimizerInfo::ParameterInfo> params;

    for (const auto& i : (py::list)self.attr("param_groups")) {
      for (auto& param : py::cast<py::dict>(i).attr("get")("params")) {
        if (THPVariable_CheckExact(param.ptr())) {
          // While `self.state` is permitted to store data in an arbitrary way,
          // all generic optimizers (SGD, Adam, etc) use param as the key since
          // the state in question is tied to particular parameters. We can
          // relax this assumption if the need arises.
          params.push_back(
              {toTensorMetadata(param.ptr()),
               recordIfTensor(py::getattr(param, "grad", py::none())),
               unpackTensorMap(py::cast<py::dict>(self.attr("state"))
                                   .attr("get")(param, py::dict()))});
        }
      }
    }

    cache.cls_and_parameters_[key] = {cls, std::move(params)};
  }
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<
    CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
  const auto& cls_and_parameters = cache.cls_and_parameters_.at(key);
  auto cls = cls_and_parameters.cls_;
  OptimizerInfo info{
      key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_};
  return {
      /*frame_state_=*/std::get<CallType::PyCall>(state_).at(
          // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
          cache.location_.value()),
      /*module_info_=*/std::nullopt,
      /*optimizer_info_=*/std::move(info)};
}

template <>
void ValueCache::store<CallType::PyCCall>(
    const PyCCallKey& key,
    Config<CallType::PyCCall>::ephemeral_t arg) {
  auto& names = std::get<CallType::PyCCall>(state_);
  if (C10_UNLIKELY(names.find(key) == names.end())) {
    names[key] = at::StringView(py::repr(arg));
  }
}

template <>
ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
    const PyCCallKey& key) const {
  return std::get<CallType::PyCCall>(state_).at(key);
}

// TODO: Use re2.
void ValueCache::trimPrefixes() {
  static const auto prefixes = []() {
    pybind11::gil_scoped_acquire gil;
    return py::module::import("torch.profiler.python_tracer")
        .attr("_prefix_regex")()
        .cast<std::vector<std::string>>();
  }();

  for (auto& it : std::get<CallType::PyCall>(state_)) {
    std::string filename = it.second.filename_.str();
    for (const auto& p : prefixes) {
      if (filename.compare(0, p.size(), p) == 0) {
        filename.erase(0, p.size());
        it.second.filename_ = at::StringView(filename);
        break;
      }
    }
  }
}

// ============================================================================
// == TraceKey cache ==========================================================
// ============================================================================
using python_tracer::TraceKey;

TraceKey nextKey() {
  static std::atomic<uint64_t> key{0};
  return TraceKey{++key};
}

template <CallType C>
struct TraceKeyCacheState {
  struct Hash {
    size_t operator()(const Callsite<C>& key) {
      return c10::get_hash(key.value_, key.caller_);
    }
  };

  TraceKey intern(
      Callsite<C> callsite,
      typename Config<C>::ephemeral_t ephemeral,
      ValueCache& value_cache) {
    auto it = state_.find(callsite);
    if (C10_UNLIKELY(it == state_.end())) {
      value_cache.store<C>(callsite.value_, ephemeral);
      value_cache.store<CallType::PyCall>(callsite.caller_, no_ephemeral_t());
      it = state_.insert({callsite, nextKey()}).first;
    }
    return it->second;
  }

  auto lookup(Callsite<C>& callsite, ValueCache& value_cache) const {
    return std::make_pair(
        value_cache.load<C>(callsite.value_),
        value_cache.load<CallType::PyCall>(callsite.caller_));
  }

  ska::flat_hash_map<Callsite<C>, TraceKey, Hash> state_;
};

// ============================================================================
// == Core CPython data types =================================================
// ============================================================================
// PyObject that allows different threads to record events without colliding.
// It is passed as the second argument when enabling tracing via
// `PyEval_SetProfile`.
struct ThreadLocalResults;
struct TraceContext {
  PyObject_HEAD
  ThreadLocalResults* thread_local_results_;
};

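// Illustrative only (the real callback is `PythonTracer::pyProfileFn` below;
// `exampleRecoverTLS` is a hypothetical name): CPython passes the second
// argument of `PyEval_SetProfile` back to the profile callback on every event,
// so each thread can recover its own results from `obj` without any locking.
[[maybe_unused]] static ThreadLocalResults* exampleRecoverTLS(PyObject* obj) {
  return reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
}
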
// CPython boilerplate to define `TraceContext` as a proper python object.
static PyTypeObject TraceContextType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "TraceContext", /* tp_name */
    sizeof(TraceContext), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT, /* tp_flags */
    "Python tracer TLS", /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    PyType_GenericNew, /* tp_new */
    nullptr /* tp_free */
};

class gil_and_restore_thread {
 public:
  gil_and_restore_thread() : initial_thread_state_{PyThreadState_Get()} {}
  ~gil_and_restore_thread() {
    PyThreadState_Swap(initial_thread_state_);

    // `gil_scoped_acquire` is a bit fragile in on-demand mode:
    // https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458
    if (!Py_IsInitialized()) {
      gil_.disarm();
    }
  }

  PyThreadState* initial_thread_state() const {
    return initial_thread_state_;
  }

 private:
  pybind11::gil_scoped_acquire gil_;
  PyThreadState* initial_thread_state_;
};

// ============================================================================
// == Thread local cache ======================================================
// ============================================================================
class PythonTracer;
struct ThreadLocalResults {
  ThreadLocalResults(
      PyThreadState* thread_state,
      ValueCache* value_cache,
      PythonTracer* active_tracer)
      : thread_state_{thread_state},
        ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)},
        value_cache_{value_cache},
        active_tracer_{active_tracer} {
    ctx_->thread_local_results_ = this;
  }

  ThreadLocalResults() = delete;
  ThreadLocalResults(const ThreadLocalResults&) = delete;
  ThreadLocalResults(ThreadLocalResults&&) = delete;
  ThreadLocalResults& operator=(const ThreadLocalResults&) = delete;
  ThreadLocalResults& operator=(ThreadLocalResults&&) = delete;

  ~ThreadLocalResults() {
    // Currently, there is a bug in Profiler when using Python 3.12 that causes
    // a segfault when decrementing the refcount of a TraceContext during
    // on-demand. We are purposefully allowing for a small leak in this
    // situation to avoid the segfault. This should be fixed in the future.
#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 12)
    Py_DECREF((PyObject*)ctx_);
#endif
  }

  template <CallType C, EventType E, typename Ephemeral, typename... Args>
  TraceKey intern(Ephemeral ephemeral, Args... args) {
    static_assert(
        Config<C>::event_type == E,
        "ThreadLocalResults.intern called from the wrong typed context.");
    auto callsite = Callsite<C>(std::forward<Args>(args)...);
    return std::get<C>(trace_keys_).intern(callsite, ephemeral, *value_cache_);
  }

  static constexpr size_t BLOCK_SIZE = 1024;

  PyThreadState* thread_state_;
  TraceContext* ctx_;
  ValueCache* value_cache_;
  PythonTracer* active_tracer_;
  CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
  AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> exit_times_;
  AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> c_exit_times_;
};

// ============================================================================
// == Tracing implementation ==================================================
// ============================================================================
class PythonTracer final : public python_tracer::PythonTracerBase {
 public:
  PythonTracer(torch::profiler::impl::RecordQueue* queue);
  // NOLINTNEXTLINE(bugprone-exception-escape)
  ~PythonTracer() override;

  static int pyProfileFn(
      PyObject* obj,
      PyFrameObject* frame,
      int what,
      PyObject* arg);

  void stop() override;
  void restart() override;
  std::vector<std::shared_ptr<Result>> getEvents(
      std::function<c10::time_t(c10::approx_time_t)> time_converter,
      std::vector<python_tracer::CompressedEvent>& enters,
      c10::time_t end_time_ns) override;

  struct StartFrame {
    TraceKey trace_key_;
    c10::approx_time_t start_time{};
  };

 private:
  void recordPyCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      bool is_startup_frame);

  void recordCCall(
      ThreadLocalResults& tls,
      PyFrameObject* frame,
      PyObject* arg,
      bool start_frame = false);

  const std::vector<PyThreadState*> interpreterThreads() const;

  PyObject* get_callable_from_frame(PyFrameObject* frame);

  std::atomic<bool> active_lock_{false};
  bool active_{false};

  torch::profiler::impl::RecordQueue* queue_;
  PyInterpreterState* interpreter_{nullptr};
  PyCodeObject* module_call_code_;
  PyCodeObject* optimizer_hook_;

  std::vector<StartFrame> start_frames_;
  std::deque<ThreadLocalResults> thread_local_results_;
  ValueCache value_cache_;
};

const std::vector<PyThreadState*> PythonTracer::interpreterThreads() const {
  pybind11::gil_scoped_acquire gil;
  std::vector<PyThreadState*> out;
  if (SOFT_ASSERT(interpreter_)) {
    auto* thread_state = PyInterpreterState_ThreadHead(interpreter_);
    while (thread_state != nullptr) {
      out.push_back(thread_state);
      thread_state = PyThreadState_Next(thread_state);
    }
  }
  return out;
}

PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
    : queue_(queue),
      module_call_code_(getCode<CallType::PyModuleCall>()),
      optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
  TORCH_CHECK(queue_ != nullptr);

  bool expected{false};
  active_ = active_lock_.compare_exchange_strong(expected, true);
  if (!active_) {
    TORCH_WARN(
        "There is already an active Python tracer. "
        "Refusing to register profile functions.");
    return;
  }

  gil_and_restore_thread gil;
  interpreter_ = PyInterpreterState_Get();

  if (!gil.initial_thread_state()) {
    TORCH_WARN("PyThreadState_Get returned NULL");
    return;
  }

  // Register the tracer in each thread.
  for (const auto thread_state : interpreterThreads()) {
    PyThreadState_Swap(thread_state);

    thread_local_results_.emplace_back(thread_state, &value_cache_, this);
    auto* ctx = thread_local_results_.back().ctx_;

    // When we begin profiling there are already frames on the Python
    // interpreter stack. To ensure a complete trace, we must push calls
    // to all the prior frames onto our event stack. (We stop at depth=128.)

    std::vector<THPFrameObjectPtr> current_stack;
    auto frame = PyEval_GetFrame();
    Py_XINCREF(frame);

    size_t depth = 0; // Make sure we can't infinite loop.
    while (frame != nullptr) {
      current_stack.emplace_back(frame);
      if (++depth == 128) {
        break;
      }

      // NB: `PyFrame_GetBack` returns a strong reference.
      frame = PyFrame_GetBack(frame);
    }

    for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
      recordPyCall(thread_local_results_.back(), it->get(), true);
      PyFrameObject* frame = it->get();
      PyObject* callable = get_callable_from_frame(frame);
      if (callable) {
        // If the frame has a callable, record it as a C call, since
        // `PyEval_GetFrame` only returns the Python frame. We record this C
        // call so that we don't have a mismatched C call when exiting the
        // profiler.
        recordCCall(thread_local_results_.back(), it->get(), callable, true);
      }

      auto frame_refcount = Py_REFCNT(it->get());

      // We hold one reference in `current_stack`, and the interpreter holds
      // another.
      TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
    }

    // Note:
    //   This profile will not compose with other CPython profilers, and
    //   cannot be round tripped via `sys.settrace(sys.gettrace())`.
    PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
  }
}

void PythonTracer::stop() {
  gil_and_restore_thread gil;
  if (active_) {
    for (const auto thread_state : interpreterThreads()) {
      if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) {
        PyThreadState_Swap(thread_state);
        PyEval_SetProfile(nullptr, nullptr);
      }
    }

    auto lock_returned = active_lock_.compare_exchange_strong(active_, false);
    active_ = false;
    SOFT_ASSERT(lock_returned, "Failed to return python tracer lock.");
  }
}

void PythonTracer::restart() {
  gil_and_restore_thread gil;
  active_ = active_lock_.compare_exchange_strong(active_, true);
  if (!active_) {
    TORCH_WARN(
        "There is already an active Python tracer. "
        "Refusing to register profile functions.");
    return;
  }
  int cur_thread = 0;
  for (const auto thread_state : interpreterThreads()) {
    if (thread_state->c_profilefunc == nullptr) {
      auto* ctx = thread_local_results_[cur_thread].ctx_;
      PyThreadState_Swap(thread_state);
      PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx);
    }
    // Advance in lockstep with interpreterThreads() so each thread gets its
    // own TraceContext. (The counter was previously never incremented.)
    cur_thread++;
  }
}

// NOLINTNEXTLINE(bugprone-exception-escape)
PythonTracer::~PythonTracer() {
  if (active_) {
    TORCH_WARN("`PythonTracer::stop()` was not called.");
    stop();
  }
}

void PythonTracer::recordPyCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    bool is_startup_frame) {
  static constexpr auto E = EventType::PyCall;
  const auto key = [&]() -> TraceKey {
    auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
    if (code.get() == module_call_code_) {
      // By default, CPython stores locals in a "fast" format, with an array
      // of names and an array of values. Consequently, frame->f_locals is
      // NULL since the interpreter has no need to populate it.
      //
      // If these arrays were part of the public API then we could very
      // quickly access `self`. Unfortunately they are not, and moreover are
      // not stable across versions. As a result, we are forced to call
      // `PyFrame_FastToLocals` which forces the interpreter to materialize
      // the full dict of locals.
      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));

#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 13)
      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
#else
      // In Python 3.13+, `PyFrame_GetLocals()` returns an instance of
      // PyFrameLocalsProxy_Type; see PEP 667 for more info.
      auto self = THPObjectPtr(PyMapping_GetItemString(locals, "self"));
#endif
      Py_INCREF(self.get());
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyModuleCall, E>(
          frame, self.get(), back.get());
    } else if (code.get() == optimizer_hook_) {
      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 13)
      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
#else
      auto self = THPObjectPtr(PyMapping_GetItemString(locals, "self"));
#endif
      Py_INCREF(self.get());
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      TORCH_INTERNAL_ASSERT(back != nullptr);
      return tls.intern<CallType::PyOptimizerCall, E>(
          frame, self.get(), back.get());
    } else {
      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
      auto f_back = (back.get() != nullptr) ? back.get() : frame;
      return tls.intern<CallType::PyCall, E>(no_ephemeral_t(), frame, f_back);
    }
  }();
  const auto time = c10::getApproximateTime();
  is_startup_frame ? start_frames_.push_back({key, time})
                   : queue_->getSubqueue()->emplace_py_call(key, time);
}

void PythonTracer::recordCCall(
    ThreadLocalResults& tls,
    PyFrameObject* frame,
    PyObject* arg,
    bool start_frame) {
  // For startup frames we duplicate callable Python functions to avoid having
  // empty C frames in the trace when exiting.
  if (!start_frame) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg));
  }
  auto fn = reinterpret_cast<PyCFunctionObject*>(arg);

  // NB: For C calls a new frame is not created, so we use `frame` rather than
  // `frame->f_back`.
  auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
      arg, (void*)(fn->m_ml), frame);
  queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime());
}

PyObject* PythonTracer::get_callable_from_frame(PyFrameObject* frame) {
  if (frame == nullptr) {
    return nullptr;
  }
  // Get the code object associated with the frame.
  auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
  if (code == nullptr) {
    return nullptr;
  }
  // Use the function's name to look the callable up in the frame's globals.
  auto name = THPUtils_unpackStringView(code->co_name).data();
  // NB: `PyFrame_GetGlobals` returns a strong reference, so hold it in an
  // owning pointer to avoid leaking the globals dict.
  auto globals = THPObjectPtr(PyFrame_GetGlobals(frame));
  PyObject* func = PyDict_GetItemString(globals, name);
  if (func) {
    Py_INCREF(func); // Make sure the returned function has a reference.
  }
  return func; // Returns a PyObject* (the function), or nullptr.
}

// ============================================================================
// == Post processing =========================================================
// ============================================================================
struct Exit {
  bool operator>(const Exit& other) const {
    return t_ > other.t_;
  }

  c10::time_t t_;
  size_t python_tid_;
};

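// Illustrative only (`examplePopEarliestExit` is a hypothetical name; assumes
// a non-empty queue): pairing `Exit::operator>` with `std::greater<>` turns
// the priority queue used in `PostProcess::State` below into a min-heap, so
// `top()` is always the earliest pending exit. That is exactly the order in
// which exits must be replayed against the time-sorted enter events.
[[maybe_unused]] static c10::time_t examplePopEarliestExit(
    std::priority_queue<Exit, std::vector<Exit>, std::greater<>>& exits) {
  auto t = exits.top().t_; // earliest exit time, not the latest
  exits.pop();
  return t;
}
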
class PostProcess {
 public:
  PostProcess(
      std::function<c10::time_t(c10::approx_time_t)> time_converter,
      std::deque<ThreadLocalResults>& tls,
      const ValueCache& value_cache,
      c10::time_t end_time_ns)
      : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} {
    for (size_t python_tid : c10::irange(tls.size())) {
      CallTypeHelper<TraceKeyCacheState>::map(
          tls[python_tid].trace_keys_, *this, value_cache, python_tid);

      addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
      addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
    }
  }

  void set_start_frames(
      const std::vector<PythonTracer::StartFrame>& start_frames,
      std::vector<python_tracer::CompressedEvent>& enters) {
    for (const auto& frame : start_frames) {
      enters.push_back(
          {frame.trace_key_,
           NoTID, // Allows us to detect unhandled start frames
           {},
           time_converter_(frame.start_time)});
    }
  }

  template <CallType C>
  void operator()(
      const TraceKeyCacheState<C>& trace_cache,
      const ValueCache& value_cache,
      size_t python_tid) {
    for (const auto& it : trace_cache.state_) {
      const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
          {it.second, value_cache.load(it.first, python_tid)});
      TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
    }
  }

  template <EventType E, size_t N>
  void addExits(
      AppendOnlyList<c10::approx_time_t, N>& exits,
      size_t python_tid) {
    for (const auto i : exits) {
      get_state<E>().exits_.push({time_converter_(i), python_tid});
    }
  }

  std::vector<std::shared_ptr<Result>> run(
      std::vector<python_tracer::CompressedEvent>& enters) {
    std::stable_sort(
        enters.begin(), enters.end(), [](const auto a, const auto b) {
          return a.enter_t_ < b.enter_t_;
        });
    std::vector<std::shared_ptr<Result>> out;
    populate<EventType::PyCall>(enters, out);
    populate<EventType::PyCCall>(enters, out);
    return out;
  }

 private:
  template <EventType E>
  void populate(
      std::vector<python_tracer::CompressedEvent>& enters,
      std::vector<std::shared_ptr<Result>>& out) {
    using stack_t = std::vector<std::shared_ptr<Result>>;
    const auto initial_size = out.size();
    auto pop = [](stack_t& stack, c10::time_t t) {
      if (!stack.empty()) {
        std::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ = t;
        stack.pop_back();
      } else {
        TORCH_WARN_ONCE(
            "Python replay stack is empty during pop operation! May result in incorrect stack tracing.");
      }
    };

    ska::flat_hash_map<size_t, stack_t> stacks;
    auto& state = get_state<E>();
    // We already own the GIL at this point.
    for (const auto& enter : enters) {
      auto fields_it = state.fields_.find(enter.key_);
      if (fields_it != state.fields_.end()) {
        while (!state.exits_.empty() &&
               state.exits_.top().t_ < enter.enter_t_) {
          auto& exit = state.exits_.top();
          auto& tstack = stacks[exit.python_tid_];
          if (!tstack.empty()) {
            pop(tstack, exit.t_);
          }
          state.exits_.pop();
        }
        out.push_back(Result::create(
            enter.enter_t_,
            enter.system_tid_,
            enter.kineto_info_,
            fields_it->second));

        stacks[fields_it->second.python_tid_].push_back(out.back());
      }
    }

    // Handle events which were still running when profiling ended.
    for (auto& i : stacks) {
      while (!i.second.empty()) {
        pop(i.second, end_time_);
      }
    }

    // Assign system TIDs to start events based on the system TID of the next
    // observed event with the same Python TID.
    ska::flat_hash_map<size_t, std::pair<size_t, kineto::DeviceAndResource>>
        tid_map;
    auto it = out.rbegin();
    for ([[maybe_unused]] auto _ : c10::irange(initial_size, out.size())) {
      const auto python_tid =
          std::get<ExtraFields<E>>((*it)->extra_fields_).python_tid_;
      if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) {
        const auto& tid_info =
            tid_map.insert({python_tid, {NoTID, kineto::DeviceAndResource()}})
                .first->second;
        (*it)->start_tid_ = tid_info.first;
        (*it)->kineto_info_ = tid_info.second;
      }
      tid_map[python_tid] = {(*it)->start_tid_, (*it)->kineto_info_};
      ++it;
    }
  }

  template <EventType E>
  struct State {
    ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
    std::priority_queue<Exit, std::vector<Exit>, std::greater<>> exits_;
  };

  template <EventType E>
  auto& get_state() {
    return std::get < E == EventType::PyCall ? 0 : 1 > (state_);
  }

  c10::time_t end_time_;
  std::function<c10::time_t(c10::approx_time_t)> time_converter_;
  std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
};

struct PythonIDVisitor {
  void operator()(ExtraFields<EventType::PyCall>& py_call) {
    py_call.id_ = ++current_python_id_;
    if (py_call.module_.has_value()) {
      auto& m = py_call.module_;
      auto& module_ids = module_ids_[m->cls_];
      m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
    }
  }

  void operator()(ExtraFields<EventType::PyCCall>& py_call) {
    py_call.id_ = ++current_python_id_;
  }

  template <typename T>
  void operator()(T&) {}

  size_t current_python_id_{0};
  ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
      module_ids_;
};

std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
    std::function<c10::time_t(c10::approx_time_t)> time_converter,
    std::vector<python_tracer::CompressedEvent>& enters,
    c10::time_t end_time_ns) {
  value_cache_.trimPrefixes();
  PostProcess post_process(
      std::move(time_converter),
      thread_local_results_,
      value_cache_,
      end_time_ns);
  post_process.set_start_frames(start_frames_, enters);
  auto out = post_process.run(enters);

  std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
    return a->start_time_ns_ < b->start_time_ns_;
  });

  PythonIDVisitor id_visitor;
  for (auto& i : out) {
    std::visit(id_visitor, i->extra_fields_);
  }

  return out;
}

// ============================================================================
// == Memory Tracer ===========================================================
// ============================================================================
class PythonMemoryTracer final : public python_tracer::PythonMemoryTracerBase {
 public:
  explicit PythonMemoryTracer() = default;
  ~PythonMemoryTracer() override = default;
  void start() override;
  void stop() override;
  void export_memory_history(const std::string& path) override;
};

static void toggle_memory_tracing(bool enable) {
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr torch_cuda_memory_module(
      PyImport_ImportModule("torch.cuda.memory"));
  if (!torch_cuda_memory_module) {
    return;
  }
  THPObjectPtr snapshot_func(PyObject_GetAttrString(
      torch_cuda_memory_module.get(), "_record_memory_history_impl"));
  if (!snapshot_func) {
    return;
  }
  // Build the argument tuple. NB: `PyTuple_SetItem` steals a reference, so
  // `Py_None` must be INCREF'd before it is stored.
  PyObject* args = PyTuple_New(6);
  PyObject* enable_arg = enable ? PyUnicode_FromString("all") : Py_None;
  if (enable_arg == Py_None) {
    Py_INCREF(enable_arg);
  }
  PyTuple_SetItem(args, 0, enable_arg);
  PyTuple_SetItem(args, 1, PyUnicode_FromString("all")); // context
  PyTuple_SetItem(args, 2, PyUnicode_FromString("all")); // stacks
  PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries
  Py_INCREF(Py_None);
  PyTuple_SetItem(args, 4, Py_None); // device (None)
  PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False)
  PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr);
  Py_DECREF(args);
  // `result` (Py_None on success, nullptr on failure) is unused; release it
  // rather than leak it.
  Py_XDECREF(result);
}

void PythonMemoryTracer::start() {
  toggle_memory_tracing(true);
}

void PythonMemoryTracer::export_memory_history(const std::string& path) {
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr torch_cuda_memory_module(
      PyImport_ImportModule("torch.cuda.memory"));
  if (!torch_cuda_memory_module) {
    return;
  }
  THPObjectPtr snapshot_func(
      PyObject_GetAttrString(torch_cuda_memory_module.get(), "_dump_snapshot"));
  if (!snapshot_func) {
    return;
  }
  // Call the function with the file path as its single argument. NB: the
  // temporaries are released on every path so that nothing is leaked.
  PyObject* py_filename = PyUnicode_FromString(path.c_str());
  PyObject* args = PyTuple_Pack(1, py_filename);
  Py_DECREF(py_filename);
  PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr);
  Py_DECREF(args);
  Py_XDECREF(result);
}

void PythonMemoryTracer::stop() {
  toggle_memory_tracing(false);
}

// ============================================================================
// == API =====================================================================
// ============================================================================
int PythonTracer::pyProfileFn(
    PyObject* obj,
    PyFrameObject* frame,
    int what,
    PyObject* arg) {
  auto& local_results =
      *reinterpret_cast<TraceContext*>(obj)->thread_local_results_;
  switch (what) {
    case PyTrace_CALL:
      local_results.active_tracer_->recordPyCall(local_results, frame, false);
      break;

    case PyTrace_C_CALL:
      local_results.active_tracer_->recordCCall(local_results, frame, arg);
      break;

    case PyTrace_RETURN:
      local_results.exit_times_.emplace_back(c10::getApproximateTime());
      break;

    case PyTrace_C_EXCEPTION:
    case PyTrace_C_RETURN:
      local_results.c_exit_times_.emplace_back(c10::getApproximateTime());
      break;
  }
  return 0;
}

std::unique_ptr<python_tracer::PythonTracerBase> getTracer(
    torch::profiler::impl::RecordQueue* queue) {
  return std::make_unique<PythonTracer>(queue);
}

std::unique_ptr<python_tracer::PythonMemoryTracerBase> getMemoryTracer() {
  return std::make_unique<PythonMemoryTracer>();
}

} // namespace
} // namespace torch::profiler::impl

namespace torch::autograd::profiler::python_tracer {

void init() {
  pybind11::gil_scoped_acquire gil;
  TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
  torch::profiler::impl::python_tracer::registerTracer(
      &torch::profiler::impl::getTracer);
  torch::profiler::impl::python_tracer::registerMemoryTracer(
      &torch::profiler::impl::getMemoryTracer);
}

} // namespace torch::autograd::profiler::python_tracer