[static runtime] Swap to out-variant compatible nodes (#44127)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44127

Test Plan: Imported from OSS

Reviewed By: hlu1

Differential Revision: D23604306

Pulled By: bwasti

fbshipit-source-id: 18ccfb9b466b822e28130be3d5c4fae36c76820b
Author: Bram Wasti
Date: 2020-09-14 12:33:02 -07:00
Committed by: Facebook GitHub Bot
Parent: 856510c96d
Commit: a475613d1d

3 changed files with 100 additions and 95 deletions
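The core of the change: StaticRuntime previously kept a constant_table_ plus a list of std::pair<Node*, Operation>, and runNodes re-dispatched on each node's kind on every call. This commit folds that bookkeeping into a new ProcessedNode class that resolves its Operation once, at construction, and executes against a workspace_ map now owned by the runtime. What follows is a minimal, self-contained sketch of that pattern, not the PyTorch code itself; MiniProcessedNode, Workspace, and the double-valued ops are illustrative stand-ins.

    // Sketch only: stand-ins for the Node/Operation/IValue types in torch::jit.
    #include <functional>
    #include <iostream>
    #include <map>
    #include <stdexcept>
    #include <string>
    #include <vector>

    using Value = std::string;                 // graph value name
    using Workspace = std::map<Value, double>; // Value -> runtime value

    // Plays the role of ProcessedNode: the operation is resolved once at
    // construction, then run() executes it against the shared workspace.
    class MiniProcessedNode {
     public:
      MiniProcessedNode(std::vector<Value> inputs, Value output,
                        std::function<double(double, double)> op)
          : inputs_(std::move(inputs)),
            output_(std::move(output)),
            op_(std::move(op)) {}

      void run(Workspace& ws) const {
        // Mirrors the stack build in ProcessedNode::run: each input is
        // looked up in the workspace, and a missing value fails loudly.
        std::vector<double> stack;
        for (const auto& v : inputs_) {
          auto it = ws.find(v);
          if (it == ws.end()) throw std::runtime_error("missing value: " + v);
          stack.push_back(it->second);
        }
        ws[output_] = op_(stack[0], stack[1]); // binary ops only, for brevity
      }

     private:
      std::vector<Value> inputs_;
      Value output_;
      std::function<double(double, double)> op_; // resolved once, like op_
    };

    int main() {
      // Constants go straight into the workspace at construction time, the
      // way the new StaticRuntime constructor handles prim::Constant nodes.
      Workspace ws{{"c", 3.0}};
      std::vector<MiniProcessedNode> nodes;
      nodes.emplace_back(std::vector<Value>{"a", "b"}, "t",
                         [](double x, double y) { return x * y; });
      nodes.emplace_back(std::vector<Value>{"t", "c"}, "out",
                         [](double x, double y) { return x + y; });

      ws["a"] = 2.0;
      ws["b"] = 4.0;
      for (const auto& n : nodes) n.run(ws); // the new run() loop
      std::cout << ws["out"] << "\n";        // 2 * 4 + 3 = 11
    }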

File 1: StaticRuntime Python tests

@@ -16,6 +16,7 @@ class StaticRuntime:
     def __call__(self, *inps):
         return self.static_runtime.run(inps)
 
+
 def linear_shim(input, weight, bias=None):
     # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
     output = input.matmul(weight.t())
@@ -23,6 +24,8 @@ def linear_shim(input, weight, bias=None):
         output += bias
     ret = output
     return ret
+
+
 torch.nn.functional.linear = linear_shim
@@ -92,6 +95,7 @@ def trivial_graph(a, b, c):
     s = torch.tensor([[3, 3], [3, 3]])
     return a + b * c + s
 
+
 class TestStaticRuntime(TestCase):
     def test_multihead_attention_layer(self):
         HID_DIM = 256
@@ -133,7 +137,15 @@ class TestStaticRuntime(TestCase):
         ref_top = top_l(top_inp)
         acc_top = top_l_acc(top_inp)[0]
         torch.testing.assert_allclose(acc_top, ref_top)
+        for _ in range(5):
+            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
+            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
+            ref_bot = bot_l(bot_inp)
+            acc_bot = bot_l_acc(bot_inp)[0]
+            torch.testing.assert_allclose(acc_bot, ref_bot)
+            ref_top = top_l(top_inp)
+            acc_top = top_l_acc(top_inp)[0]
+            torch.testing.assert_allclose(acc_top, ref_top)
 
     # def test_trivial_graph(self):
     #     s = torch.full((2, 2), 2)
@@ -143,5 +155,6 @@ class TestStaticRuntime(TestCase):
     #     o_test = tg_a(s, s, s)[0]
     #     torch.testing.assert_allclose(o_ref, o_test)
 
+
 if __name__ == "__main__":
     run_tests()
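The five-iteration loop added above checks that a single compiled instance stays correct when invoked repeatedly with fresh inputs, which matters once node outputs live in a reused workspace. The commit title refers to out-variant ops: ATen exposes most operators both as a functional form that allocates a new output tensor and as an _out form that writes into a caller-provided one, and pre-processing every node is the groundwork for swapping the latter in. A short sketch of the difference, using at::add_out as the example (the harness around it is illustrative):

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      at::Tensor a = at::randn({2, 3});
      at::Tensor b = at::randn({2, 3});

      // Functional variant: allocates a fresh output tensor on every call.
      at::Tensor c = at::add(a, b);

      // Out variant: writes into a preallocated buffer, so a static runtime
      // can reuse the same memory across iterations instead of allocating.
      at::Tensor out = at::empty({2, 3});
      at::add_out(out, a, b);

      std::cout << c.equal(out) << "\n"; // 1: same result, buffer reused
    }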

File 2: StaticRuntime C++ implementation

@@ -104,110 +104,36 @@ StaticRuntime::StaticRuntime(const torch::jit::Module& m)
     }
   }
 
-  // fill constant_table_ and operator_table_
+  // fill workspace_ with constants
   for (Node* node : graph_->nodes()) {
-    switch (node->kind()) {
-      case prim::Constant:
-        CHECK(node->output()->type()->kind() != FunctionType::Kind);
-        constant_table_[node->output()] = toIValue(node->output()).value();
-        break;
-      case prim::ListConstruct:
-        nodes_.emplace_back(node, nullptr);
-        break;
-      case prim::TupleConstruct:
-        nodes_.emplace_back(node, nullptr);
-        break;
-      default: {
-        const Operator& op = node->getOperator();
-        CHECK(op.hasOperation());
-        nodes_.emplace_back(node, op.getOperation(node));
-      }
+    if (node->kind() == prim::Constant) {
+      CHECK(node->output()->type()->kind() != FunctionType::Kind);
+      workspace_[node->output()] = toIValue(node->output()).value();
+    } else {
+      nodes_.emplace_back(node);
     }
   }
 }
 
-void StaticRuntime::getInputIValues(
-    Node* node,
-    const ConstantMap& ws,
-    std::vector<IValue>& stack) const {
-  const size_t size = node->inputs().size();
-  stack.reserve(size);
-  for (size_t i = 0; i < size; i++) {
-    Value* v = node->inputs()[i];
-    auto f = constant_table_.find(v);
-    if (f == constant_table_.end()) {
-      auto f_ws = ws.find(v);
-      TORCH_CHECK(
-          f_ws != ws.end(),
-          "Workspace does not contain Value ",
-          v->debugName());
-      stack.emplace_back(f_ws->second);
-    } else {
-      stack.emplace_back(f->second);
-    }
-  }
-}
-
-void StaticRuntime::runNodes(ConstantMap& workspace) const {
-  std::vector<IValue> stack;
-  for (const auto& p : nodes_) {
-    Node* node = p.first;
-    const Operation& op = p.second;
-    getInputIValues(node, workspace, stack);
-    VLOG(1) << node->kind().toDisplayString();
-
-    switch (node->kind()) {
-      case prim::ListConstruct: {
-        listConstruct(
-            stack,
-            node->output()->type()->expect<ListType>(),
-            node->inputs().size());
-      } break;
-      case prim::TupleConstruct: {
-        bool named =
-            node->output()->type()->expect<TupleType>()->name().has_value();
-        if (named) {
-          namedTupleConstruct(
-              stack,
-              node->output()->type()->expect<TupleType>(),
-              node->inputs().size());
-        } else {
-          tupleConstruct(stack, node->inputs().size());
-        }
-      } break;
-      default: {
-        DCHECK(op);
-        op(&stack);
-        break;
-      }
-    }
-
-    DCHECK_EQ(stack.size(), node->outputs().size());
-    for (auto i = 0; i < node->outputs().size(); i++) {
-      workspace[node->outputs()[i]] = stack[i];
-    }
-    stack.clear();
-  }
-}
-
 std::vector<at::Tensor> StaticRuntime::run(
-    const std::vector<at::Tensor>& inps) const {
-  // Container for inputs, outputs, and activations (excluding parameters)
-  ConstantMap workspace_;
-
+    const std::vector<at::Tensor>& inps) {
   int start = 0;
   if (graph_->inputs().size() != inps.size()) {
     start = 1;
     CHECK_EQ(graph_->inputs().size(), inps.size() + 1);
     CHECK((graph_->inputs().at(0)->type()->is_module()));
-    workspace_.emplace(graph_->inputs()[0], module_._ivalue());
+    workspace_[graph_->inputs()[0]] = module_._ivalue();
   }
 
   for (size_t i = 0; i < inps.size(); i++) {
-    workspace_.emplace(graph_->inputs()[i + start], inps[i]);
+    workspace_[graph_->inputs()[i + start]] = inps[i];
   }
 
-  runNodes(workspace_);
+  for (const auto& n : nodes_) {
+    n.run(workspace_);
+  }
 
   std::vector<at::Tensor> out;
   for (Value* output : graph_->outputs()) {
@@ -223,5 +149,61 @@ std::vector<at::Tensor> StaticRuntime::run(
   }
   return out;
 }
 
+ProcessedNode::ProcessedNode(Node* node) : node_(node) {
+  if (node->kind() != prim::ListConstruct &&
+      node->kind() != prim::TupleConstruct) {
+    const Operator& op = node->getOperator();
+    CHECK(op.hasOperation());
+    op_ = op.getOperation(node);
+  }
+}
+
+void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const {
+  if (use_stack_) {
+    std::vector<IValue> stack;
+    const size_t size = node_->inputs().size();
+    stack.reserve(size);
+    for (size_t i = 0; i < size; i++) {
+      Value* v = node_->inputs()[i];
+      auto f = workspace.find(v);
+      TORCH_CHECK(
+          f != workspace.end(),
+          "Workspace does not contain Value ",
+          v->debugName());
+      stack.emplace_back(f->second);
+    }
+    if (op_) {
+      (*op_)(&stack);
+    } else {
+      if (node_->kind() == prim::ListConstruct) {
+        listConstruct(
+            stack,
+            node_->output()->type()->expect<ListType>(),
+            node_->inputs().size());
+      } else if (node_->kind() == prim::TupleConstruct) {
+        bool named =
+            node_->output()->type()->expect<TupleType>()->name().has_value();
+        if (named) {
+          namedTupleConstruct(
+              stack,
+              node_->output()->type()->expect<TupleType>(),
+              node_->inputs().size());
+        } else {
+          tupleConstruct(stack, node_->inputs().size());
+        }
+      } else {
+        TORCH_CHECK(0, "Unhandled operation!", node_->kind().toQualString());
+      }
+    }
+    DCHECK_EQ(stack.size(), node_->outputs().size());
+    for (auto i = 0; i < node_->outputs().size(); i++) {
+      workspace[node_->outputs()[i]] = stack[i];
+    }
+  } else {
+    TORCH_CHECK(0, "Non-stack execution not yet implemented");
+  }
+}
+
 } // namespace jit
 } // namespace torch
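ProcessedNode::run keeps the JIT interpreter's stack calling convention: inputs are looked up in the workspace and pushed onto a stack of IValues, the Operation consumes them and pushes its results, and the results are copied back into the workspace with a check that exactly one value remains per output. The use_stack_ flag reserves a future fast path that bypasses the stack, e.g. for out-variant kernels. A toy version of the convention, with doubles standing in for IValues:

    #include <functional>
    #include <vector>

    using SimpleStack = std::vector<double>; // stands in for torch::jit::Stack
    using Op = std::function<void(SimpleStack*)>;

    // An op pops its inputs off the stack and pushes its outputs, exactly
    // the contract that op_(&stack) relies on in ProcessedNode::run.
    Op makeAddOp() {
      return [](SimpleStack* stack) {
        double b = stack->back(); stack->pop_back();
        double a = stack->back(); stack->pop_back();
        stack->push_back(a + b);
      };
    }

    int main() {
      SimpleStack stack = {1.5, 2.5}; // inputs pushed from the workspace
      makeAddOp()(&stack);            // op replaces inputs with outputs
      // One value left for one output, mirroring
      // DCHECK_EQ(stack.size(), node_->outputs().size()).
      return (stack.size() == 1 && stack[0] == 4.0) ? 0 : 1;
    }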

File 3: StaticRuntime C++ header

@@ -14,6 +14,7 @@
 namespace torch {
 namespace jit {
 
+class ProcessedNode;
 class TORCH_API StaticRuntime {
  public:
   explicit StaticRuntime(std::shared_ptr<torch::jit::Graph> g)
@@ -21,7 +22,7 @@ class TORCH_API StaticRuntime {
   explicit StaticRuntime(const torch::jit::Module& m);
 
-  std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;
+  std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps);
 
 #ifdef FBCODE_CAFFE2
   using ConstantMap = folly::F14FastMap<Value*, IValue>;
@@ -34,17 +35,26 @@ class TORCH_API StaticRuntime {
   std::shared_ptr<torch::jit::Graph> graph_;
 
   // Static runtime states
-  // Constant table (including weights)
-  ConstantMap constant_table_;
+  // Value table (including weights)
+  ConstantMap workspace_;
 
   // The nodes we need to run
-  std::vector<std::pair<Node*, Operation>> nodes_;
-
-  void getInputIValues(
-      Node* node,
-      const ConstantMap& ws,
-      std::vector<IValue>& stack) const;
-
-  void runNodes(ConstantMap& ws_) const;
+  std::vector<ProcessedNode> nodes_;
 };
 
+class ProcessedNode {
+ public:
+  ProcessedNode(Node* n);
+  void run(StaticRuntime::ConstantMap& workspace) const;
+
+  Node* get_node() const {
+    return node_;
+  }
+
+ private:
+  Node* node_;
+  c10::optional<Operation> op_;
+  // if false, we have an optimized version
+  bool use_stack_ = true;
+};
+
 } // namespace jit
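With the workspace moved into a member, run() drops its const qualifier, but one StaticRuntime can now be built once and invoked many times, which is the behavior the new test loop exercises. A hypothetical usage sketch; the include path and model file name are assumptions, not from this diff:

    #include <torch/script.h>
    #include <vector>
    // Assumed include path for the static runtime header shown above.
    #include <torch/csrc/jit/runtime/static/impl.h>

    int main() {
      torch::jit::Module mod = torch::jit::load("model.pt"); // hypothetical file
      torch::jit::StaticRuntime runtime(mod); // builds ProcessedNodes once

      std::vector<at::Tensor> inputs = {torch::randn({2048, 512})};
      // run() is non-const now: inputs and activations are written into the
      // member workspace_, which is reused on every call.
      for (int i = 0; i < 5; i++) {
        std::vector<at::Tensor> out = runtime.run(inputs);
      }
    }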