[static runtime] Move all heavy constructor logic into InferenceModule (renamed to StaticModule) (#51564)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51564

Constructor logic was spread across InferenceModule and StaticRuntime; this diff unifies the two. After a lot of discussion on D25961626, it became apparent that `clone` is uglier than a cheap-to-construct StaticRuntime.

This means the old StaticRuntime effectively becomes StaticModule, and the only code left in the new StaticRuntime is the `run` logic (now exposed as `operator()`).

```
graph, schema = PrepareForStaticModule(torchscript_module)
sm = StaticModule(graph, schema, options)
sm(inputs)
// or create many cheap runtimes with the module
sr = StaticRuntime(sm)
sr(inputs)
```
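
To make the pseudo-code above concrete, here is a minimal single-threaded C++ sketch using the constructors and call operators this diff introduces. The model path and input shape are placeholders, and this is a sketch rather than code from the diff:

```
// Assumes a scripted model saved as "model.pt"; shapes are made up.
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/script.h>

#include <unordered_map>
#include <vector>

int main() {
  torch::jit::Module mod = torch::jit::load("model.pt");

  // All heavy setup (graph preparation, schema handling, liveness analysis)
  // happens once, inside the StaticModule constructor.
  torch::jit::StaticModuleOptions opts;
  torch::jit::StaticModule smod(mod, opts);

  // Call it like a module: positional args plus an optional kwargs map.
  std::vector<c10::IValue> args{torch::randn({4, 32})};
  std::unordered_map<std::string, c10::IValue> kwargs;
  c10::IValue out = smod(args, kwargs);

  // The lazily created cached runtime can be checked for leaks, as the
  // updated tests do.
  smod.runtime().check_for_memory_leak();
  return 0;
}
```

The kwargs overload needs the schema, which is only available when the StaticModule is built from a `torch::jit::Module` rather than from a bare graph.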

Changelist:
- Rename InferenceModule to StaticModule
- Move all construction logic into StaticModule
- Create a new StaticRuntime that only holds a unique memory planner (everything else is in StaticModule); see the sketch after this list
- Update comments with explanations
- Propagate all changes to the predictor integration
- Propagate all changes to the Python integration
- Make the API a bit more PyTorch-standard (no "run" calls, no "get_" getters)
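
As referenced in the changelist, here is a rough sketch of the per-thread pattern the updated header comment describes: one StaticModule per model, one cheap StaticRuntime per thread. The model path, thread count, and shapes are placeholders, not code from this diff:

```
// Share one StaticModule; give each thread its own StaticRuntime, which
// holds only the per-run state (IValue storage and memory planner).
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/script.h>

#include <thread>
#include <vector>

int main() {
  torch::jit::Module mod = torch::jit::load("model.pt");
  torch::jit::StaticModule smod(mod);

  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&smod] {
      // Cheap per-thread construction; references smod's graph and constants.
      torch::jit::StaticRuntime runtime(smod);
      std::vector<c10::IValue> args{torch::randn({8, 32})};
      c10::IValue out = runtime(args, {});
      (void)out;
    });
  }
  for (auto& w : workers) {
    w.join();
  }
  return 0;
}
```

In a real service, the header comment suggests caching StaticRuntime instances in a synchronized stack (e.g. boost::lockfree::stack) instead of constructing one per request.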

Test Plan:
buck test //caffe2/test:static_runtime
buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest

Reviewed By: hlu1

Differential Revision: D25592967

fbshipit-source-id: 8233bed03137ce129137af2d44bce0095033ef0f
Bram Wasti authored on 2021-03-05 10:12:17 -08:00
committed by Facebook GitHub Bot
parent 5ebfabb310
commit 56f8379802
11 changed files with 445 additions and 328 deletions


@ -75,8 +75,7 @@ static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) {
static void BM_deep_wide_static(benchmark::State& state) {
auto mod = getDeepAndWideSciptModel();
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticRuntime runtime(g);
torch::jit::StaticModule smod(mod);
const int batch_size = state.range(0);
auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
@ -85,21 +84,19 @@ static void BM_deep_wide_static(benchmark::State& state) {
std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});
runtime.run(inputs);
smod(inputs);
for (auto _ : state) {
runtime.run(inputs);
smod(inputs);
}
}
const std::shared_ptr<torch::jit::InferenceModule>& getStaticGraph() {
static const std::shared_ptr<torch::jit::InferenceModule> g =
torch::jit::PrepareForStaticRuntime(getDeepAndWideSciptModel());
return g;
torch::jit::StaticRuntime getStaticRuntime() {
static auto smod = std::make_shared<torch::jit::StaticModule>(getDeepAndWideSciptModel());
return torch::jit::StaticRuntime(*smod);
}
static void BM_deep_wide_static_threaded(benchmark::State& state) {
auto g = getStaticGraph();
torch::jit::StaticRuntime runtime(g);
auto sr = getStaticRuntime();
const int batch_size = 1; // state.range(0);
auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
@ -108,39 +105,38 @@ static void BM_deep_wide_static_threaded(benchmark::State& state) {
std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});
sr(inputs);
for (auto _ : state) {
runtime.run(inputs);
sr(inputs);
}
}
static void BM_leaky_relu_const(benchmark::State& state) {
auto mod = getLeakyReLUConstScriptModel();
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticRuntime runtime(g);
torch::jit::StaticModule smod(mod);
const int batch_size = state.range(0);
auto data = torch::randn({batch_size, num_features});
std::vector<at::Tensor> inputs({data});
runtime.run(inputs);
smod(inputs);
for (auto _ : state) {
runtime.run(inputs);
smod(inputs);
}
}
static void BM_leaky_relu(benchmark::State& state) {
auto mod = getLeakyReLUScriptModel();
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticRuntime runtime(g);
torch::jit::StaticModule smod(mod);
const int batch_size = state.range(0);
auto neg_slope = torch::randn(1);
auto data = torch::randn({batch_size, num_features});
std::vector<at::Tensor> inputs({data, neg_slope[0]});
runtime.run(inputs);
smod(inputs);
for (auto _ : state) {
runtime.run(inputs);
smod(inputs);
}
}
@ -149,10 +145,9 @@ BENCHMARK(BM_leaky_relu_const)->RangeMultiplier(8)->Ranges({{1, 20}});
static void BM_long_static_memory_optimization(benchmark::State& state) {
auto mod = getLongScriptModel();
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticRuntimeOptions opts;
torch::jit::StaticModuleOptions opts;
opts.optimize_memory = state.range(1);
torch::jit::StaticRuntime runtime(g, opts);
torch::jit::StaticModule smod(mod, opts);
const auto N = state.range(0);
auto a = torch::randn({N, N});
@ -160,9 +155,9 @@ static void BM_long_static_memory_optimization(benchmark::State& state) {
auto c = torch::randn({N, N});
std::vector<at::Tensor> inputs({a, b, c});
runtime.run(inputs);
smod(inputs);
for (auto _ : state) {
runtime.run(inputs);
smod(inputs);
}
}


@ -70,9 +70,9 @@ void testStaticRuntime(
auto expect = module.forward(args);
StaticRuntime runtime(module);
auto actual = runtime.run(args, {});
runtime.check_for_memory_leak();
torch::jit::StaticModule smodule(module);
auto actual = smodule(args, {});
smodule.runtime().check_for_memory_leak();
if (expect.isTuple()) {
compareTensorLists(
@ -187,10 +187,9 @@ TEST(StaticRuntime, LongModel) {
// run static runtime
std::vector<at::Tensor> input_tensors({a, b, c});
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticRuntime runtime(g);
at::Tensor output_2 = runtime.run(input_tensors)[0];
runtime.check_for_memory_leak();
torch::jit::StaticModule smod(mod);
at::Tensor output_2 = smod(input_tensors)[0];
smod.runtime().check_for_memory_leak();
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}
@ -206,10 +205,9 @@ TEST(StaticRuntime, TrivialModel) {
// run static runtime
std::vector<at::Tensor> input_tensors({a, b, c});
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticRuntime runtime(g);
at::Tensor output_2 = runtime.run(input_tensors)[0];
runtime.check_for_memory_leak();
torch::jit::StaticModule smod(mod);
at::Tensor output_2 = smod(input_tensors)[0];
smod.runtime().check_for_memory_leak();
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}
@ -223,10 +221,9 @@ TEST(StaticRuntime, LeakyReLU) {
// run static runtime
std::vector<at::Tensor> input_tensors({inputs});
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticRuntime runtime(g);
at::Tensor output_2 = runtime.run(input_tensors)[0];
runtime.check_for_memory_leak();
torch::jit::StaticModule smod(mod);
at::Tensor output_2 = smod(input_tensors)[0];
smod.runtime().check_for_memory_leak();
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}
@ -234,8 +231,7 @@ TEST(StaticRuntime, DeepWide) {
const int embedding_size = 32;
const int num_features = 50;
torch::jit::Module mod = getDeepAndWideSciptModel();
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticRuntime runtime(g);
torch::jit::StaticModule smod(mod);
for (int batch_size : {1, 8, 32}) {
for (int i = 0; i < 2; ++i) {
@ -249,8 +245,8 @@ TEST(StaticRuntime, DeepWide) {
// run static runtime
std::vector<at::Tensor> input_tensors({ad_emb_packed, user_emb, wide});
at::Tensor output_2 = runtime.run(input_tensors)[0];
runtime.check_for_memory_leak();
at::Tensor output_2 = smod(input_tensors)[0];
smod.runtime().check_for_memory_leak();
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}
}
@ -260,7 +256,7 @@ TEST(StaticRuntime, KWargsAPI_1) {
const int embedding_size = 32;
const int num_features = 50;
auto module = getDeepAndWideSciptModel();
torch::jit::StaticRuntime runtime(module);
torch::jit::StaticModule smod(module);
for (int batch_size : {1, 8, 32}) {
for (int i = 0; i < 2; ++i) {
@ -274,8 +270,8 @@ TEST(StaticRuntime, KWargsAPI_1) {
at::Tensor output_1 = getTensor(module.forward(inputs));
// run static runtime
c10::IValue output_ivalue = runtime.run(inputs, {});
runtime.check_for_memory_leak();
c10::IValue output_ivalue = smod(inputs, {});
smod.runtime().check_for_memory_leak();
at::Tensor output_2 = getTensor(output_ivalue);
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
@ -300,8 +296,7 @@ TEST(StaticRuntime, KWargsAPI_2) {
const int embedding_size = 32;
const int num_features = 50;
auto module = getDeepAndWideSciptModel();
auto g = torch::jit::PrepareForStaticRuntime(module);
torch::jit::StaticRuntime runtime(module);
torch::jit::StaticModule smod(module);
for (int batch_size : {1, 8, 32}) {
for (int i = 0; i < 2; ++i) {
@ -319,8 +314,8 @@ TEST(StaticRuntime, KWargsAPI_2) {
{"wide", wide}});
// run static runtime
c10::IValue output_ivalue = runtime.run({}, kwargs);
runtime.check_for_memory_leak();
c10::IValue output_ivalue = smod({}, kwargs);
smod.runtime().check_for_memory_leak();
at::Tensor output_2 = getTensor(output_ivalue);
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
@ -343,14 +338,14 @@ TEST(StaticRuntime, CleanUpMemory) {
const int embedding_size = 32;
const int num_features = 50;
torch::jit::Module mod = getDeepAndWideSciptModel();
auto g = torch::jit::PrepareForStaticRuntime(mod);
torch::jit::StaticModule smod(mod);
for (auto cleanup_memory : {true, false}) {
for (auto enable_out_variant : {true, false}) {
VLOG(1) << "cleanup_memory: " << cleanup_memory
<< ", enable_out_variant: " << enable_out_variant;
torch::jit::StaticRuntimeOptions opts{cleanup_memory, enable_out_variant};
torch::jit::StaticRuntime runtime(g, opts);
torch::jit::StaticModuleOptions opts{cleanup_memory, enable_out_variant};
torch::jit::StaticModule smod(mod, opts);
for (int batch_size : {1, 8, 32}) {
for (int i = 0; i < 2; ++i) {
@ -365,8 +360,8 @@ TEST(StaticRuntime, CleanUpMemory) {
// run static runtime
std::vector<at::Tensor> input_tensors(
{ad_emb_packed, user_emb, wide});
at::Tensor output_2 = runtime.run(input_tensors)[0];
runtime.check_for_memory_leak();
at::Tensor output_2 = smod(input_tensors)[0];
smod.runtime().check_for_memory_leak();
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}
}


@ -5,25 +5,25 @@ from torch.testing._internal.common_utils import TestCase, run_tests
from typing import Dict, Optional
class StaticRuntime:
class StaticModule:
def __init__(self, scripted):
# this is an nn.Module
if hasattr(scripted, "_c"):
self.static_runtime = torch._C._jit_to_static_runtime(scripted._c)
self.static_module = torch._C._jit_to_static_module(scripted._c)
else:
self.static_runtime = torch._C._jit_to_static_runtime(scripted.graph)
self.static_module = torch._C._jit_to_static_module(scripted.graph)
def __call__(self, *args, **kwargs):
if not kwargs:
return self.static_runtime.run(args)
return self.static_module(args)
else:
return self.static_runtime.run(args, kwargs)
return self.static_module(args, kwargs)
def benchmark(self, args, kwargs, warmup_runs, main_runs):
self.static_runtime.benchmark(args, kwargs, warmup_runs, main_runs)
self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)
def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
return self.static_runtime.benchmark_individual_ops(
return self.static_module.benchmark_individual_ops(
args, kwargs, warmup_runs, main_runs
)
@ -121,7 +121,7 @@ def output_graph(a, b, c, iters : int):
d[i] = k + i
return d
class TestStaticRuntime(TestCase):
class TestStaticModule(TestCase):
def test_multihead_attention_layer(self):
HID_DIM = 256
QUERY_LEN = 8
@ -140,7 +140,7 @@ class TestStaticRuntime(TestCase):
attention.eval()
o_ref = attention(src, src, src, src_mask)
attention_a = StaticRuntime(attention)
attention_a = StaticModule(attention)
o_test = attention_a(src, src, src, src_mask)
o_test_kw = attention_a(src, src, value=src, mask=src_mask)
@ -165,7 +165,7 @@ class TestStaticRuntime(TestCase):
attention.eval()
attention = torch.jit.script(attention)
attention_a = StaticRuntime(attention)
attention_a = StaticModule(attention)
attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
metrics = attention_a.benchmark_individual_ops(
@ -179,9 +179,9 @@ class TestStaticRuntime(TestCase):
ln_top = [100, 1024, 1024, 1024, 1]
sigmoid_top = 3
bot_l = create_mlp(ln_bot, sigmoid_bot)
bot_l_acc = StaticRuntime(bot_l)
bot_l_acc = StaticModule(bot_l)
top_l = create_mlp(ln_top, sigmoid_top)
top_l_acc = StaticRuntime(top_l)
top_l_acc = StaticModule(top_l)
with torch.no_grad():
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
@ -206,7 +206,7 @@ class TestStaticRuntime(TestCase):
s = torch.full((2, 2), 2)
tg = torch.jit.script(trivial_graph)
o_ref = tg(s, s, s)
tg_a = StaticRuntime(tg)
tg_a = StaticModule(tg)
o_test = tg_a(s, s, s)[0]
torch.testing.assert_allclose(o_ref, o_test)
@ -214,7 +214,7 @@ class TestStaticRuntime(TestCase):
s = torch.randn(5, 5)
tg = torch.jit.script(nn.LeakyReLU(0.1))
o_ref = tg(s)
tg_a = StaticRuntime(tg)
tg_a = StaticModule(tg)
o_test = tg_a(s)[0]
torch.testing.assert_allclose(o_ref, o_test)
@ -222,7 +222,7 @@ class TestStaticRuntime(TestCase):
s = torch.full((2, 2), 2)
tg = torch.jit.script(trivial_graph)
o_ref = tg(s, s, s)
torch._C._fuse_to_static_runtime(tg.graph)
torch._C._fuse_to_static_module(tg.graph)
assert "StaticSubgraph" in str(tg.graph)
o_test = tg(s, s, s)
torch.testing.assert_allclose(o_ref, o_test)
@ -245,7 +245,7 @@ class TestStaticRuntime(TestCase):
attention.eval()
o_ref = attention(src, src, src, src_mask)
torch._C._fuse_to_static_runtime(attention._c)
torch._C._fuse_to_static_module(attention._c)
o_test = attention(src, src, src, src_mask)
for a, b in zip(o_ref, o_test):
@ -257,7 +257,7 @@ class TestStaticRuntime(TestCase):
c = 4
lg = torch.jit.script(loop_graph)
o_ref = lg(a, b, c)
torch._C._fuse_to_static_runtime(lg.graph)
torch._C._fuse_to_static_module(lg.graph)
assert "StaticSubgraph" in str(lg.graph)
o_test = lg(a, b, c)
torch.testing.assert_allclose(o_ref, o_test)
@ -268,7 +268,7 @@ class TestStaticRuntime(TestCase):
c = 4
og = torch.jit.script(output_graph)
o_ref = og(a, b, b, c)
torch._C._fuse_to_static_runtime(og.graph)
torch._C._fuse_to_static_module(og.graph)
assert "StaticSubgraph" in str(og.graph)
o_test = og(a, b, b, c)
for i in o_ref.keys():


@ -1335,7 +1335,7 @@ void initJITBindings(PyObject* module) {
initTreeViewBindings(module);
initJitScriptBindings(module);
initJitBackendBindings(module);
initStaticRuntimeBindings(module);
initStaticModuleBindings(module);
initTensorExprBindings(module);
setPrintHandler([](const std::string& str) {


@ -2,7 +2,10 @@
#include <ATen/core/interned_strings.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/static/impl.h>
@ -14,24 +17,30 @@ namespace jit {
void createFusionGroups(Block* block, AliasDb* aliasDb);
void fuseStaticSubgraphs(std::shared_ptr<Graph> graph) {
PrepareGraphForStaticRuntime(graph);
Inline(*graph);
ConstantPropagation(graph);
Canonicalize(graph);
ConstantPropagation(graph);
RemoveTensorMutation(graph);
ConstantPropagation(graph);
EliminateDeadCode(graph);
auto aliasDb = torch::make_unique<AliasDb>(graph);
createFusionGroups(graph->block(), aliasDb.get());
torch::jit::EliminateDeadCode(graph);
}
Operation createStaticSubgraphRuntime(const Node* node) {
auto g = torch::jit::PrepareForStaticRuntime(node->g(attr::Subgraph));
auto runtime = std::make_shared<torch::jit::StaticRuntime>(g);
auto num_inputs = runtime->num_inputs();
return [runtime, num_inputs](Stack* stack) {
auto g = node->g(attr::Subgraph);
auto module = std::make_shared<torch::jit::StaticModule>(g);
auto num_inputs = module->num_inputs();
return [module, num_inputs](Stack* stack) {
RECORD_FUNCTION("Static Runtime", std::vector<c10::IValue>());
auto inps = torch::jit::last(stack, num_inputs);
// TODO maybe avoid call to vec
auto outputs = runtime->run(inps.vec(), {});
auto outputs = (*module)(inps.vec(), {});
torch::jit::drop(stack, num_inputs);
if (runtime->num_outputs() > 1) {
if (module->num_outputs() > 1) {
for (auto& o : outputs.toTuple()->elements()) {
push_one(*stack, std::move(o));
}


@ -18,7 +18,9 @@
namespace torch {
namespace jit {
void PrepareGraphForStaticRuntime(std::shared_ptr<torch::jit::Graph> graph) {
namespace {
void OptimizeGraph(std::shared_ptr<torch::jit::Graph>& graph) {
Inline(*graph);
ConstantPropagation(graph);
Canonicalize(graph);
@ -26,11 +28,6 @@ void PrepareGraphForStaticRuntime(std::shared_ptr<torch::jit::Graph> graph) {
RemoveTensorMutation(graph);
ConstantPropagation(graph);
EliminateDeadCode(graph);
}
namespace {
void OptimizeGraph(std::shared_ptr<torch::jit::Graph>& graph) {
PrepareGraphForStaticRuntime(graph);
FuseInferenceOpsForSparseNN(graph);
// TODO: we can avoid this guard by moving operations
@ -82,11 +79,10 @@ void RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
}
// remove "self" from function schema
std::unique_ptr<c10::FunctionSchema> RemoveSelfFromSchema(
const c10::FunctionSchema& s) {
c10::FunctionSchema RemoveSelfFromSchema(const c10::FunctionSchema& s) {
TORCH_CHECK(s.arguments().size() >= 1 && s.arguments()[0].name() == "self");
std::vector<Argument> args({s.arguments().begin() + 1, s.arguments().end()});
return std::make_unique<c10::FunctionSchema>(s.cloneWithArguments(args));
return s.cloneWithArguments(args);
}
bool mayContainAlias(AliasDb& db, const Value* a, const Value* b) {
@ -427,74 +423,78 @@ std::unordered_map<const Value*, std::vector<const Value*>> FindShared(
return shared;
}
} // namespace
void InferenceModule::init() {
void PrepareGraphForStaticModule(std::shared_ptr<torch::jit::Graph> graph) {
OptimizeGraph(graph);
CheckGraphEligibility(graph);
RemoveSelfFromGraphInput(graph);
}
InferenceModule::InferenceModule(const torch::jit::Module& m)
: module(m.copy()), graph(nullptr), schema(nullptr) {
std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
PrepareForStaticModule(const torch::jit::Module& m) {
auto module = m.copy();
module.eval();
Method method = module.get_method("forward");
graph = method.graph();
auto graph = method.graph();
// Move this pass to before running the freeze_module pass so that the
// sigrid_hash_compute_multipler_shift op can be precomputed and the results
// being folded into the module as constants. This is required to enable the
// ClipRangesGatherRangesX2SigridHashPrecompute pass. See D26833478
SplitOutPrecomputeOpsForSparseNN(graph);
module = freeze_module(module);
method = module.get_method("forward");
graph = method.graph();
graph = module.get_method("forward").graph();
PrepareGraphForStaticModule(graph);
const c10::FunctionSchema& s = method.function().getSchema();
schema = RemoveSelfFromSchema(s);
init();
c10::FunctionSchema s = RemoveSelfFromSchema(method.function().getSchema());
return std::make_pair(graph, s);
}
InferenceModule::InferenceModule(std::shared_ptr<torch::jit::Graph> g)
: module(), graph(std::move(g)), schema(nullptr) {
init();
std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
PrepareForStaticModule(std::shared_ptr<torch::jit::Graph> graph) {
PrepareGraphForStaticModule(graph);
return std::make_pair(graph, c10::nullopt);
}
StaticRuntime::StaticRuntime(
StaticModule::StaticModule(
std::shared_ptr<torch::jit::Graph> g,
const StaticModuleOptions& opts)
: StaticModule(PrepareForStaticModule(g), opts) {}
StaticModule::StaticModule(
const torch::jit::Module& m,
const StaticRuntimeOptions& opts)
: StaticRuntime(PrepareForStaticRuntime(m), opts) {}
const StaticModuleOptions& opts)
: StaticModule(PrepareForStaticModule(m), opts) {}
StaticRuntime::StaticRuntime(
std::shared_ptr<InferenceModule> m,
const StaticRuntimeOptions& opts)
: module_(m), opts_(opts) {
TORCH_CHECK(
module_ != nullptr,
"std::shared_ptr<InferenceModule> module_ cannot be nullptr")
Graph* graph = module_->graph.get();
StaticModule::StaticModule(
std::pair<
std::shared_ptr<torch::jit::Graph>,
c10::optional<c10::FunctionSchema>> graph_and_schema,
const StaticModuleOptions& opts)
: opts_(opts),
graph_(std::move(graph_and_schema.first)),
schema_(std::move(graph_and_schema.second)) {
std::unordered_map<Value*, IValue*> val_to_ival;
// value -> index into nodes, index into outputs of node
std::unordered_map<Value*, std::pair<int, int>> val_to_idx;
// NB: create an unchanging std::vector<IValue> we can reference
for (auto input : graph->inputs()) {
inputs_.emplace_back();
}
for (auto i = 0; i < graph->inputs().size(); ++i) {
Value* input = graph->inputs()[i];
val_to_ival[input] = &(inputs_[i]);
// N inputs map to the first N entries in storage
for (auto i = 0; i < graph_->inputs().size(); ++i) {
Value* input = graph_->inputs()[i];
val_to_ival[input] = nullptr;
// input denoted by -1
val_to_idx[input] = std::make_pair(-1, i);
}
// fill workspace_ with constants and create ProcessedNodes
// NB: before optimizing the order of execution, ensure that the
// memory optimization pass (GetLivenessInformation + AssignRegisters) is
// memory optimization pass (LivenessMap) is
// aware of the new order!
// Fill constants first, so we have a std::vector<IValue> we can reference
// later
for (Node* node : graph->nodes()) {
for (Node* node : graph_->nodes()) {
if (node->kind() != prim::Constant) {
continue;
}
@ -504,37 +504,46 @@ StaticRuntime::StaticRuntime(
}
{
int i = 0;
for (Node* node : graph->nodes()) {
for (Node* node : graph_->nodes()) {
if (node->kind() != prim::Constant) {
continue;
}
auto* v = node->output();
// constants denoted -2, i
val_to_idx[v] = std::make_pair(-2, i);
val_to_ival[v] = &(constants_[i++]);
}
}
for (Node* node : graph->nodes()) {
int node_idx = 0;
for (Node* node : graph_->nodes()) {
if (node->kind() == prim::Constant) {
continue;
}
std::vector<const IValue*> inputs;
std::vector<std::pair<int, int>> indices;
for (Value* input : node->inputs()) {
inputs.emplace_back(val_to_ival.at(input));
indices.emplace_back(val_to_idx.at(input));
}
index_map_[node_idx] = indices;
nodes_.emplace_back(
ProcessedNode(node, std::move(inputs), opts.enable_out_variant));
for (auto i = 0; i < node->outputs().size(); ++i) {
val_to_ival[node->outputs()[i]] = &nodes_.back().Output(i);
val_to_ival[node->outputs()[i]] = nullptr;
val_to_idx[node->outputs()[i]] = std::make_pair(node_idx, i);
}
node_idx++;
}
for (auto output : graph->outputs()) {
outputs_.emplace_back(val_to_ival.at(output));
for (auto output : graph_->outputs()) {
output_indices_.emplace_back(val_to_idx[output]);
}
AliasDb alias_db(module_->graph);
auto lm = GetLivenessInformation(module_->graph, alias_db);
AliasDb alias_db(graph_);
auto lm = GetLivenessInformation(graph_, alias_db);
external_values_ = lm.second;
if (opts_.optimize_memory) {
auto values = GetOptimizableValues(module_->graph);
auto values = GetOptimizableValues(graph_);
if (!opts_.enable_out_variant) {
values.first = {};
}
@ -542,7 +551,82 @@ StaticRuntime::StaticRuntime(
}
}
std::vector<at::Tensor> StaticRuntime::run(
const StaticModuleOptions& StaticModule::opts() const {
return opts_;
}
size_t StaticModule::num_outputs() const {
return graph_->outputs().size();
}
size_t StaticModule::num_inputs() const {
return graph_->inputs().size();
}
StaticRuntime& StaticModule::runtime() {
if (!cached_runtime_) {
cached_runtime_ = std::make_unique<StaticRuntime>(*this);
}
return *cached_runtime_;
}
std::vector<at::Tensor> StaticModule::operator()(
const std::vector<at::Tensor>& inps) {
return runtime()(inps);
}
c10::IValue StaticModule::operator()(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs) {
return runtime()(args, kwargs);
}
StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) {
// NB: create unchanging std::vector<IValue>s we can reference
inputs_.resize(sm.num_inputs());
nodes_.resize(sm.nodes().size());
for (auto idx = 0; idx < sm.nodes().size(); ++idx) {
const auto& n_ref = sm.nodes()[idx];
nodes_[idx] = n_ref; // copy the node
auto& n = nodes_[idx];
// hook up the inputs
for (auto i = 0; i < n.inputs().size(); ++i) {
if (n.inputs()[i] == nullptr) {
int node_idx;
int out_idx;
std::tie(node_idx, out_idx) = sm.index_map().at(idx)[i];
DCHECK(out_idx >= 0);
// input
if (node_idx == -1) {
n.set_input(i, &inputs_[out_idx]);
} else if (node_idx == -2) {
n.set_input(i, &sm.constants()[out_idx]);
} else {
n.set_input(i, &(nodes_[node_idx].Output(out_idx)));
}
}
}
}
for (const auto& index_pair : sm.output_indices()) {
int node_idx;
int out_idx;
std::tie(node_idx, out_idx) = index_pair;
if (node_idx == -1) {
outputs_.emplace_back(&inputs_[out_idx]);
} else if (node_idx == -2) {
// This is a very rare case where const correctness
// breaks -- the user is returning a constant from
// the graph.
outputs_.emplace_back(const_cast<IValue*>(&sm.constants()[out_idx]));
} else {
auto& n = nodes_.at(node_idx);
auto* out = &n.Output(out_idx);
outputs_.emplace_back(out);
}
}
}
std::vector<at::Tensor> StaticRuntime::operator()(
const std::vector<at::Tensor>& inps) {
std::vector<c10::IValue> stack;
stack.resize(inps.size());
@ -550,7 +634,8 @@ std::vector<at::Tensor> StaticRuntime::run(
stack[i] = inps[i];
}
c10::IValue v = run(stack, std::unordered_map<std::string, c10::IValue>());
c10::IValue v =
(*this)(stack, std::unordered_map<std::string, c10::IValue>());
std::vector<at::Tensor> out;
@ -565,7 +650,7 @@ std::vector<at::Tensor> StaticRuntime::run(
return out;
}
c10::IValue StaticRuntime::run(
c10::IValue StaticRuntime::operator()(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs) {
// We assume inference workloads, so we do not need
@ -581,11 +666,11 @@ c10::IValue StaticRuntime::run(
if (!kwargs.empty()) {
// This is not ideal
TORCH_CHECK(
module_->schema != nullptr,
static_module_.schema(),
"Schema is not available. Consider creating the Static Runtime "
"with StaticRuntime(const torch::jit::Module& m) instead.");
"with StaticModule(const torch::jit::Module& m) instead.");
std::vector<c10::IValue> s = args;
module_->schema->checkAndNormalizeInputs(s, kwargs);
static_module_.schema()->checkAndNormalizeInputs(s, kwargs);
for (size_t i = 0; i < s.size(); i++) {
Input(i) = std::move(s[i]);
}
@ -596,16 +681,19 @@ c10::IValue StaticRuntime::run(
}
// NB: before optimizing the order of execution, ensure that the
// memory optimization pass (GetLivenessInformation + AssignRegisters) is
// memory optimization pass (LivenessMap) is
// aware of the new order!
for (auto& n : nodes_) {
n.run();
}
if (opts_.cleanup_activations) {
if (static_module_.opts().cleanup_activations) {
if (!planner_) {
planner_ = std::make_unique<MemoryPlanner>(
this, shared_values_, external_values_, opts_.enable_out_variant);
this,
static_module_.shared_values(),
static_module_.external_values(),
static_module_.opts().enable_out_variant);
}
planner_->deallocate();
// clean up owning refs of input tensors
@ -615,10 +703,10 @@ c10::IValue StaticRuntime::run(
}
// no need to keep references of outputs in static runtime anymore
if (num_outputs() > 1) {
if (static_module_.num_outputs() > 1) {
std::vector<c10::IValue> outputs;
outputs.reserve(num_outputs());
for (auto i = 0; i < num_outputs(); ++i) {
outputs.reserve(static_module_.num_outputs());
for (auto i = 0; i < static_module_.num_outputs(); ++i) {
// use move here. Otherwise, clean up outputs_[i] explicitly
outputs.emplace_back(std::move(*outputs_[i]));
}
@ -646,7 +734,7 @@ void StaticRuntime::benchmark(
benchmark_individual_ops(args, kwargs, warmup_runs, main_runs);
for (size_t i = 0; i < nodes_.size(); i++) {
const Node* node = nodes_[i].get_node();
const Node* node = nodes_[i].node();
std::cout << "Node #" << i << ": " << results.time_per_node[i]
<< " ms/iter, ";
node->print(std::cout, 0, nullptr, false);
@ -682,7 +770,7 @@ void StaticRuntime::benchmark(
if (planner_) {
std::cout << "Total memory managed: " << planner_->total_managed()
<< " bytes" << std::endl;
if (opts_.optimize_memory) {
if (static_module_.opts().optimize_memory) {
std::cout << "Total number of reused tensors: "
<< planner_->total_reused_tensors() << std::endl;
}
@ -697,11 +785,11 @@ float StaticRuntime::benchmark_model(
TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1);
for (int i = 0; i < warmup_runs; i++) {
run(args, kwargs);
operator()(args, kwargs);
}
caffe2::Timer timer;
for (int i = 0; i < main_runs; i++) {
run(args, kwargs);
operator()(args, kwargs);
}
float millis = timer.MilliSeconds();
return millis / static_cast<float>(main_runs);
@ -727,10 +815,10 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
if (!kwargs.empty()) {
// This is not ideal
TORCH_CHECK(
module_->schema != nullptr,
static_module_.schema(),
"Schema is not available. Consider creating the Static Runtime "
"with StaticRuntime(const torch::jit::Module& m) instead.");
module_->schema->checkAndNormalizeInputs(stack, kwargs);
"with StaticModule(const torch::jit::Module& m) instead.");
static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
}
for (size_t i = 0; i < stack.size(); i++) {
Input(i) = stack[i];
@ -739,7 +827,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
// warmup runs
for (int i = 0; i < warmup_runs; i++) {
run(args, kwargs);
operator()(args, kwargs);
}
// main runs
@ -761,10 +849,13 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
results.time_per_node[i] += millis;
}
timer.Start();
if (opts_.cleanup_activations) {
if (static_module_.opts().cleanup_activations) {
if (!planner_) {
planner_ = std::make_unique<MemoryPlanner>(
this, shared_values_, external_values_, opts_.enable_out_variant);
this,
static_module_.shared_values(),
static_module_.external_values(),
static_module_.opts().enable_out_variant);
}
planner_->deallocate();
// clean up owning refs of input tensors
@ -778,10 +869,10 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
timer.Start();
// no need to keep references of outputs in static runtime anymore
c10::IValue output;
if (num_outputs() > 1) {
if (static_module_.num_outputs() > 1) {
std::vector<c10::IValue> outputs;
outputs.reserve(num_outputs());
for (auto i = 0; i < num_outputs(); ++i) {
outputs.reserve(static_module_.num_outputs());
for (auto i = 0; i < static_module_.num_outputs(); ++i) {
// use move here. Otherwise, clean up outputs_[i] explicitly
outputs.emplace_back(std::move(*outputs_[i]));
}
@ -802,7 +893,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
// post processing
for (size_t i = 0; i < nodes_.size(); i++) {
const Node* node = nodes_[i].get_node();
const Node* node = nodes_[i].node();
std::string kind = std::string(node->kind().toQualString());
results.time_per_node[i] /= static_cast<float>(main_runs);
results.time_per_node_type[kind] += results.time_per_node[i];
@ -820,7 +911,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
}
void StaticRuntime::check_for_memory_leak(bool output_returned) {
if (!opts_.cleanup_activations) {
if (!static_module_.opts().cleanup_activations) {
return;
}
@ -832,16 +923,16 @@ void StaticRuntime::check_for_memory_leak(bool output_returned) {
std::unordered_set<const IValue*> output_ivalues(
outputs_.begin(), outputs_.end());
for (size_t n = 0; n < nodes_.size(); n++) {
auto& node = nodes_[n];
for (size_t i = 0; i < node.outputs().size(); i++) {
const IValue* ival = &node.Output(i);
auto& pnode = nodes_[n];
for (size_t i = 0; i < pnode.outputs().size(); i++) {
const IValue* ival = &pnode.Output(i);
const std::string error_msg = "Output " + c10::to_string(i) +
" of node " + c10::to_string(n) + " was not cleaned up";
if (output_ivalues.count(ival) == 0) {
// check for intermediates
if (!ival->isNone()) {
TORCH_CHECK(
ival->isTensor() || canOptimizeConstruct(node.get_node()),
ival->isTensor() || canOptimizeConstruct(pnode.node()),
error_msg);
if (ival->isTensor()) {
const auto& t = ival->toTensor();
@ -870,17 +961,17 @@ MemoryPlanner::MemoryPlanner(
// collect register indices of outputs of ops with out variant
std::unordered_set<const Value*> managed_values;
std::unordered_set<IValue*> unmanaged_ivalue_set;
for (ProcessedNode& pnode : runtime->get_nodes()) {
if (canReuseInputsOutputs(pnode.get_node())) {
for (ProcessedNode& pnode : runtime->nodes()) {
if (canReuseInputsOutputs(pnode.node())) {
for (auto i = 0; i < pnode.outputs().size(); ++i) {
// Types are stored in the underlying TorchScript IR
const Value* out_v = pnode.get_node()->outputs()[i];
const Value* out_v = pnode.node()->outputs()[i];
IValue& out = pnode.Output(i);
const auto& type = out_v->type();
if (out_variants && !external_values.count(out_v)) {
if (type->cast<TensorType>()) {
managed_values.insert(out_v);
} else if (canOptimizeConstruct(pnode.get_node())) {
} else if (canOptimizeConstruct(pnode.node())) {
// We "leak" containers of this type
} else {
unmanaged_ivalue_set.insert(&out);
@ -896,10 +987,8 @@ MemoryPlanner::MemoryPlanner(
}
}
const InferenceModule* module = runtime->get_inference_module();
// remove model outputs from managed_values and unmanaged_ivalue_set
for (Value* output : module->graph->outputs()) {
for (const Value* output : runtime->graph().outputs()) {
managed_values.erase(output);
}
for (IValue* output : runtime->outputs()) {
@ -918,10 +1007,10 @@ MemoryPlanner::MemoryPlanner(
std::unordered_set<c10::StorageImpl*> managed_storage_impls;
// Snapshot of the current memory state
for (const auto& pnode : runtime->get_nodes()) {
for (const auto& pnode : runtime->nodes()) {
for (auto i = 0; i < pnode.outputs().size(); ++i) {
const auto& ival = pnode.outputs()[i];
const auto* val = pnode.get_node()->outputs()[i];
const auto* val = pnode.node()->outputs()[i];
if (managed_values.count(val)) {
TORCH_CHECK(ival.isTensor());
auto* impl = ival.toTensor().storage().unsafeGetStorageImpl();


@ -6,42 +6,46 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/inliner.h>
namespace torch {
namespace jit {
struct TORCH_API StaticRuntimeOptions {
struct TORCH_API StaticModuleOptions {
bool cleanup_activations{true};
bool enable_out_variant{true};
bool optimize_memory{true};
};
/// Static runime supports two execution modes.
/// The static runime supports two execution modes.
///
/// Mode 1: single-threaded with no parallelism except for intra-op parallelism
/// For this mode, you can do either:
/// @code
/// // m is the TorchScript module
/// auto runtime = StaticRuntime(m, opts);
/// auto output = runtime.run(args, kwargs);
/// // m is a TorchScript module
/// auto module = StaticModule(m, opts);
/// auto output = module(args, kwargs);
/// @endcode
///
/// or
///
/// @code
/// auto mod = PrepareForStaticRuntime(m);
/// auto runtime = StaticRuntime(mod, opts);
/// auto output = runtime.run(args, kwargs);
/// // g is the TorchScript graph
/// auto module = StaticModule(g, opts);
/// auto output = module(args, kwargs);
/// @endcode
///
/// Mode 2: similar to data parallelism, run the same model for different inputs
/// on different threads at the same time. In this case, run
/// PrepareForStaticRuntime to prepare the graph for Static Runtime. You
/// should have one InferenceModule instance per model, and one Static Runtime
/// instance per running thread. To avoiding creating StaticRuntime on the fly,
/// use a synchronized stack (i.e. boost::lockfree::stack) to cache all the
/// Static Runtime instances in your code.
/// on different threads at the same time.
/// You should have one StaticModule per model, and one StaticRuntime instance
/// per running thread. To avoiding creating StaticRuntimes on the fly, use a
/// synchronized stack (i.e. boost::lockfree::stack) to cache all the
/// StaticRuntime instances in your code.
/// @code
/// // initialization
/// auto mod = PrepareForStaticRuntime(m);
/// auto module = std::make_shared<StaticModule>(m, opts);
///
/// // 128 is good for most cases. Pick a number that works for you
/// boost::lockfree::stack<std::shared_ptr<StaticRuntime>,
/// boost::lockfree::fixed_sized<true>> pool(128);
@ -50,58 +54,113 @@ struct TORCH_API StaticRuntimeOptions {
/// std::shared_ptr<StaticRuntime> runtime = nullptr;
/// pool.pop(runtime);
/// if (!runtime) {
/// runtime = std::make_shared<StaticRuntime>(mod, opts);
/// // holds a reference to the underlying module
/// // but does its own memory management
/// runtime = std::make_shared<StaticRuntime>(*module);
/// }
/// auto output = runtime->run(args, kwargs);
/// auto output = runtime(args, kwargs);
/// pool.push(runtime);
/// @endcode
///
// Group readonly data structures into InferenceModule
struct TORCH_API InferenceModule {
public:
explicit InferenceModule(const torch::jit::Module& m);
explicit InferenceModule(std::shared_ptr<torch::jit::Graph> g);
torch::jit::Module module;
std::shared_ptr<torch::jit::Graph> graph;
std::unique_ptr<c10::FunctionSchema> schema;
private:
void init();
};
TORCH_API void PrepareGraphForStaticRuntime(
std::shared_ptr<torch::jit::Graph> g);
inline TORCH_API std::shared_ptr<InferenceModule> PrepareForStaticRuntime(
const torch::jit::Module& m) {
return std::make_shared<InferenceModule>(m);
}
inline TORCH_API std::shared_ptr<InferenceModule> PrepareForStaticRuntime(
const std::shared_ptr<torch::jit::Graph>& g) {
return std::make_shared<InferenceModule>(g);
}
class MemoryPlanner;
class ProcessedNode;
class StaticRuntime;
class TORCH_API StaticModule {
public:
explicit StaticModule(
std::shared_ptr<torch::jit::Graph> g,
const StaticModuleOptions& opts = StaticModuleOptions());
explicit StaticModule(
const torch::jit::Module& m,
const StaticModuleOptions& opts = StaticModuleOptions());
private:
explicit StaticModule(
std::pair<
std::shared_ptr<torch::jit::Graph>,
c10::optional<c10::FunctionSchema>> graph_and_schema,
const StaticModuleOptions& opts);
public:
std::vector<at::Tensor> operator()(const std::vector<at::Tensor>& inps);
// This interface only works if StaticModule was initialized
// with a TorchScript module, otherwise use the above interface
c10::IValue operator()(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs);
const Graph& graph() const {
return *graph_;
}
const StaticModuleOptions& opts() const;
size_t num_inputs() const;
size_t num_outputs() const;
inline const std::unordered_map<int, std::vector<std::pair<int, int>>>&
index_map() const {
return index_map_;
}
inline const std::vector<std::pair<int, int>>& output_indices() const {
return output_indices_;
}
inline const std::vector<IValue>& constants() const {
return constants_;
}
inline const std::vector<ProcessedNode>& nodes() const {
return nodes_;
}
inline const c10::optional<c10::FunctionSchema>& schema() const {
return schema_;
}
inline const std::unordered_map<const Value*, std::vector<const Value*>>&
shared_values() const {
return shared_values_;
}
inline const std::unordered_set<const Value*>& external_values() const {
return external_values_;
}
StaticRuntime& runtime();
private:
// Static runtime states
StaticModuleOptions opts_;
std::unique_ptr<StaticRuntime> cached_runtime_;
// IValue table (including inputs, outputs, intermediates, and weights)
std::vector<IValue> constants_;
std::vector<std::pair<int, int>> output_indices_;
std::unordered_map<int, std::vector<std::pair<int, int>>> index_map_;
// The nodes we need to run
std::vector<ProcessedNode> nodes_;
// Output of liveness analyis. A mapping from a value to the set of values
// with which it could potentially share memory.
std::unordered_map<const Value*, std::vector<const Value*>> shared_values_;
std::unordered_set<const Value*> external_values_;
// Original input
std::shared_ptr<torch::jit::Graph> graph_;
c10::optional<c10::FunctionSchema> schema_;
};
class TORCH_API StaticRuntime {
public:
// InferenceModule m is created by PrepareForStaticRuntime
explicit StaticRuntime(
std::shared_ptr<InferenceModule> m,
const StaticRuntimeOptions& opts = StaticRuntimeOptions());
explicit StaticRuntime(const StaticModule& sm);
// m is unoptimized
explicit StaticRuntime(
const torch::jit::Module& m,
const StaticRuntimeOptions& opts = StaticRuntimeOptions());
std::vector<at::Tensor> operator()(const std::vector<at::Tensor>& inps);
std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps);
// This interface only works module_ that has a non-empty TorchScript module
// member; otherwise use the above interface
c10::IValue run(
// This interface only works if StaticModule was initialized
// with a TorchScript module, otherwise use the above interface
c10::IValue operator()(
const std::vector<c10::IValue>& args,
const std::unordered_map<std::string, c10::IValue>& kwargs);
@ -135,52 +194,6 @@ class TORCH_API StaticRuntime {
const int warmup_runs,
const int main_runs);
const InferenceModule* get_inference_module() {
return module_.get();
}
const std::vector<ProcessedNode>& get_nodes() const {
return nodes_;
}
std::vector<ProcessedNode>& get_nodes() {
return nodes_;
}
size_t num_inputs() const {
return inputs_.size();
}
size_t num_outputs() const {
return outputs_.size();
}
inline const std::vector<IValue*>& outputs() const {
return outputs_;
}
void check_for_memory_leak(bool output_returned = true);
private:
// Static runtime states
std::shared_ptr<InferenceModule> module_;
StaticRuntimeOptions opts_;
// IValue table (including inputs, outputs, intermediates, and weights)
std::vector<IValue> constants_;
std::vector<IValue> inputs_;
std::vector<IValue*> outputs_;
// The nodes we need to run
std::vector<ProcessedNode> nodes_;
// Output of liveness analyis. A mapping from a value to the set of values
// with which it could potentially share memory.
std::unordered_map<const Value*, std::vector<const Value*>> shared_values_;
std::unordered_set<const Value*> external_values_;
// Memory planning is only enabled if opts_.cleanup_activations is true.
// Otherwise, the memory used by activations is cached inside the static
// runtime.
std::unique_ptr<MemoryPlanner> planner_;
// Input is readwrite
IValue& Input(size_t i) {
DCHECK(i < inputs_.size());
@ -192,6 +205,34 @@ class TORCH_API StaticRuntime {
DCHECK(i < outputs_.size());
return *outputs_[i];
}
const std::vector<IValue*> outputs() const {
return outputs_;
}
inline const std::vector<ProcessedNode>& nodes() const {
return nodes_;
}
inline std::vector<ProcessedNode>& nodes() {
return nodes_;
}
const Graph& graph() const {
return static_module_.graph();
}
void check_for_memory_leak(bool output_returned = true);
private:
// Memory planning is only enabled if sm->opts().cleanup_activations is true.
// Otherwise, the memory used by activations is cached inside the static
// runtime.
std::unique_ptr<MemoryPlanner> planner_;
std::vector<IValue> inputs_;
std::vector<IValue*> outputs_;
const StaticModule& static_module_;
std::vector<ProcessedNode> nodes_;
};
/// There are three types of ops in a processed graph in Static Runtime:
@ -223,8 +264,7 @@ class MemoryPlanner {
public:
explicit MemoryPlanner(
StaticRuntime* runtime,
const std::unordered_map<const Value*, std::vector<const Value*>>&
should_share,
const std::unordered_map<const Value*, std::vector<const Value*>>&,
const std::unordered_set<const Value*>& external_values,
bool out_variants);
@ -254,6 +294,7 @@ class MemoryPlanner {
class ProcessedNode {
public:
ProcessedNode() = default;
ProcessedNode(
Node* n,
std::vector<const IValue*>&& inputs,
@ -261,10 +302,14 @@ class ProcessedNode {
void run();
Node* get_node() const {
Node* node() const {
return node_;
}
inline void set_input(size_t index, const IValue* ival) {
inputs_[index] = ival;
}
// Input is readonly
const IValue& Input(size_t i) const {
DCHECK(i < inputs_.size());
@ -295,7 +340,7 @@ class ProcessedNode {
std::function<void(ProcessedNode*)> fn_;
std::function<void(ProcessedNode*)> native_fn_;
std::vector<const IValue*> inputs_; // unowned
std::vector<IValue> outputs_; // TODO make list for safety
std::vector<IValue> outputs_;
};
} // namespace jit


@ -7,11 +7,11 @@
namespace torch {
namespace jit {
void initStaticRuntimeBindings(PyObject* module) {
void initStaticModuleBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
py::class_<StaticRuntime> static_runtime(m, "StaticRuntime");
py::class_<StaticModule> static_module(m, "StaticModule");
py::class_<StaticRuntime::IndividualMetrics>(
static_runtime, "IndividualMetrics")
static_module, "IndividualMetrics")
.def_readonly("setup_time", &StaticRuntime::IndividualMetrics::setup_time)
.def_readonly(
"memory_alloc_time",
@ -34,25 +34,25 @@ void initStaticRuntimeBindings(PyObject* module) {
.def_readonly(
"instances_per_node_type",
&StaticRuntime::IndividualMetrics::instances_per_node_type);
static_runtime
static_module
.def(
"run",
"__call__",
py::overload_cast<const std::vector<at::Tensor>&>(
&StaticRuntime::run))
&StaticModule::operator()))
.def(
"run",
[](StaticRuntime& self,
"__call__",
[](StaticModule& self,
const std::vector<at::Tensor>& args,
const std::unordered_map<std::string, at::Tensor>& kwargs) {
std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
kwargs.begin(), kwargs.end()};
c10::IValue ret = self.run(arg_ivalues, kwarg_ivalues);
c10::IValue ret = self(arg_ivalues, kwarg_ivalues);
return toPyObject(ret);
})
.def(
"benchmark",
[](StaticRuntime& self,
[](StaticModule& self,
const std::vector<at::Tensor>& args,
const std::unordered_map<std::string, at::Tensor>& kwargs,
const int warmup_runs,
@ -60,11 +60,12 @@ void initStaticRuntimeBindings(PyObject* module) {
std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
kwargs.begin(), kwargs.end()};
self.benchmark(arg_ivalues, kwarg_ivalues, warmup_runs, main_runs);
self.runtime().benchmark(
arg_ivalues, kwarg_ivalues, warmup_runs, main_runs);
})
.def(
"benchmark_individual_ops",
[](StaticRuntime& self,
[](StaticModule& self,
const std::vector<at::Tensor>& args,
const std::unordered_map<std::string, at::Tensor>& kwargs,
const int warmup_runs,
@ -72,21 +73,17 @@ void initStaticRuntimeBindings(PyObject* module) {
std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
kwargs.begin(), kwargs.end()};
return self.benchmark_individual_ops(
return self.runtime().benchmark_individual_ops(
arg_ivalues, kwarg_ivalues, warmup_runs, main_runs);
});
m.def(
"_jit_to_static_runtime",
[](std::shared_ptr<torch::jit::Graph> g) {
return StaticRuntime(PrepareForStaticRuntime(g));
})
"_jit_to_static_module",
[](std::shared_ptr<torch::jit::Graph> g) { return StaticModule(g); })
.def(
"_jit_to_static_runtime",
[](const torch::jit::Module& m) {
return StaticRuntime(PrepareForStaticRuntime(m));
})
"_jit_to_static_module",
[](const torch::jit::Module& module) { return StaticModule(module); })
.def(
"_fuse_to_static_runtime",
"_fuse_to_static_module",
[](torch::jit::Module& module) {
module.eval();
module = freeze_module(module);
@ -95,7 +92,7 @@ void initStaticRuntimeBindings(PyObject* module) {
auto graph = method.graph();
fuseStaticSubgraphs(graph);
})
.def("_fuse_to_static_runtime", [](std::shared_ptr<torch::jit::Graph> g) {
.def("_fuse_to_static_module", [](std::shared_ptr<torch::jit::Graph> g) {
fuseStaticSubgraphs(g);
});
}


@ -3,7 +3,7 @@
namespace torch {
namespace jit {
void initStaticRuntimeBindings(PyObject* module);
void initStaticModuleBindings(PyObject* module);
} // namespace jit
} // namespace torch


@ -643,11 +643,9 @@ REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
at::native::copy_(out_t, in0_t, false);
};
});
REGISTER_OPERATOR_FUNCTOR_OPT(
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_byte_rowwise_offsets,
quantized_embedding_bag_byte_rowwise_offsets,
false, // don't reuse byte inputs
true,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& weight = p_node->Input(0).toTensor();
@ -677,11 +675,9 @@ REGISTER_OPERATOR_FUNCTOR_OPT(
include_last_offset);
};
});
REGISTER_OPERATOR_FUNCTOR_OPT(
REGISTER_OPERATOR_FUNCTOR(
quantized::embedding_bag_4bit_rowwise_offsets,
embedding_bag_4bit_rowwise_offsets,
false, // don't reuse byte inputs
true,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
const auto& weight = p_node->Input(0).toTensor();
@ -918,7 +914,7 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
stack.emplace_back(p_node->Input(i));
}
// run op
auto* node = p_node->get_node();
auto* node = p_node->node();
const auto& type = node->output()->type()->expect<TupleType>();
if (type->name().has_value()) {
namedTupleConstruct(stack, type, node->inputs().size());
@ -940,7 +936,7 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
// run op
listConstruct(
stack,
p_node->get_node()->output()->type()->expectRef<ListType>(),
p_node->node()->output()->type()->expectRef<ListType>(),
p_node->inputs().size());
// put output back
p_node->Output(0) = std::move(stack[0]);


@ -27,24 +27,15 @@ C10_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);
// TODO: reuse_inp reuse_out can be deprecated with further analysis
// try to avoid this API.
#define REGISTER_OPERATOR_FUNCTOR_OPT(name, id, reuse_inp, reuse_out, ...) \
struct SROperatorFunctor_##id : public SROperatorFunctor { \
const SROpFunctor fn = __VA_ARGS__; \
bool CanReuseInput() override { \
return reuse_inp; \
} \
bool CanReuseOutput() override { \
return reuse_out; \
} \
SROperator Generate(Node* n) override { \
return fn(n); \
} \
}; \
#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \
struct SROperatorFunctor_##id : public SROperatorFunctor { \
const SROpFunctor fn = __VA_ARGS__; \
SROperator Generate(Node* n) override { \
return fn(n); \
} \
}; \
C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id);
#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \
REGISTER_OPERATOR_FUNCTOR_OPT(name, id, true, true, __VA_ARGS__)
#define REGISTER_VIEW_OPERATOR_FUNCTOR(name, id, ...) \
struct SROperatorFunctor_##id : public SROperatorFunctor { \
const SROpFunctor fn = __VA_ARGS__; \