Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[jit] Scaffold a static runtime (#42753)
Summary: The premise of this approach is that a small subset of neural networks is well represented by a data flow graph. The README contains more information. The name is subject to change, but I thought it was a cute reference to fire. suo, let me know if you'd prefer this in a different spot; since it lowers a JIT'd module directly, I assumed the JIT folder would be appropriate. There is no exposed Python interface yet (but one is mocked up in `test_accelerant.py`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42753
Reviewed By: zou3519
Differential Revision: D23043771
Pulled By: bwasti
fbshipit-source-id: 5353731e3aae31c08b5b49820815da98113eb551
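For orientation, the flow the mock-up exercises looks roughly like this (a minimal sketch distilled from the test added in this commit; `_jit_to_static_runtime` and `_freeze_module` are internal bindings subject to change, and the tiny `nn.Linear` model is just an illustrative stand-in):

```python
import torch
from torch import nn

# Minimal sketch of the intended flow; internal bindings, not a public API.
model = torch.jit.script(nn.Linear(4, 4).eval())
# Freeze so weights are inlined and no prim::GetAttr nodes remain.
model._c = torch._C._freeze_module(model._c)
# Lower the frozen module's forward graph to the static runtime.
runtime = torch._C._jit_to_static_runtime(
    model._c, model._c._get_method("forward").graph
)
out = runtime.run((torch.randn(2, 4),))[0]  # run() returns a list of output tensors
```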
Committed by: Facebook GitHub Bot
Parent: 59f8692350
Commit: ada8404f2d
test/test_static_runtime.py — new file (133 lines)
@@ -0,0 +1,133 @@
import torch
from torch import nn
import numpy as np


class StaticRuntime:
    def __init__(self, scripted):
        # this is an nn.Module
        if hasattr(scripted, "_c"):
            scripted._c = torch._C._freeze_module(scripted._c)
            self.static_runtime = torch._C._jit_to_static_runtime(
                scripted._c, scripted._c._get_method("forward").graph
            )
        else:
            self.static_runtime = torch._C._jit_to_static_runtime(scripted.graph)

    def __call__(self, *inps):
        return self.static_runtime.run(inps)


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]

        LL = nn.Linear(int(n), int(m), bias=True)

        mean = 0.0  # std_dev = np.sqrt(variance)
        std_dev = np.sqrt(2 / (m + n))  # np.sqrt(1 / m) # np.sqrt(1 / n)
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)  # np.sqrt(2 / (m + 1))
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)

        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


if __name__ == "__main__":
    HID_DIM = 256
    QUERY_LEN = 8
    BATCH_SIZE = 128
    LAYERS = 3
    HEADS = 8
    DROPOUT = 0.1
    device = torch.device("cpu")
    attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
    src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
    src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

    attention.eval()
    attention = torch.jit.script(attention)
    attention.eval()
    o_ref = attention(src, src, src, src_mask)

    attention_a = StaticRuntime(attention)
    o_test = attention_a(src, src, src, src_mask)
    for a, b in zip(o_ref, o_test):
        torch.testing.assert_allclose(a, b)

    s = torch.full((2, 2), 2)
    tg = torch.jit.script(trivial_graph)
    o_ref = tg(s, s, s)
    tg_a = StaticRuntime(tg)
    o_test = tg_a(s, s, s)[0]
    torch.testing.assert_allclose(o_ref, o_test)

    # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
    ln_bot = [512, 512, 64]
    sigmoid_bot = -1
    ln_top = [100, 1024, 1024, 1024, 1]
    sigmoid_top = 3
    bot_l = create_mlp(ln_bot, sigmoid_bot)
    bot_l_acc = StaticRuntime(bot_l)
    top_l = create_mlp(ln_top, sigmoid_top)
    top_l_acc = StaticRuntime(top_l)
    bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
    top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
    ref_bot = bot_l(bot_inp)
    acc_bot = bot_l_acc(bot_inp)[0]
    torch.testing.assert_allclose(acc_bot, ref_bot)
    ref_top = top_l(top_inp)
    acc_top = top_l_acc(top_inp)[0]
    torch.testing.assert_allclose(acc_top, ref_top)
@@ -219,6 +219,7 @@ core_sources_full = [
     "torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp",
     "torch/csrc/jit/runtime/profiling_record.cpp",
     "torch/csrc/jit/runtime/symbolic_script.cpp",
+    "torch/csrc/jit/runtime/static/impl.cpp",
     "torch/csrc/jit/serialization/import.cpp",
     "torch/csrc/jit/serialization/import_export_helpers.cpp",
     "torch/csrc/jit/serialization/import_source.cpp",
@@ -501,6 +502,7 @@ libtorch_python_core_sources = [
     "torch/csrc/jit/frontend/concrete_module_type.cpp",
     "torch/csrc/jit/python/python_sugared_value.cpp",
     "torch/csrc/jit/python/python_tree_views.cpp",
+    "torch/csrc/jit/runtime/static/init.cpp",
     "torch/csrc/multiprocessing/init.cpp",
     "torch/csrc/onnx/init.cpp",
     "torch/csrc/serialization.cpp",
@@ -73,6 +73,7 @@
 #include <torch/csrc/jit/runtime/jit_exception.h>
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/runtime/print_handler.h>
+#include <torch/csrc/jit/runtime/static/init.h>
 #include <torch/csrc/jit/serialization/export.h>
 #include <torch/csrc/jit/serialization/import.h>
 #include <torch/csrc/jit/tensorexpr/execution_counter.h>
@@ -1091,6 +1092,7 @@ void initJITBindings(PyObject* module) {
   initTreeViewBindings(module);
   initJitScriptBindings(module);
   initJitBackendBindings(module);
+  initStaticRuntimeBindings(module);

   setPrintHandler([](const std::string& str) {
     py::gil_scoped_acquire acquire;
torch/csrc/jit/runtime/static/README.md — new file (31 lines)
@@ -0,0 +1,31 @@
> :warning: **This is an experimental feature**

# Static Runtime

The premise of this approach is that a small subset of neural networks is well represented by a
completely flattened dataflow graph. TorchScript supports a far more featureful programming
paradigm, so many models will not work out of the box.

## Assumptions

This is a list of current assumptions for use with this feature.

- Inference-only execution
- Single CPU device

After `torch.jit.freeze` and inlining/constant propagation are run on the model (see the sketch after this list):

- No control flow
- No submodule invocations
- No references to `self`
- Inlined weights (i.e. no calls to `GetAttr`)
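For illustration only (hypothetical modules, not part of this commit): the first module below satisfies these assumptions once frozen, while the second keeps data-dependent control flow after freezing and therefore does not.

```python
import torch
from torch import nn


class FlatMLP(nn.Module):
    """Freezing inlines the weight attributes, leaving a single flattened dataflow graph."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.fc(x))


class Branchy(nn.Module):
    """The data-dependent `if` survives freezing as control flow, so this is unsupported."""

    def forward(self, x):
        if bool(x.sum() > 0):
            return x + 1
        return x - 1
```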

## Planned features

- Memory planning
- Operator dispatch inlining
- Operator substitution
- Weight layout transformations (pre-packing)
- Lowering to `torch.jit.tensorexpr`
torch/csrc/jit/runtime/static/impl.cpp — new file (45 lines)
@@ -0,0 +1,45 @@
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch {
namespace jit {

StaticRuntime::StaticRuntime(
    const torch::jit::Module& m,
    std::shared_ptr<torch::jit::Graph> g)
    : graph_(std::move(g)), module_(m.deepcopy()) {
  Inline(*graph_);
  ConstantPropagation(graph_);
  for (auto n : graph_->nodes()) {
    if (n->kind() == c10::Symbol::fromQualString("prim::GetAttr")) {
      throw std::runtime_error("Cannot accelerate unfrozen graphs");
    }
  }
}

std::vector<at::Tensor> StaticRuntime::run(
    const std::vector<at::Tensor>& inps) const {
  std::vector<torch::jit::IValue> stack;
  if (graph_->inputs().at(0)->type()->is_module()) {
    stack.emplace_back(module_._ivalue());
  }
  for (const auto& inp : inps) {
    stack.emplace_back(inp);
  }
  torch::jit::Code code(graph_, "");
  torch::jit::InterpreterState interp(code);
  interp.run(stack);
  std::vector<at::Tensor> out;
  for (const auto& v : stack) {
    if (v.isTuple()) {
      auto t = v.toTuple();
      for (const auto& el : t->elements()) {
        out.emplace_back(el.toTensor());
      }
    } else {
      out.emplace_back(v.toTensor());
    }
  }
  return out;
}
} // namespace jit
} // namespace torch
torch/csrc/jit/runtime/static/impl.h — new file (28 lines)
@@ -0,0 +1,28 @@
#pragma once

#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/inliner.h>

namespace torch {
namespace jit {

class TORCH_API StaticRuntime {
 public:
  StaticRuntime(std::shared_ptr<torch::jit::Graph> g) : graph_(std::move(g)) {}
  StaticRuntime(
      const torch::jit::Module& m,
      std::shared_ptr<torch::jit::Graph> g);

  std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps) const;

 private:
  std::shared_ptr<torch::jit::Graph> graph_;
  torch::jit::Module module_;
};

} // namespace jit
} // namespace torch
torch/csrc/jit/runtime/static/init.cpp — new file (24 lines)
@@ -0,0 +1,24 @@
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/init.h>

namespace torch {
namespace jit {

void initStaticRuntimeBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  py::class_<StaticRuntime>(m, "StaticRuntime").def("run", &StaticRuntime::run);
  m.def(
       "_jit_to_static_runtime",
       [](const std::shared_ptr<torch::jit::Graph>& g) {
         return StaticRuntime(g);
       })
      .def(
          "_jit_to_static_runtime",
          [](const torch::jit::Module& m,
             const std::shared_ptr<torch::jit::Graph>& g) {
            return StaticRuntime(m, g);
          });
}

} // namespace jit
} // namespace torch
torch/csrc/jit/runtime/static/init.h — new file (9 lines)
@@ -0,0 +1,9 @@
#include <torch/csrc/jit/python/pybind_utils.h>

namespace torch {
namespace jit {

void initStaticRuntimeBindings(PyObject* module);

} // namespace jit
} // namespace torch