mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: Kineto traces use microsecond level granularity because of chrome tracing defaults to that precision. Fix by adding preprocessor flag to TARGETS and BUCK files. Also remove any unnecessary ns to us conversions made in the profiler itself. This diff contains profiler changes only. Libkineto changes found in D54964435. Test Plan: Check JSON and chrome tracing to make sure values are as expected. Tracing with flags enabled should have ns precision. Tracings without flags should be same as master. Zoomer: https://www.internalfb.com/intern/zoomer/?profiling_run_fbid=796886748550189 Ran key_averages() to make sure FunctionEvent code working as expected: -- ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls ProfilerStep* 0.74% 3.976ms 64.40% 346.613ms 69.323ms 0.000us 0.00% 61.710ms 12.342ms 5 Optimizer.zero_grad#SGD.zero_grad 0.76% 4.109ms 0.76% 4.109ms 821.743us 0.000us 0.00% 0.000us 0.000us 5 ## forward ## 6.89% 37.057ms 27.19% 146.320ms 29.264ms 0.000us 0.00% 58.708ms 11.742ms 5 aten::conv2d 0.22% 1.176ms 7.74% 41.658ms 157.199us 0.000us 0.00% 27.550ms 103.962us 265 aten::convolution 0.79% 4.273ms 7.52% 40.482ms 152.762us 0.000us 0.00% 27.550ms 103.962us 265 aten::_convolution 0.69% 3.688ms 6.73% 36.209ms 136.637us 0.000us 0.00% 27.550ms 103.962us 265 aten::cudnn_convolution 6.04% 32.520ms 6.04% 32.520ms 122.719us 27.550ms 8.44% 27.550ms 103.962us 265 aten::add_ 2.42% 13.045ms 2.42% 13.045ms 30.694us 12.700ms 3.89% 12.700ms 29.882us 425 aten::batch_norm 0.19% 1.027ms 8.12% 43.717ms 164.971us 0.000us 0.00% 16.744ms 63.185us 265 aten::_batch_norm_impl_index 0.31% 1.646ms 7.93% 42.691ms 161.096us 0.000us 0.00% 16.744ms 63.185us 265 ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ Differential Revision: D55925068 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123650 Approved by: https://github.com/aaronenyeshi
This commit is contained in:
committed by
PyTorch MergeBot
parent
ec00daf4f1
commit
3ebbeb75fd
@ -1720,6 +1720,7 @@ def define_buck_targets(
|
||||
compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
|
||||
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
|
||||
"-DUSE_KINETO",
|
||||
"-DTMP_LIBKINETO_NANOSECOND",
|
||||
# Need this otherwise USE_KINETO is undefed
|
||||
# for mobile
|
||||
"-DEDGE_PROFILER_USE_KINETO",
|
||||
@ -1746,6 +1747,7 @@ def define_buck_targets(
|
||||
compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
|
||||
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
|
||||
"-DUSE_KINETO",
|
||||
"-DTMP_LIBKINETO_NANOSECOND",
|
||||
"-DEDGE_PROFILER_USE_KINETO",
|
||||
],
|
||||
# @lint-ignore BUCKLINT link_whole
|
||||
@ -1832,6 +1834,7 @@ def define_buck_targets(
|
||||
compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
|
||||
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
|
||||
"-DUSE_KINETO",
|
||||
"-DTMP_LIBKINETO_NANOSECOND",
|
||||
# Need this otherwise USE_KINETO is undefed
|
||||
# for mobile
|
||||
"-DEDGE_PROFILER_USE_KINETO",
|
||||
|
@ -1977,7 +1977,7 @@ if(USE_KINETO)
|
||||
set_property(TARGET kineto PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS kineto)
|
||||
string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO")
|
||||
string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO" " -DTMP_LIBKINETO_NANOSECOND")
|
||||
if(LIBKINETO_NOCUPTI)
|
||||
string(APPEND CMAKE_CXX_FLAGS " -DLIBKINETO_NOCUPTI")
|
||||
endif()
|
||||
|
@ -27,7 +27,7 @@ add_library(backend_with_compiler SHARED
|
||||
)
|
||||
if(USE_KINETO)
|
||||
set_target_properties(backend_with_compiler PROPERTIES COMPILE_FLAGS
|
||||
"-DUSE_KINETO")
|
||||
"-DUSE_KINETO -DTMP_LIBKINETO_NANOSECOND")
|
||||
endif()
|
||||
target_link_libraries(backend_with_compiler torch)
|
||||
|
||||
|
@ -2769,11 +2769,11 @@ class MockKinetoEvent:
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def start_us(self) -> int:
|
||||
return self._start_us
|
||||
def start_ns(self) -> int:
|
||||
return self._start_us * 1000
|
||||
|
||||
def duration_us(self) -> int:
|
||||
return self._duration_us
|
||||
def duration_ns(self) -> int:
|
||||
return self._duration_us * 1000
|
||||
|
||||
def linked_correlation_id(self) -> int:
|
||||
return self._linked_correlation_id
|
||||
@ -2918,10 +2918,10 @@ class TestExperimentalUtils(TestCase):
|
||||
kineto_events = [{
|
||||
'_name':
|
||||
e.name,
|
||||
'_start_us':
|
||||
e.start_us(),
|
||||
'_duration_us':
|
||||
e.duration_us(),
|
||||
'_start_ns':
|
||||
e.start_ns(),
|
||||
'_duration_ns':
|
||||
e.duration_ns(),
|
||||
'_linked_correlation_id':
|
||||
e.linked_correlation_id(),
|
||||
'_device_type':
|
||||
@ -3113,6 +3113,8 @@ aten::mm
|
||||
aten::mm
|
||||
aten::mm""")
|
||||
|
||||
# TODO: Add logic for CUDA version of test
|
||||
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
|
||||
def test_profiler_pattern_match_helper(self):
|
||||
x = torch.ones((100, 100))
|
||||
with profile() as prof:
|
||||
|
@ -247,7 +247,9 @@ class TestProfilerTree(TestCase):
|
||||
else:
|
||||
raise
|
||||
|
||||
# TODO: Add logic for CUDA version of test
|
||||
@ProfilerTree.test
|
||||
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
|
||||
def test_profiler_experimental_tree(self):
|
||||
t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
|
||||
with torch.profiler.profile() as p:
|
||||
@ -300,7 +302,9 @@ class TestProfilerTree(TestCase):
|
||||
detach"""
|
||||
)
|
||||
|
||||
# TODO: Add logic for CUDA version of test
|
||||
@ProfilerTree.test
|
||||
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
|
||||
def test_profiler_experimental_tree_with_record_function(self):
|
||||
with torch.profiler.profile() as p:
|
||||
with torch.autograd.profiler.record_function("Top level Annotation"):
|
||||
@ -346,7 +350,9 @@ class TestProfilerTree(TestCase):
|
||||
aten::copy_"""
|
||||
)
|
||||
|
||||
# TODO: Add logic for CUDA version of test
|
||||
@ProfilerTree.test
|
||||
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
|
||||
def test_profiler_experimental_tree_with_memory(self):
|
||||
t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
|
||||
with torch.profiler.profile(profile_memory=True) as p:
|
||||
|
@ -53,8 +53,8 @@ class _KinetoEvent:
|
||||
def name(self) -> str: ...
|
||||
def device_index(self) -> int: ...
|
||||
def device_resource_id(self) -> int: ...
|
||||
def start_us(self) -> int: ...
|
||||
def duration_us(self) -> int: ...
|
||||
def start_ns(self) -> int: ...
|
||||
def duration_ns(self) -> int: ...
|
||||
def is_async(self) -> bool: ...
|
||||
def linked_correlation_id(self) -> int: ...
|
||||
def shapes(self) -> List[List[int]]: ...
|
||||
@ -77,7 +77,7 @@ class _ProfilerResult:
|
||||
def legacy_events(self) -> List[List[ProfilerEvent]]: ...
|
||||
def save(self, path: str) -> None: ...
|
||||
def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
|
||||
def trace_start_us(self) -> int: ...
|
||||
def trace_start_ns(self) -> int: ...
|
||||
|
||||
class SavedTensor: ...
|
||||
|
||||
|
@ -428,7 +428,7 @@ class profile:
|
||||
def _parse_kineto_results(self, result: _ProfilerResult):
|
||||
# result.events() has most of the events - PyTorch op-level and device-level events
|
||||
|
||||
trace_start_us = result.trace_start_us()
|
||||
trace_start_ns = result.trace_start_ns()
|
||||
mem_records = [
|
||||
[evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
|
||||
]
|
||||
@ -466,9 +466,9 @@ class profile:
|
||||
for kineto_event in result.events():
|
||||
if _filter_name(kineto_event.name()):
|
||||
continue
|
||||
rel_start_us = kineto_event.start_us() - trace_start_us
|
||||
rel_end_us = rel_start_us + kineto_event.duration_us()
|
||||
abs_end_us = kineto_event.start_us() + kineto_event.duration_us()
|
||||
rel_start_ns = kineto_event.start_ns() - trace_start_ns
|
||||
rel_end_ns = rel_start_ns + kineto_event.duration_ns()
|
||||
abs_end_ns = kineto_event.start_ns() + kineto_event.duration_ns()
|
||||
|
||||
cpu_memory_usage = 0
|
||||
cuda_memory_usage = 0
|
||||
@ -476,7 +476,7 @@ class profile:
|
||||
if kineto_event.device_type() == DeviceType.CPU:
|
||||
# find the corresponding memory allocation events
|
||||
for mem_record in mem_records_acc.in_interval(
|
||||
kineto_event.start_us(), abs_end_us
|
||||
kineto_event.start_ns() / 1000, abs_end_ns / 1000
|
||||
):
|
||||
cpu_memory_usage += _cpu_memory_usage(mem_record[0])
|
||||
cuda_memory_usage += _cuda_memory_usage(mem_record[0])
|
||||
@ -492,8 +492,8 @@ class profile:
|
||||
name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
|
||||
trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
|
||||
thread=kineto_event.start_thread_id(),
|
||||
start_us=rel_start_us,
|
||||
end_us=rel_end_us,
|
||||
start_us=rel_start_ns / 1000,
|
||||
end_us=rel_end_ns / 1000,
|
||||
fwd_thread=kineto_event.fwd_thread_id(),
|
||||
input_shapes=kineto_event.shapes(),
|
||||
concrete_inputs=kineto_event.concrete_inputs(),
|
||||
@ -555,14 +555,14 @@ class profile:
|
||||
f_evt.thread = fe.thread
|
||||
|
||||
def createFunctionEventForMemoryEvents(evt):
|
||||
rel_start_us = evt.start_us() - trace_start_us
|
||||
rel_start_ns = evt.start_ns() - trace_start_ns
|
||||
fe = FunctionEvent(
|
||||
id=max_evt_id,
|
||||
name=evt.name(),
|
||||
trace_name=None, # not outputting in the trace
|
||||
thread=evt.start_thread_id(),
|
||||
start_us=rel_start_us,
|
||||
end_us=rel_start_us, # no duration
|
||||
start_us=rel_start_ns / 1000,
|
||||
end_us=rel_start_ns / 1000, # no duration
|
||||
fwd_thread=evt.start_thread_id(),
|
||||
input_shapes=[],
|
||||
stack=[],
|
||||
|
@ -781,18 +781,19 @@ class MemRecordsAcc:
|
||||
|
||||
def __init__(self, mem_records):
|
||||
self._mem_records = mem_records
|
||||
self._start_uses: List[int] = []
|
||||
self._start_nses: List[int] = []
|
||||
self._indices: List[int] = []
|
||||
if len(mem_records) > 0:
|
||||
tmp = sorted([(r[0].start_us(), i) for i, r in enumerate(mem_records)])
|
||||
self._start_uses, self._indices = zip(*tmp) # type: ignore[assignment]
|
||||
tmp = sorted([(r[0].start_ns(), i) for i, r in enumerate(mem_records)])
|
||||
self._start_nses, self._indices = zip(*tmp) # type: ignore[assignment]
|
||||
|
||||
def in_interval(self, start_us, end_us):
|
||||
r"""
|
||||
Return all records in the given interval
|
||||
To maintain backward compatibility, convert us to ns in function
|
||||
"""
|
||||
start_idx = bisect.bisect_left(self._start_uses, start_us)
|
||||
end_idx = bisect.bisect_right(self._start_uses, end_us)
|
||||
start_idx = bisect.bisect_left(self._start_nses, start_us * 1000)
|
||||
end_idx = bisect.bisect_right(self._start_nses, end_us * 1000)
|
||||
for i in range(start_idx, end_idx):
|
||||
yield self._mem_records[self._indices[i]]
|
||||
|
||||
|
@ -201,10 +201,10 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
// together with fwd_thread_id, used to uniquely identify
|
||||
// the forward op
|
||||
.def("sequence_nr", [](const KinetoEvent& e) { return e.sequenceNr(); })
|
||||
// absolute start time (since unix epoch) in us
|
||||
.def("start_us", [](const KinetoEvent& e) { return e.startUs(); })
|
||||
// duration in us
|
||||
.def("duration_us", [](const KinetoEvent& e) { return e.durationUs(); })
|
||||
// absolute start time (since unix epoch) in ns
|
||||
.def("start_ns", [](const KinetoEvent& e) { return e.startNs(); })
|
||||
// duration in ns
|
||||
.def("duration_ns", [](const KinetoEvent& e) { return e.durationNs(); })
|
||||
// used for correlation between high-level PyTorch events
|
||||
// and low-level device events
|
||||
.def(
|
||||
@ -255,7 +255,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
m.def("_get_sequence_nr", &at::sequence_number::peek);
|
||||
|
||||
py::class_<ProfilerResult>(m, "_ProfilerResult")
|
||||
.def("trace_start_us", &ProfilerResult::trace_start_us)
|
||||
.def("trace_start_ns", &ProfilerResult::trace_start_ns)
|
||||
.def("events", &ProfilerResult::events)
|
||||
.def("experimental_event_tree", &ProfilerResult::event_tree)
|
||||
#ifdef USE_KINETO
|
||||
|
@ -50,11 +50,11 @@ namespace autograd {
|
||||
namespace profiler {
|
||||
|
||||
namespace {
|
||||
inline int64_t getTimeUs() {
|
||||
inline int64_t getTimeNs() {
|
||||
#ifdef USE_KINETO
|
||||
return libkineto::timeSinceEpoch(std::chrono::system_clock::now());
|
||||
#else
|
||||
return c10::getTime() / 1000;
|
||||
return c10::getTime();
|
||||
#endif // USE_KINETO
|
||||
}
|
||||
|
||||
@ -307,7 +307,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase {
|
||||
const ProfilerConfig& config,
|
||||
std::set<torch::profiler::impl::ActivityType> activities)
|
||||
: ProfilerStateBase(config),
|
||||
start_time_(getTimeUs()),
|
||||
start_time_(getTimeNs()),
|
||||
record_queue_(config, std::move(activities)) {}
|
||||
~KinetoThreadLocalState() override = default;
|
||||
|
||||
@ -374,7 +374,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase {
|
||||
|
||||
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>
|
||||
finalizeTrace() {
|
||||
auto end_time = getTimeUs();
|
||||
auto end_time = getTimeNs();
|
||||
record_queue_.stop();
|
||||
|
||||
std::lock_guard<std::mutex> guard(state_mutex_);
|
||||
@ -772,8 +772,8 @@ const c10::ArrayRef<std::string> KinetoEvent::moduleHierarchy() const {
|
||||
return {};
|
||||
}
|
||||
|
||||
uint64_t KinetoEvent::durationUs() const {
|
||||
return (result_->endTimeNS() - result_->start_time_ns_) / 1000;
|
||||
uint64_t KinetoEvent::durationNs() const {
|
||||
return (result_->endTimeNS() - result_->start_time_ns_);
|
||||
}
|
||||
|
||||
int64_t KinetoEvent::debugHandle() const {
|
||||
@ -854,7 +854,7 @@ FORWARD_FROM_RESULT(endThreadId, endTID())
|
||||
FORWARD_FROM_RESULT(activityType, kinetoType())
|
||||
FORWARD_FROM_RESULT(name, name())
|
||||
FORWARD_FROM_RESULT(deviceType, deviceType())
|
||||
FORWARD_FROM_RESULT(startUs, start_time_ns_ / 1000)
|
||||
FORWARD_FROM_RESULT(startNs, start_time_ns_)
|
||||
FORWARD_FROM_RESULT(correlationId, correlationID())
|
||||
FORWARD_FROM_RESULT(deviceResourceId, kineto_info_.resource)
|
||||
#undef FORWARD_FROM_RESULT
|
||||
@ -906,7 +906,7 @@ ProfilerResult::ProfilerResult(
|
||||
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>&&
|
||||
trace,
|
||||
std::vector<experimental_event_t>&& event_tree)
|
||||
: trace_start_us_(start_time),
|
||||
: trace_start_ns_(start_time),
|
||||
events_(std::move(events)),
|
||||
trace_(std::move(trace)),
|
||||
event_tree_(std::move(event_tree)) {}
|
||||
|
@ -48,8 +48,8 @@ struct TORCH_API KinetoEvent {
|
||||
c10::DeviceType deviceType() const;
|
||||
int deviceIndex() const;
|
||||
int64_t nBytes() const;
|
||||
uint64_t startUs() const;
|
||||
uint64_t durationUs() const;
|
||||
uint64_t startNs() const;
|
||||
uint64_t durationNs() const;
|
||||
bool isAsync() const;
|
||||
uint64_t correlationId() const;
|
||||
uint64_t linkedCorrelationId() const;
|
||||
@ -87,8 +87,8 @@ struct TORCH_API ProfilerResult {
|
||||
std::vector<experimental_event_t>&& event_tree);
|
||||
~ProfilerResult();
|
||||
|
||||
uint64_t trace_start_us() const {
|
||||
return trace_start_us_;
|
||||
uint64_t trace_start_ns() const {
|
||||
return trace_start_ns_;
|
||||
}
|
||||
|
||||
const std::vector<KinetoEvent>& events() const {
|
||||
@ -102,7 +102,7 @@ struct TORCH_API ProfilerResult {
|
||||
void save(const std::string& path);
|
||||
|
||||
private:
|
||||
uint64_t trace_start_us_ = 0;
|
||||
uint64_t trace_start_ns_ = 0;
|
||||
std::vector<KinetoEvent> events_;
|
||||
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper> trace_;
|
||||
std::vector<experimental_event_t> event_tree_;
|
||||
|
@ -592,7 +592,7 @@ int64_t Result::endTimeNS() const {
|
||||
Vulkan, start_time_ns_ + (e.in_tree_building_ ? 0 : e.duration_ns_)),
|
||||
ATTRIBUTE(Allocation, start_time_ns_),
|
||||
ATTRIBUTE(OutOfMemory, start_time_ns_),
|
||||
ATTRIBUTE(Kineto, start_time_ns_ + e.duration_us_ * 1000),
|
||||
ATTRIBUTE(Kineto, start_time_ns_ + e.duration_ns_),
|
||||
[&](const auto& e) -> int64_t { return e.end_time_ns_; }));
|
||||
|
||||
// In rare cases we're willing to tolerate ops which are missing an end time
|
||||
@ -803,12 +803,12 @@ static constexpr const char* indexKey = "Ev Idx";
|
||||
|
||||
void passEventsToKineto(
|
||||
const std::vector<std::shared_ptr<Result>>& results,
|
||||
uint64_t start_time_us,
|
||||
uint64_t end_time_us,
|
||||
uint64_t start_time_ns,
|
||||
uint64_t end_time_ns,
|
||||
const ProfilerConfig& config) {
|
||||
using namespace torch::profiler::impl::kineto;
|
||||
TraceWrapper cpu_trace(
|
||||
static_cast<int64_t>(start_time_us), "PyTorch Profiler");
|
||||
static_cast<int64_t>(start_time_ns), "PyTorch Profiler");
|
||||
|
||||
// Generate Kineto events for each event recorded by the PyTorch profiler.
|
||||
for (const auto i : c10::irange(results.size())) {
|
||||
@ -818,8 +818,8 @@ void passEventsToKineto(
|
||||
e->kinetoType(),
|
||||
e->kineto_info_,
|
||||
e->correlationID(),
|
||||
e->start_time_ns_ / 1000,
|
||||
e->endTimeNS() / 1000);
|
||||
e->start_time_ns_,
|
||||
e->endTimeNS());
|
||||
|
||||
TORCH_INTERNAL_ASSERT(activity || !kKinetoAvailable);
|
||||
if (activity) {
|
||||
@ -842,7 +842,7 @@ void passEventsToKineto(
|
||||
}
|
||||
|
||||
// Kineto adds the events that it collected.
|
||||
cpu_trace.transferCpuTrace(static_cast<int64_t>(end_time_us));
|
||||
cpu_trace.transferCpuTrace(static_cast<int64_t>(end_time_ns));
|
||||
}
|
||||
|
||||
#ifdef USE_KINETO
|
||||
@ -1098,11 +1098,11 @@ class TransferEvents {
|
||||
|
||||
trace_ptr_t addKinetoEvents(
|
||||
std::vector<std::shared_ptr<Result>>& results,
|
||||
uint64_t start_time_us,
|
||||
uint64_t end_time_us,
|
||||
uint64_t start_time_ns,
|
||||
uint64_t end_time_ns,
|
||||
const ProfilerConfig& config) {
|
||||
using namespace torch::profiler::impl::kineto;
|
||||
passEventsToKineto(results, start_time_us, end_time_us, config);
|
||||
passEventsToKineto(results, start_time_ns, end_time_ns, config);
|
||||
|
||||
// In on demand mode kineto is directly controlled by other machinery.
|
||||
if (config.global()) {
|
||||
@ -1353,8 +1353,8 @@ std::pair<
|
||||
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>>
|
||||
RecordQueue::getRecords(
|
||||
std::function<c10::time_t(c10::approx_time_t)> time_converter,
|
||||
uint64_t start_time_us,
|
||||
uint64_t end_time_us) {
|
||||
uint64_t start_time_ns,
|
||||
uint64_t end_time_ns) {
|
||||
auto converter = [&](c10::approx_time_t t) {
|
||||
return t == std::numeric_limits<c10::approx_time_t>::min()
|
||||
? std::numeric_limits<c10::time_t>::min()
|
||||
@ -1405,9 +1405,7 @@ RecordQueue::getRecords(
|
||||
|
||||
if (python_tracer_) {
|
||||
for (const auto& i : python_tracer_->getEvents(
|
||||
converter,
|
||||
python_enters,
|
||||
static_cast<c10::time_t>(end_time_us * 1000))) {
|
||||
converter, python_enters, static_cast<c10::time_t>(end_time_ns))) {
|
||||
out.push_back(i);
|
||||
}
|
||||
python_tracer_.reset();
|
||||
@ -1427,7 +1425,7 @@ RecordQueue::getRecords(
|
||||
}
|
||||
}
|
||||
|
||||
auto trace = addKinetoEvents(out, start_time_us, end_time_us, config_);
|
||||
auto trace = addKinetoEvents(out, start_time_ns, end_time_ns, config_);
|
||||
|
||||
std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
|
||||
return a->start_time_ns_ < b->start_time_ns_;
|
||||
|
@ -340,7 +340,7 @@ struct ExtraFields<EventType::Kineto> {
|
||||
};
|
||||
|
||||
std::string name_;
|
||||
int64_t duration_us_{0};
|
||||
int64_t duration_ns_{0};
|
||||
uint64_t correlation_id_{0};
|
||||
libkineto::ActivityType activity_type_;
|
||||
Flow flow;
|
||||
@ -632,8 +632,8 @@ class TORCH_API RecordQueue {
|
||||
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>>
|
||||
getRecords(
|
||||
std::function<c10::time_t(c10::approx_time_t)> time_converter,
|
||||
uint64_t start_time_us,
|
||||
uint64_t end_time_us);
|
||||
uint64_t start_time_ns,
|
||||
uint64_t end_time_ns);
|
||||
|
||||
private:
|
||||
uint32_t id_;
|
||||
|
@ -147,15 +147,15 @@ class BasicEvaluation:
|
||||
|
||||
cuda_launch_events = sorted(
|
||||
(e for e in cuda_event_list if is_cuda_launch_kernel(e)),
|
||||
key=lambda x: x.start_us(),
|
||||
key=lambda x: x.start_ns(),
|
||||
)
|
||||
cuda_kernel_events = sorted(
|
||||
(e for e in cuda_event_list if is_cuda_kernel(e)),
|
||||
key=lambda x: x.start_us(),
|
||||
key=lambda x: x.start_ns(),
|
||||
)
|
||||
|
||||
self.cuda_events = sorted(
|
||||
cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_us()
|
||||
cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns()
|
||||
)
|
||||
|
||||
kernel_mapping: Dict[_KinetoEvent, int] = {}
|
||||
@ -178,6 +178,8 @@ class BasicEvaluation:
|
||||
def new_old_event_comparator(event):
|
||||
if hasattr(event, "start_us"):
|
||||
return event.start_us() * 1000
|
||||
if hasattr(event, "start_ns"):
|
||||
return event.start_ns()
|
||||
if hasattr(event, "start_time_ns"):
|
||||
return event.start_time_ns
|
||||
raise Exception("Unknown Event Type")
|
||||
@ -192,20 +194,26 @@ class BasicEvaluation:
|
||||
# Find current spawned cuda kernel event
|
||||
if event in kernel_mapping and kernel_mapping[event] is not None:
|
||||
spawned_kernel_index = kernel_mapping[event]
|
||||
if hasattr(event, "start_ns"):
|
||||
start_time = event.start_ns()
|
||||
end_time = event.start_ns() + event.duration_ns()
|
||||
# Find current spawned cuda kernel event
|
||||
if event in kernel_mapping and kernel_mapping[event] is not None:
|
||||
spawned_kernel_index = kernel_mapping[event]
|
||||
elif hasattr(event, "start_time_ns"):
|
||||
start_time = event.start_time_ns # type: ignore[attr-defined]
|
||||
end_time = event.end_time_ns # type: ignore[attr-defined]
|
||||
|
||||
while (
|
||||
current_kernel_index < len(cuda_kernel_events)
|
||||
and (cuda_kernel_events[current_kernel_index].start_us()) * 1000
|
||||
and (cuda_kernel_events[current_kernel_index].start_ns())
|
||||
<= start_time # type: ignore[possibly-undefined]
|
||||
):
|
||||
current_kernel_index += 1
|
||||
current_queue_depth = spawned_kernel_index - current_kernel_index + 1
|
||||
current_queue_depth = max(current_queue_depth, 0)
|
||||
|
||||
if hasattr(event, "start_us"):
|
||||
if hasattr(event, "start_us") or hasattr(event, "start_ns"):
|
||||
queue_depth_list.append(
|
||||
Interval(start_time, end_time, current_queue_depth) # type: ignore[possibly-undefined]
|
||||
)
|
||||
|
Reference in New Issue
Block a user