mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is follow-up of #164912 to mark unused C++ parameters to improve code readability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165121 Approved by: https://github.com/Skylion007
1171 lines
37 KiB
C++
1171 lines
37 KiB
C++
#include <torch/csrc/jit/jit_log.h>
|
|
#include <torch/csrc/jit/passes/onnx/function_extraction.h>
|
|
#include <torch/csrc/jit/passes/onnx/naming.h>
|
|
|
|
namespace torch::jit::onnx {
|
|
|
|
namespace {
|
|
|
|
using scope_list = std::vector<ScopePtr>;

// Annotated attributes retrieved from module by inspecting module annotations.
// These attributes are not used inside the subgraph of ONNX local function
// because they are not created by PyTorch JIT tracing, but they may be used by
// consumers to determine whether or not to replace the function with a
// particular fused kernel.
static std::unordered_map<ScopePtr, Node*> scope_attr_map_;
// Graph hosting the annotated-attribute nodes referenced by scope_attr_map_
// (presumably keeps those nodes alive; the nodes are created outside this
// chunk — confirm against the code that populates scope_attr_map_).
static std::shared_ptr<Graph> scope_attr_graph_ = std::make_shared<Graph>();

// Returns true when nodes a and b agree on attribute `attr` (both missing,
// or both present with equal kind and value). Defined later in this file.
static bool HasSameAttribute(
    const Node* a,
    const Node* b,
    const c10::Symbol& attr);
|
|
|
|
// FunctionExtractor converts module scopes of the traced graph into ONNX
// local functions. run() drives the pipeline:
//   1. PartitionNodesByScope assigns every node to a scope (inferring one
//      when tracing recorded none) and builds a ScopeContext per scope.
//   2. PartitionIdenticalScopes groups scopes whose subgraphs are
//      structurally identical.
//   3. For each group whose module class name is in module_names_,
//      ConvertScopeToFunction emits one LocalFunctionDef node and replaces
//      each occurrence of the subgraph with a function call node.
struct FunctionExtractor {
 public:
  FunctionExtractor(
      std::shared_ptr<Graph>& graph,
      const std::unordered_set<std::string>& module_names,
      const std::vector<std::string>& param_names)
      : graph_(graph),
        module_names_(module_names.begin(), module_names.end()),
        param_names_(param_names.begin(), param_names.end()) {}
  // Performs the extraction on graph_ and returns the mapping from nodes
  // inside function definitions to their exported attribute names.
  NodeAttrNameMap run();

 private:
  // Bookkeeping for a single scope: the nodes it owns, its child scopes,
  // and the inputs/outputs of the subgraph those nodes form.
  struct ScopeContext {
    std::unordered_set<ScopePtr> children_;
    ScopePtr scope_;
    node_list nlist_;
    value_list inputs_;
    value_list outputs_;
    // Maps values of the outer graph to values of the extracted subgraph.
    std::unordered_map<Value*, Value*> env_to_subgraph_;

    // Recomputes inputs_/outputs_ from nlist_. Values whose debug name is in
    // param_names are treated as initializers, appended after the inputs.
    void PopulateInputsOutputs(
        const std::unordered_set<std::string>& param_names);
    // Structural equality: same module class name, same input/output counts,
    // and same node kinds in the same order. (Name keeps historical typo
    // "Funcion" — renaming would break external callers.)
    bool IsIdenticalFuncion(const ScopeContext& other_ctx) const;
  };

  using ScopeCtxPtr = ScopeContext*;
  using scope_ctx_map = std::unordered_map<ScopePtr, ScopeCtxPtr>;

  // Bookkeeping for one extracted function: the representative scope, all
  // grouped scopes, and which node attributes differ between them.
  struct FunctionContext {
    FunctionContext(
        ScopePtr key,
        const scope_list& scopes,
        scope_ctx_map& scope_ctxs);
    void DebugPrint() const;
    // Records `name` as the exported function-attribute name for `attr` of
    // `ref_n` (a node of the representative scope).
    void SetAttrName(Node* ref_n, Symbol attr, const std::string& name);
    // Looks up the exported attribute name, if one was assigned.
    std::optional<std::string> FindAttrName(Node* ref_n, Symbol attr);
    std::optional<std::string> FindAttrName(Node* ref_const_n);

    ScopePtr scope_key_;
    scope_ctx_map scope_ctxs_;
    // ref node -> attribute symbol -> nodes (from other scopes in the group)
    // whose value for that attribute differs from the ref node's.
    std::unordered_map<
        Node*,
        std::unordered_map<Symbol, std::unordered_set<Node*>>>
        attribute_map_;

    // Passed later to serialization.
    NodeAttrNameMap node_attr_to_name_;
  };

  using FunctionCtxPtr = FunctionContext*;
  using func_ctx_map = std::unordered_map<ScopePtr, FunctionCtxPtr>;

  static bool IsValidScope(const ScopePtr& s);
  static std::optional<ScopePtr> InferScope(Node* n);
  static bool IsAncestor(const ScopePtr& parent, ScopePtr child);
  static std::optional<ScopePtr> FindCommonAncestor(ScopePtr a, ScopePtr b);
  static std::optional<ScopePtr> FindCommonAncestor(const scope_list& scopes);
  std::shared_ptr<Graph> ConstructFuncGraph(FunctionContext& ctx);

  void ConvertScopeToFunction(
      const ScopePtr& scope_key,
      const scope_list& scope_list,
      scope_ctx_map& scope_ctxs,
      const std::shared_ptr<Graph>& graph);

  static void HandleNoScopeNodes(
      scope_ctx_map& /*scope_ctxs*/,
      const node_list& no_scope_nlist);
  std::tuple<scope_ctx_map, node_list> PartitionNodesByScope(Block* b);
  scope_ctx_map PartitionNodesByScope(const std::shared_ptr<Graph>& graph);
  static std::unordered_map<ScopePtr, scope_list> PartitionIdenticalScopes(
      scope_ctx_map& scope_ctxs);
  static scope_list SortScopesByMaxDepth(
      std::unordered_map<ScopePtr, scope_list>& /*identical_scope_map*/);
  Node* CreateFunctionDefNode(
      FunctionContext& func_ctx,
      const std::shared_ptr<Graph>& graph,
      const std::string& domain_name,
      const std::string& func_name);
  Node* CreateFunctionNode(
      FunctionContext& func_ctx,
      ScopeContext& scope_ctx,
      const std::shared_ptr<Graph>& graph,
      const std::string& domain_name,
      const std::string& func_name);

  static void DebugPrintScopeContexts(const scope_ctx_map& /*scope_ctxs*/);
  static void DebugPrintGraphWithFunction(const std::shared_ptr<Graph>& g);
  static void DebugPrintConstantDiff(const FunctionContext&);

  std::shared_ptr<Graph> graph_;
  // Module class names requested for export as local functions.
  std::unordered_set<std::string> module_names_;
  // Debug names of values that are parameters/initializers.
  std::unordered_set<std::string> param_names_;
  // Track modules with same module name that are exported as different onnx
  // local functions.
  std::unordered_map<std::string, int> module_variant_count_;
  func_ctx_map func_ctxs_;
};
|
|
|
|
// Builds the function context for a group of scopes previously determined to
// be structurally identical (same node kinds in the same order).
// `key` is the representative scope; `scopes` are all members of the group.
// For every node position, attributes whose values differ between the
// representative node and the corresponding node of another scope are
// recorded in attribute_map_ so they can later be promoted to function
// attributes.
FunctionExtractor::FunctionContext::FunctionContext(
    ScopePtr key,
    const scope_list& scopes,
    scope_ctx_map& scope_ctxs)
    : scope_key_(std::move(key)) {
  GRAPH_UPDATE(
      "Process function context for scope ",
      scope_key_->name().toDisplayString());
  TORCH_INTERNAL_ASSERT(!scopes.empty());
  const auto& ref_ctx = scope_ctxs[scope_key_];
  // NOTE: Function scopes must have same number and order of nodes.
  GRAPH_DEBUG(
      "Initialized function context for scope ",
      scope_key_->name().toDisplayString());

  for (const auto& scope : scopes) {
    GRAPH_DEBUG(
        "Process function context for scope ", scope->name().toDisplayString());
    TORCH_INTERNAL_ASSERT(scope_ctxs.find(scope) != scope_ctxs.end());
    scope_ctxs_[scope] = scope_ctxs[scope];
    if (scope_key_ == scope) {
      // No attribute diffing of the representative scope against itself.
      continue;
    }
    auto& scope_ctx = scope_ctxs[scope];

    const auto& ns_a = ref_ctx->nlist_;
    const auto& ns_b = scope_ctx->nlist_;
    TORCH_INTERNAL_ASSERT(ns_a.size() == ns_b.size());

    GRAPH_DEBUG("Process nodes of scope ", scope->name().toDisplayString());
    for (const auto i : c10::irange(ns_a.size())) {
      TORCH_INTERNAL_ASSERT(ns_a[i]->kind() == ns_b[i]->kind());
      auto n_a = ns_a[i];
      auto n_b = ns_b[i];
      std::vector<c10::Symbol> diff_attrs;
      std::vector<c10::Symbol> same_attrs;
      auto n_a_attr_names = n_a->attributeNames();
      auto n_b_attr_names = n_b->attributeNames();
      // Sort so the set_difference/set_intersection preconditions hold.
      std::sort(n_a_attr_names.begin(), n_a_attr_names.end());
      std::sort(n_b_attr_names.begin(), n_b_attr_names.end());
      // Attributes present on n_a but absent on n_b differ by definition.
      std::set_difference(
          n_a_attr_names.begin(),
          n_a_attr_names.end(),
          n_b_attr_names.begin(),
          n_b_attr_names.end(),
          std::inserter(diff_attrs, diff_attrs.begin()));
      std::set_intersection(
          n_a_attr_names.begin(),
          n_a_attr_names.end(),
          n_b_attr_names.begin(),
          n_b_attr_names.end(),
          std::inserter(same_attrs, same_attrs.begin()));
      for (auto attr_name : diff_attrs) {
        attribute_map_[n_a][attr_name].insert(n_b);
      }

      // Attributes present on both nodes differ only if their values differ.
      for (auto attr_name : same_attrs) {
        if (!HasSameAttribute(n_a, n_b, attr_name)) {
          attribute_map_[n_a][attr_name].insert(n_b);
        }
      }
    }
    GRAPH_DEBUG("Process scope complete. ", scope->name().toDisplayString());
  }

  GRAPH_DEBUG(
      "Process function context complete. ",
      scope_key_->name().toDisplayString());
  DebugPrint();
}
|
|
|
|
void FunctionExtractor::FunctionContext::DebugPrint() const {
|
|
GRAPH_DEBUG("Scope name: ", scope_key_->name().toDisplayString());
|
|
|
|
for (const auto& it : attribute_map_) {
|
|
for (const auto& attr_it : it.second) {
|
|
GRAPH_DEBUG(
|
|
"Attribute value difference for attribute ",
|
|
attr_it.first.toDisplayString());
|
|
GRAPH_DEBUG(*it.first);
|
|
for (auto n : attr_it.second) {
|
|
GRAPH_DEBUG(*n);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Records `name` as the exported attribute name for `attr` of `ref_n`.
// `ref_n` lives in the outer graph; its counterpart inside the function
// definition subgraph is resolved through env_to_subgraph_ (keyed by the
// node's first output value) and must exist.
void FunctionExtractor::FunctionContext::SetAttrName(
    Node* ref_n,
    Symbol attr,
    const std::string& name) {
  const auto& env = scope_ctxs_[scope_key_]->env_to_subgraph_;
  const auto mapped = env.find(ref_n->outputs().at(0));
  TORCH_INTERNAL_ASSERT(mapped != env.end());
  Node* node_in_def = mapped->second->node();
  node_attr_to_name_[node_in_def][attr.toUnqualString()] = name;
}
|
|
|
|
// Returns the exported attribute name previously assigned (via SetAttrName)
// for `attr` of `ref_n`, or std::nullopt when the node or the attribute has
// no recorded name.
std::optional<std::string> FunctionExtractor::FunctionContext::FindAttrName(
    Node* ref_n,
    Symbol attr) {
  const auto& env = scope_ctxs_[scope_key_]->env_to_subgraph_;
  const auto mapped = env.find(ref_n->outputs().at(0));
  if (mapped == env.end()) {
    return std::nullopt;
  }
  Node* node_in_def = mapped->second->node();
  const auto node_entry = node_attr_to_name_.find(node_in_def);
  if (node_entry == node_attr_to_name_.end()) {
    return std::nullopt;
  }
  const auto& names_for_node = node_entry->second;
  const auto name_entry = names_for_node.find(attr.toUnqualString());
  if (name_entry == names_for_node.end()) {
    return std::nullopt;
  }
  return name_entry->second;
}
|
|
|
|
void FunctionExtractor::DebugPrintScopeContexts(
|
|
const scope_ctx_map& scope_ctxs) {
|
|
for (auto& it : scope_ctxs) {
|
|
GRAPH_UPDATE(
|
|
"Scope name: ",
|
|
it.first->namesFromRoot(),
|
|
" ",
|
|
it.first->name().toDisplayString());
|
|
GRAPH_UPDATE("Children scopes: ", [&]() {
|
|
std::stringstream ss;
|
|
for (const auto& child_scope : it.second->children_) {
|
|
ss << child_scope->name().toDisplayString() << " ";
|
|
}
|
|
return ss.str();
|
|
}());
|
|
GRAPH_UPDATE("Node types: \n", [&]() {
|
|
std::stringstream ss;
|
|
for (auto n : it.second->nlist_) {
|
|
ss << " " << *n;
|
|
}
|
|
return ss.str();
|
|
}());
|
|
GRAPH_UPDATE("Node count: ", it.second->nlist_.size());
|
|
}
|
|
}
|
|
|
|
void FunctionExtractor::DebugPrintGraphWithFunction(
|
|
const std::shared_ptr<Graph>& g) {
|
|
GRAPH_UPDATE("Local function definitions:");
|
|
for (auto* n : g->nodes()) {
|
|
if (n->kind() == Symbol::onnx("LocalFunctionDef")) {
|
|
GRAPH_UPDATE(
|
|
n->s(attr::name),
|
|
" graph: ",
|
|
n->g(Symbol::attr("graph"))->toString());
|
|
}
|
|
}
|
|
GRAPH_UPDATE("Main graph: ", g->toString());
|
|
}
|
|
|
|
bool FunctionExtractor::IsValidScope(const ScopePtr& s) {
|
|
return !s->isRoot() && !s->isBlank();
|
|
}
|
|
|
|
bool FunctionExtractor::IsAncestor(const ScopePtr& parent, ScopePtr child) {
|
|
if (!IsValidScope(parent) || !IsValidScope(child) ||
|
|
parent->getDepth() >= child->getDepth()) {
|
|
return false;
|
|
}
|
|
do {
|
|
child = child->parent();
|
|
if (parent == child) {
|
|
return true;
|
|
}
|
|
} while (IsValidScope(child));
|
|
return false;
|
|
}
|
|
|
|
// Finds the closest common ancestor of scopes `a` and `b`, or std::nullopt
// when either scope is invalid or no valid common ancestor exists.
std::optional<ScopePtr> FunctionExtractor::FindCommonAncestor(
    ScopePtr a,
    ScopePtr b) {
  if (!IsValidScope(a) || !IsValidScope(b)) {
    return std::nullopt;
  }

  // First bring both scopes to the same depth by walking the deeper one up.
  auto depth_a = static_cast<int64_t>(a->getDepth());
  auto depth_b = static_cast<int64_t>(b->getDepth());
  while (depth_a > depth_b) {
    a = a->parent();
    --depth_a;
  }
  while (depth_b > depth_a) {
    b = b->parent();
    --depth_b;
  }

  // Then climb both in lockstep until they meet or either becomes invalid.
  while (IsValidScope(a) && IsValidScope(b)) {
    if (a == b) {
      return a;
    }
    a = a->parent();
    b = b->parent();
  }

  return std::nullopt;
}
|
|
|
|
std::optional<ScopePtr> FunctionExtractor::FindCommonAncestor(
|
|
const scope_list& scopes) {
|
|
if (scopes.empty()) {
|
|
return std::nullopt;
|
|
}
|
|
|
|
std::optional<ScopePtr> common_ancestor = scopes.at(0);
|
|
for (const auto& scope : scopes) {
|
|
common_ancestor = FindCommonAncestor(common_ancestor.value(), scope);
|
|
if (!common_ancestor.has_value()) {
|
|
return std::nullopt;
|
|
}
|
|
}
|
|
|
|
return common_ancestor;
|
|
}
|
|
|
|
// Infers a scope for node n when tracing recorded none.
// The scope of node n is assigned based on the following rules.
// 1. If all uses of outputs of n belongs to the same scope,
//    assign that scope, otherwise
// 2. If all nodes of inputs of n belongs to the same scope,
//    assign that scope, otherwise
// 3. Find common ancestor of the scopes of uses of outputs of n,
//    and the scopes of nodes of inputs of n.
// Returns std::nullopt when no valid scope can be determined.
std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
  scope_list input_scopes;
  scope_list output_scopes;
  for (auto input : n->inputs()) {
    input_scopes.emplace_back(input->node()->scope());
  }
  for (auto output : n->outputs()) {
    for (auto use : output->uses()) {
      if (!IsValidScope(use.user->scope())) {
        // Recursively infer (and assign) the consumer's scope first, since
        // the consumer's scope feeds rule 1 below.
        auto inferred_output_scope = InferScope(use.user);
        if (inferred_output_scope.has_value() &&
            IsValidScope(inferred_output_scope.value())) {
          use.user->setScope(inferred_output_scope.value());
        }
      }
      output_scopes.emplace_back(use.user->scope());
    }
  }
  // Rule 1: all consumers share one valid scope.
  if (!output_scopes.empty() &&
      std::all_of(
          output_scopes.begin(),
          output_scopes.end(),
          [&output_scopes](const ScopePtr& scope) -> bool {
            return IsValidScope(scope) && scope == output_scopes.at(0);
          })) {
    return output_scopes.at(0);
  } else if (
      // Rule 2: all producers share one valid scope.
      !input_scopes.empty() &&
      std::all_of(
          input_scopes.begin(),
          input_scopes.end(),
          [&input_scopes](const ScopePtr& scope) -> bool {
            return IsValidScope(scope) && scope == input_scopes.at(0);
          })) {
    return input_scopes.at(0);
  } else {
    // Rule 3: common ancestor of every valid producer/consumer scope.
    scope_list scopes;
    std::copy_if(
        input_scopes.begin(),
        input_scopes.end(),
        std::back_inserter(scopes),
        IsValidScope);
    std::copy_if(
        output_scopes.begin(),
        output_scopes.end(),
        std::back_inserter(scopes),
        IsValidScope);
    if (!scopes.empty()) {
      auto common_ancestor = FindCommonAncestor(scopes);
      if (common_ancestor.has_value() &&
          IsValidScope(common_ancestor.value())) {
        return common_ancestor;
      }
    }
  }

  return std::nullopt;
}
|
|
|
|
// Builds a standalone Graph containing clones of the representative scope's
// nodes, for use as the body of the LocalFunctionDef. Populates
// ctx.env_to_subgraph_ with the outer-value -> subgraph-value mapping used
// later when resolving attribute names. Values consumed from outside the
// scope become graph inputs; values used outside become graph outputs.
std::shared_ptr<Graph> FunctionExtractor::ConstructFuncGraph(
    FunctionContext& func_ctx) {
  auto& ctx = *func_ctx.scope_ctxs_[func_ctx.scope_key_];
  const auto& nlist = ctx.nlist_;
  const auto& scope = ctx.scope_;
  auto& env = ctx.env_to_subgraph_;

  auto g = std::make_shared<Graph>();
  GRAPH_DEBUG("Constructing graph for ", scope->namesFromRoot());

  // TODO: Update input names of function to match those in Module source code
  // signature.
  // This requires mapping between function node inputs and Module inputs.
  // Due to the lack of such mapping, currently debugName is used as input
  // names.
  ctx.PopulateInputsOutputs(param_names_);
  for (auto* v : ctx.inputs_) {
    env[v] = g->addInput()->copyMetadata(v);
    GRAPH_DEBUG(
        "Add input value ",
        env[v]->debugName(),
        " for outer scope value ",
        v->debugName(),
        " from ",
        *v->node());
  }

  for (auto* n : nlist) {
    // Clone each node, remapping its inputs through env. Each input must be
    // either a registered graph input or an output of an earlier clone.
    auto clone_n = g->createClone(n, [&](Value* v) {
      TORCH_INTERNAL_ASSERT(env.find(v) != env.end());
      return env[v];
    });
    for (const auto i : c10::irange(clone_n->outputs().size())) {
      env[n->output(i)] = clone_n->output(i);
    }
    g->insertNode(clone_n);
  }

  // If values are used outside of this graph, set as graph output.
  for (auto* v : ctx.outputs_) {
    TORCH_INTERNAL_ASSERT(env.find(v) != env.end());
    g->registerOutput(env[v]);
  }

  GRAPH_DEBUG(g->toString());
  return g;
}
|
|
|
|
// Creates an onnx::LocalFunctionDef node carrying the function body graph,
// name and domain, and prepends it to `graph`. Attributes whose values
// differ across call sites (func_ctx.attribute_map_) are promoted to
// function attributes named "inferred::<node kind>_<attr>" (deduplicated
// with a numeric suffix); module-annotated attributes are appended last.
// Returns the created definition node.
Node* FunctionExtractor::CreateFunctionDefNode(
    FunctionContext& func_ctx,
    const std::shared_ptr<Graph>& graph,
    const std::string& domain_name,
    const std::string& func_name) {
  const auto func_def_nk = Symbol::onnx("LocalFunctionDef");
  const auto func_g_attr = Symbol::attr("graph");
  const auto func_name_attr = attr::name;
  const auto func_domain_attr = Symbol::attr("domain");

  auto func_graph = ConstructFuncGraph(func_ctx);

  // create and insert local function definition node
  auto func_def_n = graph->create(func_def_nk, 0);
  func_def_n->g_(func_g_attr, func_graph);
  func_def_n->s_(func_name_attr, func_name);
  func_def_n->s_(func_domain_attr, domain_name);
  graph->prependNode(func_def_n);

  // set constants and attributes of different values as function attributes.
  std::unordered_map<std::string, int> base_attr_name_count;
  std::vector<std::string> final_attr_names;

  // Appends ".<count>" to an attribute name that was already used, so every
  // exported function attribute has a unique name.
  auto adjust_attr_name = [&](std::string attr_name) {
    if (base_attr_name_count.find(attr_name) != base_attr_name_count.end()) {
      attr_name =
          attr_name + "." + std::to_string(base_attr_name_count[attr_name]++);
    } else {
      base_attr_name_count[attr_name] = 1;
    }
    return attr_name;
  };

  for (const auto& n_it : func_ctx.attribute_map_) {
    auto* n = n_it.first;
    for (const auto& attr_it : n_it.second) {
      const auto& attr = attr_it.first;
      // Add prefix "inferred::" to name of inferred attribute.
      // This is to differentiate from annotated attributes picked up
      // from python module annotation.
      auto attr_name = "inferred::" + std::string(n->kind().toUnqualString()) +
          '_' + attr.toUnqualString();
      auto final_attr_name = adjust_attr_name(attr_name);
      final_attr_names.emplace_back(final_attr_name);
      func_ctx.SetAttrName(n, attr, final_attr_name);
    }
  }

  // Set annotated attributes
  std::unordered_set<Symbol> annotated_attr_names;
  bool first_iteration = true;
  for (const auto& it : func_ctx.scope_ctxs_) {
    auto scope = it.first;
    auto annotated_attr_node = scope_attr_map_.find(scope);
    if (annotated_attr_node != scope_attr_map_.end()) {
      auto names = annotated_attr_node->second->attributeNames();
      if (first_iteration) {
        // The first annotated scope defines the expected attribute set.
        std::copy(
            names.begin(),
            names.end(),
            std::inserter(annotated_attr_names, annotated_attr_names.end()));
        first_iteration = false;
      } else {
        // Every other scope in the group must annotate the same set;
        // otherwise the function signature would be ambiguous.
        auto unseen_attr_name = std::find_if(
            names.begin(),
            names.end(),
            [&annotated_attr_names](const Symbol& name) {
              return annotated_attr_names.find(name) ==
                  annotated_attr_names.end();
            });
        TORCH_CHECK(
            unseen_attr_name == names.end(),
            "Found outstanding annotated attribute ",
            *unseen_attr_name,
            " from module ",
            scope->name(),
            ". Please ensure module instances of the same class have the same set of annotated attributes.");
      }
    }
  }
  for (auto attr_name : annotated_attr_names) {
    final_attr_names.emplace_back(attr_name.toUnqualString());
  }

  func_def_n->ss_(Symbol::attr("attributes"), final_attr_names);

  return func_def_n;
}
|
|
|
|
// Creates the call-site node <domain_name>::<func_name> for one scope:
// wires the scope's inputs/outputs to the new node, redirects all uses of
// the old outputs, copies the differing attribute values (under their
// exported names) plus any module-annotated attributes onto the node, and
// inserts it after the last node of the scope. Returns the created node.
Node* FunctionExtractor::CreateFunctionNode(
    FunctionContext& func_ctx,
    ScopeContext& scope_ctx,
    const std::shared_ptr<Graph>& graph,
    const std::string& domain_name,
    const std::string& func_name) {
  const auto& func_scope = func_ctx.scope_key_;
  GRAPH_DEBUG(
      "Create and insert local function for scope: ",
      func_scope->namesFromRoot());
  scope_ctx.PopulateInputsOutputs(param_names_);
  auto last_n = *scope_ctx.nlist_.rbegin();
  auto func_n = graph->create(
      Symbol::fromQualString(domain_name + "::" + func_name),
      scope_ctx.outputs_.size());
  func_n->copyMetadata(last_n);
  for (auto* v : scope_ctx.inputs_) {
    func_n->addInput(v);
  }
  for (const auto i : c10::irange(scope_ctx.outputs_.size())) {
    func_n->output(i)->setType(scope_ctx.outputs_[i]->type());
    scope_ctx.outputs_[i]->replaceAllUsesWith(func_n->output(i));
  }

  // set attributes of different values as function attributes.
  // Copies attribute `attr` of node a onto node b under `new_name`,
  // dispatching on the attribute kind. Graph/IValue/type/class kinds are
  // not supported here and trigger an internal assert.
  auto copy_attr =
      [](Node* a, Node* b, Symbol attr, const std::string& new_name) {
#define COPY_ATTR(kind)                                \
  case AttributeKind::kind: {                          \
    b->kind##_(Symbol::attr(new_name), a->kind(attr)); \
    break;                                             \
  }
        switch (a->kindOf(attr)) {
          COPY_ATTR(f)
          COPY_ATTR(fs)
          COPY_ATTR(i)
          COPY_ATTR(is)
          COPY_ATTR(s)
          COPY_ATTR(ss)
          COPY_ATTR(t)
          COPY_ATTR(ts)
#undef COPY_ATTR
          case AttributeKind::ival:
          case AttributeKind::g:
          case AttributeKind::gs:
          case AttributeKind::ty:
          case AttributeKind::tys:
          case AttributeKind::c:
          default:
            TORCH_INTERNAL_ASSERT(
                false,
                "Unexpected attribute type ",
                static_cast<int>(a->kindOf(attr)),
                " from node ",
                *a);
            break;
        }
      };

  for (const auto& it : func_ctx.attribute_map_) {
    auto* ref_n = it.first;
    for (const auto& attr_it : it.second) {
      const auto& attr = attr_it.first;
      auto attr_name = func_ctx.FindAttrName(ref_n, attr).value();
      // Default to the reference node's value, then override with the value
      // from this scope's node when this scope carries a differing value.
      copy_attr(ref_n, func_n, attr, attr_name);
      for (auto* n : scope_ctx.nlist_) {
        if (attr_it.second.find(n) != attr_it.second.end()) {
          copy_attr(n, func_n, attr, attr_name);
          break;
        }
      }
    }
  }

  // annotated attributes
  auto scope = scope_ctx.scope_;
  auto annotated_attr_node = scope_attr_map_.find(scope);
  if (annotated_attr_node != scope_attr_map_.end()) {
    auto node = annotated_attr_node->second;
    for (auto attr : node->attributeNames()) {
      copy_attr(node, func_n, attr, attr.toUnqualString());
    }
  }

  func_n->insertAfter(last_n);
  return func_n;
}
|
|
|
|
// Converts one group of identical scopes (keyed by `scope_key`) into an ONNX
// local function: emits the definition node, then a call node per scope, and
// finally destroys the replaced subgraph nodes.
void FunctionExtractor::ConvertScopeToFunction(
    const ScopePtr& scope_key,
    const scope_list& scope_list,
    scope_ctx_map& scope_ctxs,
    const std::shared_ptr<Graph>& graph) {
  // This function needs to be called always on inner most scopes.
  // 1. Generate function context, this identifies different constants and
  // attributes.
  // 2. Create function definition node, and insert to main graph.
  // 3. Create function node for each call, and replace subgraph nodes in parent
  // functions.

  // Owned by func_ctxs_; released in run().
  func_ctxs_.insert(std::make_pair(
      scope_key, new FunctionContext(scope_key, scope_list, scope_ctxs)));
  auto& func_ctx = *func_ctxs_[scope_key];

  // The class name has the form "<domain>.<class>": split it into the ONNX
  // function domain and the function name.
  const std::string module_class_name(
      ONNXScopeName::className(func_ctx.scope_key_));
  auto pos = module_class_name.rfind('.');
  TORCH_INTERNAL_ASSERT(pos != std::string::npos);

  // Disambiguates same-named modules that end up as different functions by
  // appending ".<variant index>" (see module_variant_count_).
  auto construct_unique_module_name = [&](std::string module_name) {
    auto module_name_variant = module_variant_count_.find(module_name);
    if (module_name_variant != module_variant_count_.end()) {
      module_variant_count_[module_name]++;
      module_name += ("." + std::to_string(module_name_variant->second));
    } else {
      module_variant_count_[module_name] = 0;
    }
    return module_name;
  };

  const auto domain_name = module_class_name.substr(0, pos);
  const auto func_name =
      construct_unique_module_name(module_class_name.substr(pos + 1));

  CreateFunctionDefNode(func_ctx, graph, domain_name, func_name);

  // create and insert local function node to graph.
  for (const auto& it : func_ctx.scope_ctxs_) {
    auto scope = it.first;
    auto& scope_ctx = *it.second;
    auto func_n =
        CreateFunctionNode(func_ctx, scope_ctx, graph, domain_name, func_name);

    std::unordered_set<Node*> old_nodes(
        scope_ctx.nlist_.begin(), scope_ctx.nlist_.end());

    auto last_n = *scope_ctx.nlist_.rbegin();
    // replace function body nodes in parent scopes with local function node.
    for (auto& it : scope_ctxs) {
      const auto& parent_scope = it.first;
      auto& parent_ctx = *it.second;

      if (!IsAncestor(parent_scope, scope)) {
        continue;
      }

      auto& ctx_nlist = parent_ctx.nlist_;
      GRAPH_DEBUG(
          "Replace local function node in parent scope: ",
          it.first->namesFromRoot(),
          " nodes to remove: ",
          old_nodes.size(),
          " parent total nodes: ",
          ctx_nlist.size());

      // insert local function node
      auto last_n_it = std::find(ctx_nlist.begin(), ctx_nlist.end(), last_n);
      ctx_nlist.insert(last_n_it, func_n);

      // remove replaced nodes from list
      ctx_nlist.erase(
          std::remove_if(
              ctx_nlist.begin(),
              ctx_nlist.end(),
              [&old_nodes](Node* n) {
                return old_nodes.find(n) != old_nodes.end();
              }),
          ctx_nlist.end());

      GRAPH_DEBUG("Parent total nodes after remove: ", ctx_nlist.size());

      // refresh inputs/outputs.
      parent_ctx.PopulateInputsOutputs(param_names_);
    }
  }

  for (const auto& it : func_ctx.scope_ctxs_) {
    auto& scope_ctx = *it.second;
    // delete replaced nodes in graph.
    // Iterate in reverse so consumers are destroyed before producers.
    for (auto it = scope_ctx.nlist_.rbegin(); it != scope_ctx.nlist_.rend();) {
      auto* n = *it;
      it++;
      GRAPH_DEBUG("Destroying node ", *n);
      n->destroy();
    }
  }
}
|
|
|
|
bool FunctionExtractor::ScopeContext::IsIdenticalFuncion(
|
|
const ScopeContext& other_ctx) const {
|
|
// Differentiate same function under different inputs.
|
|
// When constants are passed in place of inputs, it leads to different
|
|
// input count and node count. Likewise, due to different uses, output
|
|
// count can be different as well.
|
|
// For now export them as different functions.
|
|
// Covered by `test_local_function_overloads` in
|
|
// `test/onnx/test_utility_funs.py`.
|
|
if (&other_ctx == this) {
|
|
return true;
|
|
}
|
|
if (ONNXScopeName::className(this->scope_) !=
|
|
ONNXScopeName::className(other_ctx.scope_)) {
|
|
return false;
|
|
}
|
|
if (this->inputs_.size() != other_ctx.inputs_.size() ||
|
|
this->outputs_.size() != other_ctx.outputs_.size()) {
|
|
return false;
|
|
}
|
|
const auto& ns_a = this->nlist_;
|
|
const auto& ns_b = other_ctx.nlist_;
|
|
if (ns_a.size() != ns_b.size()) {
|
|
return false;
|
|
}
|
|
for (const auto i : c10::irange(ns_a.size())) {
|
|
if (ns_a[i]->kind() != ns_b[i]->kind()) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Recomputes inputs_ and outputs_ for this scope from nlist_.
// A value is an input when it is consumed by a node of the scope but not
// produced inside it; values whose debug name appears in param_names are
// treated as initializers and placed after the regular inputs. A value is
// an output when any of its uses lies outside the scope's node set.
void FunctionExtractor::ScopeContext::PopulateInputsOutputs(
    const std::unordered_set<std::string>& param_names) {
  inputs_.clear();
  outputs_.clear();
  const auto& nlist = this->nlist_;
  // v_set: values already produced inside the scope or already recorded.
  std::unordered_set<Value*> v_set;
  // n_set: nodes of the scope, used below for the "used outside" check.
  std::unordered_set<Node*> n_set;

  value_list input_list;
  value_list initializer_list;

  // Add initializers after inputs.
  for (auto* n : nlist) {
    for (auto* v : n->inputs()) {
      if (v_set.find(v) == v_set.end()) {
        if (param_names.find(v->debugName()) != param_names.end()) {
          initializer_list.emplace_back(v);
        } else {
          input_list.emplace_back(v);
        }
        v_set.insert(v);
      }
    }
    // Outputs are produced inside the scope; later consumers of these
    // values must not count them as scope inputs.
    for (auto* v : n->outputs()) {
      v_set.insert(v);
    }
    n_set.insert(n);
  }
  for (auto* v : input_list) {
    inputs_.emplace_back(v);
  }
  for (auto* v : initializer_list) {
    inputs_.emplace_back(v);
  }

  for (auto* n : nlist) {
    for (auto* v : n->outputs()) {
      bool used_outside = false;
      for (auto use : v->uses()) {
        used_outside |= (n_set.find(use.user) == n_set.end());
      }
      if (used_outside) {
        outputs_.emplace_back(v);
      }
    }
  }
}
|
|
|
|
// Reports every node whose scope could not be determined and then asserts,
// since extraction cannot proceed with scope-less nodes.
void FunctionExtractor::HandleNoScopeNodes(
    scope_ctx_map& scope_ctxs,
    const node_list& no_scope_nlist) {
  GRAPH_UPDATE("No scope node count: ", no_scope_nlist.size());
  for (auto* node : no_scope_nlist) {
    TORCH_WARN(
        "ONNX function extraction cannot determine the scope for node: ",
        *node);
  }
  TORCH_INTERNAL_ASSERT(
      no_scope_nlist.empty(),
      "ONNX function extraction cannot determine the scope for the above nodes.");
}
|
|
|
|
std::tuple<FunctionExtractor::scope_ctx_map, node_list> FunctionExtractor::
|
|
PartitionNodesByScope(Block* b) {
|
|
scope_ctx_map scope_ctxs = {};
|
|
node_list no_scope_nlist;
|
|
|
|
auto find_or_create_scope_ctx = [](scope_ctx_map& scope_ctxs,
|
|
const ScopePtr& scope) {
|
|
if (scope_ctxs.find(scope) == scope_ctxs.end()) {
|
|
scope_ctxs.insert(std::make_pair(scope, new ScopeContext()));
|
|
}
|
|
return scope_ctxs[scope];
|
|
};
|
|
|
|
auto record_node_scope = [&scope_ctxs, &find_or_create_scope_ctx](Node* n) {
|
|
const auto& scope = n->scope();
|
|
find_or_create_scope_ctx(scope_ctxs, scope)->scope_ = scope;
|
|
auto tmp_scope = scope;
|
|
while (IsValidScope(tmp_scope)) {
|
|
find_or_create_scope_ctx(scope_ctxs, tmp_scope)->nlist_.emplace_back(n);
|
|
if (IsValidScope(tmp_scope->parent())) {
|
|
find_or_create_scope_ctx(scope_ctxs, tmp_scope->parent())
|
|
->children_.insert(tmp_scope);
|
|
}
|
|
tmp_scope = tmp_scope->parent();
|
|
}
|
|
};
|
|
|
|
for (auto* n : b->nodes()) {
|
|
auto scope = n->scope();
|
|
if (scope && IsValidScope(scope)) {
|
|
record_node_scope(n);
|
|
} else {
|
|
auto inferred_scope = InferScope(n);
|
|
|
|
if (inferred_scope.has_value() && IsValidScope(inferred_scope.value())) {
|
|
n->setScope(inferred_scope.value());
|
|
record_node_scope(n);
|
|
} else {
|
|
GRAPH_UPDATE("Cannot infer proper scope for node: ", *n);
|
|
no_scope_nlist.emplace_back(n);
|
|
}
|
|
}
|
|
|
|
for (auto* sub_b : n->blocks()) {
|
|
auto [subblock_scope_ctxs, subblock_no_scope_nlist] =
|
|
PartitionNodesByScope(sub_b);
|
|
|
|
for (auto& it : subblock_scope_ctxs) {
|
|
if (scope_ctxs.find(it.first) == scope_ctxs.end()) {
|
|
scope_ctxs.insert(std::make_pair(it.first, it.second));
|
|
} else {
|
|
for (auto* s_n : it.second->nlist_) {
|
|
scope_ctxs[it.first]->nlist_.emplace_back(s_n);
|
|
}
|
|
for (const auto& s_child_scope : it.second->children_) {
|
|
scope_ctxs[it.first]->children_.insert(s_child_scope);
|
|
}
|
|
}
|
|
}
|
|
|
|
no_scope_nlist.insert(
|
|
no_scope_nlist.end(),
|
|
subblock_no_scope_nlist.begin(),
|
|
subblock_no_scope_nlist.end());
|
|
}
|
|
}
|
|
|
|
for (auto& it : scope_ctxs) {
|
|
it.second->scope_ = it.first;
|
|
it.second->PopulateInputsOutputs(param_names_);
|
|
}
|
|
|
|
return std::tie(scope_ctxs, no_scope_nlist);
|
|
}
|
|
|
|
// Convenience overload for the whole graph: partitions the top-level block
// and hard-fails (via HandleNoScopeNodes) if any node's scope is unknown.
FunctionExtractor::scope_ctx_map FunctionExtractor::PartitionNodesByScope(
    const std::shared_ptr<Graph>& graph) {
  auto [scope_ctxs, no_scope_nlist] = PartitionNodesByScope(graph->block());

  HandleNoScopeNodes(scope_ctxs, no_scope_nlist);

  return scope_ctxs;
}
|
|
|
|
// Groups scopes whose subgraphs are structurally identical (per
// ScopeContext::IsIdenticalFuncion). The returned map is keyed by the first
// scope encountered for each group; the mapped list contains every member
// of the group, including the key itself.
std::unordered_map<ScopePtr, scope_list> FunctionExtractor::
    PartitionIdenticalScopes(FunctionExtractor::scope_ctx_map& scope_ctxs) {
  std::unordered_map<ScopePtr, scope_list> identical_scope_map;

  for (auto& it : scope_ctxs) {
    auto scope = it.first;
    const auto& scope_ctx = it.second;
    bool unique = true;
    // Linear scan over existing groups; the first structural match wins.
    for (auto& kv_it : identical_scope_map) {
      auto key_scope = kv_it.first;
      const auto& key_scope_ctx = scope_ctxs[key_scope];
      auto& key_scope_vec = kv_it.second;
      if (key_scope_ctx->IsIdenticalFuncion(*scope_ctx)) {
        key_scope_vec.emplace_back(scope);
        unique = false;
        break;
      }
    }
    if (unique) {
      // Start a new group containing (so far) only this scope.
      identical_scope_map[scope].emplace_back(scope);
    }
  }

  return identical_scope_map;
}
|
|
|
|
// Returns true when nodes a and b agree on attribute `attr`: either neither
// node carries it, or both carry it with the same kind and equal value.
// Graph/IValue/type/class attribute kinds are not comparable here and
// trigger an internal assert.
static bool HasSameAttribute(
    const Node* a,
    const Node* b,
    const c10::Symbol& attr) {
  if (!a->hasAttribute(attr) && !b->hasAttribute(attr)) {
    return true;
  }
  if (!a->hasAttribute(attr) || !b->hasAttribute(attr)) {
    return false;
  }
  auto a_kind = a->kindOf(attr);
  auto b_kind = b->kindOf(attr);
  if (a_kind != b_kind) {
    return false;
  }

// Compares scalar/string(-list) attribute values with operator==.
#define COMP_ATTR(kind)              \
  case AttributeKind::kind: {        \
    const auto& a_v = a->kind(attr); \
    const auto& b_v = b->kind(attr); \
    return a_v == b_v;               \
  }

  switch (a_kind) {
    COMP_ATTR(f)
    COMP_ATTR(fs)
    COMP_ATTR(i)
    COMP_ATTR(is)
    COMP_ATTR(s)
    COMP_ATTR(ss)
#undef COMP_ATTR
    // Tensors require elementwise comparison via Tensor::equal.
    case AttributeKind::t: {
      const auto& a_v = a->t(attr);
      const auto& b_v = b->t(attr);
      return a_v.equal(b_v);
    }
    case AttributeKind::ts: {
      const auto& a_v = a->ts(attr);
      const auto& b_v = b->ts(attr);
      return std::equal(
          a_v.begin(),
          a_v.end(),
          b_v.begin(),
          b_v.end(),
          [](const at::Tensor& a_t, const at::Tensor& b_t) {
            return a_t.equal(b_t);
          });
    }
    case AttributeKind::ival:
    case AttributeKind::g:
    case AttributeKind::gs:
    case AttributeKind::ty:
    case AttributeKind::tys:
    case AttributeKind::c:
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Unexpected attribute type ",
          static_cast<int>(a_kind),
          " from node ",
          *a);
      break;
  }

  return true;
}
|
|
|
|
// Sorts the group keys of identical_scope_map so that the group containing
// the deepest member scope comes first. Processing deepest-first guarantees
// a scope is converted before any of its ancestors (see run()).
scope_list FunctionExtractor::SortScopesByMaxDepth(
    std::unordered_map<ScopePtr, scope_list>& identical_scope_map) {
  // Compute, for each group key, the maximum depth over its member scopes.
  std::unordered_map<ScopePtr, size_t> scope_max_depth;
  for (const auto& it : identical_scope_map) {
    const auto& scopes = it.second;
    size_t max_depth = 0;
    for (const auto& scope : scopes) {
      if (scope->getDepth() > max_depth) {
        max_depth = scope->getDepth();
      }
    }
    scope_max_depth[it.first] = max_depth;
  }

  scope_list sorted_scopes;
  sorted_scopes.reserve(scope_max_depth.size());
  for (const auto& it : scope_max_depth) {
    sorted_scopes.emplace_back(it.first);
  }
  // NOTE: std::sort requires a strict weak ordering. The previous `>=`
  // comparator returned true for equal-depth (and identical) elements,
  // violating irreflexivity — undefined behavior for std::sort. Use strict
  // `>` for a valid descending order.
  std::sort(
      sorted_scopes.begin(),
      sorted_scopes.end(),
      [&scope_max_depth](const ScopePtr& a, const ScopePtr& b) -> bool {
        return scope_max_depth[a] > scope_max_depth[b];
      });
  return sorted_scopes;
}
|
|
|
|
// Drives the whole extraction pass: partitions nodes by scope, groups
// identical scopes, converts each requested module scope into a local
// function node (deepest first), and returns the mapping from subgraph-node
// attribute names to function attribute names.
NodeAttrNameMap FunctionExtractor::run() {
  auto scope_ctxs = PartitionNodesByScope(graph_);
  DebugPrintScopeContexts(scope_ctxs);
  auto identical_scope_map = PartitionIdenticalScopes(scope_ctxs);
  // Deepest scope comes first, guaranteeing no other scope can be its child.
  auto sorted_scope_keys = SortScopesByMaxDepth(identical_scope_map);
  for (const auto& scope_key : sorted_scope_keys) {
    // Only scopes whose module class name was requested by the caller are
    // converted into local functions.
    if (module_names_.find(ONNXScopeName::className(scope_key)) !=
        module_names_.end()) {
      ConvertScopeToFunction(
          scope_key, identical_scope_map[scope_key], scope_ctxs, graph_);
    }
    GRAPH_DEBUG("Main graph afterwards: ", graph_->toString());
  }
  DebugPrintGraphWithFunction(graph_);

  // Construct return mappings
  NodeAttrNameMap node_attr_to_name;

  for (const auto& it : func_ctxs_) {
    // Borrow the per-function map by reference; copying it per iteration
    // was an unnecessary deep copy of the whole map.
    const auto& func_ref_map = it.second->node_attr_to_name_;
    node_attr_to_name.insert(func_ref_map.begin(), func_ref_map.end());
  }

  // Clear: contexts are raw-owned by the maps, so release them explicitly.
  for (auto& it : scope_ctxs) {
    delete it.second;
  }
  scope_ctxs.clear();
  for (auto& it : func_ctxs_) {
    delete it.second;
  }
  func_ctxs_.clear();

  return node_attr_to_name;
}
|
|
|
|
// Retrieves the node representing the most recent
|
|
// ScopePtr. This function should only be invoked from module forward hook. At
|
|
// this point, module forward call is completed, and the most recent ScopePtr
|
|
// is popped from TracingState.
|
|
// This function inspects the node, and its subblock, to find
|
|
// the node associated with the most recent ScopePtr.
|
|
Node* NodeOfMostRecentScope(Node* forward_node) {
  // Only prim::TracedModuleForward nodes carry the module-call subblocks
  // this walk relies on.
  TORCH_INTERNAL_ASSERT(
      forward_node->kind() == prim::TracedModuleForward,
      "forward_node got kind: ",
      forward_node->kind().toDisplayString());
  auto* body = forward_node->blocks()[0];
  // Walk the subblock backwards so the most recently traced module call is
  // examined first.
  for (auto* inner : body->nodes().reverse()) {
    if (inner->kind() != prim::TracedModuleForward) {
      continue;
    }
    Node* candidate = NodeOfMostRecentScope(inner);
    // A scope already registered in scope_attr_map_ was handled by an
    // earlier hook invocation; keep scanning for an untracked one.
    if (scope_attr_map_.find(inner->scope()) == scope_attr_map_.end()) {
      return candidate;
    }
  }
  // No untracked nested call found: this node itself is the most recent.
  return forward_node;
}
|
|
|
|
} // namespace
|
|
|
|
// FunctionExtractor runs in the following steps. Updates are made inplace to
|
|
// the graph argument.
|
|
// 1. Partition nodes into groups based on their scope information.
|
|
// Each scope represents an individual nn.Module call. A ScopeContext object
|
|
// is created for each group.
|
|
// 2. Compare and find groups with the same subgraph pattern from step 1.
|
|
// 3. Scopes are nested. Starting from the deepest scope, extract the
|
|
// subgraph pattern, and define as local function node. Replace subgraph
|
|
// pattern with a single node of the new local function node type. A
|
|
// FunctionContext object is created for each function.
|
|
// 4. Construct NodeAttrNameMap tracking mapping from attribute name of
|
|
// IR Node inside function subgraph, to function attribute name.
|
|
// Entry point of the pass: extracts the forward calls of the modules named
// in `module_names` as ONNX local functions, updating `graph` in place, and
// returns the node-attribute-to-function-attribute name mapping.
NodeAttrNameMap ONNXFunctionExtraction(
    std::shared_ptr<Graph>& graph,
    const std::unordered_set<std::string>& module_names,
    const std::vector<std::string>& param_names) {
  std::vector<std::string> requested{module_names.begin(), module_names.end()};
  GRAPH_UPDATE("Export these module forward calls as functions: ", requested);
  FunctionExtractor extractor(graph, module_names, param_names);
  return extractor.run();
}
|
|
|
|
// Resets the tracing-time scope attribute records: replaces the backing
// graph with a fresh one and drops every scope-to-node entry.
void ONNXClearScopeRecords() {
  scope_attr_graph_ = std::make_shared<Graph>();
  scope_attr_map_.clear();
}
|
|
|
|
void ONNXTrackScopeAttributes(
|
|
std::shared_ptr<Graph>& graph,
|
|
std::map<std::string, IValue>& attributes) {
|
|
// Skip the "real" last node which is `return_node`.
|
|
auto* last_node = graph->nodes().back()->prev();
|
|
auto* scope_node = NodeOfMostRecentScope(last_node);
|
|
auto* attr_node = scope_attr_graph_->create(prim::TracedModuleForward);
|
|
attr_node->setScope(scope_node->scope());
|
|
TORCH_INTERNAL_ASSERT(
|
|
scope_attr_map_.find(scope_node->scope()) == scope_attr_map_.end());
|
|
scope_attr_map_[scope_node->scope()] = attr_node;
|
|
|
|
for (const auto& it : attributes) {
|
|
auto k = Symbol::attr(it.first);
|
|
auto v = it.second;
|
|
if (v.isTensor()) {
|
|
attr_node->t_(k, v.toTensor());
|
|
} else if (v.isInt()) {
|
|
attr_node->i_(k, v.toInt());
|
|
} else if (v.isDouble()) {
|
|
attr_node->f_(k, v.toDouble());
|
|
} else if (v.isBool()) {
|
|
attr_node->i_(k, v.toBool());
|
|
} else if (v.isString()) {
|
|
attr_node->s_(k, v.toStringRef());
|
|
} else if (v.isIntList()) {
|
|
attr_node->is_(k, v.toIntList().vec());
|
|
} else if (v.isBoolList()) {
|
|
auto bool_list = v.toBoolList();
|
|
attr_node->is_(
|
|
k, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
|
|
} else if (v.isDoubleList()) {
|
|
attr_node->fs_(k, v.toDoubleList().vec());
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace torch::jit::onnx
|