[jit] Scaffold a static runtime (#42753)

Summary:
The premise of this approach is that a small subset of neural networks is well represented by a data flow graph.  The README contains more information.

The name is subject to change, but I thought it was a cute reference to fire.

@suo, let me know if you'd prefer this in a different spot.  Since it lowers a JIT'd module directly, I assumed the JIT folder would be appropriate.  There is no exposed Python interface yet (but one is mocked up in `test_accelerant.py`).
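For reference, the mocked-up flow looks roughly like this (a sketch of what the test in this PR exercises through the private `torch._C._jit_to_static_runtime` binding; names are subject to change along with the feature name):

```python
import torch

def trivial_graph(a, b, c):
    return a + b * c

scripted = torch.jit.script(trivial_graph)
# Lower the scripted graph directly; run() takes a tuple of input tensors
# and returns a flat list of output tensors.
runtime = torch._C._jit_to_static_runtime(scripted.graph)
s = torch.full((2, 2), 2.0)
out = runtime.run((s, s, s))[0]
```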

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42753

Reviewed By: zou3519

Differential Revision: D23043771

Pulled By: bwasti

fbshipit-source-id: 5353731e3aae31c08b5b49820815da98113eb551
Author: Bram Wasti (2020-08-12 13:02:29 -07:00)
Committed by: Facebook GitHub Bot
Parent: 59f8692350
Commit: ada8404f2d
8 changed files with 274 additions and 0 deletions

test/test_static_runtime.py (new file)

@@ -0,0 +1,133 @@
import torch
from torch import nn
import numpy as np
class StaticRuntime:
    def __init__(self, scripted):
        # scripted is a ScriptModule if it has a _c attribute; otherwise it is
        # a scripted function.
        if hasattr(scripted, "_c"):
            # Freeze the module first so weights are inlined and GetAttr calls disappear.
            scripted._c = torch._C._freeze_module(scripted._c)
            self.static_runtime = torch._C._jit_to_static_runtime(
                scripted._c, scripted._c._get_method("forward").graph
            )
        else:
            self.static_runtime = torch._C._jit_to_static_runtime(scripted.graph)

    def __call__(self, *inps):
        return self.static_runtime.run(inps)
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention
# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]
        LL = nn.Linear(int(n), int(m), bias=True)
        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)
        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())
    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s
def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s
if __name__ == "__main__":
    HID_DIM = 256
    QUERY_LEN = 8
    BATCH_SIZE = 128
    LAYERS = 3
    HEADS = 8
    DROPOUT = 0.1
    device = torch.device("cpu")
    attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
    src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
    src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

    attention.eval()
    attention = torch.jit.script(attention)
    attention.eval()
    o_ref = attention(src, src, src, src_mask)

    attention_a = StaticRuntime(attention)
    o_test = attention_a(src, src, src, src_mask)
    for a, b in zip(o_ref, o_test):
        torch.testing.assert_allclose(a, b)

    s = torch.full((2, 2), 2)
    tg = torch.jit.script(trivial_graph)
    o_ref = tg(s, s, s)
    tg_a = StaticRuntime(tg)
    o_test = tg_a(s, s, s)[0]
    torch.testing.assert_allclose(o_ref, o_test)

    # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
    ln_bot = [512, 512, 64]
    sigmoid_bot = -1
    ln_top = [100, 1024, 1024, 1024, 1]
    sigmoid_top = 3
    bot_l = create_mlp(ln_bot, sigmoid_bot)
    bot_l_acc = StaticRuntime(bot_l)
    top_l = create_mlp(ln_top, sigmoid_top)
    top_l_acc = StaticRuntime(top_l)
    bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
    top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
    ref_bot = bot_l(bot_inp)
    acc_bot = bot_l_acc(bot_inp)[0]
    torch.testing.assert_allclose(acc_bot, ref_bot)
    ref_top = top_l(top_inp)
    acc_top = top_l_acc(top_inp)[0]
    torch.testing.assert_allclose(acc_top, ref_top)

@@ -219,6 +219,7 @@ core_sources_full = [
    "torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp",
    "torch/csrc/jit/runtime/profiling_record.cpp",
    "torch/csrc/jit/runtime/symbolic_script.cpp",
    "torch/csrc/jit/runtime/static/impl.cpp",
    "torch/csrc/jit/serialization/import.cpp",
    "torch/csrc/jit/serialization/import_export_helpers.cpp",
    "torch/csrc/jit/serialization/import_source.cpp",
@@ -501,6 +502,7 @@ libtorch_python_core_sources = [
    "torch/csrc/jit/frontend/concrete_module_type.cpp",
    "torch/csrc/jit/python/python_sugared_value.cpp",
    "torch/csrc/jit/python/python_tree_views.cpp",
    "torch/csrc/jit/runtime/static/init.cpp",
    "torch/csrc/multiprocessing/init.cpp",
    "torch/csrc/onnx/init.cpp",
    "torch/csrc/serialization.cpp",

@@ -73,6 +73,7 @@
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/static/init.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
@@ -1091,6 +1092,7 @@ void initJITBindings(PyObject* module) {
  initTreeViewBindings(module);
  initJitScriptBindings(module);
  initJitBackendBindings(module);
  initStaticRuntimeBindings(module);
  setPrintHandler([](const std::string& str) {
    py::gil_scoped_acquire acquire;

Static Runtime README (new file)

@@ -0,0 +1,31 @@
> :warning: **This is an experimental feature**
# Static Runtime
The premise of this approach is that a small subset of neural networks is well represented by a
completely flattened dataflow graph.
TorchScript supports a far more feature-rich programming paradigm,
so many models will not work out of the box.
## Assumptions
This is a list of current assumptions for use with
this feature.
- Inference-only execution
- Single CPU device

After `torch.jit.freeze`, inlining, and constant propagation are run on the model:
- No control flow
- No submodule invocations
- No references to `self`
- Inlined weights (i.e. no calls to `GetAttr`)
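
As a rough sketch of how to satisfy these assumptions (mirroring `test/test_static_runtime.py` in this commit; `TwoLayer` is a stand-in model for illustration and the bindings are private and subject to change):

```python
import torch
from torch import nn

class TwoLayer(nn.Module):  # stand-in model for illustration
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

scripted = torch.jit.script(TwoLayer().eval())
# Freeze first so weights are inlined and no GetAttr/self references remain.
scripted._c = torch._C._freeze_module(scripted._c)
runtime = torch._C._jit_to_static_runtime(
    scripted._c, scripted._c._get_method("forward").graph
)
out = runtime.run((torch.randn(8, 4),))  # returns a flat list of output tensors
```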
## Planned features
- Memory planning
- Operator dispatch inlining
- Operator substitution
- Weight layout transformations (pre-packing)
- Lowering to `torch.jit.tensorexpr`

torch/csrc/jit/runtime/static/impl.cpp (new file)

@@ -0,0 +1,45 @@
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch {
namespace jit {

StaticRuntime::StaticRuntime(
    const torch::jit::Module& m,
    std::shared_ptr<torch::jit::Graph> g)
    : graph_(std::move(g)), module_(m.deepcopy()) {
  Inline(*graph_);
  ConstantPropagation(graph_);
  // The runtime only handles flattened dataflow graphs; any remaining
  // prim::GetAttr means the module was not frozen first.
  for (auto n : graph_->nodes()) {
    if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
      throw std::runtime_error("Cannot accelerate unfrozen graphs");
    }
  }
}

std::vector<at::Tensor> StaticRuntime::run(
    const std::vector<at::Tensor>& inps) const {
  std::vector<torch::jit::IValue> stack;
  // If the graph still takes a `self` argument, pass the module's ivalue first.
  if (graph_->inputs().at(0)->type()->is_module()) {
    stack.emplace_back(module_._ivalue());
  }
  for (const auto& inp : inps) {
    stack.emplace_back(inp);
  }
  torch::jit::Code code(graph_, "");
  torch::jit::InterpreterState interp(code);
  interp.run(stack);
  // Flatten tuple outputs into a single vector of tensors.
  std::vector<at::Tensor> out;
  for (const auto& v : stack) {
    if (v.isTuple()) {
      auto t = v.toTuple();
      for (const auto& el : t->elements()) {
        out.emplace_back(el.toTensor());
      }
    } else {
      out.emplace_back(v.toTensor());
    }
  }
  return out;
}

} // namespace jit
} // namespace torch

torch/csrc/jit/runtime/static/impl.h (new file)

@@ -0,0 +1,28 @@
#pragma once

#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/inliner.h>

namespace torch {
namespace jit {

class TORCH_API StaticRuntime {
 public:
  // Construct from a standalone graph (e.g. a scripted function).
  StaticRuntime(std::shared_ptr<torch::jit::Graph> g) : graph_(std::move(g)) {}
  // Construct from a (frozen) module and its forward graph.
  StaticRuntime(
      const torch::jit::Module& m,
      std::shared_ptr<torch::jit::Graph> g);
  std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;

 private:
  std::shared_ptr<torch::jit::Graph> graph_;
  torch::jit::Module module_;
};

} // namespace jit
} // namespace torch

torch/csrc/jit/runtime/static/init.cpp (new file)

@@ -0,0 +1,24 @@
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/init.h>

namespace torch {
namespace jit {

void initStaticRuntimeBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  py::class_<StaticRuntime>(m, "StaticRuntime").def("run", &StaticRuntime::run);
  // Two overloads: lower a standalone graph, or a module together with its
  // forward graph.
  m.def(
       "_jit_to_static_runtime",
       [](const std::shared_ptr<torch::jit::Graph>& g) {
         return StaticRuntime(g);
       })
      .def(
          "_jit_to_static_runtime",
          [](const torch::jit::Module& m,
             const std::shared_ptr<torch::jit::Graph>& g) {
            return StaticRuntime(m, g);
          });
}

} // namespace jit
} // namespace torch

torch/csrc/jit/runtime/static/init.h (new file)

@@ -0,0 +1,9 @@
#include <torch/csrc/jit/python/pybind_utils.h>

namespace torch {
namespace jit {

void initStaticRuntimeBindings(PyObject* module);

} // namespace jit
} // namespace torch