[Profiler] Make Kineto traces export ns granularity for finer timestamps (#122425) (#123650)

Summary:

Kineto traces currently use microsecond-level granularity because Chrome tracing defaults to that precision. Fix this by adding a preprocessor flag to the TARGETS and BUCK files. Also remove the now-unnecessary ns-to-us conversions made in the profiler itself.

This diff contains the profiler changes only. The corresponding libkineto changes are in D54964435.
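
To see why removing the us round trip matters: truncating a ns timestamp to us and scaling it back silently discards up to 999 ns per event. A minimal illustration (the timestamp value is made up):

    start_ns = 1_712_800_000_123_456_789  # ns-resolution epoch timestamp (made up)
    start_us = start_ns // 1000           # old path: profiler truncated to us
    roundtrip_ns = start_us * 1000        # consumers then scaled back up to ns
    print(start_ns - roundtrip_ns)        # 789 -> ns-level detail silently lost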

Test Plan:
Check the JSON and Chrome trace output to make sure values are as expected. Traces collected with the flag enabled should have ns precision; traces collected without it should match master.
Zoomer: https://www.internalfb.com/intern/zoomer/?profiling_run_fbid=796886748550189
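
One way to spot-check the exported JSON (a sketch; the file name and workload are arbitrary):

    import json

    import torch
    from torch.profiler import profile

    with profile() as prof:
        torch.ones((100, 100)) @ torch.ones((100, 100))

    prof.export_chrome_trace("trace.json")

    with open("trace.json") as f:
        events = json.load(f)["traceEvents"]

    # Chrome trace "ts"/"dur" fields are in us; with the flag enabled they
    # should carry sub-us detail instead of being rounded to whole us.
    print([e["ts"] for e in events if "ts" in e][:5])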
Ran key_averages() to make sure the FunctionEvent code works as expected:
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                          ProfilerStep*         0.74%       3.976ms        64.40%     346.613ms      69.323ms       0.000us         0.00%      61.710ms      12.342ms             5
                      Optimizer.zero_grad#SGD.zero_grad         0.76%       4.109ms         0.76%       4.109ms     821.743us       0.000us         0.00%       0.000us       0.000us             5
                                          ## forward ##         6.89%      37.057ms        27.19%     146.320ms      29.264ms       0.000us         0.00%      58.708ms      11.742ms             5
                                           aten::conv2d         0.22%       1.176ms         7.74%      41.658ms     157.199us       0.000us         0.00%      27.550ms     103.962us           265
                                      aten::convolution         0.79%       4.273ms         7.52%      40.482ms     152.762us       0.000us         0.00%      27.550ms     103.962us           265
                                     aten::_convolution         0.69%       3.688ms         6.73%      36.209ms     136.637us       0.000us         0.00%      27.550ms     103.962us           265
                                aten::cudnn_convolution         6.04%      32.520ms         6.04%      32.520ms     122.719us      27.550ms         8.44%      27.550ms     103.962us           265
                                             aten::add_         2.42%      13.045ms         2.42%      13.045ms      30.694us      12.700ms         3.89%      12.700ms      29.882us           425
                                       aten::batch_norm         0.19%       1.027ms         8.12%      43.717ms     164.971us       0.000us         0.00%      16.744ms      63.185us           265
                           aten::_batch_norm_impl_index         0.31%       1.646ms         7.93%      42.691ms     161.096us       0.000us         0.00%      16.744ms      63.185us           265
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
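
For reference, a table like the one above can be produced as follows (sketch; the matmul stands in for whatever workload is being profiled):

    import torch
    from torch.profiler import ProfilerActivity, profile

    # Request CUDA activity only when a GPU is present.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities) as prof:
        torch.ones((1000, 1000)) @ torch.ones((1000, 1000))

    # key_averages() aggregates FunctionEvents by name, as in the table above.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))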

Differential Revision: D55925068

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123650
Approved by: https://github.com/aaronenyeshi
Author: Shivam Raikundalia
Date: 2024-04-11 04:29:20 +00:00
Committed by: PyTorch MergeBot
Parent: ec00daf4f1
Commit: 3ebbeb75fd

14 changed files with 88 additions and 70 deletions

View File

@@ -1720,6 +1720,7 @@ def define_buck_targets(
compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
"-DUSE_KINETO",
"-DTMP_LIBKINETO_NANOSECOND",
# Need this otherwise USE_KINETO is undefed
# for mobile
"-DEDGE_PROFILER_USE_KINETO",
@@ -1746,6 +1747,7 @@ def define_buck_targets(
compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
"-DUSE_KINETO",
"-DTMP_LIBKINETO_NANOSECOND",
"-DEDGE_PROFILER_USE_KINETO",
],
# @lint-ignore BUCKLINT link_whole
@@ -1832,6 +1834,7 @@ def define_buck_targets(
compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
exported_preprocessor_flags = get_pt_preprocessor_flags() + [
"-DUSE_KINETO",
"-DTMP_LIBKINETO_NANOSECOND",
# Need this otherwise USE_KINETO is undefed
# for mobile
"-DEDGE_PROFILER_USE_KINETO",

View File

@@ -1977,7 +1977,7 @@ if(USE_KINETO)
set_property(TARGET kineto PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
list(APPEND Caffe2_DEPENDENCY_LIBS kineto)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO")
string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO" " -DTMP_LIBKINETO_NANOSECOND")
if(LIBKINETO_NOCUPTI)
string(APPEND CMAKE_CXX_FLAGS " -DLIBKINETO_NOCUPTI")
endif()

View File

@@ -27,7 +27,7 @@ add_library(backend_with_compiler SHARED
)
if(USE_KINETO)
set_target_properties(backend_with_compiler PROPERTIES COMPILE_FLAGS
"-DUSE_KINETO")
"-DUSE_KINETO -DTMP_LIBKINETO_NANOSECOND")
endif()
target_link_libraries(backend_with_compiler torch)

View File

@@ -2769,11 +2769,11 @@ class MockKinetoEvent:
def name(self) -> str:
return self._name
def start_us(self) -> int:
return self._start_us
def start_ns(self) -> int:
return self._start_us * 1000
def duration_us(self) -> int:
return self._duration_us
def duration_ns(self) -> int:
return self._duration_us * 1000
def linked_correlation_id(self) -> int:
return self._linked_correlation_id
@@ -2918,10 +2918,10 @@ class TestExperimentalUtils(TestCase):
kineto_events = [{
'_name':
e.name,
'_start_us':
e.start_us(),
'_duration_us':
e.duration_us(),
'_start_ns':
e.start_ns(),
'_duration_ns':
e.duration_ns(),
'_linked_correlation_id':
e.linked_correlation_id(),
'_device_type':
@@ -3113,6 +3113,8 @@ aten::mm
aten::mm
aten::mm""")
# TODO: Add logic for CUDA version of test
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
def test_profiler_pattern_match_helper(self):
x = torch.ones((100, 100))
with profile() as prof:

View File

@@ -247,7 +247,9 @@ class TestProfilerTree(TestCase):
else:
raise
# TODO: Add logic for CUDA version of test
@ProfilerTree.test
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
def test_profiler_experimental_tree(self):
t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
with torch.profiler.profile() as p:
@@ -300,7 +302,9 @@
detach"""
)
# TODO: Add logic for CUDA version of test
@ProfilerTree.test
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
def test_profiler_experimental_tree_with_record_function(self):
with torch.profiler.profile() as p:
with torch.autograd.profiler.record_function("Top level Annotation"):
@@ -346,7 +350,9 @@
aten::copy_"""
)
# TODO: Add logic for CUDA version of test
@ProfilerTree.test
@unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
def test_profiler_experimental_tree_with_memory(self):
t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
with torch.profiler.profile(profile_memory=True) as p:

View File

@@ -53,8 +53,8 @@ class _KinetoEvent:
def name(self) -> str: ...
def device_index(self) -> int: ...
def device_resource_id(self) -> int: ...
def start_us(self) -> int: ...
def duration_us(self) -> int: ...
def start_ns(self) -> int: ...
def duration_ns(self) -> int: ...
def is_async(self) -> bool: ...
def linked_correlation_id(self) -> int: ...
def shapes(self) -> List[List[int]]: ...
@@ -77,7 +77,7 @@ class _ProfilerResult:
def legacy_events(self) -> List[List[ProfilerEvent]]: ...
def save(self, path: str) -> None: ...
def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
def trace_start_us(self) -> int: ...
def trace_start_ns(self) -> int: ...
class SavedTensor: ...

View File

@@ -428,7 +428,7 @@ class profile:
def _parse_kineto_results(self, result: _ProfilerResult):
# result.events() has most of the events - PyTorch op-level and device-level events
trace_start_us = result.trace_start_us()
trace_start_ns = result.trace_start_ns()
mem_records = [
[evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
]
@@ -466,9 +466,9 @@
for kineto_event in result.events():
if _filter_name(kineto_event.name()):
continue
rel_start_us = kineto_event.start_us() - trace_start_us
rel_end_us = rel_start_us + kineto_event.duration_us()
abs_end_us = kineto_event.start_us() + kineto_event.duration_us()
rel_start_ns = kineto_event.start_ns() - trace_start_ns
rel_end_ns = rel_start_ns + kineto_event.duration_ns()
abs_end_ns = kineto_event.start_ns() + kineto_event.duration_ns()
cpu_memory_usage = 0
cuda_memory_usage = 0
@@ -476,7 +476,7 @@
if kineto_event.device_type() == DeviceType.CPU:
# find the corresponding memory allocation events
for mem_record in mem_records_acc.in_interval(
kineto_event.start_us(), abs_end_us
kineto_event.start_ns() / 1000, abs_end_ns / 1000
):
cpu_memory_usage += _cpu_memory_usage(mem_record[0])
cuda_memory_usage += _cuda_memory_usage(mem_record[0])
@@ -492,8 +492,8 @@
name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
thread=kineto_event.start_thread_id(),
start_us=rel_start_us,
end_us=rel_end_us,
start_us=rel_start_ns / 1000,
end_us=rel_end_ns / 1000,
fwd_thread=kineto_event.fwd_thread_id(),
input_shapes=kineto_event.shapes(),
concrete_inputs=kineto_event.concrete_inputs(),
@@ -555,14 +555,14 @@
f_evt.thread = fe.thread
def createFunctionEventForMemoryEvents(evt):
rel_start_us = evt.start_us() - trace_start_us
rel_start_ns = evt.start_ns() - trace_start_ns
fe = FunctionEvent(
id=max_evt_id,
name=evt.name(),
trace_name=None, # not outputting in the trace
thread=evt.start_thread_id(),
start_us=rel_start_us,
end_us=rel_start_us, # no duration
start_us=rel_start_ns / 1000,
end_us=rel_start_ns / 1000, # no duration
fwd_thread=evt.start_thread_id(),
input_shapes=[],
stack=[],

View File

@@ -781,18 +781,19 @@ class MemRecordsAcc:
def __init__(self, mem_records):
self._mem_records = mem_records
self._start_uses: List[int] = []
self._start_nses: List[int] = []
self._indices: List[int] = []
if len(mem_records) > 0:
tmp = sorted([(r[0].start_us(), i) for i, r in enumerate(mem_records)])
self._start_uses, self._indices = zip(*tmp) # type: ignore[assignment]
tmp = sorted([(r[0].start_ns(), i) for i, r in enumerate(mem_records)])
self._start_nses, self._indices = zip(*tmp) # type: ignore[assignment]
def in_interval(self, start_us, end_us):
r"""
Return all records in the given interval
To maintain backward compatibility, convert us to ns in function
"""
start_idx = bisect.bisect_left(self._start_uses, start_us)
end_idx = bisect.bisect_right(self._start_uses, end_us)
start_idx = bisect.bisect_left(self._start_nses, start_us * 1000)
end_idx = bisect.bisect_right(self._start_nses, end_us * 1000)
for i in range(start_idx, end_idx):
yield self._mem_records[self._indices[i]]

View File

@@ -201,10 +201,10 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
// together with fwd_thread_id, used to uniquely identify
// the forward op
.def("sequence_nr", [](const KinetoEvent& e) { return e.sequenceNr(); })
// absolute start time (since unix epoch) in us
.def("start_us", [](const KinetoEvent& e) { return e.startUs(); })
// duration in us
.def("duration_us", [](const KinetoEvent& e) { return e.durationUs(); })
// absolute start time (since unix epoch) in ns
.def("start_ns", [](const KinetoEvent& e) { return e.startNs(); })
// duration in ns
.def("duration_ns", [](const KinetoEvent& e) { return e.durationNs(); })
// used for correlation between high-level PyTorch events
// and low-level device events
.def(
@@ -255,7 +255,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
m.def("_get_sequence_nr", &at::sequence_number::peek);
py::class_<ProfilerResult>(m, "_ProfilerResult")
.def("trace_start_us", &ProfilerResult::trace_start_us)
.def("trace_start_ns", &ProfilerResult::trace_start_ns)
.def("events", &ProfilerResult::events)
.def("experimental_event_tree", &ProfilerResult::event_tree)
#ifdef USE_KINETO

View File

@@ -50,11 +50,11 @@ namespace autograd {
namespace profiler {
namespace {
inline int64_t getTimeUs() {
inline int64_t getTimeNs() {
#ifdef USE_KINETO
return libkineto::timeSinceEpoch(std::chrono::system_clock::now());
#else
return c10::getTime() / 1000;
return c10::getTime();
#endif // USE_KINETO
}
@@ -307,7 +307,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase {
const ProfilerConfig& config,
std::set<torch::profiler::impl::ActivityType> activities)
: ProfilerStateBase(config),
start_time_(getTimeUs()),
start_time_(getTimeNs()),
record_queue_(config, std::move(activities)) {}
~KinetoThreadLocalState() override = default;
@@ -374,7 +374,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase {
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>
finalizeTrace() {
auto end_time = getTimeUs();
auto end_time = getTimeNs();
record_queue_.stop();
std::lock_guard<std::mutex> guard(state_mutex_);
@@ -772,8 +772,8 @@ const c10::ArrayRef<std::string> KinetoEvent::moduleHierarchy() const {
return {};
}
uint64_t KinetoEvent::durationUs() const {
return (result_->endTimeNS() - result_->start_time_ns_) / 1000;
uint64_t KinetoEvent::durationNs() const {
return (result_->endTimeNS() - result_->start_time_ns_);
}
int64_t KinetoEvent::debugHandle() const {
@@ -854,7 +854,7 @@ FORWARD_FROM_RESULT(endThreadId, endTID())
FORWARD_FROM_RESULT(activityType, kinetoType())
FORWARD_FROM_RESULT(name, name())
FORWARD_FROM_RESULT(deviceType, deviceType())
FORWARD_FROM_RESULT(startUs, start_time_ns_ / 1000)
FORWARD_FROM_RESULT(startNs, start_time_ns_)
FORWARD_FROM_RESULT(correlationId, correlationID())
FORWARD_FROM_RESULT(deviceResourceId, kineto_info_.resource)
#undef FORWARD_FROM_RESULT
@@ -906,7 +906,7 @@ ProfilerResult::ProfilerResult(
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>&&
trace,
std::vector<experimental_event_t>&& event_tree)
: trace_start_us_(start_time),
: trace_start_ns_(start_time),
events_(std::move(events)),
trace_(std::move(trace)),
event_tree_(std::move(event_tree)) {}

View File

@@ -48,8 +48,8 @@ struct TORCH_API KinetoEvent {
c10::DeviceType deviceType() const;
int deviceIndex() const;
int64_t nBytes() const;
uint64_t startUs() const;
uint64_t durationUs() const;
uint64_t startNs() const;
uint64_t durationNs() const;
bool isAsync() const;
uint64_t correlationId() const;
uint64_t linkedCorrelationId() const;
@@ -87,8 +87,8 @@ struct TORCH_API ProfilerResult {
std::vector<experimental_event_t>&& event_tree);
~ProfilerResult();
uint64_t trace_start_us() const {
return trace_start_us_;
uint64_t trace_start_ns() const {
return trace_start_ns_;
}
const std::vector<KinetoEvent>& events() const {
@@ -102,7 +102,7 @@ struct TORCH_API ProfilerResult {
void save(const std::string& path);
private:
uint64_t trace_start_us_ = 0;
uint64_t trace_start_ns_ = 0;
std::vector<KinetoEvent> events_;
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper> trace_;
std::vector<experimental_event_t> event_tree_;

View File

@@ -592,7 +592,7 @@ int64_t Result::endTimeNS() const {
Vulkan, start_time_ns_ + (e.in_tree_building_ ? 0 : e.duration_ns_)),
ATTRIBUTE(Allocation, start_time_ns_),
ATTRIBUTE(OutOfMemory, start_time_ns_),
ATTRIBUTE(Kineto, start_time_ns_ + e.duration_us_ * 1000),
ATTRIBUTE(Kineto, start_time_ns_ + e.duration_ns_),
[&](const auto& e) -> int64_t { return e.end_time_ns_; }));
// In rare cases we're willing to tolerate ops which are missing an end time
@@ -803,12 +803,12 @@ static constexpr const char* indexKey = "Ev Idx";
void passEventsToKineto(
const std::vector<std::shared_ptr<Result>>& results,
uint64_t start_time_us,
uint64_t end_time_us,
uint64_t start_time_ns,
uint64_t end_time_ns,
const ProfilerConfig& config) {
using namespace torch::profiler::impl::kineto;
TraceWrapper cpu_trace(
static_cast<int64_t>(start_time_us), "PyTorch Profiler");
static_cast<int64_t>(start_time_ns), "PyTorch Profiler");
// Generate Kineto events for each event recorded by the PyTorch profiler.
for (const auto i : c10::irange(results.size())) {
@@ -818,8 +818,8 @@ void passEventsToKineto(
e->kinetoType(),
e->kineto_info_,
e->correlationID(),
e->start_time_ns_ / 1000,
e->endTimeNS() / 1000);
e->start_time_ns_,
e->endTimeNS());
TORCH_INTERNAL_ASSERT(activity || !kKinetoAvailable);
if (activity) {
@@ -842,7 +842,7 @@
}
// Kineto adds the events that it collected.
cpu_trace.transferCpuTrace(static_cast<int64_t>(end_time_us));
cpu_trace.transferCpuTrace(static_cast<int64_t>(end_time_ns));
}
#ifdef USE_KINETO
@@ -1098,11 +1098,11 @@ class TransferEvents {
trace_ptr_t addKinetoEvents(
std::vector<std::shared_ptr<Result>>& results,
uint64_t start_time_us,
uint64_t end_time_us,
uint64_t start_time_ns,
uint64_t end_time_ns,
const ProfilerConfig& config) {
using namespace torch::profiler::impl::kineto;
passEventsToKineto(results, start_time_us, end_time_us, config);
passEventsToKineto(results, start_time_ns, end_time_ns, config);
// In on demand mode kineto is directly controlled by other machinery.
if (config.global()) {
@@ -1353,8 +1353,8 @@ std::pair<
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>>
RecordQueue::getRecords(
std::function<c10::time_t(c10::approx_time_t)> time_converter,
uint64_t start_time_us,
uint64_t end_time_us) {
uint64_t start_time_ns,
uint64_t end_time_ns) {
auto converter = [&](c10::approx_time_t t) {
return t == std::numeric_limits<c10::approx_time_t>::min()
? std::numeric_limits<c10::time_t>::min()
@@ -1405,9 +1405,7 @@ RecordQueue::getRecords(
if (python_tracer_) {
for (const auto& i : python_tracer_->getEvents(
converter,
python_enters,
static_cast<c10::time_t>(end_time_us * 1000))) {
converter, python_enters, static_cast<c10::time_t>(end_time_ns))) {
out.push_back(i);
}
python_tracer_.reset();
@@ -1427,7 +1425,7 @@ RecordQueue::getRecords(
}
}
auto trace = addKinetoEvents(out, start_time_us, end_time_us, config_);
auto trace = addKinetoEvents(out, start_time_ns, end_time_ns, config_);
std::stable_sort(out.begin(), out.end(), [](const auto& a, const auto& b) {
return a->start_time_ns_ < b->start_time_ns_;

View File

@@ -340,7 +340,7 @@ struct ExtraFields<EventType::Kineto> {
};
std::string name_;
int64_t duration_us_{0};
int64_t duration_ns_{0};
uint64_t correlation_id_{0};
libkineto::ActivityType activity_type_;
Flow flow;
@@ -632,8 +632,8 @@ class TORCH_API RecordQueue {
std::unique_ptr<torch::profiler::impl::kineto::ActivityTraceWrapper>>
getRecords(
std::function<c10::time_t(c10::approx_time_t)> time_converter,
uint64_t start_time_us,
uint64_t end_time_us);
uint64_t start_time_ns,
uint64_t end_time_ns);
private:
uint32_t id_;

View File

@@ -147,15 +147,15 @@ class BasicEvaluation:
cuda_launch_events = sorted(
(e for e in cuda_event_list if is_cuda_launch_kernel(e)),
key=lambda x: x.start_us(),
key=lambda x: x.start_ns(),
)
cuda_kernel_events = sorted(
(e for e in cuda_event_list if is_cuda_kernel(e)),
key=lambda x: x.start_us(),
key=lambda x: x.start_ns(),
)
self.cuda_events = sorted(
cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_us()
cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns()
)
kernel_mapping: Dict[_KinetoEvent, int] = {}
@@ -178,6 +178,8 @@ class BasicEvaluation:
def new_old_event_comparator(event):
if hasattr(event, "start_us"):
return event.start_us() * 1000
if hasattr(event, "start_ns"):
return event.start_ns()
if hasattr(event, "start_time_ns"):
return event.start_time_ns
raise Exception("Unknown Event Type")
@@ -192,20 +194,26 @@
# Find current spawned cuda kernel event
if event in kernel_mapping and kernel_mapping[event] is not None:
spawned_kernel_index = kernel_mapping[event]
if hasattr(event, "start_ns"):
start_time = event.start_ns()
end_time = event.start_ns() + event.duration_ns()
# Find current spawned cuda kernel event
if event in kernel_mapping and kernel_mapping[event] is not None:
spawned_kernel_index = kernel_mapping[event]
elif hasattr(event, "start_time_ns"):
start_time = event.start_time_ns # type: ignore[attr-defined]
end_time = event.end_time_ns # type: ignore[attr-defined]
while (
current_kernel_index < len(cuda_kernel_events)
and (cuda_kernel_events[current_kernel_index].start_us()) * 1000
and (cuda_kernel_events[current_kernel_index].start_ns())
<= start_time # type: ignore[possibly-undefined]
):
current_kernel_index += 1
current_queue_depth = spawned_kernel_index - current_kernel_index + 1
current_queue_depth = max(current_queue_depth, 0)
if hasattr(event, "start_us"):
if hasattr(event, "start_us") or hasattr(event, "start_ns"):
queue_depth_list.append(
Interval(start_time, end_time, current_queue_depth) # type: ignore[possibly-undefined]
)