/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "caffe2/core/net_async_tracing.h"

#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"

C10_DEFINE_string(
    caffe2_net_async_tracing_filepath,
    "/tmp",
    "Path to save tracing information");

C10_DEFINE_string(
    caffe2_net_async_names_to_trace,
    "",
    "Comma-separated list of net names to trace");

C10_DEFINE_int(caffe2_net_async_tracing_nth, 100, "Trace every Nth batch");

// For every Nth iteration, we dump the tracing results to a json file.
// The file name has the iteration number appended.
C10_DEFINE_int(
    caffe2_net_async_tracing_dumping_nth,
    10000,
    "Dump profiling result file every Nth batch");

namespace caffe2 {
namespace tracing {

int getCounterForNetName(const std::string& net_name) {
  // Append a unique number suffix because there could be multiple instances
  // of the same net and we want to uniquely associate each instance with
  // a profiling trace.
  static std::unordered_map<std::string, int> net_name_to_counter;
  static std::mutex map_mutex;
  std::unique_lock<std::mutex> map_lock(map_mutex);
  int counter = net_name_to_counter[net_name] + 1;
  net_name_to_counter[net_name] = counter;
  return counter;
}

Tracer::Tracer(const NetBase* net, const std::string& net_name)
    : net_(net), filename_(net_name), iter_(0) {
  std::replace(filename_.begin(), filename_.end(), '/', '_');
  filename_ = c10::FLAGS_caffe2_net_async_tracing_filepath + "/" + filename_ +
      "_id_" + caffe2::to_string(getCounterForNetName(net_name));
  timer_.Start();
}

void Tracer::recordEvent(const TracerEvent& event) {
  std::lock_guard<std::mutex> lock(tracer_mutex_);
  events_.push_back(event);
}

// Forward declaration
int getUniqueShardId(const OperatorDef& op_def);

// Special handling of shard blob annotations
std::string Tracer::opTraceName(const OperatorBase* op) {
  int unique_shard_id =
      op->has_debug_def() ? getUniqueShardId(op->debug_def()) : -1;
  if (unique_shard_id != -1) {
    return op->type() + ":" + caffe2::to_string(unique_shard_id);
  } else {
    return op->type();
  }
}

std::string Tracer::opBlobsInfo(const OperatorBase& op) {
  std::string blobs_info;
  if (op.has_debug_def()) {
    blobs_info += "I: ";
    const auto& op_def = op.debug_def();
    for (const auto& input : op_def.input()) {
      blobs_info += input + "; ";
    }
    blobs_info += "O: ";
    for (const auto& output : op_def.output()) {
      blobs_info += output + "; ";
    }
  }
  return blobs_info;
}

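// Serializes a single event into a Chrome trace-viewer ("chrome://tracing")
// style JSON object: "ph" is "B" for a begin event and "E" for the matching
// end event, "ts" is the timestamp in microseconds, and operator/task/stream
// details are attached under "args" for begin events.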
std::string Tracer::serializeEvent(const TracerEvent& event) {
  std::stringstream serialized_event;
  serialized_event << std::fixed;
  serialized_event << "{\n";
  serialized_event << " \"ts\": " << event.timestamp_ << ",\n";
  serialized_event << " \"pid\": 0,\n"; // not using pid field
  if (event.thread_label_ >= 0) {
    serialized_event << " \"tid\": " << event.thread_label_ << ",\n";
  } else {
    serialized_event << " \"tid\": " << event.tid_ << ",\n";
  }

  if (event.is_beginning_) {
    std::unordered_map<std::string, int> int_args;
    std::unordered_map<std::string, std::string> string_args;
    if (event.name_) {
      serialized_event << " \"name\": \"" << event.name_ << "\",\n";
    } else if (event.op_id_ >= 0) {
      auto* op = net_->GetOperators().at(event.op_id_);
      serialized_event << " \"name\": \"" << opTraceName(op) << "\",\n";
    } else {
      serialized_event << " \"name\": \"n/a\",\n";
    }

    if (event.category_) {
      serialized_event << " \"cat\": \"" << event.category_ << "\",\n";
    } else {
      serialized_event << " \"cat\": \"net\",\n";
    }

    if (event.op_id_ >= 0) {
      auto* op = net_->GetOperators().at(event.op_id_);
      int_args["op_id"] = event.op_id_;
      int_args["device_type"] = op->device_option().device_type();
      int_args["device_id"] = DeviceId(op->device_option());
      string_args["blobs"] = opBlobsInfo(*op);
    }

    if (event.task_id_ >= 0) {
      int_args["task_id"] = event.task_id_;
    }

    if (event.stream_id_ >= 0) {
      int_args["stream_id"] = event.stream_id_;
    }

    serialized_event << " \"ph\": \"B\"";
    if (!int_args.empty() || !string_args.empty()) {
      serialized_event << ",\n \"args\": {\n";
      auto left_to_output = int_args.size() + string_args.size();
      for (const auto& kv : int_args) {
        serialized_event << " \"" << kv.first << "\": " << kv.second;
        --left_to_output;
        if (left_to_output > 0) {
          serialized_event << ",\n";
        }
      }
      for (const auto& kv : string_args) {
        serialized_event << " \"" << kv.first << "\": \"" << kv.second << "\"";
        --left_to_output;
        if (left_to_output > 0) {
          serialized_event << ",\n";
        }
      }
      serialized_event << "\n }";
    }
  } else {
    serialized_event << " \"ph\": \"E\"\n";
  }
  serialized_event << "\n}";

  return serialized_event.str();
}

// fix occasional cases with zero duration events
void Tracer::linearizeEvents() {
  std::unordered_map<long, long> time_offsets;
  std::unordered_map<long, long> last_times;
  std::hash<std::thread::id> hasher;
  const long time_eps = 1; // us
  for (auto& event : events_) {
    long tid =
        (event.thread_label_ >= 0) ? event.thread_label_ : hasher(event.tid_);
    auto event_ts = event.timestamp_;
    if (last_times.count(tid)) {
      event_ts += time_offsets[tid];
      CAFFE_ENFORCE(event_ts >= last_times[tid]);
      if (event_ts <= last_times[tid] + time_eps) {
        event_ts += time_eps;
        time_offsets[tid] += time_eps;
      } else if (event_ts > last_times[tid] + 2 * time_eps) {
        long eps_len = (event_ts - last_times[tid]) / time_eps;
        if (time_offsets[tid] >= time_eps * (eps_len - 1)) {
          time_offsets[tid] -= time_eps * (eps_len - 1);
          event_ts -= time_eps * (eps_len - 1);
        } else {
          event_ts -= time_offsets[tid];
          time_offsets[tid] = 0;
        }
      }
      event.timestamp_ = event_ts;
      last_times[tid] = event_ts;
    } else {
      last_times[tid] = event_ts;
      time_offsets[tid] = 0;
    }
  }
}

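// Relabel worker threads so that, in the trace view, threads are grouped by
// the NUMA node their operators run on: the label encodes the NUMA node id
// in the high digits and a per-node thread counter in the low digits.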
void Tracer::renameThreads() {
  std::unordered_map<long, int> tids;
  std::unordered_map<int, int> numa_counters;
  std::unordered_map<long, int> tid_to_numa;
  std::hash<std::thread::id> hasher;
  const long numa_multiplier = 1000000000;
  for (auto& event : events_) {
    if (event.thread_label_ >= 0 || event.op_id_ < 0) {
      continue;
    }
    auto* op = net_->GetOperators().at(event.op_id_);
    if (!op->device_option().has_numa_node_id()) {
      continue;
    }
    int numa_node_id = op->device_option().numa_node_id();
    CAFFE_ENFORCE_GE(numa_node_id, 0, "Invalid NUMA node id: ", numa_node_id);
    long tid = hasher(event.tid_);

    if (!tid_to_numa.count(tid)) {
      tid_to_numa[tid] = numa_node_id;
    } else {
      CAFFE_ENFORCE_EQ(tid_to_numa[tid], numa_node_id);
    }

    if (!numa_counters.count(numa_node_id)) {
      numa_counters[numa_node_id] = 1;
    }
    if (!tids.count(tid)) {
      tids[tid] = numa_counters[numa_node_id]++;
    }
    event.thread_label_ = numa_multiplier * (numa_node_id + 1) + tids[tid];
  }
}

void Tracer::setEnabled(bool enabled) {
  enabled_ = enabled;
}

bool Tracer::isEnabled() const {
  return enabled_;
}

int Tracer::bumpIter() {
  return iter_++;
}

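// Serializes all buffered events into a single JSON array, writes it to
// "<filename_>_iter_<file_suffix>.json", and then clears the event buffer.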
void Tracer::dumpTracingResultAndClearEvents(const std::string& file_suffix) {
  if (events_.empty() || filename_.empty()) {
    return;
  }
  linearizeEvents();
  renameThreads();
  std::stringstream serialized;
  serialized << "[\n";
  for (size_t idx = 0; idx < events_.size(); ++idx) {
    serialized << serializeEvent(events_[idx]);
    if (idx != events_.size() - 1) {
      serialized << ",\n";
    }
  }
  serialized << "\n]\n";

  auto output_file_name = filename_ + "_iter_" + file_suffix + ".json";
  LOG(INFO) << "Dumping profiling result file to " << output_file_name;
  WriteStringToFile(serialized.str(), output_file_name.c_str());
  events_.clear();
}

Tracer::~Tracer() {
  dumpTracingResultAndClearEvents("final_batch");
}

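// TracerGuard is an RAII helper: recordEventStart() records the begin ("B")
// event and the destructor records the matching end ("E") event, so scoping
// a guard around a piece of work produces a complete slice in the trace.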
void TracerGuard::init(Tracer* tracer) {
  enabled_ = true;
  tracer_ = tracer;
}

void TracerGuard::addArgument() {}

void TracerGuard::addArgument(TracingField field, const char* value) {
  switch (field) {
    case TRACE_NAME: {
      event_.name_ = value;
      break;
    }
    case TRACE_CATEGORY: {
      event_.category_ = value;
      break;
    }
    default: { CAFFE_THROW("Unexpected tracing string field ", field); }
  }
}

void TracerGuard::addArgument(TracingField field, int value) {
  switch (field) {
    case TRACE_OP: {
      event_.op_id_ = value;
      break;
    }
    case TRACE_TASK: {
      event_.task_id_ = value;
      break;
    }
    case TRACE_STREAM: {
      event_.stream_id_ = value;
      break;
    }
    case TRACE_THREAD: {
      event_.thread_label_ = value;
      break;
    }
    default: { CAFFE_THROW("Unexpected tracing int field ", field); }
  }
}

void TracerGuard::recordEventStart() {
  if (enabled_) {
    if (event_.thread_label_ < 0) {
      event_.tid_ = std::this_thread::get_id();
    }
    event_.is_beginning_ = true;
    event_.timestamp_ = (long)caffe2::round(tracer_->timer_.MicroSeconds());
    tracer_->recordEvent(event_);
  }
}

TracerGuard::~TracerGuard() {
  if (enabled_) {
    event_.is_beginning_ = false;
    event_.timestamp_ = (long)caffe2::round(tracer_->timer_.MicroSeconds());
    tracer_->recordEvent(event_);
  }
}

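// Extracts the numeric shard id that follows a "shard:" annotation in a blob
// name; for example, a (hypothetical) blob named "fc_w_shard:3" yields 3.
// Returns -1 when the name carries no shard annotation.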
int extractShardId(const std::string& name) {
  const std::string kShard = "shard:";
  // We sometimes have multiple shards, but actually need the last one, hence
  // using rfind here. Hacky, but it works until we pass the shard id in graph
  // metadata.
  auto pos = name.rfind(kShard);
  if (pos != std::string::npos) {
    int left_pos = pos + kShard.length();
    int right_pos = left_pos;
    while (right_pos < name.length() && isdigit(name[right_pos])) {
      right_pos++;
    }
    return caffe2::stoi(name.substr(left_pos, right_pos - left_pos));
  } else {
    return -1;
  }
}

// Return unique shard id, or -1 if it is not unique.
int getUniqueShardId(const OperatorDef& op_def) {
  int unique_shard_id = -1;
  for (const auto& names : {op_def.input(), op_def.output()}) {
    for (const auto& name : names) {
      int shard_id = extractShardId(name);
      if (shard_id != -1) {
        if (unique_shard_id != -1) {
          return -1;
        }
        unique_shard_id = shard_id;
      }
    }
  }
  return unique_shard_id;
}

bool isTraceableNetName(const std::string& net_name) {
  auto tracing_nets =
      caffe2::split(',', c10::FLAGS_caffe2_net_async_names_to_trace);
  return !net_name.empty() &&
      std::find(tracing_nets.begin(), tracing_nets.end(), net_name) !=
      tracing_nets.end();
}

bool hasEnableTracingFlag(const NetBase* net) {
  if (!net->has_debug_def()) {
    return false;
  }
  return GetFlagArgument(net->debug_def(), "enable_tracing", false);
}

std::shared_ptr<Tracer> create(
    const NetBase* net,
    const std::string& net_name) {
  // Enable the tracer if the net has the "enable_tracing" argument set OR
  // if the command-line option includes the net name in the list of
  // traceable nets.
  bool trace_net = hasEnableTracingFlag(net) || isTraceableNetName(net_name);
  return trace_net ? std::make_shared<Tracer>(net, net_name) : nullptr;
}

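// Called once per net run: bumps the iteration counter, enables tracing on
// every caffe2_net_async_tracing_nth-th iteration, and dumps (and clears)
// the accumulated events on every caffe2_net_async_tracing_dumping_nth-th
// iteration. Returns whether tracing is enabled for this iteration.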
bool startIter(const std::shared_ptr<Tracer>& tracer) {
  if (!tracer) {
    return false;
  }
  auto iter = tracer->bumpIter();
  auto is_enabled = iter % c10::FLAGS_caffe2_net_async_tracing_nth == 0;
  tracer->setEnabled(is_enabled);
  if (iter % c10::FLAGS_caffe2_net_async_tracing_dumping_nth == 0) {
    int dumping_iter = iter / c10::FLAGS_caffe2_net_async_tracing_dumping_nth;
    tracer->dumpTracingResultAndClearEvents(caffe2::to_string(dumping_iter));
  }
  return is_enabled;
}

} // namespace tracing

} // namespace caffe2