mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/138974 Approved by: https://github.com/ezyang
206 lines
5.8 KiB
C++
206 lines
5.8 KiB
C++
#include <torch/csrc/jit/passes/onnx/naming.h>
|
|
#include <torch/csrc/onnx/onnx.h>
|
|
|
|
#include <utility>
|
|
|
|
namespace torch::jit::onnx {
|
|
|
|
namespace ONNXScopeName {
|
|
|
|
// Signature shared by variableName() and className(): maps a single scope to
// one name component. Used by nameFromRoot() to build hierarchical names.
using NameFunc = std::string (*)(const torch::jit::ScopePtr&);

// Separator between the class part and the variable part of a scope's
// unqualified name, as produced by createFullScopeName():
// "<class_name>::<variable_name>".
const std::string name_separator{"::"};
|
|
|
|
namespace {
|
|
|
|
std::string nameFromRoot(
|
|
const torch::jit::ScopePtr& scope,
|
|
const std::string& layer_separator,
|
|
NameFunc name_func) {
|
|
std::string out = (*name_func)(scope);
|
|
if (scope->isRoot()) {
|
|
return out;
|
|
}
|
|
auto parent = scope->parent();
|
|
while (isCompatibleScope(parent)) {
|
|
out = std::string((*name_func)(parent)).append(layer_separator).append(out);
|
|
parent = parent->parent();
|
|
}
|
|
return out;
|
|
}
|
|
|
|
std::pair<std::string, std::string> parseNameFromScope(
|
|
const torch::jit::ScopePtr& scope) {
|
|
std::string full_name = scope->name().toUnqualString();
|
|
auto pos = full_name.find(name_separator);
|
|
TORCH_CHECK(
|
|
pos != std::string::npos,
|
|
"Scope name (" + full_name + ") does not contain '" + name_separator +
|
|
"'");
|
|
return std::make_pair(full_name.substr(0, pos), full_name.substr(pos + 2));
|
|
}
|
|
|
|
} // namespace
|
|
|
|
std::string createFullScopeName(
|
|
const std::string& class_name,
|
|
const std::string& variable_name) {
|
|
return std::string(class_name).append(name_separator).append(variable_name);
|
|
}
|
|
|
|
// Returns the variable part of `scope`'s name, i.e. the text after the first
// name_separator. Fails (TORCH_CHECK) if the separator is absent.
std::string variableName(const torch::jit::ScopePtr& scope) {
  return std::get<1>(parseNameFromScope(scope));
}
|
|
|
|
std::string variableNameFromRoot(
|
|
const torch::jit::ScopePtr& scope,
|
|
const std::string& layer_separator) {
|
|
return nameFromRoot(scope, layer_separator, &variableName);
|
|
}
|
|
|
|
// Returns the class part of `scope`'s name, i.e. the text before the first
// name_separator. Fails (TORCH_CHECK) if the separator is absent.
std::string className(const torch::jit::ScopePtr& scope) {
  return std::get<0>(parseNameFromScope(scope));
}
|
|
|
|
std::string classNameFromRoot(
|
|
const torch::jit::ScopePtr& scope,
|
|
const std::string& layer_separator) {
|
|
return nameFromRoot(scope, layer_separator, &className);
|
|
}
|
|
|
|
// A scope is "compatible" (i.e. carries a parseable "<class>::<variable>"
// name) when it is neither the root nor blank and its unqualified name
// contains name_separator.
bool isCompatibleScope(const torch::jit::ScopePtr& scope) {
  if (scope->isRoot() || scope->isBlank()) {
    return false;
  }
  const std::string unqualified = scope->name().toUnqualString();
  return unqualified.find(name_separator) != std::string::npos;
}
|
|
} // namespace ONNXScopeName
|
|
|
|
namespace {
|
|
|
|
class NodeNameGenerator {
|
|
public:
|
|
NodeNameGenerator(std::shared_ptr<Graph> g) : graph_(std::move(g)){};
|
|
virtual ~NodeNameGenerator() = 0;
|
|
void PopulateNodeNames();
|
|
|
|
protected:
|
|
virtual void CreateNodeName(Node* n) = 0;
|
|
void PopulateNodeNames(Block*);
|
|
void UpdateOutputsNames(Node* n);
|
|
bool IsGraphOutput(const Value* v, const std::shared_ptr<Graph>& graph) const;
|
|
|
|
protected:
|
|
std::string CreateUniqueName(
|
|
std::unordered_map<std::string, size_t>& base_name_count,
|
|
std::string base_name);
|
|
|
|
std::unordered_map<const Node*, std::string> node_names_;
|
|
std::unordered_map<std::string, size_t> base_node_name_counts_;
|
|
std::shared_ptr<Graph> graph_;
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
|
const std::string layer_separator_ = "/";
|
|
};
|
|
NodeNameGenerator::~NodeNameGenerator() = default;
|
|
|
|
class ScopedNodeNameGenerator : public NodeNameGenerator {
|
|
public:
|
|
ScopedNodeNameGenerator(std::shared_ptr<Graph> g)
|
|
: NodeNameGenerator(std::move(g)){};
|
|
|
|
protected:
|
|
void CreateNodeName(Node* n) override;
|
|
|
|
private:
|
|
std::string GetFullScopeName(const ScopePtr& scope);
|
|
std::unordered_map<ScopePtr, std::string> full_scope_names_;
|
|
std::unordered_map<std::string, size_t> base_scope_name_counts_;
|
|
};
|
|
|
|
// Returns a name derived from `base_name` that has not been returned before
// from this `base_name_count` table, and records it there.
//
// The first request for a base name returns it unchanged. Later requests
// append "_<k>" using a per-base counter. Every generated name is also
// registered as a key of `base_name_count`, so the same string is never
// handed out twice. Without this, sequences such as ("a", "a", "a_1") or
// ("a_1", "a", "a") would both produce "a_1" twice.
std::string NodeNameGenerator::CreateUniqueName(
    std::unordered_map<std::string, size_t>& base_name_count,
    std::string base_name) {
  auto [entry, inserted] = base_name_count.try_emplace(base_name, 0);
  if (inserted) {
    return base_name;
  }
  std::string candidate;
  do {
    candidate = base_name;
    candidate += "_";
    candidate += std::to_string(++entry->second);
  } while (base_name_count.find(candidate) != base_name_count.end());
  // Reserve the generated name so no later request can produce it again.
  // (`entry` is not used after this insertion, which may rehash.)
  base_name_count.emplace(candidate, 0);
  return candidate;
}
|
|
|
|
bool NodeNameGenerator::IsGraphOutput(
|
|
const Value* v,
|
|
const std::shared_ptr<Graph>& graph) const {
|
|
for (const auto* graph_output : graph->outputs()) {
|
|
if (v == graph_output) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Renames the outputs of `n` after its assigned name, as
// "<node name>_output_<index>". Nodes without an entry in node_names_ are left
// untouched, and values that are graph outputs keep their existing names
// (presumably so user-visible output names are preserved; confirm with the
// exporter).
void NodeNameGenerator::UpdateOutputsNames(Node* n) {
  // Single lookup; bind by reference instead of copying the name.
  const auto named = node_names_.find(n);
  if (named == node_names_.end()) {
    return;
  }
  const std::string& node_name = named->second;
  for (auto i : c10::irange(n->outputs().size())) {
    auto output = n->output(i);
    if (IsGraphOutput(output, graph_)) {
      continue;
    }
    output->setDebugName(node_name + "_output_" + std::to_string(i));
  }
}
|
|
|
|
void NodeNameGenerator::PopulateNodeNames() {
|
|
PopulateNodeNames(graph_->block());
|
|
}
|
|
|
|
void NodeNameGenerator::PopulateNodeNames(Block* b) {
|
|
for (auto* n : b->nodes()) {
|
|
for (auto* sub_block : n->blocks()) {
|
|
PopulateNodeNames(sub_block);
|
|
}
|
|
CreateNodeName(n);
|
|
UpdateOutputsNames(n);
|
|
}
|
|
}
|
|
|
|
// Ensures `n` has a unique "<full scope name>/<kind>" name and writes it to the
// ONNX node-name attribute. Nodes whose scope is not compatible (root, blank,
// or lacking a "::" in its name) and None nodes are skipped entirely.
void ScopedNodeNameGenerator::CreateNodeName(Node* n) {
  auto entry = node_names_.find(n);
  if (entry == node_names_.end()) {
    // None nodes are skipped because JIT IR does not allow attributes on them.
    if (!ONNXScopeName::isCompatibleScope(n->scope()) || n->mustBeNone()) {
      return;
    }
    std::string scoped_name = GetFullScopeName(n->scope());
    scoped_name.append(layer_separator_).append(n->kind().toUnqualString());
    entry = node_names_
                .emplace(
                    n,
                    CreateUniqueName(
                        base_node_name_counts_, std::move(scoped_name)))
                .first;
  }
  n->s_(Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute), entry->second);
}
|
|
|
|
std::string ScopedNodeNameGenerator::GetFullScopeName(const ScopePtr& scope) {
|
|
if (full_scope_names_.find(scope) == full_scope_names_.end()) {
|
|
auto full_scope_name =
|
|
ONNXScopeName::variableNameFromRoot(scope, layer_separator_);
|
|
full_scope_names_[scope] =
|
|
CreateUniqueName(base_scope_name_counts_, full_scope_name);
|
|
}
|
|
return full_scope_names_[scope];
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph>& graph) {
|
|
auto node_name_generator = std::make_unique<ScopedNodeNameGenerator>(graph);
|
|
node_name_generator->PopulateNodeNames();
|
|
}
|
|
|
|
} // namespace torch::jit::onnx
|