mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-07 01:50:04 +08:00
Summary: Relanding D78841818 with fixes Test Plan: Tested all failing tests buck build --config fbcode.use_link_groups=true --flagfile fbcode//mode/dev-nosan fbcode//sigmoid/core/executor/memory/test:layout_planner_tests buck test 'fbcode//mode/opt' fbcode//sigmoid/inference/test:test_passes Rollback Plan: Reviewed By: hl475 Differential Revision: D79038615 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159216 Approved by: https://github.com/dolpm
1579 lines
44 KiB
C++
1579 lines
44 KiB
C++
#include <torch/nativert/graph/Graph.h>
|
|
|
|
#include <fmt/ostream.h>
|
|
#include <fmt/ranges.h>
|
|
#include <limits>
|
|
#include <queue>
|
|
|
|
#include <c10/util/Enumerate.h>
|
|
#include <c10/util/FbcodeMaps.h>
|
|
#include <c10/util/StringUtil.h>
|
|
#include <torch/nativert/executor/Placement.h>
|
|
#include <torch/nativert/graph/TensorMeta.h>
|
|
|
|
namespace torch::nativert {
|
|
|
|
namespace {
|
|
|
|
// Local alias for std::is_same. The visitors in this file refer to it
// unqualified, which sidesteps an MSVC bug where `std` is reported as an
// ambiguous symbol.
template <typename T, typename U>
constexpr bool is_same_v = std::is_same<T, U>::value;
|
|
|
|
// Returns true if `n` is a whitespace character (per std::isspace in the
// current C locale).
// The cast to unsigned char is required: std::isspace has undefined behavior
// for negative arguments, which any byte >= 0x80 becomes when char is signed.
bool isBlank(char n) {
  return std::isspace(static_cast<unsigned char>(n)) != 0;
}
|
|
|
|
size_t consumeWhitespaceImpl(std::string_view source, size_t curPos) {
|
|
while (isBlank(source.at(curPos))) {
|
|
curPos++;
|
|
}
|
|
return curPos;
|
|
}
|
|
|
|
size_t expectImpl(
|
|
std::string_view source,
|
|
std::string_view expected,
|
|
size_t curPos) {
|
|
curPos = consumeWhitespaceImpl(source, curPos);
|
|
const auto actual = source.substr(curPos, expected.size());
|
|
TORCH_CHECK(
|
|
expected == actual,
|
|
fmt::format(
|
|
"Parser error: expected '{}' at position {}, but found '{}'.",
|
|
expected,
|
|
curPos,
|
|
actual));
|
|
curPos += expected.size();
|
|
return curPos;
|
|
}
|
|
|
|
size_t expectImpl(std::string_view source, char expected, size_t curPos) {
|
|
curPos = consumeWhitespaceImpl(source, curPos);
|
|
while (isBlank(source.at(curPos))) {
|
|
curPos++;
|
|
}
|
|
TORCH_CHECK(
|
|
expected == source[curPos],
|
|
"Parser error: expected '{}' at position {}, but found '{}'.",
|
|
expected,
|
|
curPos,
|
|
source[curPos]);
|
|
curPos++;
|
|
return curPos;
|
|
}
|
|
} // namespace
|
|
|
|
// Two types are equal when their kinds match and, if both carry custom-object
// data, their fully-qualified class names match too.
bool operator==(const Type& left, const Type& right) {
  if (left.kind() != right.kind()) {
    return false;
  }
  const auto* lhsObj = std::get_if<Type::CustomObjData>(&left.kind_);
  const auto* rhsObj = std::get_if<Type::CustomObjData>(&right.kind_);
  if (lhsObj == nullptr || rhsObj == nullptr) {
    return true;
  }
  return lhsObj->classFqn == rhsObj->classFqn;
}
|
|
|
|
// Constructs an empty graph holding only the two sentinel nodes, prim.Input
// followed by prim.Output. Both are created through insertNode(), which links
// each new node before `insertBefore_`; starting `insertBefore_` at
// nodes_.end() therefore appends them in that order.
// NOTE(review): this relies on nodes_, nodesOwner_ and insertBefore_ being
// declared (and hence initialized) before inputNode_/outputNode_ in Graph.h -
// confirm before reordering members.
Graph::Graph()
|
|
: insertBefore_(nodes_.end()),
|
|
inputNode_(insertNode("prim.Input", {})),
|
|
outputNode_(insertNode("prim.Output", {})) {
|
|
// Set the insertion point to append to the graph
// (subsequent insertNode()/insert() calls place nodes just before
// prim.Output).
|
|
insertBefore_ = nodes_.iterator_to(*outputNode_);
|
|
}
|
|
|
|
std::string Graph::getUniqueValueName() {
|
|
auto name = fmt::format("v{}", uniqueValueName_);
|
|
while (values_.find(name) != values_.end()) {
|
|
name = fmt::format("v{}", uniqueValueName_++);
|
|
}
|
|
return name;
|
|
}
|
|
|
|
// If `name` is null, create a unique value name
|
|
Value* Graph::addValue(
|
|
const std::optional<std::string>& name,
|
|
const Type& type,
|
|
Node* node) {
|
|
const auto valueName = name.value_or(getUniqueValueName());
|
|
ValueId valueId = getNextValueId();
|
|
const auto [it, success] = values_.insert(
|
|
{valueName, std::make_unique<Value>(valueId, valueName, type, node)});
|
|
TORCH_CHECK(
|
|
success,
|
|
fmt::format(
|
|
"Tried to create Value with name: '{}', but it already existed",
|
|
valueName));
|
|
return it->second.get();
|
|
}
|
|
|
|
// Appends a graph input named `name` of type `type`. Graph inputs are
// represented as outputs of the prim.Input sentinel node.
Value* Graph::addInput(std::string_view name, const Type& type) {
  Value* const input = inputNode_->addOutput(name, type);
  return input;
}
|
|
|
|
void Graph::addInput() {
|
|
inputNode_->addOutput();
|
|
}
|
|
|
|
// Registers `v` as a graph output by wiring it into prim.Output under its own
// name. Returns `v` unchanged.
Value* Graph::addOutput(Value* v) {
  NamedArgument arg{std::string(v->name()), v};
  outputNode_->addInput(std::move(arg));
  return v;
}
|
|
|
|
// Queues a constant (non-Value) user output; finalize() consumes these in
// order for signature outputs that have no value name.
void Graph::addConstantOutput(Constant c) {
  constantOutputs_.emplace_back(std::move(c));
}
|
|
|
|
// Create a node without inserting it into the execution graph.
|
|
Node* Graph::createNode(
|
|
std::string target,
|
|
std::vector<NamedArgument> inputs,
|
|
std::unordered_map<std::string, std::string> metadata) {
|
|
auto& node = nodesOwner_.emplace_back(std::make_unique<Node>(
|
|
this, std::move(target), std::move(inputs), std::move(metadata)));
|
|
return node.get();
|
|
}
|
|
|
|
// Links the unlinked node `toInsert` immediately before the linked node
// `insertionPoint` and returns it. Nothing may be placed before prim.Input.
Node* Graph::insertBefore(Node* toInsert, Node* insertionPoint) {
  TORCH_CHECK(insertionPoint != inputNode_, "can't insert before prim.Input");
  TORCH_CHECK(
      !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert);
  TORCH_CHECK(
      insertionPoint->is_linked(),
      "expected node to be linked: ",
      *insertionPoint);
  const auto position = nodes_.iterator_to(*insertionPoint);
  return &*nodes_.insert(position, *toInsert);
}
|
|
|
|
// Links the unlinked node `toInsert` at the current insertion point (by
// default just before prim.Output) and returns it.
Node* Graph::insert(Node* toInsert) {
  TORCH_CHECK(
      !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert);
  const auto position = insertBefore_;
  nodes_.insert(position, *toInsert);
  return toInsert;
}
|
|
|
|
// Links the unlinked node `toInsert` immediately after the linked node
// `insertionPoint` and returns it. Nothing may be placed after prim.Output.
Node* Graph::insertAfter(Node* toInsert, Node* insertionPoint) {
  TORCH_CHECK(insertionPoint != outputNode_, "can't insert after prim.Output");
  TORCH_CHECK(
      !toInsert->is_linked(), "expected node to be unlinked: ", *toInsert);
  TORCH_CHECK(
      insertionPoint->is_linked(),
      "expected node to be linked: ",
      *insertionPoint);
  // List insert places the new node *before* the given position, so step one
  // past the insertion point first.
  auto successor = nodes_.iterator_to(*insertionPoint);
  ++successor;
  return &*nodes_.insert(successor, *toInsert);
}
|
|
|
|
// Creates a node (see createNode) and immediately links it at the current
// insertion point. The constructor also uses this to create the prim.Input and
// prim.Output sentinels.
Node* Graph::insertNode(
    std::string target,
    std::vector<NamedArgument> inputs,
    std::unordered_map<std::string, std::string> metadata) {
  Node* const created =
      createNode(std::move(target), std::move(inputs), std::move(metadata));
  const auto position = insertBefore_;
  nodes_.insert(position, *created);
  return created;
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const Type& ty) {
|
|
std::visit(
|
|
[&out](auto&& arg) {
|
|
using T = std::decay_t<decltype(arg)>;
|
|
if constexpr (is_same_v<T, Type::Kind>) {
|
|
switch (arg) {
|
|
case Type::Kind::None:
|
|
out << "None";
|
|
break;
|
|
case Type::Kind::Tensor:
|
|
out << "Tensor";
|
|
break;
|
|
case Type::Kind::TensorList:
|
|
out << "TensorList";
|
|
break;
|
|
case Type::Kind::OptionalTensorList:
|
|
out << "OptionalTensorList";
|
|
break;
|
|
case Type::Kind::SymInt:
|
|
out << "SymInt";
|
|
break;
|
|
case Type::Kind::SymFloat:
|
|
out << "SymFloat";
|
|
break;
|
|
case Type::Kind::SymIntList:
|
|
out << "SymIntList";
|
|
break;
|
|
case Type::Kind::CustomObj:
|
|
out << "CustomObj";
|
|
break;
|
|
default:
|
|
TORCH_CHECK(false, "Unhandled type");
|
|
}
|
|
} else if constexpr (is_same_v<T, Type::CustomObjData>) {
|
|
out << "CustomObj: " << arg.classFqn;
|
|
}
|
|
},
|
|
ty.kind_);
|
|
return out;
|
|
}
|
|
|
|
// Returns the input named `name`, or nullptr if this node has none.
const NamedArgument* Node::tryGetInput(std::string_view name) const {
  // A linear scan is fine: nodes are expected to have very few inputs, and it
  // avoids maintaining a second lookup structure. The debug-only assert guards
  // that expectation.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs_.size() < 1000);
  const auto match = std::find_if(
      inputs_.begin(), inputs_.end(), [name](const NamedArgument& arg) {
        return arg.name == name;
      });
  return match == inputs_.end() ? nullptr : &*match;
}
|
|
|
|
// Returns the input named `name`; raises TORCH_CHECK if it does not exist.
const NamedArgument& Node::getInput(std::string_view name) const {
  const NamedArgument* found = tryGetInput(name);
  TORCH_CHECK(
      found != nullptr,
      fmt::format(
          "Expected input '{}' on node: '{}' to exist, but it does not.",
          name,
          fmt::streamed(*this)));
  return *found;
}
|
|
|
|
// Returns the attribute named `name`, or nullptr if this node has none.
const Attribute* Node::tryGetAttribute(std::string_view name) const {
  // A linear scan is fine: nodes are expected to have very few attributes, and
  // it avoids maintaining a second lookup structure. The debug-only assert
  // guards that expectation.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(attributes_.size() < 1000);
  const auto match = std::find_if(
      attributes_.begin(), attributes_.end(), [name](const Attribute& attr) {
        return attr.name == name;
      });
  return match == attributes_.end() ? nullptr : &*match;
}
|
|
|
|
// Returns the attribute named `name`; raises TORCH_CHECK if it does not exist.
const Attribute& Node::getAttribute(std::string_view name) const {
  const Attribute* found = tryGetAttribute(name);
  TORCH_CHECK(
      found != nullptr,
      fmt::format(
          "Expected attribute '{}' on node: '{}' to exist, but it does not.",
          name,
          fmt::streamed(*this)));
  return *found;
}
|
|
|
|
void Node::applyDevicePlacement(const Placement& placement) {
|
|
for (auto& attribute : attributes_) {
|
|
if (std::holds_alternative<c10::Device>(attribute.value)) {
|
|
auto device = std::get<c10::Device>(attribute.value);
|
|
auto targetDevice =
|
|
placement.getMappedDevice(std::get<c10::Device>(attribute.value));
|
|
if (!isSameDevice(targetDevice, device)) {
|
|
LOG(INFO) << "Overriding " << device.str() << " to "
|
|
<< targetDevice.str() << " for node " << *this;
|
|
attribute.value = targetDevice;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Returns the node after this one in its graph, or nullptr for prim.Output.
Node* Node::next() {
  auto* const graph = owningGraph();
  return graph->nodeAfter(this);
}
|
|
|
|
// Const overload: the node after this one, or nullptr for prim.Output.
const Node* Node::next() const {
  const auto* graph = owningGraph();
  return graph->nodeAfter(this);
}
|
|
|
|
// Returns the node before this one in its graph, or nullptr for prim.Input.
Node* Node::prev() {
  auto* const graph = owningGraph();
  return graph->nodeBefore(this);
}
|
|
|
|
// Const overload: the node before this one, or nullptr for prim.Input.
const Node* Node::prev() const {
  const auto* graph = owningGraph();
  return graph->nodeBefore(this);
}
|
|
|
|
bool Node::isBefore(const Node* n) const {
|
|
if (this == n) {
|
|
return false;
|
|
}
|
|
|
|
for (const Node* cursor = this->next(); cursor != nullptr;
|
|
cursor = cursor->next()) {
|
|
if (cursor == n) {
|
|
return true;
|
|
}
|
|
}
|
|
// Reached the end without finding n
|
|
return false;
|
|
}
|
|
|
|
std::vector<Node*> Node::producers() const {
|
|
std::vector<Node*> ret;
|
|
|
|
if (this->prev() == nullptr /* prim.Input */) {
|
|
return ret;
|
|
}
|
|
|
|
if (this->next() == nullptr /* prim.Output */) {
|
|
for (auto& node : owningGraph_->nodes()) {
|
|
if (node.next() == nullptr /* prim.Output */ ||
|
|
node.prev() == nullptr /* prim.Input */) {
|
|
continue;
|
|
}
|
|
for (auto* dep : node.users()) {
|
|
if (dep == this /* prim.Output */) {
|
|
ret.push_back(&node);
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
std::unordered_set<const Node*> seen;
|
|
|
|
for (const auto& input : inputs()) {
|
|
auto* n = input.value->producer();
|
|
if (n == nullptr) {
|
|
continue;
|
|
}
|
|
if (const auto [_, inserted] = seen.insert(n); inserted) {
|
|
ret.push_back(n);
|
|
}
|
|
}
|
|
|
|
if (ret.empty()) {
|
|
ret.push_back(owningGraph_->inputNode());
|
|
}
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
std::vector<Node*> Node::users() const {
|
|
std::vector<Node*> ret;
|
|
|
|
if (this->next() == nullptr /* prim.Output */) {
|
|
return ret;
|
|
}
|
|
|
|
if (this->prev() == nullptr /* prim.Input */) {
|
|
for (auto& node : owningGraph_->nodes()) {
|
|
if (node.prev() == nullptr /* prim.Input */ ||
|
|
node.next() == nullptr /* prim.Output */) {
|
|
continue;
|
|
}
|
|
for (auto* dep : node.producers()) {
|
|
if (dep == this /* prim.Input */) {
|
|
ret.push_back(&node);
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
std::unordered_set<const Node*> seen;
|
|
|
|
for (const auto* output : outputs()) {
|
|
for (auto* n : output->users()) {
|
|
if (const auto [_, inserted] = seen.insert(n); inserted) {
|
|
ret.push_back(n);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (ret.empty()) {
|
|
ret.push_back(owningGraph_->outputNode());
|
|
}
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
// Creates (but does not link) a prim.ListPack node that packs `inputs`, all of
// which must have type `inputType`. Inputs are named "l0", "l1", ... in order.
// The node gets a single freshly named output of type TensorList (for Tensor
// elements) or SymIntList (for SymInt elements); other element types are
// rejected.
Node* Graph::createListPack(std::vector<Value*> inputs, const Type& inputType) {
  // Validate everything before createNode(): the Node constructor registers
  // the node as a user of each input, so failing afterwards would leave a
  // dangling, never-linked user on those values.
  TORCH_CHECK(
      inputType == Type::Kind::Tensor || inputType == Type::Kind::SymInt,
      "prim.ListPack only supports Tensor or SymInt elements, got ",
      inputType);
  for (const auto* input : inputs) {
    TORCH_CHECK(
        input->type() == inputType,
        "prim.ListPack expects all inputs to have type ",
        inputType,
        ", but '",
        input->name(),
        "' has type ",
        input->type());
  }
  const Type::Kind listKind = inputType == Type::Kind::Tensor
      ? Type::Kind::TensorList
      : Type::Kind::SymIntList;

  std::vector<NamedArgument> nodeInputs;
  nodeInputs.reserve(inputs.size());
  for (auto [i, input] : c10::enumerate(inputs)) {
    nodeInputs.push_back({fmt::format("l{}", i), input});
  }

  // Create a new named value for the packed list.
  auto name = getUniqueValueName();
  auto node = createNode("prim.ListPack", std::move(nodeInputs));
  node->addOutput(name, listKind);
  return node;
}
|
|
|
|
// Creates (but does not link) a prim.ListPack node whose inputs must each be
// a Tensor or None. Inputs are named "l0", "l1", ... in order; the node gets a
// single freshly named output of type OptionalTensorList.
Node* Graph::createOptionalListPack(std::vector<Value*> inputs) {
  // Validate before createNode(): the Node constructor registers the node as a
  // user of each input, so failing afterwards would leave a dangling,
  // never-linked user on those values.
  for (const auto* input : inputs) {
    TORCH_CHECK(
        input->type() == Type::Kind::None ||
            input->type() == Type::Kind::Tensor,
        "prim.ListPack (optional list) expects None or Tensor inputs, but '",
        input->name(),
        "' has type ",
        input->type());
  }

  std::vector<NamedArgument> nodeInputs;
  nodeInputs.reserve(inputs.size());
  for (auto [i, input] : c10::enumerate(inputs)) {
    nodeInputs.push_back({fmt::format("l{}", i), input});
  }

  // Create a new named value for the packed list.
  auto name = getUniqueValueName();
  auto node = createNode("prim.ListPack", std::move(nodeInputs));
  node->addOutput(name, Type::Kind::OptionalTensorList);
  return node;
}
|
|
|
|
// Creates a producer-less SymInt value with a fresh unique name and records
// `value` as its constant payload in constantSymIntValues_ (keyed by value id).
Value* Graph::createConstantSymIntValue(int value) {
  const std::string valueName = getUniqueValueName();
  const ValueId valueId = getNextValueId();
  auto symInt = std::make_unique<Value>(
      valueId, valueName, Type::Kind::SymInt, nullptr);
  const auto [it, success] = values_.insert({valueName, std::move(symInt)});
  TORCH_CHECK(
      success,
      fmt::format(
          "Tried to create constant SymInt Value with name: '{}', but it already existed",
          valueName));
  constantSymIntValues_[valueId] = value;
  return it->second.get();
}
|
|
|
|
// Returns the value named `name`; the container's at() throws if it is absent.
Value* Graph::getValue(std::string_view name) const {
  // TODO: can eliminate this string copy by enabling heterogeneous lookup for
  // the container
  const std::string key(name);
  const auto& owned = values_.at(key);
  return owned.get();
}
|
|
|
|
// Returns the value named `name`, or nullptr if no such value exists.
Value* Graph::tryGetValue(std::string_view name) const {
  // TODO: can eliminate this string copy by enabling heterogeneous lookup for
  // the container
  // Reuse the iterator from find() instead of hashing the key a second time
  // with at().
  const auto it = values_.find(std::string(name));
  return it == values_.end() ? nullptr : it->second.get();
}
|
|
|
|
void Graph::renumberValues() {
|
|
std::vector<Value*> currentValues;
|
|
currentValues.reserve(values_.size());
|
|
for (auto& kv : values_) {
|
|
currentValues.push_back(kv.second.get());
|
|
}
|
|
|
|
// Sort values in creation order (by value ids)
|
|
std::sort(currentValues.begin(), currentValues.end(), [](Value* a, Value* b) {
|
|
return a->id() < b->id();
|
|
});
|
|
|
|
// Build a new id map with all ids < values_.size()
|
|
std::unordered_map<ValueId, ValueId> oldToNew;
|
|
oldToNew.reserve(currentValues.size());
|
|
ValueId newId = 0;
|
|
for (Value* v : currentValues) {
|
|
oldToNew[v->id()] = newId;
|
|
v->setId(newId);
|
|
newId++;
|
|
}
|
|
|
|
std::unordered_map<ValueId, int> newSymIntMap;
|
|
for (auto& [oldId, symIntVal] : constantSymIntValues_) {
|
|
auto it = oldToNew.find(oldId);
|
|
if (it != oldToNew.end()) {
|
|
ValueId updatedId = it->second;
|
|
newSymIntMap[updatedId] = symIntVal;
|
|
}
|
|
}
|
|
constantSymIntValues_ = std::move(newSymIntMap);
|
|
uniqueValueId_ = newId;
|
|
}
|
|
|
|
bool Graph::cleanupDeadNodes() {
|
|
std::unordered_set<const Node*> visited;
|
|
std::vector<const Node*> visitStack;
|
|
|
|
// Mark reachable nodes from output
|
|
visitStack.push_back(outputNode_);
|
|
visited.insert(outputNode_);
|
|
|
|
while (!visitStack.empty()) {
|
|
const Node* current = visitStack.back();
|
|
visitStack.pop_back();
|
|
|
|
for (auto& namedArg : current->inputs()) {
|
|
Value* val = namedArg.value;
|
|
Node* producer = val->producer();
|
|
|
|
if (!producer) {
|
|
continue;
|
|
}
|
|
if (!visited.count(producer)) {
|
|
visited.insert(producer);
|
|
visitStack.push_back(producer);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove all nodes not in visited (other than input/outputs)
|
|
std::vector<Node*> toRemove;
|
|
for (auto& n : nodes()) {
|
|
if (n.target() == "prim.Input" || n.target() == "prim.Output" ||
|
|
visited.count(&n)) {
|
|
continue;
|
|
}
|
|
toRemove.push_back(&n);
|
|
}
|
|
|
|
const bool mutated = !toRemove.empty();
|
|
|
|
// Remove nodes in reverse order to handle input/output dependencies
|
|
for (auto it = toRemove.rbegin(); it != toRemove.rend(); ++it) {
|
|
removeNode(*it);
|
|
}
|
|
|
|
renumberValues();
|
|
lint();
|
|
|
|
return mutated;
|
|
}
|
|
|
|
void Graph::lint() const {
|
|
// Check that every value has a producer marked.
|
|
for (const auto& [name, value] : values_) {
|
|
// Some constant symint and None don't have producer nodes
|
|
if (value->type().kind() != Type::Kind::SymInt &&
|
|
value->type().kind() != Type::Kind::None) {
|
|
TORCH_CHECK(value->isFolded() || value->producer() != nullptr);
|
|
}
|
|
}
|
|
for (const auto& node : nodes()) {
|
|
TORCH_CHECK(node.owningGraph() == this);
|
|
}
|
|
// Check that every list type is either produced by a prim.ListPack or
|
|
// immediately consumed by a prim.ListUnpack. We make use of this invariant
|
|
// to retrieve list elements in `getListElements`.
|
|
for (const auto& [_, value] : values_) {
|
|
if (value->type().kind() != Type::Kind::TensorList) {
|
|
continue;
|
|
}
|
|
const bool producedByListPack =
|
|
value->producer(/* resolve_folded = */ true)->target() ==
|
|
"prim.ListPack";
|
|
const bool consumedByListUnpack = value->users().size() == 1 &&
|
|
value->users()[0]->target() == "prim.ListUnpack";
|
|
TORCH_CHECK(producedByListPack || consumedByListUnpack);
|
|
}
|
|
|
|
auto getNames = [](const auto& values) {
|
|
c10::FastSet<std::string> names;
|
|
for (const auto* value : values) {
|
|
if (value) {
|
|
names.emplace(value->name());
|
|
}
|
|
}
|
|
return names;
|
|
};
|
|
signature_.lint(getNames(inputs()), getNames(outputs()));
|
|
}
|
|
|
|
void Graph::finalize() {
|
|
// build userOutputs_ view
|
|
userOutputs_.clear();
|
|
size_t constantIndex = 0;
|
|
for (auto& outputName : signature_.userOutputs()) {
|
|
if (outputName.has_value()) {
|
|
userOutputs_.emplace_back(getValue(*outputName));
|
|
} else {
|
|
if (constantIndex < constantOutputs_.size()) {
|
|
userOutputs_.emplace_back(std::move(constantOutputs_[constantIndex]));
|
|
constantIndex++;
|
|
} else {
|
|
TORCH_CHECK(false, "No more constant outputs available");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
// Scan through a node's inputs, replacing ALL instances of `old` with
|
|
// `replacement`. Returns true if a replacement occurred, otherwise false.
|
|
bool replace(Node* node, Value* old, Value* replacement) {
|
|
bool replacementOccurred = false;
|
|
for (auto& input : node->inputs()) {
|
|
if (input.value == old) {
|
|
input.value = replacement;
|
|
replacementOccurred = true;
|
|
}
|
|
}
|
|
return replacementOccurred;
|
|
}
|
|
} // namespace
|
|
|
|
// Redirects every use of `old` (node inputs and the graph signature) to
// `replacement`, moving the user bookkeeping from `old` to `replacement`.
void Graph::replaceAllUses(Value* old, Value* replacement) {
  // Replacing a value with itself is a no-op. Without this guard the final
  // eraseAllUsers() would clear the value's user list while the nodes still
  // reference it.
  if (old == replacement) {
    return;
  }
  for (auto* user : old->users()) {
    // Find this use in the input list and replace it.
    const bool replaced = replace(user, old, replacement);
    TORCH_CHECK(replaced);
    replacement->addUser(user);
  }
  old->eraseAllUsers();
  signature_.replaceAllUses(old->name(), replacement->name());
}
|
|
|
|
// Redirects uses of `old` by nodes strictly after `afterThis` (in execution
// order) to `replacement`, updating both values' user lists.
// NOTE(review): the graph signature is updated unconditionally.
void Graph::replaceAllUsesAfterNode(
    Value* old,
    Value* replacement,
    Node* afterThis) {
  auto cursor = nodes_.iterator_to(*afterThis);
  // Skip `afterThis` itself and scan every later node linearly.
  for (++cursor; cursor != nodes_.end(); ++cursor) {
    Node* candidate = &*cursor;
    if (!replace(candidate, old, replacement)) {
      continue;
    }
    old->eraseUser(candidate);
    replacement->addUser(candidate);
  }
  signature_.replaceAllUses(old->name(), replacement->name());
}
|
|
|
|
void Graph::applyDevicePlacement(const Placement& placement) {
|
|
TORCH_CHECK(
|
|
!placementApplied_,
|
|
"placement has been applied to the graph! placement must be applied once and once only.");
|
|
|
|
placementApplied_ = true;
|
|
|
|
// inplace override node's device-typed attributes according to placement
|
|
for (auto& node : nodes_) {
|
|
node.applyDevicePlacement(placement);
|
|
}
|
|
|
|
// inplace override weightMeta_'s device according to placement
|
|
for (auto& [_, weightMeta] : weightsMeta_) {
|
|
weightMeta.applyDevicePlacement(placement);
|
|
}
|
|
|
|
// inplace override tensorValuesMeta_'s device according to placement
|
|
for (auto& [_, tensorMeta] : tensorValuesMeta_) {
|
|
tensorMeta.applyDevicePlacement(placement);
|
|
}
|
|
}
|
|
|
|
// Returns the node immediately after `n`, or nullptr when `n` is prim.Output.
Node* Graph::nodeAfter(Node* n) {
  TORCH_CHECK(n->owningGraph() == this);
  const bool isLast = n == outputNode_;
  if (isLast) {
    return nullptr;
  }
  auto successor = nodes_.iterator_to(*n);
  ++successor;
  return &*successor;
}
|
|
|
|
// Const overload: the node after `n`, or nullptr when `n` is prim.Output.
const Node* Graph::nodeAfter(const Node* n) const {
  TORCH_CHECK(n->owningGraph() == this);
  const bool isLast = n == outputNode_;
  if (isLast) {
    return nullptr;
  }
  auto successor = nodes_.iterator_to(*n);
  ++successor;
  return &*successor;
}
|
|
|
|
// Returns the node immediately before `n`, or nullptr when `n` is prim.Input.
Node* Graph::nodeBefore(Node* n) {
  TORCH_CHECK(n->owningGraph() == this);
  const bool isFirst = n == inputNode_;
  if (isFirst) {
    return nullptr;
  }
  auto predecessor = nodes_.iterator_to(*n);
  --predecessor;
  return &*predecessor;
}
|
|
|
|
// Const overload: the node before `n`, or nullptr when `n` is prim.Input.
const Node* Graph::nodeBefore(const Node* n) const {
  TORCH_CHECK(n->owningGraph() == this);
  const bool isFirst = n == inputNode_;
  if (isFirst) {
    return nullptr;
  }
  auto predecessor = nodes_.iterator_to(*n);
  --predecessor;
  return &*predecessor;
}
|
|
|
|
// Unlinks `n` from this graph, deletes its output values, unregisters it from
// its inputs' user lists, and destroys it. All of its outputs must be unused.
// Every precondition is validated before anything is mutated, so a failed
// check leaves the graph untouched. (Previously values were deleted output by
// output, so a later failing check left `n` holding dangling Value pointers.)
void Graph::removeNode(Node* n) {
  TORCH_CHECK(n->owningGraph() == this, "Node does not belong to this graph!");
  TORCH_CHECK(n->is_linked(), "Node is not linked to the graph!");

  for (auto* outputVal : n->outputs()) {
    // Output slots may be null placeholders (see Node::addOutput()).
    TORCH_CHECK(
        outputVal == nullptr || outputVal->users().empty(),
        "Trying to erase a node that still has users: ",
        outputVal->name());
  }

  auto owner = std::find_if(
      nodesOwner_.begin(),
      nodesOwner_.end(),
      [n](const std::unique_ptr<Node>& ptr) { return ptr.get() == n; });
  TORCH_CHECK(owner != nodesOwner_.end(), "Node not found in nodesOwner_!");

  // All checks passed; mutate.
  for (auto* outputVal : n->outputs()) {
    if (outputVal != nullptr) {
      removeValue(outputVal);
    }
  }
  for (const auto& input : n->inputs()) {
    input.value->eraseUser(n);
  }
  n->unlink();
  nodesOwner_.erase(owner);
}
|
|
|
|
// Deletes `value` from this graph. It must have no users and must belong to
// this graph. The Value object is destroyed, so `value` dangles afterwards.
void Graph::removeValue(Value* value) {
  // TODO: assuming not removing from constantSymIntValues_
  TORCH_CHECK(value->users().empty(), "Cannot erase a value with users.");
  const std::string key(value->name());
  const auto entry = values_.find(key);
  TORCH_CHECK(
      entry != values_.end(),
      "Attempted to erase a value not in graph ",
      value->name());
  values_.erase(entry);
}
|
|
|
|
std::vector<Value*> Graph::insertGraph(
|
|
const Graph& subgraph,
|
|
std::vector<Value*> inputs,
|
|
std::unordered_map<const Value*, Value*>& valueMap) {
|
|
TORCH_CHECK(subgraph.inputs().size() == inputs.size(), "Input size mismatch");
|
|
for (auto i : c10::irange(subgraph.inputs().size())) {
|
|
valueMap[subgraph.inputs()[i]] = inputs[i];
|
|
}
|
|
|
|
// Clone each node from subgraph
|
|
for (const auto& n : subgraph.nodes()) {
|
|
if (n.target() == "prim.Input" || n.target() == "prim.Output") {
|
|
continue;
|
|
}
|
|
|
|
std::vector<NamedArgument> clonedInputs;
|
|
auto inputs = n.inputs();
|
|
clonedInputs.reserve(inputs.size());
|
|
for (auto& inp : inputs) {
|
|
auto it = valueMap.find(inp.value);
|
|
TORCH_CHECK(it != valueMap.end(), "Missing input value in subgraph");
|
|
clonedInputs.push_back({inp.name, it->second});
|
|
}
|
|
|
|
Node* newNode = insertNode(
|
|
std::string(n.target()), std::move(clonedInputs), n.metadata());
|
|
|
|
for (const auto& attr : n.attributes()) {
|
|
Attribute newAttr;
|
|
newAttr.name = attr.name;
|
|
|
|
std::visit(
|
|
[&](auto&& val) -> void {
|
|
// Workaround for MSVC bug: "std" ambiguous symbol.
|
|
using std::unique_ptr;
|
|
using std::move;
|
|
using T = std::decay_t<decltype(val)>;
|
|
if constexpr (is_same_v<T, unique_ptr<Graph>>) {
|
|
LOG(ERROR)
|
|
<< "Graph attributes are not supported yet. Skipping attribute: "
|
|
<< attr.name;
|
|
} else {
|
|
newAttr.value = val;
|
|
#ifdef __clang__
|
|
#pragma GCC diagnostic push
|
|
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
|
|
#pragma GCC diagnostic ignored "-Wunqualified-std-cast-call"
|
|
#endif
|
|
newNode->addAttribute(move(newAttr));
|
|
#ifdef __clang__
|
|
#pragma GCC diagnostic pop
|
|
#endif
|
|
}
|
|
},
|
|
attr.value);
|
|
}
|
|
|
|
for (const auto* outVal : n.outputs()) {
|
|
const auto& uniqueName = getUniqueValueName();
|
|
Value* newOut = newNode->addOutput(uniqueName, outVal->type());
|
|
valueMap[outVal] = newOut;
|
|
}
|
|
}
|
|
|
|
auto subgraphOutputs = subgraph.outputs();
|
|
std::vector<Value*> outputValues;
|
|
outputValues.reserve(subgraphOutputs.size());
|
|
for (auto* outputValue : subgraphOutputs) {
|
|
outputValues.emplace_back(valueMap[outputValue]);
|
|
}
|
|
lint();
|
|
return outputValues;
|
|
}
|
|
|
|
// Constructs a node owned by `owningGraph` (not linked into its node list)
// and registers it as a user of each of its input values.
Node::Node(
    Graph* owningGraph,
    std::string target,
    std::vector<NamedArgument> inputs,
    std::unordered_map<std::string, std::string> metadata)
    : owningGraph_(owningGraph),
      target_(std::move(target)),
      inputs_(std::move(inputs)),
      metadata_(std::move(metadata)) {
  for (const NamedArgument& arg : inputs_) {
    arg.value->addUser(this);
  }
}
|
|
|
|
// Appends `input`, registers this node as a user of its value, and returns
// that value.
Value* Node::addInput(NamedArgument input) {
  Value* const value = input.value;
  inputs_.push_back(std::move(input));
  value->addUser(this);
  return value;
}
|
|
|
|
void Node::addInputs(const std::vector<NamedArgument>& inputs) {
|
|
for (const auto& input : inputs) {
|
|
addInput(input);
|
|
}
|
|
}
|
|
|
|
// Appends `attr` to this node's attribute list (no duplicate-name check).
void Node::addAttribute(Attribute attr) {
  attributes_.emplace_back(std::move(attr));
}
|
|
|
|
// Appends an empty output slot (a null Value pointer).
void Node::addOutput() {
  outputs_.emplace_back(nullptr);
}
|
|
|
|
// Appends an output with an auto-generated unique name. Only None-typed
// outputs may be created this way.
Value* Node::addOutput(const Type& type) {
  TORCH_CHECK(type == Type::Kind::None);
  Value* const created = owningGraph_->addValue(std::nullopt, type, this);
  outputs_.push_back(created);
  return created;
}
|
|
|
|
// Appends an output named `name` of type `type`, produced by this node.
// Throws (via Graph::addValue) if the name is already taken.
Value* Node::addOutput(std::string_view name, const Type& type) {
  Value* const created =
      owningGraph_->addValue(std::string(name), type, this);
  outputs_.push_back(created);
  return created;
}
|
|
|
|
void Node::destroy() {
|
|
owningGraph_->removeNode(this);
|
|
}
|
|
|
|
// Records `node` as a user of this value; already-recorded users are ignored.
void Value::addUser(Node* node) {
  const bool known =
      std::find(users_.begin(), users_.end(), node) != users_.end();
  if (!known) {
    users_.push_back(node);
  }
}
|
|
|
|
// Removes every occurrence of `node` from this value's user list.
void Value::eraseUser(Node* node) {
  const auto newEnd = std::remove(users_.begin(), users_.end(), node);
  users_.erase(newEnd, users_.end());
}
|
|
|
|
std::vector<const Value*> Value::getListElements() const {
|
|
std::vector<const Value*> ret;
|
|
if (auto p = producer(); p && p->target() == "prim.ListPack") {
|
|
for (const auto& tv : p->inputs()) {
|
|
ret.push_back(tv.value);
|
|
}
|
|
} else {
|
|
TORCH_CHECK(users().size() == 1);
|
|
const auto listUnpack = users()[0];
|
|
TORCH_CHECK(listUnpack->target() == "prim.ListUnpack");
|
|
for (const auto v : listUnpack->outputs()) {
|
|
ret.push_back(v);
|
|
}
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
// Dependent `false` for static_assert in the fall-through branch of
// exhaustive `if constexpr` visitors.
template <typename>
[[maybe_unused]] inline constexpr bool AlwaysFalse = false;
|
|
|
|
// Converts a Constant to a c10::IValue. None becomes an empty IValue, any
// alternative convertible to IValue is converted directly, and subgraph
// constants are rejected via TORCH_CHECK.
c10::IValue constantToIValue(const Constant& constant) {
  // Workaround for MSVC bug: "std" ambiguous symbol.
  using std::string;
  using std::unique_ptr;
  using std::vector;
  const auto toIValue = [](auto&& arg) -> c10::IValue {
    using T = std::decay_t<decltype(arg)>;
    if constexpr (is_same_v<T, None>) {
      return c10::IValue();
    } else if constexpr (std::is_convertible_v<T, c10::IValue>) {
      return arg;
    } else if constexpr (is_same_v<T, unique_ptr<Graph>>) {
      TORCH_CHECK(false, "subgraph arguments cannot be turned into ivalues!");
    } else {
      static_assert(AlwaysFalse<T>, "non-exhaustive visitor!");
    }
  };
  return std::visit(toIValue, constant);
}
|
|
|
|
namespace {
|
|
|
|
// File-local dependent `false` for static_assert in exhaustive visitors
// (equivalent to AlwaysFalse).
template <typename>
[[maybe_unused]] inline constexpr bool always_false_v = false;
|
|
|
|
// Writes `arg` using fmt's default "{}" floating-point formatting.
void printDouble(std::ostream& out, double arg) {
  out << fmt::format("{}", arg);
}
|
|
|
|
// Writes the elements of `list` separated by ", ", rendering each element with
// `formatter(out, element)`. When `encloseInSquareBrackets` is set, the output
// is wrapped in '[' ... ']'. Returns `out`.
template <typename T, typename F>
std::ostream& printList(
    std::ostream& out,
    bool encloseInSquareBrackets,
    const T& list,
    F formatter) {
  if (encloseInSquareBrackets) {
    out << '[';
  }
  const char* separator = "";
  for (const auto& element : list) {
    out << separator;
    formatter(out, element);
    separator = ", ";
  }
  if (encloseInSquareBrackets) {
    out << ']';
  }
  return out;
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const Constant& constant) {
|
|
// Workaround for MSVC bug: "std" ambiguous symbol.
|
|
using std::quoted;
|
|
using std::string;
|
|
using std::unique_ptr;
|
|
using std::vector;
|
|
std::visit(
|
|
[&](auto&& arg) {
|
|
using T = std::decay_t<decltype(arg)>;
|
|
if constexpr (is_same_v<T, None>) {
|
|
out << "None";
|
|
} else if constexpr (is_same_v<T, int64_t> || is_same_v<T, bool>) {
|
|
out << arg;
|
|
} else if constexpr (
|
|
is_same_v<T, vector<int64_t>> || is_same_v<T, vector<bool>>) {
|
|
out << fmt::format("{}", fmt::streamed(arg));
|
|
} else if constexpr (is_same_v<T, double>) {
|
|
printDouble(out, arg);
|
|
} else if constexpr (is_same_v<T, vector<double>>) {
|
|
printList(out, true, arg, printDouble);
|
|
} else if constexpr (is_same_v<T, string>) {
|
|
out << quoted(arg);
|
|
} else if constexpr (is_same_v<T, c10::ScalarType>) {
|
|
out << kScalarTypePrefix << arg;
|
|
} else if constexpr (is_same_v<T, c10::MemoryFormat>) {
|
|
out << kMemoryFormatPrefix << arg;
|
|
} else if constexpr (is_same_v<T, c10::Layout>) {
|
|
out << kLayoutPrefix << arg;
|
|
} else if constexpr (is_same_v<T, c10::Device>) {
|
|
out << kDevicePrefix << "{" << arg << "}";
|
|
} else if constexpr (is_same_v<T, vector<string>>) {
|
|
out << fmt::format("[{}]", fmt::join(arg, ","));
|
|
} else if constexpr (is_same_v<T, unique_ptr<Graph>>) {
|
|
out << fmt::format("<subgraph>");
|
|
VLOG(0) << "Subgraph pretty print is not implemented";
|
|
} else {
|
|
static_assert(always_false_v<T>, "non-exhaustive visitor!");
|
|
}
|
|
},
|
|
constant);
|
|
return out;
|
|
}
|
|
|
|
// Prints a value as %name, or "<Constant>" for an empty (null) output slot.
void printValue(std::ostream& out, const Value* v) {
  if (v == nullptr) {
    out << "<Constant>";
  } else {
    out << *v;
  }
}
|
|
|
|
void printNamedArgument(std::ostream& out, const NamedArgument& nv) {
|
|
out << nv.name << "=" << *nv.value;
|
|
}
|
|
|
|
void printAttribute(std::ostream& out, const Attribute& nv) {
|
|
out << nv.name << "=" << nv.value;
|
|
}
|
|
} // namespace
|
|
|
|
std::ostream& operator<<(std::ostream& out, const Value& v) {
|
|
out << "%" << v.name();
|
|
// If a list, distinguish it by adding a []
|
|
// Looks like %my_list[]
|
|
if (v.type() == Type::Kind::TensorList) {
|
|
out << "[]";
|
|
}
|
|
return out;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const Node& node) {
|
|
// special casing for inputs and outputs
|
|
if (node.target() == "prim.Input") {
|
|
out << "graph(";
|
|
printList(out, false, node.outputs(), printValue);
|
|
out << "):";
|
|
return out;
|
|
}
|
|
if (node.target() == "prim.Output") {
|
|
out << "return(";
|
|
printList(out, false, node.inputs(), [](std::ostream& out, const auto& nv) {
|
|
out << *nv.value;
|
|
});
|
|
out << ")";
|
|
return out;
|
|
}
|
|
|
|
printList(out, false, node.outputs_, printValue);
|
|
|
|
out << " = ";
|
|
out << node.target_ << "(";
|
|
printList(out, false, node.inputs_, printNamedArgument);
|
|
if (!node.inputs_.empty() && !node.attributes_.empty()) {
|
|
// Emit a connective ',' between inputs and attributes.
|
|
out << ", ";
|
|
}
|
|
|
|
printList(out, false, node.attributes_, printAttribute);
|
|
out << ")";
|
|
return out;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& out, const Graph& graph) {
|
|
for (const auto& node : graph.nodes_) {
|
|
out << node << "\n";
|
|
}
|
|
return out;
|
|
}
|
|
|
|
// Parses a device symbol of the form `Device{<type>:<index>}` (e.g.
// `Device{cuda:1}`), as written by operator<<(Constant). The `:<index>` part
// is optional. c10::Device's stream output omits the index when none is set
// (e.g. `cpu`), so `Device{cpu}` must be accepted for printing and parsing to
// round-trip; it yields a device without an index.
// Raises TORCH_CHECK on malformed input or an index outside
// c10::DeviceIndex's range.
c10::Device convertDevice(std::string_view symbol) {
  // If '{' is absent, find() returns npos and npos + 1 wraps to 0, so parsing
  // starts at the beginning of the symbol.
  const auto typeStart = symbol.find('{') + 1;
  TORCH_CHECK(typeStart < symbol.size());

  const auto indexEnd = symbol.find('}', typeStart);
  TORCH_CHECK(indexEnd != std::string_view::npos);

  // Text between the braces, e.g. "cuda:1" or "cpu".
  const auto body = symbol.substr(typeStart, indexEnd - typeStart);
  const auto typeEnd = body.find(':');

  const std::string type(body.substr(0, typeEnd));
  c10::Device device(type);
  if (typeEnd == std::string_view::npos) {
    // No explicit index, e.g. `Device{cpu}`.
    return device;
  }

  const auto index = body.substr(typeEnd + 1);
  auto indexValue = c10::tryToNumber<int64_t>(std::string{index});
  TORCH_CHECK(indexValue.has_value(), "Invalid device index format");
  const int64_t deviceIndex = indexValue.value();
  TORCH_CHECK(
      deviceIndex >= std::numeric_limits<c10::DeviceIndex>::min() &&
          deviceIndex <= std::numeric_limits<c10::DeviceIndex>::max(),
      "Device index out of range for int8_t");
  device.set_index(static_cast<c10::DeviceIndex>(deviceIndex));
  return device;
}
|
|
|
|
// Convert one textual atomic constant into a Constant. Recognized forms:
//   "..."                    -> std::string (outer quotes removed)
//   None / true / false      -> None / bool
//   kMemoryFormatPrefix...   -> memory format
//   kLayoutPrefix...         -> layout
//   kDevicePrefix...         -> c10::Device (see convertDevice)
//   kScalarTypePrefix...     -> scalar type
//   integer literal          -> int64_t (only if it contains no '.')
//   other numeric literal    -> double
// Raises (via TORCH_CHECK) for anything else.
Constant convertAtomicConstant(std::string_view symbol) {
  if (c10::starts_with(symbol, "\"")) {
    // chop off the outer quotes and return the string. Require a closing
    // quote as well; otherwise remove_suffix would drop a real character.
    TORCH_CHECK(
        symbol.size() >= 2 && symbol.back() == '"',
        "unterminated string constant: ",
        symbol);
    symbol.remove_prefix(1);
    symbol.remove_suffix(1);
    return std::string(symbol);
  } else if (symbol == "None") {
    return None();
  } else if (symbol == "true") {
    return true;
  } else if (symbol == "false") {
    return false;
  } else if (c10::starts_with(symbol, kMemoryFormatPrefix)) {
    torch::_export::MemoryFormat value = torch::_export::MemoryFormat::Unknown;
    symbol.remove_prefix(kMemoryFormatPrefix.length());
    torch::_export::parseEnum(symbol, value);
    return convertJsonMemoryFormat(value);
  } else if (c10::starts_with(symbol, kLayoutPrefix)) {
    torch::_export::Layout value = torch::_export::Layout::Unknown;
    symbol.remove_prefix(kLayoutPrefix.length());
    torch::_export::parseEnum(symbol, value);
    return convertJsonLayout(value);
  } else if (c10::starts_with(symbol, kDevicePrefix)) {
    return convertDevice(symbol);
  } else if (c10::starts_with(symbol, kScalarTypePrefix)) {
    torch::_export::ScalarType value = torch::_export::ScalarType::UNKNOWN;
    symbol.remove_prefix(kScalarTypePrefix.length());
    torch::_export::parseEnum(symbol, value);
    return convertJsonScalarType(value);
  }

  // match number
  // We need to disambiguate between int and float constants.
  // Build the owning string once; both number parsers take std::string.
  const std::string numeric{symbol};
  const auto maybeInt = c10::tryToNumber<int64_t>(numeric);

  // Libraries may happily convert "5.0" to an int 5, but we want that to
  // become a float. So add an extra check for whether a '.' is in the string
  // to guard against that.
  const bool hasDecimalSeparator = symbol.find('.') != std::string_view::npos;
  if (maybeInt.has_value() && !hasDecimalSeparator) {
    return maybeInt.value();
  }

  const auto maybeDouble = c10::tryToNumber<double>(numeric);
  if (maybeDouble.has_value()) {
    return maybeDouble.value();
  }

  TORCH_CHECK(false, "unhandled symbol: ", symbol);
}
|
|
|
|
// Parse a bracketed list of atomic constants such as `[1, 2, 3]` into the
// matching list Constant (std::vector<int64_t>, std::vector<double> or
// std::vector<bool>).
//
// All elements must share one type, and only int, float and bool elements
// are supported; violations raise via TORCH_CHECK. An empty list `[]`
// carries no element type, so it is returned as an empty
// std::vector<int64_t> instead of failing on an empty element.
Constant convertListConstant(std::string_view source) {
  std::vector<Constant> values;
  size_t curPos = 0;
  // Holds the first element once one has been seen; None() means "no element
  // yet". Every later element must have the same variant alternative.
  Constant type = None();

  // This is basically the same as parseValueList; it's probably better to
  // refactor.
  curPos = expectImpl(source, '[', curPos);
  curPos = consumeWhitespaceImpl(source, curPos);
  if (source.at(curPos) == ']') {
    return std::vector<int64_t>{};
  }

  while (true) {
    curPos = consumeWhitespaceImpl(source, curPos);

    const size_t start = curPos;
    while (source.at(curPos) != ',' && source.at(curPos) != ']') {
      curPos++;
    }
    auto symbol = source.substr(start, curPos - start);
    // Drop whitespace between the element and its delimiter (e.g. `[1 , 2]`),
    // which the number parsers would otherwise reject.
    while (!symbol.empty() && isBlank(symbol.back())) {
      symbol.remove_suffix(1);
    }

    auto val = convertAtomicConstant(symbol);
    if (std::holds_alternative<None>(type)) {
      // First element: it fixes the element type for the rest of the list.
      TORCH_CHECK(
          std::holds_alternative<int64_t>(val) ||
              std::holds_alternative<double>(val) ||
              std::holds_alternative<bool>(val),
          "constant lists only support int, float, bool");
      type = val;
    } else {
      TORCH_CHECK(
          type.index() == val.index(), "lists must have all the same type");
    }
    values.push_back(std::move(val));

    if (source.at(curPos) == ']') {
      break;
    }
    curPos = expectImpl(source, ',', curPos);
  }
  expectImpl(source, ']', curPos);

  // Unwrap std::vector<Constant> (each holding a T) into Constant holding
  // std::vector<T>. The argument only carries the element type T.
  const auto unwrap = [&values](auto sample) -> Constant {
    using T = decltype(sample);
    std::vector<T> inner;
    inner.reserve(values.size());
    for (const auto& el : values) {
      inner.push_back(std::get<T>(el));
    }
    return inner;
  };

  if (std::holds_alternative<int64_t>(type)) {
    return unwrap(int64_t{});
  }
  if (std::holds_alternative<double>(type)) {
    return unwrap(double{});
  }
  if (std::holds_alternative<bool>(type)) {
    return unwrap(bool{});
  }
  TORCH_CHECK(false, "constant lists only support int, float, bool");
}
|
|
|
|
namespace {
|
|
|
|
/**
 * Deserialization for graphs: parse the output produced by operator<<(Graph).
 * This parser really only expects the exact output generated by well-formed
 * Graph objects, so it is not very permissive and does not give good error
 * messages.
 *
 * Lifetime notes: source_ is a std::string_view, so the caller's text must
 * outlive the parser. parse() moves graph_ out, so a Parser is single-use.
 */
class Parser {
 public:
  explicit Parser(std::string_view source)
      : source_(source), graph_(Graph::createGraph()) {}
  // Parse the header, every node and the return line, then attach the
  // signature, finalize and lint the graph, and hand ownership to the caller.
  std::unique_ptr<Graph> parse();

 private:
  // Parse `open elem, elem, ... close`, calling parseFn once per element.
  template <typename T>
  std::vector<T> parseList(
      char open,
      char close,
      const std::function<T()>& parseFn);

  // Advance until fn() is true and return the consumed slice of source_
  // (optionally including the terminating character).
  std::string_view parseUntil(
      const std::function<bool()>& fn,
      bool includeEnd = false);

  // Low-level cursor helpers over source_/curPos_.
  void expect(std::string_view expected);
  void expect(char expected);
  bool nextEquals(std::string_view expected) const;
  bool nextIf(std::string_view expected);
  bool nextIf(char expected);
  void consumeWhitespace();
  bool validIdent(char n);
  char cur();

  // Grammar productions.
  void parseReturn();
  void parseNode();
  std::pair<std::string_view, Type> parseOutput();
  void parseGraphInputs();
  std::string_view parseString();
  std::variant<Value*, Constant> parseArgument();
  std::variant<NamedArgument, Attribute> parseNamedArgument();
  Value* parseSymbolicArgument();
  // Symbols look like %v109, with the same valid ident rules as Python
  // This returns the symbol *without* the % at the front.
  std::string_view parseAtomicSymbol();

  // Current read offset into source_.
  size_t curPos_ = 0;
  // Text being parsed; not owned.
  std::string_view source_;
  // Graph under construction; moved out by parse().
  std::unique_ptr<Graph> graph_;
  // Filled by parseGraphInputs() / parseReturn() and installed by parse().
  torch::_export::GraphSignature signature_;
};
|
|
|
|
std::unique_ptr<Graph> Parser::parse() {
|
|
parseGraphInputs();
|
|
while (true) {
|
|
consumeWhitespace();
|
|
if (nextEquals("return")) {
|
|
parseReturn();
|
|
break;
|
|
}
|
|
parseNode();
|
|
}
|
|
// For graph textual format, it should be safe to assume all
|
|
// inputs/outputs are from users.
|
|
graph_->setSignature(GraphSignature{signature_});
|
|
graph_->finalize();
|
|
graph_->lint();
|
|
// TODO: Might have some source left over, should check it if so.
|
|
return std::move(graph_);
|
|
}
|
|
|
|
bool Parser::nextIf(std::string_view expected) {
|
|
if (nextEquals(expected)) {
|
|
curPos_ += expected.size();
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool Parser::nextIf(char expected) {
|
|
if (cur() == expected) {
|
|
curPos_++;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void Parser::parseGraphInputs() {
|
|
TORCH_CHECK(curPos_ == 0);
|
|
expect("graph");
|
|
const auto inputs = parseList<std::string_view>(
|
|
'(', ')', [&]() { return parseAtomicSymbol(); });
|
|
std::vector<torch::_export::InputSpec> inputSpecs;
|
|
inputSpecs.reserve(inputs.size());
|
|
for (const auto& input : inputs) {
|
|
graph_->addInput(input, Type::Kind::Tensor);
|
|
|
|
torch::_export::TensorArgument inputTensorArg;
|
|
inputTensorArg.set_name(std::string{input});
|
|
torch::_export::Argument inputArg;
|
|
inputArg.set_as_tensor(std::move(inputTensorArg));
|
|
torch::_export::UserInputSpec userInput;
|
|
userInput.set_arg(std::move(inputArg));
|
|
torch::_export::InputSpec inputSpec;
|
|
inputSpec.set_user_input(std::move(userInput));
|
|
inputSpecs.push_back(std::move(inputSpec));
|
|
}
|
|
signature_.set_input_specs(std::move(inputSpecs));
|
|
// TODO populate graphinputs
|
|
expect(":");
|
|
}
|
|
|
|
template <typename T>
|
|
std::vector<T> Parser::parseList(
|
|
char open,
|
|
char close,
|
|
const std::function<T()>& parseFn) {
|
|
std::vector<T> ret;
|
|
expect(open);
|
|
|
|
// Handle empty list
|
|
if (nextIf(close)) {
|
|
return ret;
|
|
}
|
|
while (true) {
|
|
ret.push_back(parseFn());
|
|
if (cur() == close) {
|
|
break;
|
|
}
|
|
expect(',');
|
|
}
|
|
expect(close);
|
|
return ret;
|
|
}
|
|
|
|
// Parse until `fn` returns true, returning the segment of the source that was
|
|
// consumed. If `includeEnd` is true, the returned segment will also include
|
|
// final character, which caused `fn` to return true.
|
|
std::string_view Parser::parseUntil(
|
|
const std::function<bool()>& fn,
|
|
bool includeEnd) {
|
|
size_t start = curPos_;
|
|
while (!fn()) {
|
|
curPos_++;
|
|
}
|
|
if (includeEnd) {
|
|
curPos_++;
|
|
}
|
|
return source_.substr(start, curPos_ - start);
|
|
}
|
|
|
|
// Parse a string, including the outer quotes
|
|
std::string_view Parser::parseString() {
|
|
size_t start = curPos_;
|
|
expect('"');
|
|
while (cur() != '"') {
|
|
// Handle escaped characters by skipping the next char when we see a
|
|
// backslash
|
|
if (cur() == '\\') {
|
|
curPos_++;
|
|
}
|
|
curPos_++;
|
|
}
|
|
|
|
// Consume final quote
|
|
curPos_++;
|
|
auto ret = source_.substr(start, curPos_ - start);
|
|
return ret;
|
|
}
|
|
|
|
// True for characters allowed in a symbol name: ASCII letters, digits and
// '_' (as tested by std::isalpha / std::isdigit).
bool Parser::validIdent(char n) {
  // The <cctype> classifiers require a value representable as unsigned char.
  // A plain char holding a byte >= 0x80 (e.g. UTF-8) is negative on most
  // platforms, and passing it unconverted is undefined behavior.
  const auto c = static_cast<unsigned char>(n);
  return std::isalpha(c) || c == '_' || std::isdigit(c);
}
|
|
|
|
// Symbols look like %v109, with the same valid ident rules as Python
|
|
// This returns the symbol *without* the % at the front.
|
|
std::string_view Parser::parseAtomicSymbol() {
|
|
expect("%");
|
|
return parseUntil([&]() { return !validIdent(cur()); });
|
|
}
|
|
|
|
char Parser::cur() {
|
|
return source_.at(curPos_);
|
|
}
|
|
|
|
void Parser::consumeWhitespace() {
|
|
while (isBlank(cur())) {
|
|
curPos_++;
|
|
}
|
|
}
|
|
|
|
// Skip whitespace and require `expected` next; advances past it or raises
// (both handled by expectImpl).
void Parser::expect(std::string_view expected) {
  const size_t next = expectImpl(source_, expected, curPos_);
  curPos_ = next;
}
|
|
|
|
void Parser::expect(char expected) {
|
|
curPos_ = expectImpl(source_, expected, curPos_);
|
|
}
|
|
|
|
bool Parser::nextEquals(std::string_view expected) const {
|
|
const auto actual = source_.substr(curPos_, expected.size());
|
|
return expected == actual;
|
|
}
|
|
|
|
// %a, %b = aten.foo.default(input=%foo, foo=[7616], blah=%lol)
|
|
void Parser::parseNode() {
|
|
std::vector<std::pair<std::string_view, Type>> outputs;
|
|
|
|
outputs.push_back(parseOutput());
|
|
while (nextIf(",")) {
|
|
outputs.push_back(parseOutput());
|
|
}
|
|
expect("=");
|
|
consumeWhitespace();
|
|
|
|
// parse target name
|
|
const auto target = parseUntil([&]() { return cur() == '('; });
|
|
|
|
Node* node = graph_->insertNode(std::string(target));
|
|
for (auto& [name, var] : outputs) {
|
|
node->addOutput(name, var);
|
|
}
|
|
|
|
auto arguments = parseList<std::variant<NamedArgument, Attribute>>(
|
|
'(', ')', [&]() { return parseNamedArgument(); });
|
|
|
|
// Split the arguments into symbolic inputs and constant attributes
|
|
for (auto& arg : arguments) {
|
|
if (std::holds_alternative<NamedArgument>(arg)) {
|
|
node->addInput(std::get<NamedArgument>(arg));
|
|
} else {
|
|
node->addAttribute(std::get<Attribute>(std::move(arg)));
|
|
}
|
|
}
|
|
}
|
|
|
|
void Parser::parseReturn() {
|
|
expect("return");
|
|
const auto returns =
|
|
parseList<Value*>('(', ')', [&]() { return parseSymbolicArgument(); });
|
|
std::vector<torch::_export::OutputSpec> outputSpecs;
|
|
outputSpecs.reserve(returns.size());
|
|
for (const auto ret : returns) {
|
|
graph_->addOutput(ret);
|
|
|
|
torch::_export::TensorArgument retTensorArg;
|
|
retTensorArg.set_name(std::string{ret->name()});
|
|
torch::_export::Argument retArg;
|
|
retArg.set_as_tensor(std::move(retTensorArg));
|
|
torch::_export::UserOutputSpec userOutput;
|
|
userOutput.set_arg(std::move(retArg));
|
|
torch::_export::OutputSpec outputSpec;
|
|
outputSpec.set_user_output(std::move(userOutput));
|
|
outputSpecs.push_back(std::move(outputSpec));
|
|
}
|
|
signature_.set_output_specs(std::move(outputSpecs));
|
|
}
|
|
|
|
std::variant<NamedArgument, Attribute> Parser::parseNamedArgument() {
|
|
consumeWhitespace();
|
|
// Parse name
|
|
const auto symbol = parseUntil([&]() { return cur() == '='; });
|
|
expect('=');
|
|
|
|
// Parse value
|
|
auto value = parseArgument();
|
|
if (std::holds_alternative<Value*>(value)) {
|
|
return NamedArgument{std::string(symbol), std::get<Value*>(value)};
|
|
} else {
|
|
return Attribute{std::string(symbol), std::get<Constant>(std::move(value))};
|
|
}
|
|
}
|
|
|
|
std::pair<std::string_view, Type> Parser::parseOutput() {
|
|
consumeWhitespace();
|
|
TORCH_CHECK(cur() == '%', fmt::format("expected % but got {}", cur()));
|
|
|
|
auto symbol = parseAtomicSymbol();
|
|
if (nextIf('[')) {
|
|
expect(']');
|
|
return {symbol, Type::Kind::TensorList};
|
|
} else {
|
|
return {symbol, Type::Kind::Tensor};
|
|
}
|
|
}
|
|
|
|
// Parse a reference to an existing value, `%name`, and return it from the
// graph.
//
// A trailing `[...]` (printed for tensor lists, e.g. `%my_list[]`) is parsed
// only to advance past it. Its elements are still resolved through getValue,
// as before, but the result is discarded. The return value is always the
// value named by the leading symbol.
Value* Parser::parseSymbolicArgument() {
  consumeWhitespace();
  TORCH_CHECK(cur() == '%', fmt::format("expected % but got {}", cur()));

  auto symbol = parseAtomicSymbol();
  if (cur() == '[') {
    // Result intentionally unused (previously stored in a dead local).
    (void)parseList<Value*>(
        '[', ']', [&]() { return graph_->getValue(parseAtomicSymbol()); });
  }
  return graph_->getValue(symbol);
}
|
|
|
|
std::variant<Value*, Constant> Parser::parseArgument() {
|
|
consumeWhitespace();
|
|
|
|
// match symbol
|
|
if (cur() == '%') {
|
|
return parseSymbolicArgument();
|
|
}
|
|
|
|
// match list
|
|
if (cur() == '[') {
|
|
const auto symbol =
|
|
parseUntil([&]() { return cur() == ']'; }, /*includeEnd=*/true);
|
|
return convertListConstant(symbol);
|
|
}
|
|
|
|
// match string
|
|
if (cur() == '"') {
|
|
return convertAtomicConstant(parseString());
|
|
}
|
|
|
|
// otherwise parse this as a value
|
|
const auto symbol =
|
|
parseUntil([&]() { return cur() == ',' || cur() == ')'; });
|
|
return convertAtomicConstant(symbol);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
std::unique_ptr<Graph> stringToGraph(std::string_view source) {
|
|
return Parser(source).parse();
|
|
}
|
|
|
|
std::string graphToString(const Graph& g, bool include_signature) {
|
|
std::stringstream ss;
|
|
ss << g;
|
|
|
|
if (include_signature) {
|
|
ss << "\nGraphSignature\n";
|
|
ss << g.signature();
|
|
}
|
|
|
|
return ss.str();
|
|
}
|
|
|
|
} // namespace torch::nativert
|