Files
pytorch/torch/csrc/jit/passes/onnx/naming.cpp
Yuanyuan Chen 36871622f1 [2/N] Mark unused parameters in C++ code (#165121)
This is a follow-up to #164912 that marks unused C++ parameters to improve code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165121
Approved by: https://github.com/Skylion007
2025-10-15 03:04:39 +00:00

206 lines
5.9 KiB
C++

#include <torch/csrc/jit/passes/onnx/naming.h>
#include <torch/csrc/onnx/onnx.h>
#include <utility>
namespace torch::jit::onnx {
namespace ONNXScopeName {
// Signature shared by the per-scope name extractors (className / variableName).
using NameFunc = std::string (*)(const torch::jit::ScopePtr&);
// Separates the class part (before) from the variable part (after) of a
// compatible scope name.
const std::string name_separator{"::"};
namespace {
// Builds a path by applying `name_func` to `scope` and then to each enclosing
// ancestor, joining the pieces with `layer_separator` (outermost first).
// The walk stops at the first ancestor for which isCompatibleScope() is false
// (the root, a blank scope, or one whose name lacks name_separator).
std::string nameFromRoot(
    const torch::jit::ScopePtr& scope,
    const std::string& layer_separator,
    NameFunc name_func) {
  std::string path = name_func(scope);
  if (scope->isRoot()) {
    return path;
  }
  for (auto ancestor = scope->parent(); isCompatibleScope(ancestor);
       ancestor = ancestor->parent()) {
    // Prepend: ancestors are visited innermost to outermost.
    path = name_func(ancestor) + layer_separator + path;
  }
  return path;
}
// Splits the unqualified name of `scope` at the first occurrence of
// name_separator.
// Returns {text before the separator, text after the separator}.
// Throws (via TORCH_CHECK) if the scope name does not contain the separator.
std::pair<std::string, std::string> parseNameFromScope(
    const torch::jit::ScopePtr& scope) {
  std::string full_name = scope->name().toUnqualString();
  auto pos = full_name.find(name_separator);
  TORCH_CHECK(
      pos != std::string::npos,
      "Scope name (" + full_name + ") does not contain '" + name_separator +
          "'");
  // Skip the whole separator using its actual length rather than a
  // hard-coded 2, so the split stays correct if name_separator changes.
  return std::make_pair(
      full_name.substr(0, pos), full_name.substr(pos + name_separator.size()));
}
} // namespace
// Composes a scope name of the form "<class_name>::<variable_name>", the
// format that parseNameFromScope() splits apart again.
std::string createFullScopeName(
    const std::string& class_name,
    const std::string& variable_name) {
  return class_name + name_separator + variable_name;
}
// Returns the variable part (after name_separator) of `scope`'s name.
std::string variableName(const torch::jit::ScopePtr& scope) {
  return std::get<1>(parseNameFromScope(scope));
}
// Joins the variable names of `scope` and its compatible ancestors,
// outermost first, using `layer_separator`.
std::string variableNameFromRoot(
    const torch::jit::ScopePtr& scope,
    const std::string& layer_separator) {
  const NameFunc extractor = &variableName;
  return nameFromRoot(scope, layer_separator, extractor);
}
// Returns the class part (before name_separator) of `scope`'s name.
std::string className(const torch::jit::ScopePtr& scope) {
  return std::get<0>(parseNameFromScope(scope));
}
// Joins the class names of `scope` and its compatible ancestors,
// outermost first, using `layer_separator`.
std::string classNameFromRoot(
    const torch::jit::ScopePtr& scope,
    const std::string& layer_separator) {
  const NameFunc extractor = &className;
  return nameFromRoot(scope, layer_separator, extractor);
}
// A scope is usable for naming only if it is neither the root nor blank and
// its unqualified name contains name_separator (i.e. "<class>::<variable>").
bool isCompatibleScope(const torch::jit::ScopePtr& scope) {
  if (scope->isRoot() || scope->isBlank()) {
    return false;
  }
  const std::string scope_name = scope->name().toUnqualString();
  return scope_name.find(name_separator) != std::string::npos;
}
} // namespace ONNXScopeName
namespace {
// Base class for assigning names to every node of a graph (recursing into
// sub-blocks) and deriving the debug names of node outputs from them.
// Subclasses decide how an individual node's name is built by overriding
// CreateNodeName().
class NodeNameGenerator {
 public:
  explicit NodeNameGenerator(std::shared_ptr<Graph> g) : graph_(std::move(g)) {}
  virtual ~NodeNameGenerator() = 0;
  // Names all nodes (and their outputs) reachable from graph_'s top block.
  void PopulateNodeNames();

 protected:
  // Assigns the name for a single node; implementations may leave a node
  // unnamed (it then has no entry in node_names_).
  virtual void CreateNodeName(Node* n) = 0;
  // Names the nodes of `b`; a node's nested blocks are processed before the
  // node itself.
  void PopulateNodeNames(Block* b);
  // Renames outputs of a named node that are not graph outputs to
  // "<node name>_output_<index>".
  void UpdateOutputsNames(Node* n);
  bool IsGraphOutput(const Value* v, const std::shared_ptr<Graph>& graph) const;

  // Returns `base_name` the first time it is seen in `base_name_count`, and a
  // suffixed variant "<base_name>_<k>" on later requests.
  std::string CreateUniqueName(
      std::unordered_map<std::string, size_t>& base_name_count,
      std::string base_name);

  // Names assigned so far, keyed by node.
  std::unordered_map<const Node*, std::string> node_names_;
  // Per-base-name counters used when creating node names.
  std::unordered_map<std::string, size_t> base_node_name_counts_;
  std::shared_ptr<Graph> graph_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::string layer_separator_ = "/";
};
// A pure virtual destructor still needs a definition because derived
// destructors invoke it.
NodeNameGenerator::~NodeNameGenerator() = default;
// Names each node "<full scope name><layer_separator_><node kind>", where the
// full scope name is built from the variable names of the node's compatible
// scopes. Leaf of the hierarchy, so it is marked final.
class ScopedNodeNameGenerator final : public NodeNameGenerator {
 public:
  explicit ScopedNodeNameGenerator(std::shared_ptr<Graph> g)
      : NodeNameGenerator(std::move(g)) {}

 protected:
  void CreateNodeName(Node* n) override;

 private:
  // Returns the cached full name of `scope`, computing it on first use.
  std::string GetFullScopeName(const ScopePtr& scope);
  // Scope -> full scope name already handed out.
  std::unordered_map<ScopePtr, std::string> full_scope_names_;
  // Per-base-name counters used when creating full scope names.
  std::unordered_map<std::string, size_t> base_scope_name_counts_;
};
// Returns a name derived from `base_name` that has not previously been
// returned for, or used as a base name in, `base_name_count`.
// The first request for a base name returns it unchanged; later requests
// return "<base_name>_<k>" with k counting up from 1.
//
// Every generated name is itself recorded in `base_name_count`. Without this,
// the sequence "a", "a", "a_1" would yield "a", "a_1", "a_1" (a duplicate).
std::string NodeNameGenerator::CreateUniqueName(
    std::unordered_map<std::string, size_t>& base_name_count,
    std::string base_name) {
  // Single lookup: inserts a zero counter only when the base is new.
  auto [it, inserted] = base_name_count.try_emplace(base_name, size_t{0});
  if (inserted) {
    return base_name;
  }
  // Bump the suffix until it lands on a name that is not already taken.
  // No insertion happens inside the loop, so `it` stays valid.
  std::string candidate;
  do {
    candidate = base_name + "_" + std::to_string(++it->second);
  } while (base_name_count.count(candidate) != 0);
  // Reserve the generated name so a later request with it as the base gets
  // suffixed instead of colliding.
  base_name_count.emplace(candidate, size_t{0});
  return candidate;
}
// Reports whether `v` is one of `graph`'s outputs, using a linear scan over
// the outputs.
bool NodeNameGenerator::IsGraphOutput(
    const Value* v,
    const std::shared_ptr<Graph>& graph) const {
  const auto graph_outputs = graph->outputs();
  for (size_t idx = 0; idx < graph_outputs.size(); ++idx) {
    if (graph_outputs[idx] == v) {
      return true;
    }
  }
  return false;
}
// For a node that has been assigned a name, sets the debug name of each of
// its outputs to "<node name>_output_<index>". Graph outputs are skipped so
// their names are not overwritten.
void NodeNameGenerator::UpdateOutputsNames(Node* n) {
  const auto named = node_names_.find(n);
  if (named == node_names_.end()) {
    // The node was left unnamed; its outputs keep their current names.
    return;
  }
  const std::string& node_name = named->second;
  const auto outputs = n->outputs();
  for (const auto idx : c10::irange(outputs.size())) {
    Value* out = outputs[idx];
    if (IsGraphOutput(out, graph_)) {
      continue;
    }
    out->setDebugName(node_name + "_output_" + std::to_string(idx));
  }
}
void NodeNameGenerator::PopulateNodeNames() {
PopulateNodeNames(graph_->block());
}
// Names every node in `b` in order. The blocks nested inside a node are
// processed before that node is named and its outputs renamed.
void NodeNameGenerator::PopulateNodeNames(Block* b) {
  for (Node* node : b->nodes()) {
    for (Block* nested : node->blocks()) {
      PopulateNodeNames(nested);
    }
    CreateNodeName(node);
    UpdateOutputsNames(node);
  }
}
// On first visit, builds "<full scope name><layer_separator_><node kind>" for
// `n` and records it in node_names_ after passing it through
// CreateUniqueName. The recorded name is then stored on the node as the ONNX
// node-name string attribute.
// Nodes with an incompatible scope, or that must be None, are left unnamed.
void ScopedNodeNameGenerator::CreateNodeName(Node* n) {
  auto entry = node_names_.find(n);
  if (entry == node_names_.end()) {
    // JIT IR does not allow attribute for None node.
    if (!ONNXScopeName::isCompatibleScope(n->scope()) || n->mustBeNone()) {
      return;
    }
    std::string name = GetFullScopeName(n->scope()) + layer_separator_ +
        n->kind().toUnqualString();
    entry = node_names_
                .emplace(n, CreateUniqueName(base_node_name_counts_, name))
                .first;
  }
  n->s_(Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute), entry->second);
}
// Returns the full scope name for `scope`, computing it once and caching it.
// The name is the "/"-joined variable path from the outermost compatible
// scope, passed through CreateUniqueName with the scope-name counters.
std::string ScopedNodeNameGenerator::GetFullScopeName(const ScopePtr& scope) {
  const auto cached = full_scope_names_.find(scope);
  if (cached != full_scope_names_.end()) {
    return cached->second;
  }
  std::string unique_name = CreateUniqueName(
      base_scope_name_counts_,
      ONNXScopeName::variableNameFromRoot(scope, layer_separator_));
  full_scope_names_.emplace(scope, unique_name);
  return unique_name;
}
} // namespace
void AssignScopedNamesForNodeAndValue(std::shared_ptr<Graph>& graph) {
auto node_name_generator = std::make_unique<ScopedNodeNameGenerator>(graph);
node_name_generator->PopulateNodeNames();
}
} // namespace torch::jit::onnx