mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds make_unique in places Pull Request resolved: https://github.com/pytorch/pytorch/pull/126966 Approved by: https://github.com/Skylion007
2321 lines
65 KiB
C++
2321 lines
65 KiB
C++
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
#include <ATen/core/builtin_function.h>
|
|
#include <ATen/core/function.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/StringUtil.h>
|
|
#include <c10/util/irange.h>
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
#include <torch/csrc/jit/frontend/error_report.h>
|
|
#include <torch/csrc/jit/frontend/schema_matching.h>
|
|
#include <torch/csrc/jit/ir/constants.h>
|
|
#include <torch/csrc/jit/runtime/operator.h>
|
|
#include <torch/csrc/jit/serialization/python_print.h>
|
|
|
|
#include <algorithm>
|
|
#include <iostream>
|
|
#include <locale>
|
|
#include <memory>
|
|
#include <set>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
|
|
namespace torch::jit {
|
|
|
|
namespace utils {
|
|
std::string getNodesModuleHierarchy(const Node& n) {
|
|
if (!n.callstack().has_value()) {
|
|
return std::string();
|
|
}
|
|
InlinedCallStackPtr callstack_ptr = n.callstack().value();
|
|
std::string module_hierarchy;
|
|
for (auto& entry : callstack_ptr->vec()) {
|
|
const auto& opt_module_info = std::get<kModuleInstanceInfo>(entry);
|
|
if (opt_module_info.has_value()) {
|
|
const auto& module_instance_info = opt_module_info.value();
|
|
if (!module_hierarchy.empty()) {
|
|
module_hierarchy.append(".");
|
|
}
|
|
module_hierarchy.append(utils::get_module_info(module_instance_info));
|
|
} else {
|
|
module_hierarchy += ".UNKNOWN_INSTANCE(UNKNOWN_TYPE)";
|
|
}
|
|
}
|
|
return module_hierarchy;
|
|
}
|
|
} // namespace utils
|
|
|
|
namespace {
|
|
|
|
// Constants relating to maintaining the topological index of nodes.
|
|
//
|
|
// Lower and upper bounds of the index. Inclusive range.
|
|
constexpr topo_position_t kLowerBound = INT64_MIN;
|
|
constexpr topo_position_t kUpperBound = INT64_MAX;
|
|
constexpr topo_position_t kMidPoint = 0;
|
|
|
|
// How far away to space nodes that are appended to the graph.
|
|
// should be 2^n, where:
|
|
// - n is the maximum number of repeated insertions without a re-index
|
|
// - 2^(64-n) is the maximum number of appends to the end without reindex
|
|
constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;
|
|
|
|
void printValueRef(std::ostream& out, const Value* n) {
|
|
out << "%" << n->debugName();
|
|
}
|
|
|
|
bool isNumber(c10::string_view str) {
|
|
return str.find_first_not_of("0123456789") == std::string::npos;
|
|
}
|
|
|
|
std::string normalizeAttrName(c10::string_view field) {
|
|
if (isNumber(field)) {
|
|
return "_" + std::string{field};
|
|
}
|
|
return std::string{field};
|
|
}
|
|
|
|
void findAllNodes(
|
|
Block& block,
|
|
Symbol kind,
|
|
bool recurse,
|
|
std::vector<Node*>& ret) {
|
|
for (Node* n : block.nodes()) {
|
|
if (n->kind() == kind) {
|
|
ret.push_back(n);
|
|
}
|
|
if (recurse) {
|
|
for (auto b : n->blocks()) {
|
|
findAllNodes(*b, kind, recurse, ret);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// NB: This overload will become ambiguous with the one Caffe2 provides in its
|
|
// logging, if they ever intersect.
|
|
template <typename T>
|
|
std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
|
|
out << at::ArrayRef<T>{nodes};
|
|
return out;
|
|
}
|
|
|
|
template <typename T>
|
|
static std::ostream& printValueRefs(
|
|
std::ostream& out,
|
|
const at::ArrayRef<T> nodes) {
|
|
size_t i = 0;
|
|
for (auto n : nodes) {
|
|
if (i++ > 0) {
|
|
out << ", ";
|
|
}
|
|
printValueRef(out, n);
|
|
}
|
|
return out;
|
|
}
|
|
|
|
// Can't make these two overloads directly a template, it'll be ambiguous with
|
|
// the global printer for operator<<.
|
|
|
|
static std::ostream& operator<<(
|
|
std::ostream& out,
|
|
const at::ArrayRef<const Value*> nodes) {
|
|
return printValueRefs(out, nodes);
|
|
}
|
|
|
|
static std::ostream& operator<<(
|
|
std::ostream& out,
|
|
const at::ArrayRef<Value*> nodes) {
|
|
return printValueRefs(out, nodes);
|
|
}
|
|
|
|
struct const_value_list_with_types {
|
|
const ArrayRef<const Value*> values;
|
|
std::string delim;
|
|
const_value_list_with_types(
|
|
ArrayRef<const Value*> values,
|
|
std::string delim_ = ", ")
|
|
: values(values), delim(std::move(delim_)) {}
|
|
};
|
|
|
|
static std::ostream& operator<<(
|
|
std::ostream& out,
|
|
const const_value_list_with_types& l) {
|
|
size_t i = 0;
|
|
for (auto n : l.values) {
|
|
if (i++ > 0) {
|
|
out << l.delim;
|
|
}
|
|
printValueRef(out, n);
|
|
if (c10::type_verbosity() >= c10::TypeVerbosity::Type) {
|
|
out << " : ";
|
|
out << *n->type();
|
|
}
|
|
}
|
|
return out;
|
|
}
|
|
|
|
static void printAttribute(std::ostream& out, const at::Tensor& tensor) {
|
|
// 1-elem tensors are usually boxed scalars, so print them like it
|
|
if (tensor.numel() == 1) {
|
|
auto scalar_tensor = tensor.view(std::vector<int64_t>{}).item();
|
|
out << "{";
|
|
if (scalar_tensor.isFloatingPoint()) {
|
|
out << scalar_tensor.toDouble();
|
|
} else if (scalar_tensor.isComplex()) {
|
|
out << scalar_tensor.toComplexDouble();
|
|
} else {
|
|
out << scalar_tensor.toLong();
|
|
}
|
|
out << "}";
|
|
} else if (tensor.numel() <= max_tensor_display_size) {
|
|
// TODO: This is awful code. Also it doesn't work on Windows.
|
|
std::ostringstream tensor_ss;
|
|
tensor_ss << tensor;
|
|
std::string tensor_s{tensor_ss.str()};
|
|
// Remove newlines
|
|
std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
|
|
out << tensor_s;
|
|
} else {
|
|
out << "<Tensor>";
|
|
}
|
|
}
|
|
|
|
static void printAttribute(std::ostream& out, const IValue& ival) {
|
|
const auto customFormatter = [](std::ostream& ss, const IValue& input) {
|
|
if (input.isTensor()) {
|
|
printAttribute(ss, input.toTensor());
|
|
return true;
|
|
} else if (input.isTensorList()) {
|
|
ss << "[<Tensors>]";
|
|
return true;
|
|
} else if (input.isObject() && !input.type()->is_module()) {
|
|
ss << "object(" << &input.toObjectRef() << ")";
|
|
return true;
|
|
}
|
|
return false;
|
|
};
|
|
ival.repr(out, customFormatter);
|
|
}
|
|
|
|
static void printTypeList(
|
|
std::ostream& out,
|
|
const std::vector<TypePtr>& items) {
|
|
out << "[";
|
|
int i = 0;
|
|
for (auto& item : items) {
|
|
if (i++ > 0)
|
|
out << ", ";
|
|
out << *item;
|
|
}
|
|
out << "]";
|
|
}
|
|
|
|
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
|
|
switch (kindOf(name)) {
|
|
case AttributeKind::c:
|
|
printAttribute(out, c(name));
|
|
break;
|
|
case AttributeKind::cs:
|
|
// TODO(@anjali411): fix this
|
|
AT_ASSERT(false);
|
|
break;
|
|
case AttributeKind::f:
|
|
printAttribute(out, f(name));
|
|
break;
|
|
case AttributeKind::fs:
|
|
printAttribute(out, fs(name));
|
|
break;
|
|
case AttributeKind::i:
|
|
printAttribute(out, i(name));
|
|
break;
|
|
case AttributeKind::is:
|
|
printAttribute(out, is(name));
|
|
break;
|
|
case AttributeKind::s:
|
|
printAttribute(out, s(name));
|
|
break;
|
|
case AttributeKind::ss:
|
|
printAttribute(out, ss(name));
|
|
break;
|
|
case AttributeKind::t:
|
|
printAttribute(out, t(name));
|
|
break;
|
|
case AttributeKind::ts:
|
|
out << "[<Tensors>]";
|
|
break;
|
|
case AttributeKind::ival:
|
|
printAttribute(out, ival(name));
|
|
break;
|
|
case AttributeKind::g:
|
|
out << "<Graph>";
|
|
break;
|
|
case AttributeKind::gs:
|
|
out << "[<Graphs>]";
|
|
break;
|
|
case AttributeKind::ty:
|
|
out << *ty(name);
|
|
break;
|
|
case AttributeKind::tys:
|
|
printTypeList(out, tys(name));
|
|
break;
|
|
}
|
|
}
|
|
|
|
void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
|
|
const {
|
|
out << "[";
|
|
auto names = attributeNames();
|
|
int i = 0;
|
|
for (auto name : names) {
|
|
if (ignore_subgraph && name == attr::Subgraph) {
|
|
continue;
|
|
}
|
|
if (i++ > 0) {
|
|
out << ", ";
|
|
}
|
|
// TODO: debugging mode to see the qualifier. We definitely
|
|
// don't want to print the qualifier since it should always
|
|
// be attribute, but you might be able to track down a weird
|
|
// bug by printing it out.
|
|
out << name.toUnqualString() << "=";
|
|
|
|
printAttrValue(out, name);
|
|
}
|
|
out << "]";
|
|
}
|
|
|
|
SourceRange Node::sourceRange() const {
|
|
if (source_range_) {
|
|
return *source_range_;
|
|
}
|
|
return SourceRange();
|
|
}
|
|
|
|
static std::ostream& indent(std::ostream& out, size_t level) {
|
|
for (const auto i : c10::irange(level)) {
|
|
(void)i; // Suppress unused variable warning
|
|
out << " ";
|
|
}
|
|
return out;
|
|
}
|
|
|
|
std::ostream& Node::print(
|
|
std::ostream& out,
|
|
size_t level,
|
|
std::vector<const Node*>* groups,
|
|
bool print_source_locations,
|
|
bool print_attributes,
|
|
bool print_scopes,
|
|
bool print_body) const {
|
|
auto outs = outputs();
|
|
indent(out, level) << const_value_list_with_types(outs);
|
|
out << " = ";
|
|
if (kind() == prim::PythonOp) {
|
|
auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
|
|
out << "^" << pyOp->name();
|
|
printAttributes(out, /*ignore_subgraph=*/false);
|
|
pyOp->writeScalars(out);
|
|
} else if (hasAttribute(attr::Subgraph) && groups) {
|
|
out << kind().toQualString() << "_" << groups->size();
|
|
if (print_attributes && numAttributes() > 1 &&
|
|
kind() != prim::DifferentiableGraph) {
|
|
printAttributes(out, /*ignore_subgraph=*/true);
|
|
}
|
|
|
|
groups->push_back(this);
|
|
} else {
|
|
out << kind().toQualString();
|
|
if (print_attributes && hasAttributes()) {
|
|
printAttributes(out);
|
|
}
|
|
}
|
|
out << "(" << inputs() << ")";
|
|
|
|
if (print_scopes) {
|
|
std::string scName = scopeName();
|
|
if (!scName.empty()) {
|
|
out << ", ";
|
|
out << "scope: " << scName;
|
|
}
|
|
}
|
|
|
|
// In debug print, append file:line:col as a comment after each node
|
|
if (print_source_locations) {
|
|
SourceRange r = sourceRange();
|
|
if (sourceRange().source()) {
|
|
if (auto orig = sourceRange().source()->findSourceRangeThatGenerated(r)) {
|
|
r = *orig;
|
|
}
|
|
}
|
|
if (auto file_line_col = r.file_line_col()) {
|
|
auto [filename, line, col] = *file_line_col;
|
|
out << " # " << filename << ":" << line << ":" << col;
|
|
}
|
|
}
|
|
|
|
if (!print_body) {
|
|
return out;
|
|
}
|
|
|
|
out << "\n";
|
|
|
|
for (const auto i : c10::irange(blocks().size())) {
|
|
auto b = blocks()[i];
|
|
indent(out, level + 1) << "block" << i << "("
|
|
<< const_value_list_with_types(b->inputs())
|
|
<< "):\n";
|
|
for (auto nested : b->nodes()) {
|
|
nested->print(out, level + 2, groups);
|
|
}
|
|
indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const Node& n) {
|
|
return n.print(out, 0, nullptr);
|
|
}
|
|
|
|
std::ostream& Graph::print(std::ostream& out, bool print_source_locations)
|
|
const {
|
|
out << "graph(" << const_value_list_with_types(inputs(), ",\n ")
|
|
<< "):\n";
|
|
std::vector<const Node*> groups;
|
|
for (auto n : nodes()) {
|
|
n->print(out, 1, &groups, print_source_locations);
|
|
}
|
|
out << " return (" << outputs() << ")\n";
|
|
size_t i = 0;
|
|
for (auto fg : groups) {
|
|
out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
|
|
<< *fg->g(attr::Subgraph);
|
|
}
|
|
out.flush();
|
|
|
|
/*
|
|
// Uncomment this to debug all_nodes issues
|
|
{
|
|
out << "\n";
|
|
out << "all_nodes:\n";
|
|
for (auto& n : all_nodes) {
|
|
printNode(out, const_cast<Node*>(n), nullptr);
|
|
}
|
|
}
|
|
*/
|
|
return out;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const Graph& g) {
|
|
return g.print(out, true);
|
|
}
|
|
|
|
static void checkSameDevice(const Node* node) {
|
|
bool has_device = false;
|
|
std::optional<at::Device> device = c10::nullopt;
|
|
auto checkValue = [&](const Value* v) {
|
|
if (TensorTypePtr type = v->type()->cast<TensorType>()) {
|
|
if (type->device() && !has_device) {
|
|
has_device = true;
|
|
device = *type->device();
|
|
} else {
|
|
AT_ASSERT(device == type->device());
|
|
}
|
|
}
|
|
};
|
|
for (auto input : node->inputs()) {
|
|
checkValue(input);
|
|
}
|
|
for (auto output : node->outputs()) {
|
|
checkValue(output);
|
|
}
|
|
}
|
|
|
|
using node_set = std::set<const Node*>;
|
|
#define ALL_OF(container) container.begin(), container.end()
|
|
|
|
// These functions purposely operate on the internal members directly, to force
|
|
// you to think about how the invariants change if you change the data
|
|
// representation (even if the external API does not change.)
|
|
|
|
// NB: This assert is written to assume you don't have any unattached
|
|
// nodes. Unattached nodes can occur while manipulations to the
|
|
// graph are occurring.
|
|
void Node::lint() const {
|
|
// Node invariants
|
|
// - if node should live in list, nodes_iter is consistent
|
|
// - Inputs are all marked as a use by the nodes they refer to
|
|
// - Owning graph is non-null and consistent
|
|
// - The "Select" invariant, when the node is MultiReturn
|
|
//
|
|
// The handle invariant:
|
|
// If a node takes a handle as an input, it is always the
|
|
// LAST input of the node. There is at most one handle input.
|
|
|
|
{
|
|
size_t i = 0;
|
|
for (auto input : inputs_) {
|
|
// WARNING: O(n^2)
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
|
AT_ASSERT(
|
|
std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
|
|
input->uses_.end());
|
|
AT_ASSERT(graph_->all_nodes.count(this) == 1);
|
|
i++;
|
|
}
|
|
}
|
|
|
|
for (auto o : outputs()) {
|
|
for (auto use : o->uses()) {
|
|
// Use invariants
|
|
// - Use is consistent with inputs
|
|
// - Every user node is live (checked in Graph)
|
|
AT_ASSERT(use.user->inputs_[use.offset] == o);
|
|
}
|
|
}
|
|
|
|
// Node subclass invariants
|
|
switch (kind()) {
|
|
case prim::Constant:
|
|
AT_ASSERT(inputs_.empty());
|
|
break;
|
|
case prim::Return:
|
|
// Return uses is zero
|
|
AT_ASSERT(outputs().empty());
|
|
break;
|
|
case prim::Param:
|
|
// Param inputs is zero
|
|
AT_ASSERT(inputs_.empty());
|
|
break;
|
|
case prim::PythonOp: {
|
|
// Python operator cconv is correct
|
|
auto* value = static_cast<const PythonOp*>(this);
|
|
value->lint_python();
|
|
break;
|
|
}
|
|
case prim::Eval:
|
|
// TODO: add invariants
|
|
// TODO: It's not good for these ops to be top-level, it makes cases
|
|
// longer.
|
|
break;
|
|
case prim::FusionGroup:
|
|
case prim::CudaFusionGroup:
|
|
case prim::oneDNNFusionGroup:
|
|
checkSameDevice(this);
|
|
// TODO: Typecheck the parameters
|
|
g(attr::Subgraph)->lint();
|
|
break;
|
|
}
|
|
}
|
|
|
|
// TODO: When lint fails, give better indication about which
|
|
// instruction triggered the failure.
|
|
void Graph::lint() const {
|
|
// Graph invariants
|
|
|
|
// Uncomment the following to see the graph
|
|
// std::cout << *const_cast<Graph*>(this);
|
|
|
|
// nodes
|
|
// - nodes_ is a valid topological ordering for inputs
|
|
// - No repeated nodes
|
|
// - Params and return do NOT occur in nodes
|
|
// - next_unique_ is greater than all uniques in graph
|
|
// - uniques in all_nodes are unique
|
|
// - every use will occur later in the toposort
|
|
|
|
struct LintScope {
|
|
LintScope() = default;
|
|
LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
|
|
bool contains(const Value* v) {
|
|
return values.count(v) > 0 || (parent && parent->contains(v));
|
|
}
|
|
bool contains(const Node* n) {
|
|
return nodes.count(n) > 0 || (parent && parent->contains(n));
|
|
}
|
|
void insert(const Value* v) {
|
|
AT_ASSERT(!contains(v));
|
|
values.insert(v);
|
|
}
|
|
void insert(const Node* n) {
|
|
AT_ASSERT(!contains(n));
|
|
nodes.insert(n);
|
|
}
|
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
|
std::unique_ptr<LintScope> parent;
|
|
|
|
private:
|
|
std::unordered_set<const Value*> values;
|
|
std::unordered_set<const Node*> nodes;
|
|
};
|
|
// Struct enables mutual recursion in linting methods.
|
|
// Putting it inside Graph::lint enables access to private Graph members
|
|
struct LintImpl {
|
|
LintImpl(const Graph& g)
|
|
: g(g),
|
|
scope(new LintScope()),
|
|
all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
|
|
const Graph& g;
|
|
std::unique_ptr<LintScope> scope;
|
|
std::unordered_set<size_t> seen_uniques;
|
|
std::unordered_map<const Node*, int64_t> anticipated_uses;
|
|
node_set all_nodes_set;
|
|
node_set sum_set;
|
|
|
|
void check_value(const Value* v) {
|
|
scope->insert(v);
|
|
auto b2 = seen_uniques.insert(v->unique());
|
|
AT_ASSERT(b2.second); // insertion took place
|
|
AT_ASSERT(v->unique() < g.next_unique_);
|
|
|
|
for (auto use : v->uses()) {
|
|
AT_ASSERT(!scope->contains(use.user));
|
|
AT_ASSERT(g.all_nodes.count(use.user) == 1);
|
|
anticipated_uses[use.user]++; // int default constructs to 0
|
|
}
|
|
}
|
|
void check_node(const Node* n) {
|
|
for (auto input : n->inputs_) {
|
|
if (!scope->contains(input)) {
|
|
AT_ASSERTM(0, input->unique(), " not in scope");
|
|
}
|
|
}
|
|
AT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
|
|
anticipated_uses[n] = -1; // we saw the anticipated user!
|
|
scope->insert(n);
|
|
for (auto block : n->blocks()) {
|
|
scope = std::make_unique<LintScope>(std::move(scope));
|
|
check_block(block);
|
|
scope = std::move(scope->parent);
|
|
}
|
|
size_t i = 0;
|
|
for (auto o : n->outputs()) {
|
|
AT_ASSERT(o->node() == n);
|
|
AT_ASSERT(i++ == o->offset_);
|
|
check_value(o);
|
|
}
|
|
n->lint();
|
|
}
|
|
void check_block(const Block* b) {
|
|
// Check topological ordering
|
|
AT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
|
|
auto curNode = *b->nodes().begin();
|
|
while (curNode != b->return_node()) {
|
|
AT_ASSERT(curNode->isBefore(curNode->next()));
|
|
curNode = curNode->next();
|
|
}
|
|
|
|
for (auto input : b->inputs()) {
|
|
check_value(input);
|
|
AT_ASSERT(input->node()->kind_ == prim::Param);
|
|
}
|
|
|
|
for (auto n : b->nodes()) {
|
|
AT_ASSERT(n->kind_ != prim::Param);
|
|
AT_ASSERT(n->kind_ != prim::Return);
|
|
check_node(n);
|
|
}
|
|
|
|
AT_ASSERT(b->output_->kind() == prim::Return);
|
|
check_node(b->output_);
|
|
|
|
// all_nodes
|
|
// - inputs_, output_ and nodes_ are all included in all_nodes
|
|
// - all_nodes does not contain dead nodes??? (likely to be temporarily
|
|
// suspended). Weaker: all_nodes contains all inputs and returns
|
|
// - only one return node???
|
|
|
|
node_set nodes_set(ALL_OF(b->nodes()));
|
|
node_set inputs_set{b->input_};
|
|
node_set output_set{b->output_};
|
|
// TODO: Make a more type safe std::includes wrapper which disallows use
|
|
// on non-ordered containers
|
|
AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
|
|
AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
|
|
AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));
|
|
|
|
sum_set.insert(ALL_OF(nodes_set));
|
|
sum_set.insert(ALL_OF(inputs_set));
|
|
sum_set.insert(ALL_OF(output_set));
|
|
}
|
|
void check_graph() {
|
|
node_set all_nodes_set(
|
|
ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*
|
|
|
|
check_block(g.block_);
|
|
for (auto kv : anticipated_uses) {
|
|
AT_ASSERT(kv.second == -1);
|
|
}
|
|
AT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));
|
|
}
|
|
};
|
|
LintImpl(*this).check_graph();
|
|
}
|
|
|
|
void Graph::dump() const {
|
|
std::cout << *this << "\n";
|
|
}
|
|
|
|
void Graph::push_scope(const std::string& scope_name) {
|
|
current_scope_ = current_scope_->push(Symbol::scope(scope_name));
|
|
Node* block_node = insertNode(create(prim::TracedModuleForward, 0));
|
|
block_node->s_(attr::scope, scope_name);
|
|
Block* b = block_node->addBlock();
|
|
setInsertPoint(b);
|
|
}
|
|
void Graph::pop_scope() {
|
|
current_scope_ = current_scope_->parent();
|
|
if (insertPoint()->owningBlock()->owningNode()->kind() ==
|
|
prim::TracedModuleForward) {
|
|
setInsertPoint(insertPoint()->owningBlock()->owningNode()->next());
|
|
}
|
|
}
|
|
|
|
void LintGraph(const std::shared_ptr<Graph>& graph) {
|
|
graph->lint();
|
|
}
|
|
|
|
Block::Block(Graph* graph_, Node* node_)
|
|
: graph_(graph_),
|
|
output_(graph_->create(prim::Return, 0)),
|
|
input_(graph_->create(prim::Param, 0)),
|
|
owning_node_(node_) {
|
|
input_->next() = output_;
|
|
input_->prev() = output_;
|
|
output_->next() = input_;
|
|
output_->prev() = input_;
|
|
|
|
graph_->all_blocks.emplace(this);
|
|
output_->owning_block_ = this;
|
|
output_->topo_position_ = kUpperBound;
|
|
input_->owning_block_ = this;
|
|
input_->topo_position_ = kLowerBound;
|
|
}
|
|
|
|
void Block::reIndexTopology() {
|
|
auto curPos = kLowerBound;
|
|
for (auto node : nodes()) {
|
|
AT_ASSERT(curPos <= (kUpperBound - kAppendInterval));
|
|
curPos += kAppendInterval;
|
|
node->topo_position_ = curPos;
|
|
}
|
|
}
|
|
|
|
void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
|
|
std::unordered_map<Value*, Value*> local_map;
|
|
auto env = [&](Value* v) {
|
|
auto it = local_map.find(v);
|
|
if (it != local_map.end()) {
|
|
return it->second;
|
|
}
|
|
return value_map(v);
|
|
};
|
|
|
|
auto graph = owningGraph();
|
|
for (auto input : src->inputs()) {
|
|
local_map[input] = this->addInput()->copyMetadata(input);
|
|
}
|
|
|
|
for (auto node : src->nodes()) {
|
|
auto new_node = this->appendNode(graph->createClone(node, env));
|
|
for (size_t i = 0; i < node->outputs().size(); ++i) {
|
|
auto oo = node->outputs()[i];
|
|
auto no = new_node->outputs()[i];
|
|
local_map[oo] = no;
|
|
no->copyMetadata(oo);
|
|
}
|
|
}
|
|
for (auto output : src->outputs()) {
|
|
this->registerOutput(env(output));
|
|
}
|
|
}
|
|
|
|
void Block::destroy() {
|
|
// we cannot destroy the output because it is used as the sentinel
|
|
// for the nodes() list and has to remain valid for the loop
|
|
output_->removeAllInputs();
|
|
for (auto it = this->nodes().reverse().begin(),
|
|
end = this->nodes().reverse().end();
|
|
it != end;
|
|
++it) {
|
|
it.destroyCurrent();
|
|
}
|
|
output_->destroy();
|
|
input_->destroy();
|
|
graph_->freeBlock(this);
|
|
}
|
|
|
|
void Graph::cloneFrom(Graph& src) {
|
|
auto env = [](Value* v) -> Value* {
|
|
AT_ERROR(
|
|
"Graph::copy() encountered a use of a value " + v->debugName() +
|
|
" not in scope. Run lint!");
|
|
};
|
|
block()->cloneFrom(src.block(), env);
|
|
}
|
|
|
|
std::shared_ptr<Graph> Graph::copy() {
|
|
auto new_g = std::make_shared<Graph>();
|
|
new_g->cloneFrom(*this);
|
|
return new_g;
|
|
}
|
|
|
|
std::unique_ptr<Graph> Graph::copyUnique() {
|
|
auto new_g = std::make_unique<Graph>();
|
|
new_g->cloneFrom(*this);
|
|
return new_g;
|
|
}
|
|
|
|
void Block::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
|
|
for (Value* input : inputs()) {
|
|
input->setType(type_map(input->type()));
|
|
}
|
|
for (Node* node : nodes()) {
|
|
for (Value* output : node->outputs()) {
|
|
output->setType(type_map(output->type()));
|
|
}
|
|
for (Block* sub_block : node->blocks()) {
|
|
sub_block->remapTypes(type_map);
|
|
}
|
|
for (Symbol name : node->attributeNames()) {
|
|
if (node->kindOf(name) == AttributeKind::g) {
|
|
node->g(name)->remapTypes(type_map);
|
|
} else if (node->kindOf(name) == AttributeKind::gs) {
|
|
for (const auto& g : node->gs(name)) {
|
|
g->remapTypes(type_map);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
|
|
block()->remapTypes(type_map);
|
|
}
|
|
|
|
void Value::inferTypeFrom(const at::Tensor& output) {
|
|
setType(TensorType::create(output));
|
|
}
|
|
|
|
void Value::inferTypeFrom(
|
|
const c10::intrusive_ptr<c10::ivalue::Object>& output) {
|
|
setType(output->type());
|
|
}
|
|
|
|
bool Value::mustBeNone() const {
|
|
return type()->cast<NoneType>() || node_->mustBeNone();
|
|
}
|
|
bool Value::mustNotBeNone() const {
|
|
return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
|
|
!type()->cast<OptionalType>() &&
|
|
!(type()->cast<UnionType>() &&
|
|
type()->expect<UnionType>()->canHoldType(*NoneType::get()));
|
|
}
|
|
|
|
std::string Value::debugNameBase() const {
|
|
std::string name = debugName();
|
|
std::string name_base = name;
|
|
auto last_dot_pos = name.find_last_of('.');
|
|
if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
|
|
if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
|
|
std::string::npos) {
|
|
name_base = name.substr(0, last_dot_pos);
|
|
}
|
|
}
|
|
return name_base;
|
|
}
|
|
|
|
bool Value::isValidName(const std::string& name) {
|
|
// Empty strings are legal
|
|
if (name.empty()) {
|
|
return true;
|
|
}
|
|
|
|
// Numbers are not legal
|
|
if (isNumber(name)) {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
Value* Value::setDebugName(const std::string& name) {
|
|
if (!isValidName(name)) {
|
|
throw std::runtime_error("Invalid name: '" + name + "'");
|
|
}
|
|
|
|
auto& names = node()->owningGraph()->unique_names_;
|
|
|
|
// clear any old name from the map
|
|
if (hasDebugName()) {
|
|
names.erase(unique_name_);
|
|
unique_name_ = "";
|
|
}
|
|
|
|
// allow "" to clear the uniquename
|
|
if (name.empty()) {
|
|
return this;
|
|
}
|
|
|
|
// if someone else has this name, then rename the other value
|
|
auto old_owner_of_name = names.find(name);
|
|
if (old_owner_of_name != names.end()) {
|
|
size_t suffix = 1;
|
|
std::string name_base = name;
|
|
auto last_dot_pos = name.find_last_of('.');
|
|
if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
|
|
if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
|
|
std::string::npos) {
|
|
suffix = std::stoll(name.substr(last_dot_pos + 1));
|
|
name_base = name.substr(0, last_dot_pos);
|
|
}
|
|
}
|
|
|
|
auto& names_suffixes = node()->owningGraph()->name_base_suffix_;
|
|
auto it = names_suffixes.find(name_base);
|
|
if (it != names_suffixes.end()) {
|
|
suffix = std::max(suffix, it->second + 1);
|
|
}
|
|
|
|
// Verify that new name is not used and find next usable name in case
|
|
// suffix is used.
|
|
std::string replacement_name;
|
|
do {
|
|
std::stringstream ss;
|
|
#ifndef _WIN32
|
|
// Protect 12345 integer from becoming "1,2345" if some other process sets
|
|
// global locale For more details see
|
|
// https://github.com/pytorch/pytorch/issues/79583#issuecomment-1161260061
|
|
static std::locale c_locale("C");
|
|
ss.imbue(c_locale);
|
|
#endif
|
|
ss << name_base << "." << suffix++;
|
|
replacement_name = ss.str();
|
|
} while (names.count(replacement_name) > 0);
|
|
|
|
names_suffixes[name_base] = suffix;
|
|
|
|
old_owner_of_name->second->setDebugName(replacement_name);
|
|
}
|
|
|
|
names[name] = this;
|
|
unique_name_ = name;
|
|
return this;
|
|
}
|
|
|
|
Value* Value::copyMetadata(Value* from) {
|
|
setType(from->type());
|
|
if (from->hasDebugName()) {
|
|
setDebugName(from->debugName());
|
|
}
|
|
return this;
|
|
}
|
|
|
|
void Value::replaceFirstUseWith(Value* newValue) {
|
|
AT_ASSERT(owningGraph() == newValue->owningGraph());
|
|
auto u = uses()[0];
|
|
u.user->inputs_[u.offset] = newValue;
|
|
newValue->uses_.push_back(u);
|
|
uses_.erase(uses_.begin());
|
|
}
|
|
|
|
void Value::replaceAllUsesWith(Value* newValue) {
|
|
while (!uses().empty()) {
|
|
replaceFirstUseWith(newValue);
|
|
}
|
|
}
|
|
|
|
void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) {
|
|
std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
|
|
if (u.user->isAfter(node)) {
|
|
u.user->inputs_[u.offset] = newValue;
|
|
newValue->uses_.push_back(u);
|
|
}
|
|
});
|
|
|
|
uses_.erase(
|
|
std::remove_if(
|
|
uses_.begin(),
|
|
uses_.end(),
|
|
[&node](const Use& u) { return u.user->isAfter(node); }),
|
|
uses_.end());
|
|
}
|
|
|
|
void Value::replaceAllUsesDominatedByNodeWith(
|
|
const Node* node,
|
|
Value* newValue) {
|
|
std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
|
|
if (u.user->isDominatedBy(node)) {
|
|
u.user->inputs_[u.offset] = newValue;
|
|
newValue->uses_.push_back(u);
|
|
}
|
|
});
|
|
|
|
uses_.erase(
|
|
std::remove_if(
|
|
uses_.begin(),
|
|
uses_.end(),
|
|
[&node](const Use& u) { return u.user->isDominatedBy(node); }),
|
|
uses_.end());
|
|
}
|
|
|
|
static size_t findArgument(
|
|
const FunctionSchema& the_schema,
|
|
const std::string& unqualName) {
|
|
for (const auto i : c10::irange(the_schema.arguments().size())) {
|
|
const Argument* arg = &the_schema.arguments()[i];
|
|
if (arg->name() == unqualName) {
|
|
return i;
|
|
}
|
|
}
|
|
throw std::runtime_error(
|
|
std::string("Couldn't find an argument called ") + unqualName);
|
|
}
|
|
|
|
static size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
|
|
const auto unqualName = name.toUnqualString();
|
|
return findArgument(the_schema, unqualName);
|
|
}
|
|
|
|
std::optional<IValue> Node::get(Symbol name) const {
|
|
return toIValue(namedInput(name));
|
|
}
|
|
|
|
bool Node::hasNamedInput(const std::string& name) const {
|
|
for (const auto& argument : schema().arguments()) {
|
|
if (argument.name() == name) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
Value* Node::namedInput(const std::string& unqualName) const {
|
|
return input(findArgument(schema(), unqualName));
|
|
}
|
|
Value* Node::namedInput(Symbol name) const {
|
|
return input(findArgument(schema(), name));
|
|
}
|
|
|
|
bool Node::matches(const FunctionSchema& schema) const {
|
|
if (isBlockListedSchema(schema)) {
|
|
return false;
|
|
}
|
|
// wrong name
|
|
if (kind().toQualString() != schema.name()) {
|
|
return false;
|
|
}
|
|
at::ArrayRef<const Value*> actuals = inputs();
|
|
const auto& formals = schema.arguments();
|
|
|
|
// not enough inputs
|
|
if (actuals.size() < formals.size()) {
|
|
return false;
|
|
}
|
|
|
|
TypeEnv type_env;
|
|
for (const auto i : c10::irange(formals.size())) {
|
|
auto formal = formals[i].type();
|
|
const MatchTypeReturn matched_type =
|
|
matchTypeVariables(formal, actuals[i]->type(), type_env);
|
|
if (!matched_type.success()) {
|
|
return false;
|
|
}
|
|
|
|
TypePtr resolved = tryEvalTypeVariables(formal, type_env);
|
|
if (resolved) {
|
|
formal = resolved;
|
|
}
|
|
// note: it is possible at this point that type variable matching has
|
|
// not resolved all type variables, e.g. if None was matched to Optional[T]
|
|
// we will not succeed at matching T. However None <: Optional[T] so this
|
|
// check can still succeed.
|
|
|
|
if (!actuals[i]->type()->isSubtypeOf(*formal)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// too many inputs
|
|
if (!schema.is_vararg() && actuals.size() != formals.size()) {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool Node::matches(
|
|
const char* signature_literal,
|
|
at::ArrayRef<Symbol> const_inputs) const {
|
|
if (!matches(getOperatorForLiteral(signature_literal)->schema())) {
|
|
return false;
|
|
}
|
|
for (Symbol s : const_inputs) {
|
|
if (!is_constant(s)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool Node::mustBeNone() const {
|
|
// We can statically deduce this Node has returning None if:
|
|
return
|
|
// It's an AutogradZero node, or ...
|
|
kind_ == prim::AutogradZero ||
|
|
// It has only one output and that output is NoneType, or ...
|
|
(outputs().size() == 1 && output()->type() == NoneType::get()) ||
|
|
// It's a constant optional with no value in the attributes.
|
|
(kind_ == prim::Constant && !this->hasAttributes() &&
|
|
output()->type()->cast<OptionalType>());
|
|
}
|
|
|
|
void Node::dump() const {
|
|
std::cout << *this << "\n";
|
|
}
|
|
|
|
const FunctionSchema& Node::schema() const {
|
|
if (op_) {
|
|
return op_->schema();
|
|
}
|
|
return getOperator().schema();
|
|
}
|
|
|
|
const FunctionSchema* Node::maybeSchema() const {
|
|
if (auto op = maybeOperator()) {
|
|
return &op->schema();
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
const Operator* Node::maybeOperator() const {
|
|
if (!op_) {
|
|
const auto& candidates = getAllOperatorsFor(kind());
|
|
for (const auto& candidate : candidates) {
|
|
if (matches(candidate->schema())) {
|
|
op_ = candidate.get();
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
return op_;
|
|
}
|
|
|
|
const Operator& Node::getOperator() const {
|
|
const Operator* maybe = maybeOperator();
|
|
if (maybe)
|
|
return *maybe;
|
|
|
|
auto er = ErrorReport(sourceRange());
|
|
er << "Schema not found for node. File a bug report.\n";
|
|
er << "Node: " << *this << "\n";
|
|
er << "Input types:";
|
|
for (const auto i : c10::irange(inputs().size())) {
|
|
if (i > 0)
|
|
er << ", ";
|
|
er << *inputs()[i]->type();
|
|
}
|
|
const auto& candidates = getAllOperatorsFor(kind());
|
|
if (!candidates.empty()) {
|
|
er << "\ncandidates were:\n";
|
|
for (auto& candidate : candidates) {
|
|
er << " " << candidate->schema() << "\n";
|
|
}
|
|
} else {
|
|
er << "\nno candidates found\n";
|
|
}
|
|
er << "within the graph:\n";
|
|
er << *owningGraph() << "\n";
|
|
throw er;
|
|
}
|
|
|
|
Operation Node::getOperation() const {
|
|
// note: some operators require the node to produce a runnable operation,
|
|
// which is why 'this' is passed here. getOperator() ensures that 'this'
|
|
// matches the schema of the returned operator.
|
|
return getOperator().getOperation(this);
|
|
}
|
|
|
|
bool Node::isNondeterministic() const {
|
|
const auto schema = maybeSchema();
|
|
if (!kind().is_aten()) {
|
|
return false;
|
|
}
|
|
// All aten ops are expecte to have a schema. However this is left as a
|
|
// warning instead of an assert to ensure that previous use cases do not
|
|
// break.
|
|
if (!schema) {
|
|
TORCH_WARN("aten Schema not found.");
|
|
return false;
|
|
}
|
|
torch::utils::SchemaInfo schema_info(*schema);
|
|
if (hasNamedInput("train")) {
|
|
auto value = constant_as<bool>(namedInput("train"));
|
|
if (value.has_value()) {
|
|
schema_info.addArgumentValue("train", *value);
|
|
}
|
|
}
|
|
return schema_info.is_nondeterministic();
|
|
}
|
|
|
|
bool Node::hasSideEffects() const {
|
|
switch (kind_) {
|
|
case prim::PythonOp:
|
|
case prim::IgnoredPythonOp:
|
|
case prim::Print:
|
|
case prim::RaiseException:
|
|
case aten::warn:
|
|
case aten::save:
|
|
case aten::manual_seed:
|
|
case prim::AddStatValue:
|
|
case prim::TimePoint:
|
|
case prim::CallFunction:
|
|
case prim::CallMethod:
|
|
case prim::BailoutTemplate:
|
|
case prim::BailOut:
|
|
case prim::rpc_async: // It represents RPC message sent.
|
|
case prim::rpc_sync: // It represents RPC message sent.
|
|
case prim::rpc_remote: // It represents RPC message sent.
|
|
case aten::wait: // It can represent RPC message received.
|
|
#if !defined(USE_ROCM)
|
|
case cuda::set_stream:
|
|
case cuda::_set_device:
|
|
case cuda::_current_device:
|
|
case cuda::synchronize:
|
|
#endif
|
|
case prim::Enter:
|
|
case prim::Exit:
|
|
return true;
|
|
}
|
|
|
|
auto op = maybeOperator();
|
|
if (!op) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
kind_.is_prim(),
|
|
"Only prim ops are allowed to not have a registered operator but ",
|
|
kind_.toDisplayString(),
|
|
" doesn't have one either. We don't know if this op has side effects.");
|
|
return false;
|
|
}
|
|
|
|
if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) {
|
|
// TODO There is nothing in the system that relies on aten:: and prim::
|
|
// ops using AliasAnalysisKind::FROM_SCHEMA,
|
|
// AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or
|
|
// AliasAnalysisKind::CONSERVATIVE but this is the intended behavior for all
|
|
// current ops and a good error check. We can consider lifting this
|
|
// constraint later if we have a use case for it.
|
|
TORCH_INTERNAL_ASSERT(
|
|
op->aliasAnalysisKind() == AliasAnalysisKind::INTERNAL_SPECIAL_CASE ||
|
|
op->aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA ||
|
|
op->aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE,
|
|
"aten:: and prim:: ops should have AliasAnalysisKind::INTERNAL_SPECIAL_CASE"
|
|
", AliasAnalysisKind::FROM_SCHEMA or AliasAnalysisKind::CONSERVATIVE but ",
|
|
kind_.toDisplayString(),
|
|
" has ",
|
|
toString(op->aliasAnalysisKind()));
|
|
}
|
|
|
|
switch (op->aliasAnalysisKind()) {
|
|
case AliasAnalysisKind::PURE_FUNCTION:
|
|
case AliasAnalysisKind::FROM_SCHEMA:
|
|
case AliasAnalysisKind::INTERNAL_SPECIAL_CASE:
|
|
return false;
|
|
case AliasAnalysisKind::CONSERVATIVE:
|
|
return true;
|
|
}
|
|
TORCH_INTERNAL_ASSERT(false, "Unhandled AliasAnalysisKind case");
|
|
return false; // silence compiler warning
|
|
}
|
|
|
|
// Assign this node a topological position, to facilitate fast isBefore() and
|
|
// isAfter() queries. Must be called right after a node is inserted into the
|
|
// node list.
|
|
//
|
|
// The basic scheme is: assign every node a position (uint64_t). The common
|
|
// case (appending to the end of the graph) is made more efficient by advancing
|
|
// a fixed interval past the previous node and placing `this` there. Otherwise,
|
|
// assign `this` a position at the midpoint between its prev() and next()
|
|
// nodes.
|
|
//
|
|
// If we ever run out of space (by, e.g. inserting too much in place), we
|
|
// reindex by spreading out all the nodes again.
|
|
void Node::assignTopoPosition() {
|
|
bool is_first = prev() == owningBlock()->param_node();
|
|
bool is_last = next() == owningBlock()->return_node();
|
|
|
|
const auto prevPos = prev()->topo_position_;
|
|
const auto nextPos = next()->topo_position_;
|
|
|
|
// Append to the end of the graph
|
|
if (is_last) {
|
|
if (is_first) {
|
|
// the node list is empty, assign the first position
|
|
topo_position_ = kMidPoint;
|
|
return;
|
|
}
|
|
|
|
if (prevPos >= (kUpperBound - kAppendInterval)) {
|
|
// we're running off the edge
|
|
owningBlock()->reIndexTopology();
|
|
return;
|
|
}
|
|
|
|
topo_position_ = prevPos + kAppendInterval;
|
|
|
|
// Prepend to the graph
|
|
} else if (is_first) {
|
|
// next() is the first element in the block list
|
|
if (nextPos <= (kLowerBound + kAppendInterval)) {
|
|
// we're running off the edge
|
|
owningBlock()->reIndexTopology();
|
|
return;
|
|
}
|
|
topo_position_ = nextPos - kAppendInterval;
|
|
|
|
// insert between two existing nodes
|
|
} else {
|
|
int64_t remaining = nextPos - prevPos;
|
|
AT_ASSERT(remaining > 0);
|
|
if (remaining == 1) {
|
|
// There was no room
|
|
owningBlock()->reIndexTopology();
|
|
return;
|
|
}
|
|
int64_t predicted_future_insertions = 0;
|
|
if (next() == graph_->insertPoint()) {
|
|
predicted_future_insertions = graph_->predicted_insert_count_++;
|
|
}
|
|
topo_position_ = prevPos +
|
|
std::max(int64_t(1), remaining / (2 + predicted_future_insertions));
|
|
AT_ASSERT(prevPos < topo_position_ && topo_position_ < nextPos);
|
|
}
|
|
}
|
|
|
|
Node::Node(Graph* graph_, NodeKind kind_)
|
|
: kind_(kind_),
|
|
graph_(graph_),
|
|
owning_block_(nullptr),
|
|
scope_(graph_->current_scope_),
|
|
callstack_(c10::nullopt),
|
|
op_(nullptr),
|
|
topo_position_(0) {
|
|
graph_->all_nodes.emplace(this);
|
|
}
|
|
|
|
void Node::eraseOutput(size_t i) {
|
|
AT_ASSERT(i < outputs_.size());
|
|
AT_ASSERT(outputs_[i]->uses().empty());
|
|
op_ = nullptr;
|
|
Value* n = outputs_[i];
|
|
outputs_.erase(outputs_.begin() + i);
|
|
owningGraph()->freeValue(n);
|
|
for (const auto j : c10::irange(i, outputs_.size())) {
|
|
outputs_[j]->offset_--;
|
|
}
|
|
}
|
|
|
|
Block* Node::addBlock() {
|
|
op_ = nullptr;
|
|
blocks_.push_back(new Block(owningGraph(), this));
|
|
return blocks_.back();
|
|
}
|
|
|
|
void Node::eraseBlock(size_t i) {
|
|
AT_ASSERT(i < blocks_.size());
|
|
op_ = nullptr;
|
|
Block* n = blocks_[i];
|
|
blocks_.erase(blocks_.begin() + i);
|
|
n->destroy();
|
|
}
|
|
|
|
void Node::destroy() {
|
|
while (!outputs().empty()) {
|
|
eraseOutput(outputs().size() - 1);
|
|
}
|
|
while (!blocks().empty()) {
|
|
eraseBlock(blocks().size() - 1);
|
|
}
|
|
removeAllInputs();
|
|
if (inBlockList()) {
|
|
removeFromList();
|
|
}
|
|
graph_->freeNode(this);
|
|
}
|
|
|
|
void Node::cloneFrom(Node* s) {
|
|
source_range_ = s->source_range_;
|
|
if (s->scope_ && !s->scope_->isBlank()) {
|
|
scope_ = s->scope_;
|
|
}
|
|
copyAttributes(*s);
|
|
callstack_ = s->callstack_;
|
|
}
|
|
|
|
void Node::replaceAllUsesWith(Node* n) {
|
|
AT_ASSERT(outputs().size() == n->outputs().size());
|
|
size_t nOutputs = outputs().size();
|
|
for (const auto i : c10::irange(nOutputs)) {
|
|
outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
|
|
}
|
|
}
|
|
|
|
Node* Node::replaceWithNewSymbol(Symbol new_symbol) {
|
|
WithInsertPoint insert_guard{this};
|
|
bool had_operator = maybeOperator() != nullptr;
|
|
auto graph = owningGraph();
|
|
auto replace_node = graph->insertNode(graph->create(new_symbol, 0));
|
|
for (Value* v : inputs()) {
|
|
replace_node->addInput(v);
|
|
}
|
|
for (Value* v : outputs()) {
|
|
auto new_out = replace_node->addOutput()->copyMetadata(v);
|
|
v->replaceAllUsesWith(new_out);
|
|
}
|
|
replace_node->copyMetadata(this);
|
|
replace_node->copyAttributes(*this);
|
|
TORCH_INTERNAL_ASSERT(
|
|
(replace_node->maybeOperator() != nullptr) == had_operator,
|
|
"invalid symbol replacement:",
|
|
new_symbol,
|
|
kind());
|
|
return replace_node;
|
|
}
|
|
|
|
bool Node::isDominatedBy(const Node* dominator) const {
|
|
const Node* node = this;
|
|
while (node) {
|
|
if (node->owningBlock() == dominator->owningBlock()) {
|
|
return dominator->isBefore(node);
|
|
}
|
|
node = node->owningBlock()->owningNode();
|
|
}
|
|
return false;
|
|
}
|
|
|
|
Value* Node::insertInput(size_t i, Value* value) {
|
|
AT_ASSERT(graph_ == value->owningGraph());
|
|
op_ = nullptr;
|
|
// First we update the offsets for all existing inputs that will reside
|
|
// after the one we're inserting. Concretely, these are the inputs at
|
|
// indices [i, # input). Since we're inserting one input before all of
|
|
// these inputs, increment their use offsets for this value by 1
|
|
for (const auto use_itr : c10::irange(i, inputs_.size())) {
|
|
// See Note [User node does not uniquely identify use]
|
|
auto use = findUseForInput(use_itr);
|
|
use->offset += 1;
|
|
}
|
|
// Insert the actual input at the specified index
|
|
inputs_.insert(inputs_.begin() + i, value);
|
|
// Register the new use of the value we're inserted as an input.
|
|
value->uses_.emplace_back(this, i);
|
|
return value;
|
|
}
|
|
|
|
Value* Node::addInput(Value* value) {
|
|
AT_ASSERT(graph_ == value->owningGraph());
|
|
op_ = nullptr;
|
|
value->uses_.emplace_back(this, inputs_.size());
|
|
inputs_.push_back(value);
|
|
return value;
|
|
}
|
|
|
|
Value* Node::replaceInput(size_t i, Value* newValue) {
|
|
AT_ASSERT(newValue->owningGraph() == graph_);
|
|
op_ = nullptr;
|
|
Value* old = dropInput(i);
|
|
inputs_[i] = newValue;
|
|
newValue->uses_.emplace_back(this, i);
|
|
return old;
|
|
}
|
|
|
|
void Node::replaceInputWith(Value* from, Value* to) {
|
|
AT_ASSERT(from->owningGraph() == graph_);
|
|
AT_ASSERT(to->owningGraph() == graph_);
|
|
op_ = nullptr;
|
|
size_t i = 0;
|
|
for (auto input : inputs()) {
|
|
if (input == from) {
|
|
replaceInput(i, to);
|
|
}
|
|
i++;
|
|
}
|
|
}
|
|
|
|
Value* Node::addOutput() {
|
|
outputs_.push_back(new Value(this, outputs_.size()));
|
|
op_ = nullptr;
|
|
return outputs_.back();
|
|
}
|
|
|
|
Value* Node::insertOutput(size_t i) {
|
|
op_ = nullptr;
|
|
outputs_.insert(outputs_.begin() + i, new Value(this, i));
|
|
for (size_t itr = i + 1; itr < outputs_.size(); ++itr) {
|
|
outputs_[itr]->setOffset(outputs_[itr]->offset() + 1);
|
|
}
|
|
return outputs_.at(i);
|
|
}
|
|
|
|
bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
|
|
if (this->owningBlock() == n->owningBlock()) {
|
|
if (moveSide == MoveSide::BEFORE) {
|
|
return this->topo_position_ < n->topo_position_;
|
|
}
|
|
|
|
if (moveSide == MoveSide::AFTER) {
|
|
return this->topo_position_ > n->topo_position_;
|
|
}
|
|
|
|
AT_ASSERT(this == n);
|
|
return false;
|
|
}
|
|
|
|
// These nodes don't share a common block. Traverse the blockchains upward
|
|
// until we find the first common block.
|
|
auto lhs = this;
|
|
while (lhs) {
|
|
AT_ASSERT(lhs->owningBlock());
|
|
|
|
auto rhs = n;
|
|
while (rhs) {
|
|
if (!rhs->owningBlock()) {
|
|
break;
|
|
}
|
|
|
|
if (lhs->owningBlock() == rhs->owningBlock()) {
|
|
return lhs->isBeforeOrAfter(rhs, moveSide);
|
|
}
|
|
rhs = rhs->owningBlock()->owningNode();
|
|
}
|
|
|
|
lhs = lhs->owningBlock()->owningNode();
|
|
}
|
|
// should never reach here, since both nodes are ultimately in the same graph
|
|
AT_ASSERT(false);
|
|
}
|
|
|
|
bool Node::isBefore(const Node* n) const {
|
|
return isBeforeOrAfter(n, MoveSide::BEFORE);
|
|
}
|
|
|
|
bool Node::isAfter(const Node* n) const {
|
|
return isBeforeOrAfter(n, MoveSide::AFTER);
|
|
}
|
|
|
|
Node* Node::insertBefore(Node* n) {
|
|
AT_ASSERT(n->inBlockList());
|
|
insertAfter(n->prev());
|
|
return this;
|
|
}
|
|
|
|
Node* Node::insertAfter(Node* n) {
|
|
AT_ASSERT(!inBlockList() && n->inBlockList());
|
|
AT_ASSERT(n->owningBlock());
|
|
AT_ASSERTM(
|
|
n->kind() != prim::Return,
|
|
"Attempting to insert a Node after the Return node or before the Param node. Tried to insert",
|
|
*this,
|
|
" after ",
|
|
*n,
|
|
".");
|
|
this->owning_block_ = n->owningBlock();
|
|
Node* next = n->next();
|
|
n->next() = this;
|
|
this->prev() = n;
|
|
this->next() = next;
|
|
next->prev() = this;
|
|
assignTopoPosition();
|
|
return this;
|
|
}
|
|
|
|
void Node::moveAfter(Node* n) {
|
|
removeFromList();
|
|
insertAfter(n);
|
|
}
|
|
|
|
void Node::moveBefore(Node* n) {
|
|
removeFromList();
|
|
insertBefore(n);
|
|
}
|
|
|
|
void Node::removeInput(size_t i) {
|
|
op_ = nullptr;
|
|
dropInput(i);
|
|
// everything after this input shifts left,
|
|
// so we need to update their use offsets to match
|
|
for (size_t j = i + 1; j < inputs_.size(); j++) {
|
|
auto it = findUseForInput(j);
|
|
it->offset--;
|
|
}
|
|
inputs_.erase(inputs_.begin() + i);
|
|
}
|
|
|
|
void Node::removeAllInputs() {
|
|
op_ = nullptr;
|
|
for (const auto i : c10::irange(inputs().size())) {
|
|
dropInput(i);
|
|
}
|
|
inputs_.clear();
|
|
}
|
|
|
|
void Node::removeAllOutputs() {
|
|
op_ = nullptr;
|
|
size_t init_osize = outputs_.size();
|
|
for (auto i : c10::irange(init_osize)) {
|
|
eraseOutput(init_osize - i - 1);
|
|
}
|
|
}
|
|
|
|
void Node::permuteInputs(const std::vector<size_t>& new_order) {
|
|
op_ = nullptr;
|
|
AT_ASSERT(new_order.size() == inputs_.size());
|
|
std::vector<Value*> new_inputs;
|
|
new_inputs.reserve(new_order.size());
|
|
for (const auto i : c10::irange(new_order.size())) {
|
|
AT_ASSERTM(inputs_.at(new_order[i]) != nullptr, "Repeated index");
|
|
new_inputs.push_back(inputs_.at(new_order[i]));
|
|
auto it = findUseForInput(new_order[i]);
|
|
it->offset = i;
|
|
inputs_.at(new_order[i]) = nullptr;
|
|
}
|
|
inputs_ = std::move(new_inputs);
|
|
}
|
|
|
|
void Node::permuteOutputs(const std::vector<size_t>& new_order) {
|
|
op_ = nullptr;
|
|
AT_ASSERT(new_order.size() == outputs_.size());
|
|
std::vector<Value*> new_outputs;
|
|
new_outputs.reserve(new_order.size());
|
|
for (const auto i : c10::irange(new_order.size())) {
|
|
AT_ASSERTM(outputs_.at(new_order[i]) != nullptr, "Repeated index");
|
|
new_outputs.push_back(outputs_.at(new_order[i]));
|
|
outputs_.at(new_order[i])->setOffset(i);
|
|
outputs_.at(new_order[i]) = nullptr;
|
|
}
|
|
outputs_ = std::move(new_outputs);
|
|
}
|
|
|
|
use_list::iterator Node::findUseForInput(size_t i) {
|
|
auto& input_uses = inputs_[i]->uses_;
|
|
// O(N) on the use list, but unless we get nodes with +100 uses
|
|
// vector traversal still is probably faster than linked list
|
|
auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
|
|
AT_ASSERT(use_it != input_uses.end());
|
|
return use_it;
|
|
}
|
|
|
|
Value* Node::dropInput(size_t i) {
|
|
AT_ASSERT(i < inputs_.size());
|
|
auto input_node = inputs_[i];
|
|
auto use_it = findUseForInput(i);
|
|
input_node->uses_.erase(use_it);
|
|
inputs_[i] = nullptr;
|
|
return input_node;
|
|
}
|
|
|
|
void Node::removeFromList() {
|
|
AT_ASSERT(inBlockList());
|
|
this->owning_block_ = nullptr;
|
|
Node* next = this->next();
|
|
Node* prev = this->prev();
|
|
prev->next() = next;
|
|
next->prev() = prev;
|
|
this->next() = nullptr;
|
|
this->prev() = nullptr;
|
|
}
|
|
|
|
Block* Node::findCommonAncestorBlockWith(Node* n) {
|
|
if (n->owningBlock() == owningBlock()) {
|
|
return owningBlock();
|
|
}
|
|
|
|
Node* n1 = this;
|
|
Node* n2 = n;
|
|
|
|
size_t d_1 = n1->blocksFromGraphBlock();
|
|
size_t d_2 = n2->blocksFromGraphBlock();
|
|
|
|
for (; d_1 > d_2; --d_1) {
|
|
n1 = n1->owningBlock()->owningNode();
|
|
// n2 contains n1
|
|
}
|
|
|
|
for (; d_2 > d_1; --d_2) {
|
|
n2 = n2->owningBlock()->owningNode();
|
|
}
|
|
|
|
// Now they are the same numer of blocks from the graph block,
|
|
// recurse upwards, checking if they are on the same block
|
|
while (true) {
|
|
if (n1->owningBlock() == n2->owningBlock()) {
|
|
return n1->owningBlock();
|
|
}
|
|
|
|
n1 = n1->owningBlock()->owningNode();
|
|
n2 = n2->owningBlock()->owningNode();
|
|
|
|
AT_ASSERT(n1 != nullptr);
|
|
AT_ASSERT(n2 != nullptr);
|
|
}
|
|
}
|
|
|
|
size_t Node::blocksFromGraphBlock() {
|
|
Node* n = this;
|
|
size_t dist = 0;
|
|
while (n->owningBlock()->owningNode()) {
|
|
n = n->owningBlock()->owningNode();
|
|
++dist;
|
|
}
|
|
return dist;
|
|
}
|
|
|
|
inline const SourceRange& fakeRange() {
|
|
static SourceRange range(std::make_shared<Source>(std::string("")), 0, 1);
|
|
return range;
|
|
}
|
|
|
|
Value* Graph::insert(
|
|
Symbol opname,
|
|
at::ArrayRef<NamedValue> args,
|
|
at::ArrayRef<NamedValue> kwargs,
|
|
const std::optional<SourceRange>& range) {
|
|
return emitBuiltinCall(
|
|
range.value_or(fakeRange()), *this, opname, args, kwargs);
|
|
}
|
|
|
|
Node* Graph::create(NodeKind kind, size_t num_outputs) {
|
|
// NB: Node constructor adds node to all_nodes
|
|
auto n = new Node(this, kind);
|
|
for (const auto i : c10::irange(num_outputs)) {
|
|
(void)i;
|
|
n->addOutput();
|
|
}
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::create(
|
|
NodeKind kind,
|
|
ArrayRef<Value*> inputs,
|
|
size_t num_outputs) {
|
|
auto n = create(kind, num_outputs);
|
|
for (auto i : inputs) {
|
|
n->addInput(i);
|
|
}
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createAutogradZero() {
|
|
return create(prim::AutogradZero);
|
|
}
|
|
|
|
Node* Graph::createNone() {
|
|
Node* n = create(prim::Constant);
|
|
n->output()->setType(NoneType::get());
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createUninitialized(TypePtr typ) {
|
|
Node* n = create(prim::Uninitialized);
|
|
n->output()->setType(std::move(typ));
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createWithSubgraph(Symbol kind) {
|
|
auto n = create(kind, 0);
|
|
n->g_(attr::Subgraph, std::make_shared<Graph>(current_scope()));
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createTuple(at::ArrayRef<Value*> values, TupleTypePtr tuple_type) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
!tuple_type || tuple_type->schema(),
|
|
"only pass tuple_type when creating a named tuple");
|
|
if (!tuple_type) {
|
|
auto types = fmap(values, [](Value* v) { return v->type(); });
|
|
tuple_type = TupleType::create(std::move(types));
|
|
}
|
|
auto n = create(prim::TupleConstruct, values);
|
|
|
|
n->output()->setType(tuple_type);
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createTupleUnpack(Value* v) {
|
|
TupleTypePtr tt = v->type()->expect<TupleType>();
|
|
auto n = create(prim::TupleUnpack, {v}, 0);
|
|
for (auto& element : tt->elements()) {
|
|
n->addOutput()->setType(element);
|
|
}
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createTupleIndex(
|
|
Value* tup,
|
|
Value* idx,
|
|
const TypePtr& output_type) {
|
|
auto n = create(prim::TupleIndex, {tup, idx});
|
|
n->output()->setType(output_type);
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createTupleSlice(
|
|
Value* tup,
|
|
int64_t beg,
|
|
int64_t step_size,
|
|
int64_t num_values) {
|
|
std::vector<Value*> new_vals;
|
|
TupleTypePtr tt = tup->type()->expect<TupleType>();
|
|
new_vals.reserve(num_values);
|
|
|
|
int64_t i = beg;
|
|
for (const auto j : c10::irange(num_values)) {
|
|
(void)j; // Suppress unused variable warning
|
|
auto idx = insertConstant(IValue(static_cast<int64_t>(i)));
|
|
auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i]));
|
|
|
|
new_vals.push_back(tupleIndex->output());
|
|
i += step_size;
|
|
}
|
|
|
|
auto n = createTuple(new_vals);
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createEnumName(Value* e) {
|
|
e->type()->expect<EnumType>();
|
|
assert(e->type()->cast<EnumType>());
|
|
auto n = create(prim::EnumName, {e});
|
|
n->output()->setType(StringType::get());
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createEnumValue(Value* e) {
|
|
auto enum_type = e->type()->expect<EnumType>();
|
|
auto n = create(prim::EnumValue, {e});
|
|
n->output()->setType(enum_type->getValueType());
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createList(
|
|
const TypePtr& contained_type,
|
|
at::ArrayRef<Value*> values) {
|
|
auto n = create(prim::ListConstruct, values);
|
|
for (const auto& v : values) {
|
|
TORCH_CHECK(
|
|
v->type()->isSubtypeOf(*contained_type),
|
|
"Expected a list element that subtypes '",
|
|
contained_type->repr_str(),
|
|
"' but got an element of type '",
|
|
v->type()->repr_str(),
|
|
"'");
|
|
}
|
|
n->output()->setType(ListType::create(contained_type));
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createListUnpack(Value* v, size_t size) {
|
|
ListTypePtr list_type = v->type()->expect<ListType>();
|
|
TypePtr elem_type = list_type->getElementType();
|
|
auto n = create(prim::ListUnpack, {v}, 0);
|
|
for (const auto i : c10::irange(size)) {
|
|
(void)i; // Suppress unused variable warning
|
|
n->addOutput()->setType(elem_type);
|
|
}
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createDict(
|
|
const TypePtr& key_type,
|
|
const TypePtr& value_type,
|
|
at::ArrayRef<Value*> keys,
|
|
at::ArrayRef<Value*> values) {
|
|
AT_ASSERT(keys.size() == values.size());
|
|
auto n = create(prim::DictConstruct, 1);
|
|
for (const auto i : c10::irange(keys.size())) {
|
|
AT_ASSERT(keys[i]->type()->isSubtypeOf(*key_type));
|
|
AT_ASSERT(values[i]->type()->isSubtypeOf(*value_type));
|
|
|
|
n->addInput(keys[i]);
|
|
n->addInput(values[i]);
|
|
}
|
|
n->output()->setType(DictType::create(key_type, value_type));
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createNumToTensor(Value* value) {
|
|
Node* result = create(prim::NumToTensor, {value});
|
|
result->output()->setType(TensorType::fromNumberType(*value->type()));
|
|
return result;
|
|
}
|
|
|
|
Node* Graph::createObject(const ClassTypePtr& type) {
|
|
auto result = create(prim::CreateObject);
|
|
result->output()->setType(type);
|
|
return result;
|
|
}
|
|
|
|
Node* Graph::createSetAttr(
|
|
Value* obj,
|
|
const std::string& field,
|
|
Value* newValue) {
|
|
auto n = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
|
|
n->s_(attr::name, field);
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createGetAttr(Value* obj, const std::string& field) {
|
|
const auto classType = obj->type()->expect<ClassType>();
|
|
|
|
auto n = create(prim::GetAttr, {obj}, /*num_outputs=*/1);
|
|
n->s_(attr::name, field);
|
|
|
|
const auto outputType = classType->getAttribute(field);
|
|
n->output()->setType(outputType);
|
|
n->output()->setDebugName(normalizeAttrName(field));
|
|
return n;
|
|
}
|
|
|
|
Node* Graph::createStore(const std::string& name, Value* v) {
  auto n = create(prim::Store, {v}, /*num_outputs*/ 0);
  n->s_(attr::name, name);
  return n;
}

Node* Graph::createLoad(const std::string& name, const TypePtr& type) {
  auto n = create(prim::Load, {}, /*num_outputs*/ 1);
  n->s_(attr::name, name);
  n->output()->setType(type);
  return n;
}

Node* Graph::createIsInstance(Value* v, at::ArrayRef<TypePtr> types) {
  auto n = create(prim::isinstance, {v}, /*num_outputs*/ 1);
  n->tys_(attr::types, types.vec());
  n->output()->setType(BoolType::get());
  return n;
}
Value* Graph::insertUncheckedCast(Value* v, TypePtr type) {
  Node* n = insertNode(create(prim::unchecked_cast, {v}));
  n->output()->setType(std::move(type));
  return n->output();
}

Value* Graph::insertToList(Value* v, TypePtr type) {
  int dim = 0;
  TypePtr ptr = type;

  // Unwrap the type to determine the number of dimensions.
  while (auto list_type = ptr->cast<ListType>()) {
    ptr = list_type->getElementType();
    ++dim;
  }

  // Encode the base element type as an integer.
  int elem_ty = 0;
  if (ptr == IntType::get()) {
    elem_ty = 0;
  } else if (ptr == FloatType::get()) {
    elem_ty = 1;
  } else if (ptr == BoolType::get()) {
    elem_ty = 2;
  } else if (ptr == ComplexType::get()) {
    elem_ty = 3;
  } else {
    TORCH_CHECK(
        false,
        ptr->repr_str(),
        " is not one of the supported element types for tolist: int, float, complex, bool");
  }

  // Pass in the number of dimensions and base element type as arguments
  // to the op.
  Value* dim_val = insertConstant(IValue(dim));
  Value* elem_ty_val = insertConstant(IValue(elem_ty));
  Node* n = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val}));
  n->output()->setType(std::move(type));
  return n->output();
}

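// Worked example (illustrative, derived from the code above): for an
// annotation of List[List[int]], the while loop in insertToList unwraps two
// ListType layers, so dim == 2, and the remaining element type is IntType,
// giving elem_ty == 0. The emitted node is therefore
// prim::tolist(v, 2, 0), and its output is typed as List[List[int]].
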
Value* Graph::insertFunctionCall(
    Function* callee,
    const MatchedSchema& matched) {
  std::string func_name = callee->name();
  Value* fn_constant = insertNode(create(prim::Constant))
                           ->s_(attr::name, func_name)
                           ->output()
                           ->setType(FunctionType::create(callee));
  std::vector<Value*> inputs = {fn_constant};
  inputs.insert(inputs.end(), matched.inputs.begin(), matched.inputs.end());
  Value* result = insertNode(create(prim::CallFunction, inputs))
                      ->output()
                      ->setType(matched.return_types.at(0));
  return result;
}

Value* Graph::insertMethodCall(
    std::string method_name,
    const MatchedSchema& matched) {
  Value* result = insertNode(create(prim::CallMethod, matched.inputs))
                      ->s_(attr::name, std::move(method_name))
                      ->output()
                      ->setType(matched.return_types.at(0));
  return result;
}

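// Note on the IR produced above (descriptive, paraphrasing the code):
// insertFunctionCall emits a prim::Constant of FunctionType naming the
// callee and passes it as input 0 of a prim::CallFunction node, followed by
// the matched arguments. insertMethodCall instead stores the method name as
// the `name` attribute of a prim::CallMethod node whose inputs are
// matched.inputs (input 0 is typically the receiver object). Both take the
// first matched return type as the output type.
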
Node* Graph::createClone(
    Node* n,
    const std::function<Value*(Value*)>& value_map,
    bool copy_blocks) {
  // n can be from a different graph
  Node* r = n->allocNewInstance(this);
  for (auto o : n->outputs()) {
    r->addOutput()->copyMetadata(o);
  }
  r->cloneFrom(n);
  for (auto i : n->inputs()) {
    r->addInput(value_map(i));
  }
  if (copy_blocks) {
    for (auto b : n->blocks()) {
      r->addBlock()->cloneFrom(b, value_map);
    }
  }
  return r;
}

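// Illustrative usage sketch (not part of the original source): cloning a
// node into the same graph with an identity value map, i.e. the clone reuses
// the original inputs. `g` and `n` are assumed to be an existing Graph and
// Node.
//
//   Node* copy = g.insertNode(
//       g.createClone(n, [](Value* v) { return v; }));
//
// When cloning across graphs, the lambda is where each input Value* of the
// source graph gets translated to the corresponding Value* in the target.
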
Value* Graph::insertConstant(
    const IValue& val,
    std::optional<SourceRange> loc,
    std::optional<ScopePtr> scope) {
  return jit::insertConstant(*this, val, std::move(loc), std::move(scope));
}

std::string Graph::toString(bool print_source_locations) const {
  std::ostringstream oss;
  print(oss, print_source_locations);
  return oss.str();
}

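// Ownership note (descriptive, inferred from the code below): all_nodes,
// all_values, and all_blocks track every Node, Value, and Block created for
// this graph. The destructor deletes whatever is still registered, while
// freeNode/freeValue/freeBlock unregister and delete individual objects that
// are destroyed earlier (e.g. via Node::destroy).
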
Graph::~Graph() {
  for (const Node* n : all_nodes) {
    delete n;
  }
  for (const Value* v : all_values) {
    delete v;
  }
  for (const Block* b : all_blocks) {
    delete b;
  }
}

void Graph::freeNode(Node* n) {
  auto it = all_nodes.find(n);
  AT_ASSERT(it != all_nodes.end());
  delete *it;
  all_nodes.erase(it);
}
void Graph::freeValue(Value* v) {
  v->setDebugName("");
  auto it = all_values.find(v);
  AT_ASSERT(it != all_values.end());
  delete *it;
  all_values.erase(it);
}
void Graph::freeBlock(Block* b) {
  auto it = all_blocks.find(b);
  AT_ASSERT(it != all_blocks.end());
  delete *it;
  all_blocks.erase(it);
}

at::ArrayRef<Value*> createTupleUnpack(Value* v) {
  // small peephole optimization to ensure IntArrayRef attributes can still turn
  // into constants e.g. in x.expand([3, 4])
  if (v->node()->kind() == prim::TupleConstruct) {
    return v->node()->inputs();
  }
  auto& g = *v->owningGraph();
  return g.insertNode(g.createTupleUnpack(v))->outputs();
}

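// Clarifying note (descriptive): if `v` was just produced by a
// prim::TupleConstruct, no prim::TupleUnpack node is emitted at all; the
// construct's inputs are returned directly. For example, unpacking the
// result of TupleConstruct(%a, %b) yields {%a, %b} without adding any node,
// which is what allows attribute lists like the [3, 4] in x.expand([3, 4])
// to still become constants, as the comment above notes.
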
void inlineCallStackOfNode(
    Node* n,
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
    Function* callee,
    Node* to_replace,
    std::optional<ModuleInstanceInfo> m_info);

static void inlineCallStackOfBlock(
    Block* b,
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
    Function* callee,
    Node* to_replace,
    std::optional<ModuleInstanceInfo> m_info) {
  for (auto n : b->nodes()) {
    inlineCallStackOfNode(n, new_cs_entries, callee, to_replace, m_info);
  }
}

void inlineCallStackOfNode(
    Node* new_node,
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
    Function* callee,
    Node* to_replace,
    std::optional<ModuleInstanceInfo> m_info) {
  auto new_node_cs = new_node->callstack();

  InlinedCallStack* raw_callstack_ptr =
      new_node_cs ? new_node_cs->get() : nullptr;

  if (!new_cs_entries.count(raw_callstack_ptr)) {
    if (new_node_cs) {
      new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
          *new_node_cs, callee, to_replace->sourceRange(), m_info);
    } else {
      new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
          callee, to_replace->sourceRange(), m_info);
    }
  }
  new_node->setCallStack(new_cs_entries.at(raw_callstack_ptr));
  // We updated the inlined callstack of new_node. The same must be done for
  // the nodes in new_node's blocks; otherwise, for example, an If node's body
  // would not be annotated appropriately.
  for (auto block : new_node->blocks()) {
    inlineCallStackOfBlock(block, new_cs_entries, callee, to_replace, m_info);
  }
}

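// Clarifying note (descriptive): new_cs_entries memoizes, per original
// callstack pointer (or nullptr for nodes without a callstack), the single
// extended InlinedCallStack that appends the inlined callee and the
// call-site source range. All inlined nodes that shared a callstack before
// inlining therefore share one extended callstack object afterwards instead
// of each allocating its own copy.
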
std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    Graph* callee_graph) {
  WithInsertPoint guard(to_replace);
  std::unordered_map<Value*, Value*> value_map;
  std::vector<torch::jit::Value*> new_outputs = insertGraph(
      *to_replace->owningGraph(),
      *callee_graph,
      to_replace->inputs(),
      value_map);

  std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>
      new_callstack_entries;

  std::optional<ModuleInstanceInfo> module_instance_info = c10::nullopt;
  if (to_replace->kind() == prim::CallMethod) {
    auto class_type_ptr = to_replace->input(0)->type()->cast<c10::ClassType>();
    if (to_replace->input(0)->node()->kind() == prim::GetAttr) {
      module_instance_info = c10::make_optional(ModuleInstanceInfo(
          class_type_ptr, to_replace->input(0)->node()->s(attr::name)));
    } else if (
        !to_replace->owningGraph()->inputs().empty() &&
        to_replace->input(0) == to_replace->owningGraph()->inputs()[0]) {
      // This CallMethod must correspond to a method of the same object
      // to which this graph belongs.
      module_instance_info =
          c10::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF"));
    } else {
      // It is unclear whether this branch is ever reachable.
      // TODO: Remove this else, or add an assert.
      module_instance_info = c10::make_optional(
          ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN"));
    }
  }

  // TODO: We might need to use nodes_map instead of value_map. Otherwise, we
  // are missing nodes without outputs (e.g. prim::Print).
  std::unordered_set<Node*> updated_nodes;
  for (const auto& kv : value_map) {
    /* Skip the old value if it is a graph input.
     * value_map contains entries not only for the values produced by the
     * callee's nodes but also for its primary inputs, and those entries would
     * create duplicates when the first inlined graph feeds the next one. To
     * avoid this, skip the old value when it is one of
     * callee->optimized_graph()->inputs() or callee->graph()->inputs(),
     * depending on which graph was inlined.
     */
    auto is_graph_input = std::find(
        callee_graph->inputs().begin(), callee_graph->inputs().end(), kv.first);
    if (is_graph_input != callee_graph->inputs().end()) {
      continue;
    }

    Node* new_node = kv.second->node();
    if (!updated_nodes.insert(new_node).second) {
      continue;
    }

    inlineCallStackOfNode(
        new_node,
        new_callstack_entries,
        callee,
        to_replace,
        module_instance_info);
  }
  const auto& old_outputs = to_replace->outputs();

  AT_ASSERT(new_outputs.size() == old_outputs.size());
  for (const auto i : c10::irange(old_outputs.size())) {
    if (old_outputs[i]->hasDebugName()) {
      new_outputs[i]->setDebugName(old_outputs[i]->debugName());
    }
    old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
  }
  to_replace->destroy();

  return new_outputs;
}

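// Illustrative usage sketch (not part of the original source): inlining a
// call node found in a scripted module's graph. `call_node` is assumed to be
// a prim::CallFunction or prim::CallMethod node and `fn` the GraphFunction
// it resolves to; obtaining them is caller-specific.
//
//   std::vector<Value*> outs = inlineCallTo(call_node, fn);
//   // call_node has been destroyed; uses of its former outputs have been
//   // rewritten to the returned values, and every inlined node carries an
//   // extended InlinedCallStack.
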
// The inline_optimized_graph argument is used when substituting function
// calls during ONNX conversion.
std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    bool inline_optimized_graph /*=true*/) {
  auto graph =
      inline_optimized_graph ? callee->optimized_graph() : callee->graph();
  return inlineCallTo(to_replace, callee, graph.get());
}

std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs) {
  std::vector<Value*> new_outputs;
  if (outputs.size() != 1 || outputs.at(0)->type()->kind() != TupleType::Kind) {
    return outputs;
  }

  auto tup = outputs[0];
  for (Value* v : createTupleUnpack(tup)) {
    new_outputs.emplace_back(v);
  }
  // if this was a peephole tuple unpack we can just get rid of
  // the tuple construct here and prevent needing DCE
  if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
    tup->node()->destroy();
  }
  return new_outputs;
}

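// Clarifying note (descriptive): unpackOutputs is a no-op unless there is
// exactly one output and it has tuple type; in that case the tuple is
// expanded into its elements. For instance, a callee returning
// (Tensor, int) comes back from inlineCallTo as a single tuple-typed Value,
// and this helper turns it into two Values, destroying the now-unused
// TupleConstruct when the peephole in createTupleUnpack applied.
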
std::vector<Node*> findAllNodes(
    at::ArrayRef<Block*> array,
    Symbol kind,
    bool recurse) {
  std::vector<Node*> ret;
  for (auto block : array) {
    findAllNodes(*block, kind, recurse, ret);
  }
  return ret;
}

std::vector<Node*> findAllNodes(Block& block, Symbol kind, bool recurse) {
  return findAllNodes({&block}, kind, recurse);
}

std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse) {
  return findAllNodes(*g.block(), kind, recurse);
}

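// Illustrative usage sketch (not part of the original source): collecting
// every prim::If node in a graph `g`, including nodes nested inside
// sub-blocks.
//
//   std::vector<Node*> ifs =
//       findAllNodes(g, prim::If, /*recurse=*/true);
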
std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs,
    std::unordered_map<Value*, Value*>& value_map) {
  auto value_map_func = [&](Value* v) { return value_map.at(v); };
  AT_ASSERT(callee.inputs().size() == inputs.size());
  for (const auto i : c10::irange(inputs.size())) {
    value_map[callee.inputs()[i]] = inputs[i];
  }
  for (auto* node : callee.nodes()) {
    auto* new_node = g.insertNode(g.createClone(node, value_map_func));
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      value_map[node->outputs()[i]] = new_node->outputs()[i];
    }
  }

  std::vector<Value*> outputs;
  for (auto* output : callee.outputs()) {
    outputs.push_back(value_map_func(output));
  }

  return outputs;
}

std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs) {
  std::unordered_map<Value*, Value*> value_map;
  return insertGraph(g, callee, inputs, value_map);
}

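// Clarifying note (descriptive): insertGraph is the workhorse behind
// inlineCallTo above. It seeds value_map with the callee's formal inputs
// mapped to the provided actual inputs, clones each callee node at the
// current insertion point while translating its inputs through value_map,
// records the clone's outputs back into value_map, and finally returns the
// Values corresponding to the callee's outputs.
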
void ProfileOp::cloneFrom(Node* other_) {
  Node::cloneFrom(other_);
  auto other = other_->cast<ProfileOp>();
  this->callback_ = other->getCallback();
}

Node* ProfileOp::allocNewInstance(Graph* g) {
  return new ProfileOp(g, {nullptr});
}

void ProfileIValueOp::cloneFrom(Node* other_) {
  Node::cloneFrom(other_);
  auto other = other_->cast<ProfileIValueOp>();
  this->callback_ = other->getCallback();
}

Node* ProfileIValueOp::allocNewInstance(Graph* g) {
  return new ProfileIValueOp(g, {nullptr});
}

TypePtr NamedValue::type() const {
  if (value_) {
    return value_->type();
  } else {
    return ivalue_.type();
  }
}

const Symbol ProfileOp::Kind = ::c10::prim::profile;
const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue;

OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
  insert(sig_literals);
}

std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
  std::vector<std::shared_ptr<Operator>> result;
  for (const auto& kv : ops) {
    auto ops_for_symbol = kv.second;
    result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
  }
  return result;
}

void OperatorSet::insert(std::initializer_list<const char*> sig_literals) {
  for (const char* sig : sig_literals) {
    auto op = getOperatorForLiteral(sig);
    ops[Symbol::fromQualString(op->schema().name())].push_back(op);
  }
}

bool Node::isMemberOf(const OperatorSet& os) const {
  auto it = os.ops.find(kind());
  if (it == os.ops.end()) {
    return false;
  }
  for (auto& op : it->second) {
    if (matches(op->schema())) {
      return true;
    }
  }
  return false;
}

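// Illustrative usage sketch (not part of the original source): building an
// OperatorSet from full schema literals and testing a node against it. The
// schema strings are examples of the expected literal form, not an
// exhaustive or authoritative list.
//
//   static const OperatorSet pointwise_ops{
//       "aten::relu(Tensor self) -> Tensor",
//       "aten::tanh(Tensor self) -> Tensor",
//   };
//   if (node->isMemberOf(pointwise_ops)) {
//     // node matches one of the schemas above
//   }
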
} // namespace torch::jit