mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[static runtime] Swap to out-variant compatible nodes (#44127)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44127 Test Plan: Imported from OSS Reviewed By: hlu1 Differential Revision: D23604306 Pulled By: bwasti fbshipit-source-id: 18ccfb9b466b822e28130be3d5c4fae36c76820b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
856510c96d
commit
a475613d1d
@ -16,6 +16,7 @@ class StaticRuntime:
|
||||
def __call__(self, *inps):
|
||||
return self.static_runtime.run(inps)
|
||||
|
||||
|
||||
def linear_shim(input, weight, bias=None):
|
||||
# type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
|
||||
output = input.matmul(weight.t())
|
||||
@ -23,6 +24,8 @@ def linear_shim(input, weight, bias=None):
|
||||
output += bias
|
||||
ret = output
|
||||
return ret
|
||||
|
||||
|
||||
torch.nn.functional.linear = linear_shim
|
||||
|
||||
|
||||
@ -92,6 +95,7 @@ def trivial_graph(a, b, c):
|
||||
s = torch.tensor([[3, 3], [3, 3]])
|
||||
return a + b * c + s
|
||||
|
||||
|
||||
class TestStaticRuntime(TestCase):
|
||||
def test_multihead_attention_layer(self):
|
||||
HID_DIM = 256
|
||||
@ -133,7 +137,15 @@ class TestStaticRuntime(TestCase):
|
||||
ref_top = top_l(top_inp)
|
||||
acc_top = top_l_acc(top_inp)[0]
|
||||
torch.testing.assert_allclose(acc_top, ref_top)
|
||||
|
||||
for _ in range(5):
|
||||
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
||||
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
||||
ref_bot = bot_l(bot_inp)
|
||||
acc_bot = bot_l_acc(bot_inp)[0]
|
||||
torch.testing.assert_allclose(acc_bot, ref_bot)
|
||||
ref_top = top_l(top_inp)
|
||||
acc_top = top_l_acc(top_inp)[0]
|
||||
torch.testing.assert_allclose(acc_top, ref_top)
|
||||
|
||||
# def test_trivial_graph(self):
|
||||
# s = torch.full((2, 2), 2)
|
||||
@ -143,5 +155,6 @@ class TestStaticRuntime(TestCase):
|
||||
# o_test = tg_a(s, s, s)[0]
|
||||
# torch.testing.assert_allclose(o_ref, o_test)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -104,110 +104,36 @@ StaticRuntime::StaticRuntime(const torch::jit::Module& m)
|
||||
}
|
||||
}
|
||||
|
||||
// fill constant_table_ and operator_table_
|
||||
// fill workspace_ with constants
|
||||
for (Node* node : graph_->nodes()) {
|
||||
switch (node->kind()) {
|
||||
case prim::Constant:
|
||||
CHECK(node->output()->type()->kind() != FunctionType::Kind);
|
||||
constant_table_[node->output()] = toIValue(node->output()).value();
|
||||
break;
|
||||
case prim::ListConstruct:
|
||||
nodes_.emplace_back(node, nullptr);
|
||||
break;
|
||||
case prim::TupleConstruct:
|
||||
nodes_.emplace_back(node, nullptr);
|
||||
break;
|
||||
default: {
|
||||
const Operator& op = node->getOperator();
|
||||
CHECK(op.hasOperation());
|
||||
nodes_.emplace_back(node, op.getOperation(node));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void StaticRuntime::getInputIValues(
|
||||
Node* node,
|
||||
const ConstantMap& ws,
|
||||
std::vector<IValue>& stack) const {
|
||||
const size_t size = node->inputs().size();
|
||||
stack.reserve(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
Value* v = node->inputs()[i];
|
||||
auto f = constant_table_.find(v);
|
||||
if (f == constant_table_.end()) {
|
||||
auto f_ws = ws.find(v);
|
||||
TORCH_CHECK(
|
||||
f_ws != ws.end(),
|
||||
"Workspace does not contain Value ",
|
||||
v->debugName());
|
||||
stack.emplace_back(f_ws->second);
|
||||
if (node->kind() == prim::Constant) {
|
||||
CHECK(node->output()->type()->kind() != FunctionType::Kind);
|
||||
workspace_[node->output()] = toIValue(node->output()).value();
|
||||
} else {
|
||||
stack.emplace_back(f->second);
|
||||
nodes_.emplace_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void StaticRuntime::runNodes(ConstantMap& workspace) const {
|
||||
std::vector<IValue> stack;
|
||||
for (const auto& p : nodes_) {
|
||||
Node* node = p.first;
|
||||
const Operation& op = p.second;
|
||||
getInputIValues(node, workspace, stack);
|
||||
VLOG(1) << node->kind().toDisplayString();
|
||||
|
||||
switch (node->kind()) {
|
||||
case prim::ListConstruct: {
|
||||
listConstruct(
|
||||
stack,
|
||||
node->output()->type()->expect<ListType>(),
|
||||
node->inputs().size());
|
||||
} break;
|
||||
case prim::TupleConstruct: {
|
||||
bool named =
|
||||
node->output()->type()->expect<TupleType>()->name().has_value();
|
||||
if (named) {
|
||||
namedTupleConstruct(
|
||||
stack,
|
||||
node->output()->type()->expect<TupleType>(),
|
||||
node->inputs().size());
|
||||
} else {
|
||||
tupleConstruct(stack, node->inputs().size());
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
DCHECK(op);
|
||||
op(&stack);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
DCHECK_EQ(stack.size(), node->outputs().size());
|
||||
for (auto i = 0; i < node->outputs().size(); i++) {
|
||||
workspace[node->outputs()[i]] = stack[i];
|
||||
}
|
||||
stack.clear();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> StaticRuntime::run(
|
||||
const std::vector<at::Tensor>& inps) const {
|
||||
const std::vector<at::Tensor>& inps) {
|
||||
// Container for inputs, outputs, and activations (excluding parameters)
|
||||
ConstantMap workspace_;
|
||||
|
||||
int start = 0;
|
||||
if (graph_->inputs().size() != inps.size()) {
|
||||
start = 1;
|
||||
CHECK_EQ(graph_->inputs().size(), inps.size() + 1);
|
||||
CHECK((graph_->inputs().at(0)->type()->is_module()));
|
||||
workspace_.emplace(graph_->inputs()[0], module_._ivalue());
|
||||
workspace_[graph_->inputs()[0]] = module_._ivalue();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < inps.size(); i++) {
|
||||
workspace_.emplace(graph_->inputs()[i + start], inps[i]);
|
||||
workspace_[graph_->inputs()[i + start]] = inps[i];
|
||||
}
|
||||
|
||||
runNodes(workspace_);
|
||||
for (const auto& n : nodes_) {
|
||||
n.run(workspace_);
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> out;
|
||||
for (Value* output : graph_->outputs()) {
|
||||
@ -223,5 +149,61 @@ std::vector<at::Tensor> StaticRuntime::run(
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
ProcessedNode::ProcessedNode(Node* node) : node_(node) {
|
||||
if (node->kind() != prim::ListConstruct &&
|
||||
node->kind() != prim::TupleConstruct) {
|
||||
const Operator& op = node->getOperator();
|
||||
CHECK(op.hasOperation());
|
||||
op_ = op.getOperation(node);
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const {
|
||||
if (use_stack_) {
|
||||
std::vector<IValue> stack;
|
||||
const size_t size = node_->inputs().size();
|
||||
stack.reserve(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
Value* v = node_->inputs()[i];
|
||||
auto f = workspace.find(v);
|
||||
TORCH_CHECK(
|
||||
f != workspace.end(),
|
||||
"Workspace does not contain Value ",
|
||||
v->debugName());
|
||||
stack.emplace_back(f->second);
|
||||
}
|
||||
if (op_) {
|
||||
(*op_)(&stack);
|
||||
} else {
|
||||
if (node_->kind() == prim::ListConstruct) {
|
||||
listConstruct(
|
||||
stack,
|
||||
node_->output()->type()->expect<ListType>(),
|
||||
node_->inputs().size());
|
||||
} else if (node_->kind() == prim::TupleConstruct) {
|
||||
bool named =
|
||||
node_->output()->type()->expect<TupleType>()->name().has_value();
|
||||
if (named) {
|
||||
namedTupleConstruct(
|
||||
stack,
|
||||
node_->output()->type()->expect<TupleType>(),
|
||||
node_->inputs().size());
|
||||
} else {
|
||||
tupleConstruct(stack, node_->inputs().size());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(0, "Unhandled operation!", node_->kind().toQualString());
|
||||
}
|
||||
}
|
||||
DCHECK_EQ(stack.size(), node_->outputs().size());
|
||||
for (auto i = 0; i < node_->outputs().size(); i++) {
|
||||
workspace[node_->outputs()[i]] = stack[i];
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(0, "Non-stack execution not yet implemented");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -14,6 +14,7 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
class ProcessedNode;
|
||||
class TORCH_API StaticRuntime {
|
||||
public:
|
||||
explicit StaticRuntime(std::shared_ptr<torch::jit::Graph> g)
|
||||
@ -21,7 +22,7 @@ class TORCH_API StaticRuntime {
|
||||
|
||||
explicit StaticRuntime(const torch::jit::Module& m);
|
||||
|
||||
std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;
|
||||
std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps);
|
||||
|
||||
#ifdef FBCODE_CAFFE2
|
||||
using ConstantMap = folly::F14FastMap<Value*, IValue>;
|
||||
@ -34,17 +35,26 @@ class TORCH_API StaticRuntime {
|
||||
std::shared_ptr<torch::jit::Graph> graph_;
|
||||
|
||||
// Static runtime states
|
||||
// Constant table (including weights)
|
||||
ConstantMap constant_table_;
|
||||
// Value table (including weights)
|
||||
ConstantMap workspace_;
|
||||
|
||||
// The nodes we need to run
|
||||
std::vector<std::pair<Node*, Operation>> nodes_;
|
||||
std::vector<ProcessedNode> nodes_;
|
||||
};
|
||||
|
||||
void getInputIValues(
|
||||
Node* node,
|
||||
const ConstantMap& ws,
|
||||
std::vector<IValue>& stack) const;
|
||||
class ProcessedNode {
|
||||
public:
|
||||
ProcessedNode(Node* n);
|
||||
void run(StaticRuntime::ConstantMap& workspace) const;
|
||||
Node* get_node() const {
|
||||
return node_;
|
||||
}
|
||||
|
||||
void runNodes(ConstantMap& ws_) const;
|
||||
private:
|
||||
Node* node_;
|
||||
c10::optional<Operation> op_;
|
||||
// if false, we have an optimized version
|
||||
bool use_stack_ = true;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
Reference in New Issue
Block a user