mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128301 Approved by: https://github.com/ezyang, https://github.com/r-barnes
82 lines
2.3 KiB
C++
82 lines
2.3 KiB
C++
#pragma once
|
|
#include <c10/util/ApproximateClock.h>
|
|
#include <torch/csrc/autograd/profiler.h>
|
|
|
|
namespace c10d {
|
|
// Sentinel stored in a Timer timestamp slot that has not been recorded yet.
constexpr int kUnsetTime{-1};
|
|
|
|
inline int64_t current_time_in_nanos() {
|
|
return c10::getTime();
|
|
}
|
|
|
|
class TORCH_API Timer {
|
|
private:
|
|
// The timestamp of forward call start time in each iteration.
|
|
int64_t forward_start_time = kUnsetTime;
|
|
// The timestamp of backward computation start and end time in each
|
|
// iteration.
|
|
int64_t backward_compute_start_time = kUnsetTime;
|
|
int64_t backward_compute_end_time = kUnsetTime;
|
|
// The timestamp of first communication call start time in each iteration.
|
|
int64_t backward_comm_start_time = kUnsetTime;
|
|
// The timestamp of last communication call end time in each iteration.
|
|
int64_t backward_comm_end_time = kUnsetTime;
|
|
|
|
public:
|
|
enum class Event : uint8_t {
|
|
kForwardStart,
|
|
kBackwardComputeStart,
|
|
kBackwardComputeEnd,
|
|
kBackwardCommStart,
|
|
kBackwardCommEnd,
|
|
};
|
|
|
|
// Record the current event, i.e., mark it as having occurred now. Default
|
|
// CPU implementation.
|
|
virtual void record(Event event) {
|
|
getTimeRef(event) = current_time_in_nanos();
|
|
}
|
|
|
|
// Return the difference between when two events occurred, in nanoseconds.
|
|
// Or nullopt if one of them hasn't been recorded.
|
|
virtual std::optional<int64_t> measureDifference(Event start, Event end) = 0;
|
|
|
|
virtual ~Timer() = default;
|
|
|
|
// Return host-side timestamp, or nullopt if it has not yet been recorded.
|
|
std::optional<int64_t> getTimestamp(Event event) {
|
|
auto time = getTimeRef(event);
|
|
if (time == kUnsetTime) {
|
|
return std::nullopt;
|
|
} else {
|
|
return time;
|
|
}
|
|
}
|
|
|
|
// Return host-side time member variable corresponding to the given event.
|
|
int64_t& getTimeRef(Event event) {
|
|
switch (event) {
|
|
case Event::kForwardStart:
|
|
return forward_start_time;
|
|
case Event::kBackwardComputeStart:
|
|
return backward_compute_start_time;
|
|
case Event::kBackwardComputeEnd:
|
|
return backward_compute_end_time;
|
|
case Event::kBackwardCommStart:
|
|
return backward_comm_start_time;
|
|
case Event::kBackwardCommEnd:
|
|
return backward_comm_end_time;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false);
|
|
}
|
|
}
|
|
};
|
|
|
|
// Declares the registry named TimerRegistry.
// NOTE(review): presumably keyed by c10::DeviceType, producing
// std::unique_ptr<Timer> instances constructed from a c10::Device — confirm
// against the definition of TORCH_DECLARE_TYPED_REGISTRY.
TORCH_DECLARE_TYPED_REGISTRY(
    TimerRegistry, c10::DeviceType, Timer, std::unique_ptr, c10::Device);
|
|
} // namespace c10d
|