Files
pytorch/torch/csrc/jit/passes/lower_graph.cpp

156 lines
5.1 KiB
C++

#include <torch/csrc/jit/passes/lower_graph.h>
#include <torch/csrc/jit/api/object.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/custom_class.h>
#include <unordered_map>
namespace torch::jit {
struct Slot {
c10::intrusive_ptr<c10::ivalue::Object> obj;
size_t offset;
bool operator==(const Slot& other) const {
return (this->obj == other.obj && this->offset == other.offset);
}
};
// remove the first module argument, replacing any access of its
// parameters/attributes with extra_ivalue input Slots that hold what value to
// pass into the graph. Used for ONNX export to remove first-class modules
// so it can deal purely with parameters and inputs
static std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
const ModulePtr& self,
Graph& g_,
size_t self_offset = 0) {
std::shared_ptr<Graph> g = g_.copy();
// Inline to remove method/function calls
Inline(*g);
std::vector<Slot> extra_ivalues;
struct SlotHash {
std::size_t operator()(const Slot& slot) const {
auto obj_hash = std::hash<c10::ivalue::Object*>{}(slot.obj.get());
auto offset_hash = std::hash<size_t>{}(slot.offset);
return c10::hash_combine(obj_hash, offset_hash);
}
};
std::unordered_map<Slot, size_t, SlotHash> slot_to_offset;
struct ToScan {
ModulePtr mod;
Node* n;
size_t offset;
};
std::vector<ToScan> to_scan;
std::vector<Node*> to_clean; // nodes that should be dead at the end
auto getOrAddSlot = [&](const Slot& slot) -> Value* {
auto it = slot_to_offset.find(slot);
if (it != slot_to_offset.end()) {
size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
return g->inputs().at(ivalues_start + it->second);
}
extra_ivalues.emplace_back(slot);
slot_to_offset[slot] = extra_ivalues.size() - 1;
return g->addInput()->setType(slot.obj->getSlot(slot.offset).type());
};
auto self_value = g->inputs().at(self_offset);
for (Use use : self_value->uses()) {
to_scan.emplace_back(ToScan{self, use.user, use.offset});
}
while (!to_scan.empty()) {
auto e = to_scan.back();
to_scan.pop_back();
// when we lambda lift forks, first-class modules may be passed across
// forks. This code recursively lowers the module in the fork call.
if (e.n->kind() == prim::fork) {
auto subgraph = e.n->g(attr::Subgraph);
std::vector<Slot> new_slots;
std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
e.n->g_(attr::Subgraph, subgraph);
for (const Slot& slot : new_slots) {
e.n->addInput(getOrAddSlot(slot));
}
e.n->removeInput(e.offset);
continue;
}
if (e.n->kind() == prim::PythonOp) {
throw ErrorReport(e.n->sourceRange()) << "Couldn't export Python method.";
}
if (e.n->kind() != prim::GetAttr) {
throw ErrorReport(e.n->sourceRange())
<< "temporary: the only valid use of a module is looking up an "
"attribute but found "
<< *e.n;
}
size_t slot_idx = e.mod->type()->getAttributeSlot(e.n->s(attr::name));
auto iv = e.mod->getSlot(slot_idx);
if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
if (c->is_module()) {
for (Use use : e.n->output()->uses()) {
to_scan.emplace_back(ToScan{iv.toObject(), use.user, use.offset});
}
to_clean.emplace_back(e.n);
continue;
}
}
e.n->output()->replaceAllUsesWith(getOrAddSlot({e.mod, slot_idx}));
e.n->destroy();
}
while (!to_clean.empty()) {
Node* n = to_clean.back();
AT_ASSERT(!n->hasUses());
n->destroy();
to_clean.pop_back();
}
AT_ASSERT(!self_value->hasUses());
g->eraseInput(self_offset);
return std::make_pair(std::move(g), std::move(extra_ivalues));
}
static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
std::vector<IValue> result;
result.reserve(slots.size());
for (const Slot& slot : slots) {
auto obj = slot.obj->getSlot(slot.offset);
if (obj.isTensor()) {
result.emplace_back(obj.toTensor());
} else {
// Unpack quantization packed tensor
auto type = obj.type();
TORCH_CHECK(
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.LinearPackedParamsBase")),
"Unknown type ",
type->repr_str(),
" encountered in graph lowering. This type is not supported in ONNX export.");
result.emplace_back(
script::Object(obj.toObject()).run_method("__getstate__"));
}
}
return result;
}
// Public entry point. Lowers `graph` with respect to the module `self`, which
// is taken to be the graph input at index 0 (the default `self_offset` of
// lower_graph). Returns the lowered graph together with the values to pass
// for the inputs that lowering added, in the same order as those inputs.
std::pair<std::shared_ptr<Graph>, std::vector<IValue>> LowerGraph(
    Graph& graph,
    const ModulePtr& self) {
  auto [lowered_graph, slots] = lower_graph(self, graph);
  return std::make_pair(std::move(lowered_graph), loadTensors(slots));
}
} // namespace torch::jit