Files
pytorch/torch/csrc/distributed/c10d/reducer_timer.hpp

82 lines
2.3 KiB
C++

#pragma once
#include <c10/util/ApproximateClock.h>
#include <torch/csrc/autograd/profiler.h>
namespace c10d {
constexpr int kUnsetTime = -1;
inline int64_t current_time_in_nanos() {
return c10::getTime();
}
class TORCH_API Timer {
private:
// The timestamp of forward call start time in each iteration.
int64_t forward_start_time = kUnsetTime;
// The timestamp of backward computation start and end time in each
// iteration.
int64_t backward_compute_start_time = kUnsetTime;
int64_t backward_compute_end_time = kUnsetTime;
// The timestamp of first communication call start time in each iteration.
int64_t backward_comm_start_time = kUnsetTime;
// The timestamp of last communication call end time in each iteration.
int64_t backward_comm_end_time = kUnsetTime;
public:
enum class Event : uint8_t {
kForwardStart,
kBackwardComputeStart,
kBackwardComputeEnd,
kBackwardCommStart,
kBackwardCommEnd,
};
// Record the current event, i.e., mark it as having occurred now. Default
// CPU implementation.
virtual void record(Event event) {
getTimeRef(event) = current_time_in_nanos();
}
// Return the difference between when two events occurred, in nanoseconds.
// Or nullopt if one of them hasn't been recorded.
virtual std::optional<int64_t> measureDifference(Event start, Event end) = 0;
virtual ~Timer() = default;
// Return host-side timestamp, or nullopt if it has not yet been recorded.
std::optional<int64_t> getTimestamp(Event event) {
auto time = getTimeRef(event);
if (time == kUnsetTime) {
return std::nullopt;
} else {
return time;
}
}
// Return host-side time member variable corresponding to the given event.
int64_t& getTimeRef(Event event) {
switch (event) {
case Event::kForwardStart:
return forward_start_time;
case Event::kBackwardComputeStart:
return backward_compute_start_time;
case Event::kBackwardComputeEnd:
return backward_compute_end_time;
case Event::kBackwardCommStart:
return backward_comm_start_time;
case Event::kBackwardCommEnd:
return backward_comm_end_time;
default:
TORCH_INTERNAL_ASSERT(false);
}
}
};
TORCH_DECLARE_TYPED_REGISTRY(
TimerRegistry,
c10::DeviceType,
Timer,
std::unique_ptr,
c10::Device);
} // namespace c10d