mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This is an automatic change generated by the following script: ``` #!/usr/bin/env python3 from subprocess import check_output, check_call import os def get_compiled_files_list(): import json with open("build/compile_commands.json") as f: data = json.load(f) files = [os.path.relpath(node['file']) for node in data] for idx, fname in enumerate(files): if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'): files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')] return files def run_clang_tidy(fname): check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"]) changes = check_output(["git", "ls-files", "-m"]) if len(changes) == 0: return check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"]) def main(): git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n") compiled_files = get_compiled_files_list() for idx, fname in enumerate(git_files): if fname not in compiled_files: continue if fname.startswith("caffe2/contrib/aten/"): continue print(f"[{idx}/{len(git_files)}] Processing {fname}") run_clang_tidy(fname) if __name__ == "__main__": main() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892 Reviewed By: H-Huang Differential Revision: D27991944 Pulled By: malfet fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
288 lines
8.7 KiB
C++
288 lines
8.7 KiB
C++
#include "caffe2/core/graph.h"
|
|
|
|
#include "caffe2/core/common.h"
|
|
#include "caffe2/core/logging.h"
|
|
#include "caffe2/core/net.h"
|
|
#include "caffe2/proto/caffe2_pb.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
namespace transform {
|
|
|
|
// Builds the dependency graph of `net`.
//
// One node is created per operator, in net order. An edge is added from the
// most recent writer of a blob to every operator that reads that blob, and the
// blob name is stored on both ends of the edge. Blobs that are read before any
// operator in the net has written them are collected as external inputs;
// blobs that are written and never read by a later operator are collected as
// external outputs.
Graph::Graph(const NetDef& net) : netdef_(net) {
  const int num_ops = net.op_size();
  nodes_.clear();
  nodes_.resize(num_ops);

  // Every node keeps its own copy of the operator it stands for.
  for (int idx = 0; idx < num_ops; ++idx) {
    node(idx).op = net.op(idx);
  }

  // Forward sweep: last_writer maps a blob to the index of the operator that
  // wrote it most recently. In python, this is known as "versions".
  std::unordered_map<string, int> last_writer;
  for (int cur = 0; cur < (int)nodes_.size(); ++cur) {
    for (const string& in_blob : node(cur).op.input()) {
      const auto writer = last_writer.find(in_blob);
      if (writer == last_writer.end()) {
        // No operator seen so far produces this blob.
        external_input_.insert(in_blob);
        continue;
      }
      const int producer = writer->second;
      node(cur).parents[producer].push_back(in_blob);
      node(producer).children[cur].push_back(in_blob);
    }
    // Outputs are recorded only after the inputs were resolved, so an operator
    // that reads and writes the same blob is linked to the previous writer
    // instead of to itself.
    for (const string& out_blob : node(cur).op.output()) {
      last_writer[out_blob] = cur;
    }
  }

  // Backward sweep to find the external outputs: while visiting operator
  // `cur`, later_reader holds every blob that is read by an operator placed
  // after `cur` in the net.
  std::unordered_map<string, int> later_reader;
  for (int cur = (int)nodes_.size() - 1; cur >= 0; --cur) {
    for (const string& out_blob : node(cur).op.output()) {
      if (later_reader.find(out_blob) == later_reader.end()) {
        external_output_.insert(out_blob);
      }
    }
    for (const string& in_blob : node(cur).op.input()) {
      later_reader[in_blob] = cur;
    }
  }
}
|
|
|
|
const std::vector<std::pair<string, int>> Graph::GetSubgraphInput(
|
|
const std::vector<int>& match) {
|
|
return GetSubgraphPerimeterHelper(true, match);
|
|
}
|
|
|
|
const std::vector<std::pair<string, int>> Graph::GetSubgraphOutput(
|
|
const std::vector<int>& match) {
|
|
return GetSubgraphPerimeterHelper(false, match);
|
|
}
|
|
|
|
// This helper function will either get:
|
|
// 1) a list for the blobs that write INTO a subgraph
|
|
// 2) a list of for the blobs that are written FROM a subgraph.
|
|
//
|
|
// The "from_children" flag determines if it is case 1 (true) or case 2 (false).
|
|
const std::vector<std::pair<string, int>> Graph::GetSubgraphPerimeterHelper(
|
|
bool from_children,
|
|
const std::vector<int>& match) {
|
|
std::vector<std::pair<string, int>> edge_list;
|
|
std::unordered_set<int> match_set(match.begin(), match.end());
|
|
for (int x = 0; x < (int)nodes_.size(); x++) {
|
|
if (!is_node_active(x)) {
|
|
continue;
|
|
}
|
|
if (!match_set.count(x)) { // x is not in subgraph
|
|
const auto& list = from_children ? node(x).children : node(x).parents;
|
|
for (const auto& edge : list) {
|
|
int parent = edge.first;
|
|
const auto& blobs = edge.second;
|
|
if (match_set.count(parent)) { // but has a parent that is in subgraph
|
|
for (const string& blob : blobs) {
|
|
// NOLINTNEXTLINE(modernize-use-emplace)
|
|
edge_list.push_back({blob, x});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// return the list in sorted order, to allow binary searching
|
|
std::sort(edge_list.begin(), edge_list.end());
|
|
return edge_list;
|
|
}
|
|
|
|
// Converts the graph back into a NetDef.
//
// The result is a copy of the NetDef this graph was built from, with its
// operator list replaced by the operators of the active nodes, emitted in a
// topological order of the graph. Among the nodes that are ready at the same
// time, the one with the smallest index is emitted first. A node whose count
// of unprocessed parents never reaches zero is not emitted.
NetDef Graph::GetNetDef() {
  const int num_nodes = static_cast<int>(nodes_.size());
  std::vector<bool> visited(nodes_.size(), false);

  // Copy over all the properties of the netdef we're based on, but we're
  // going to put in our own operators.
  NetDef netdef = netdef_;
  netdef.clear_op();

  // Keeps track of the number of parents yet to be processed. The final size
  // is known up front, so allocate once.
  std::vector<int> unchecked_parent_count;
  unchecked_parent_count.reserve(nodes_.size());

  // We will perform a topological traversal on the nodes, but we will prefer
  // nodes that come earlier in the execution order.
  //
  // This is a min-heap: the smallest ready node index is always on top, which
  // guarantees the lowest lexicographical topological ordering. This also
  // means the original nodes will be kept in their execution order.
  // NOLINTNEXTLINE(modernize-use-transparent-functors)
  std::priority_queue<int, std::vector<int>, std::greater<int>> q;

  // In our graph, G, the nodes don't have a strict ordering. But in the netdef,
  // they must (since nets are operators executed in some order).
  // How do we make sure that the order of operators in our generated netdef
  // is valid?
  // 1) The ordering of the netdef must be topologically sorted, respect to G.
  //    If A -> B is an edge in the graph G, then A must come before B in the
  //    netdef's ordering.
  // 2) No blob conflicts: If A -> B is an edge in the graph G, and A writes to
  //    blob X and B reads from blob X, then there cannot be an op that writes
  //    to blob X between A and B in the ordering.
  //
  // Perform a Topological Sort, to find an order for the Operators to be in.
  // We will keep track of the number of parents each node has.
  // We begin with an empty queue, and push in all nodes that do not have any
  // parents. Then, we keep track of all unprocessed parents for each node.
  // When a node has no more unprocessed parents, we can push it into the queue
  // to be processed. This guarantees condition 1 is satisfied.

  // TODO(benz): Currently, condition 2 is not guaranteed to be satisfied.
  // However, giving each blob unique names via SSA will satisfy this condition.
  // Then, the resulting graph can be optimized with memonger.

  for (int i = 0; i < num_nodes; ++i) {
    const auto& parents = node(i).parents;
    unchecked_parent_count.push_back(static_cast<int>(parents.size()));
    if (parents.empty() && is_node_active(i)) {
      q.push(i);
      visited[i] = true;
    }
  }

  while (!q.empty()) {
    const int idx = q.top();
    q.pop();
    if (!is_node_active(idx)) {
      continue;
    }
    // Append a new OperatorDef to the NetDef and make it a copy of the
    // OperatorDef stored at node(idx).
    *netdef.add_op() = node(idx).op;

    // This node is done: release every child that was only waiting on it.
    for (const auto& edge : node(idx).children) {
      const int child = edge.first;
      if (visited[child] || !is_node_active(child)) {
        continue;
      }
      if (--unchecked_parent_count[child] == 0) {
        q.push(child);
        visited[child] = true;
      }
    }
  }
  return netdef;
}
|
|
|
|
void Graph::DeactivateSubgraph(std::vector<int> subgraph) {
|
|
for (int idx : subgraph) {
|
|
// remove all edges connected to inactive node
|
|
for (const auto& edge : node(idx).parents) {
|
|
int parent = edge.first;
|
|
node(parent).children.erase(idx);
|
|
}
|
|
for (const auto& edge : node(idx).children) {
|
|
int child = edge.first;
|
|
node(child).parents.erase(idx);
|
|
}
|
|
// actually mark flags as false
|
|
node(idx).active = false;
|
|
}
|
|
}
|
|
|
|
} // namespace transform
|
|
|
|
// Appends a new operator of type `op_type` to `*netdef_ptr`, gives it the
// listed input and output blob names (in order), and returns a pointer to the
// newly added OperatorDef. `netdef_ptr` is checked to be non-null.
OperatorDef* AddOp(
    NetDef* netdef_ptr,
    string op_type,
    std::vector<string> inputs,
    std::vector<string> outputs) {
  CHECK(netdef_ptr);
  OperatorDef* const new_op = netdef_ptr->add_op();
  new_op->set_type(op_type);
  for (const string& input_name : inputs) {
    new_op->add_input(input_name);
  }
  for (const string& output_name : outputs) {
    new_op->add_output(output_name);
  }
  return new_op;
}
|
|
|
|
// Returns true if the pattern `p` accepts the string `s`.
//
// The pattern "*" accepts anything. Any other pattern is split on '|' and
// accepts `s` only if one of the resulting pieces is exactly equal to `s`.
bool MatchStrings(string p, string s) {
  const bool is_wildcard = (p == "*");
  if (is_wildcard) { // star accepts anything
    return true;
  }
  // TODO(benz): memoize this. (high constant factor boost in performance)
  for (const string& option : split('|', p)) {
    if (option == s) {
      return true;
    }
  }
  return false;
}
|
|
|
|
bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op) {
|
|
for (const auto& p_arg : p_op.arg()) {
|
|
if (!p_arg.has_name()) {
|
|
continue;
|
|
}
|
|
bool found = false;
|
|
for (const auto& g_arg : g_op.arg()) {
|
|
if (p_arg.name() == g_arg.name()) {
|
|
found = true;
|
|
if (p_arg.has_f()) {
|
|
if (!g_arg.has_f() || p_arg.f() != g_arg.f()) {
|
|
return false;
|
|
}
|
|
}
|
|
if (p_arg.has_i()) {
|
|
if (!g_arg.has_i() || p_arg.i() != g_arg.i()) {
|
|
return false;
|
|
}
|
|
}
|
|
if (p_arg.has_s()) {
|
|
if (!g_arg.has_s() || !MatchStrings(p_arg.s(), g_arg.s())) {
|
|
return false;
|
|
}
|
|
}
|
|
if (p_arg.floats_size() != g_arg.floats_size()) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < p_arg.floats_size(); i++) {
|
|
if (p_arg.floats(i) != g_arg.floats(i)) {
|
|
return false;
|
|
}
|
|
}
|
|
if (p_arg.ints_size() != g_arg.ints_size()) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < p_arg.ints_size(); i++) {
|
|
if (p_arg.ints(i) != g_arg.ints(i)) {
|
|
return false;
|
|
}
|
|
}
|
|
if (p_arg.strings_size() != g_arg.strings_size()) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < p_arg.strings_size(); i++) {
|
|
if (!MatchStrings(p_arg.strings(i), g_arg.strings(i))) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if (!found) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
} // namespace caffe2
|