Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[static runtime] Move all heavy constructor logic into InferenceModule (renamed to StaticModule) (#51564)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51564

Constructor logic was spread throughout InferenceModule and StaticRuntime; this diff unifies the two. After a lot of discussion on diff D25961626 it became apparent that `clone` is uglier than a cheap StaticRuntime. This means the old StaticRuntime is effectively StaticModule, and the only code left in the new StaticRuntime is the `run` functions.

```
graph, schema = PrepareForStaticModule(torchscript_module)
sm = StaticModule(graph, schema, options)
sm(inputs)

// or create many cheap runtimes with the module
sr = StaticRuntime(sm)
sr(inputs)
```

Changelist:
- Rename InferenceModule to StaticModule
- Move all logic for construction into StaticModule
- Create a new StaticRuntime that only has a unique memory planner (everything else is in StaticModule)
- Update comments with explanation
- Propagate all changes to predictor integration
- Propagate all changes to python integration
- Change semantics to be a bit more PyTorch-standard (no "run" calls, no "get_" getters)

Test Plan:
buck test //caffe2/test:static_runtime
buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest

Reviewed By: hlu1

Differential Revision: D25592967

fbshipit-source-id: 8233bed03137ce129137af2d44bce0095033ef0f
Committed by: Facebook GitHub Bot
Parent: 5ebfabb310
Commit: 56f8379802
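For orientation before the diff, here is a minimal usage sketch of the API this commit introduces, based only on the summary above and the updated header comments further down; the function name `sketch` and the `m`/`args` arguments are placeholders, not code from this change.

```cpp
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/script.h>

void sketch(const torch::jit::Module& m, const std::vector<c10::IValue>& args) {
  // Heavy, shareable part: graph preparation, schema, constants, liveness info.
  torch::jit::StaticModuleOptions opts; // cleanup_activations / enable_out_variant / optimize_memory
  torch::jit::StaticModule smod(m, opts);

  // Mode 1 (single-threaded): calling the StaticModule uses its cached runtime.
  c10::IValue out1 = smod(args, /*kwargs=*/{});

  // Mode 2 (per-thread): create a cheap StaticRuntime over the shared StaticModule.
  torch::jit::StaticRuntime runtime(smod);
  c10::IValue out2 = runtime(args, /*kwargs=*/{});
  (void)out1;
  (void)out2;
}
```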
@@ -75,8 +75,7 @@ static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) {

static void BM_deep_wide_static(benchmark::State& state) {
  auto mod = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});

@@ -85,21 +84,19 @@ static void BM_deep_wide_static(benchmark::State& state) {

  std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});

  runtime.run(inputs);
  smod(inputs);
  for (auto _ : state) {
    runtime.run(inputs);
    smod(inputs);
  }
}

const std::shared_ptr<torch::jit::InferenceModule>& getStaticGraph() {
  static const std::shared_ptr<torch::jit::InferenceModule> g =
      torch::jit::PrepareForStaticRuntime(getDeepAndWideSciptModel());
  return g;
torch::jit::StaticRuntime getStaticRuntime() {
  static auto smod = std::make_shared<torch::jit::StaticModule>(getDeepAndWideSciptModel());
  return torch::jit::StaticRuntime(*smod);
}

static void BM_deep_wide_static_threaded(benchmark::State& state) {
  auto g = getStaticGraph();
  torch::jit::StaticRuntime runtime(g);
  auto sr = getStaticRuntime();

  const int batch_size = 1; // state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});

@@ -108,39 +105,38 @@ static void BM_deep_wide_static_threaded(benchmark::State& state) {

  std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});

  sr(inputs);
  for (auto _ : state) {
    runtime.run(inputs);
    sr(inputs);
  }
}

static void BM_leaky_relu_const(benchmark::State& state) {
  auto mod = getLeakyReLUConstScriptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto data = torch::randn({batch_size, num_features});
  std::vector<at::Tensor> inputs({data});

  runtime.run(inputs);
  smod(inputs);
  for (auto _ : state) {
    runtime.run(inputs);
    smod(inputs);
  }
}

static void BM_leaky_relu(benchmark::State& state) {
  auto mod = getLeakyReLUScriptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto neg_slope = torch::randn(1);
  auto data = torch::randn({batch_size, num_features});
  std::vector<at::Tensor> inputs({data, neg_slope[0]});

  runtime.run(inputs);
  smod(inputs);
  for (auto _ : state) {
    runtime.run(inputs);
    smod(inputs);
  }
}

@@ -149,10 +145,9 @@ BENCHMARK(BM_leaky_relu_const)->RangeMultiplier(8)->Ranges({{1, 20}});

static void BM_long_static_memory_optimization(benchmark::State& state) {
  auto mod = getLongScriptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntimeOptions opts;
  torch::jit::StaticModuleOptions opts;
  opts.optimize_memory = state.range(1);
  torch::jit::StaticRuntime runtime(g, opts);
  torch::jit::StaticModule smod(mod, opts);

  const auto N = state.range(0);
  auto a = torch::randn({N, N});

@@ -160,9 +155,9 @@ static void BM_long_static_memory_optimization(benchmark::State& state) {
  auto c = torch::randn({N, N});
  std::vector<at::Tensor> inputs({a, b, c});

  runtime.run(inputs);
  smod(inputs);
  for (auto _ : state) {
    runtime.run(inputs);
    smod(inputs);
  }
}
@@ -70,9 +70,9 @@ void testStaticRuntime(

  auto expect = module.forward(args);

  StaticRuntime runtime(module);
  auto actual = runtime.run(args, {});
  runtime.check_for_memory_leak();
  torch::jit::StaticModule smodule(module);
  auto actual = smodule(args, {});
  smodule.runtime().check_for_memory_leak();

  if (expect.isTuple()) {
    compareTensorLists(

@@ -187,10 +187,9 @@ TEST(StaticRuntime, LongModel) {

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  runtime.check_for_memory_leak();
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors)[0];
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}

@@ -206,10 +205,9 @@ TEST(StaticRuntime, TrivialModel) {

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  runtime.check_for_memory_leak();
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors)[0];
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}

@@ -223,10 +221,9 @@ TEST(StaticRuntime, LeakyReLU) {

  // run static runtime
  std::vector<at::Tensor> input_tensors({inputs});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  runtime.check_for_memory_leak();
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors)[0];
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}

@@ -234,8 +231,7 @@ TEST(StaticRuntime, DeepWide) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  torch::jit::StaticModule smod(mod);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {

@@ -249,8 +245,8 @@ TEST(StaticRuntime, DeepWide) {

      // run static runtime
      std::vector<at::Tensor> input_tensors({ad_emb_packed, user_emb, wide});
      at::Tensor output_2 = runtime.run(input_tensors)[0];
      runtime.check_for_memory_leak();
      at::Tensor output_2 = smod(input_tensors)[0];
      smod.runtime().check_for_memory_leak();
      EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
    }
  }

@@ -260,7 +256,7 @@ TEST(StaticRuntime, KWargsAPI_1) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticRuntime runtime(module);
  torch::jit::StaticModule smod(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {

@@ -274,8 +270,8 @@ TEST(StaticRuntime, KWargsAPI_1) {
      at::Tensor output_1 = getTensor(module.forward(inputs));

      // run static runtime
      c10::IValue output_ivalue = runtime.run(inputs, {});
      runtime.check_for_memory_leak();
      c10::IValue output_ivalue = smod(inputs, {});
      smod.runtime().check_for_memory_leak();

      at::Tensor output_2 = getTensor(output_ivalue);
      EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));

@@ -300,8 +296,7 @@ TEST(StaticRuntime, KWargsAPI_2) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(module);
  torch::jit::StaticRuntime runtime(module);
  torch::jit::StaticModule smod(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {

@@ -319,8 +314,8 @@ TEST(StaticRuntime, KWargsAPI_2) {
           {"wide", wide}});

      // run static runtime
      c10::IValue output_ivalue = runtime.run({}, kwargs);
      runtime.check_for_memory_leak();
      c10::IValue output_ivalue = smod({}, kwargs);
      smod.runtime().check_for_memory_leak();

      at::Tensor output_2 = getTensor(output_ivalue);
      EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));

@@ -343,14 +338,14 @@ TEST(StaticRuntime, CleanUpMemory) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticModule smod(mod);

  for (auto cleanup_memory : {true, false}) {
    for (auto enable_out_variant : {true, false}) {
      VLOG(1) << "cleanup_memory: " << cleanup_memory
              << ", enable_out_variant: " << enable_out_variant;
      torch::jit::StaticRuntimeOptions opts{cleanup_memory, enable_out_variant};
      torch::jit::StaticRuntime runtime(g, opts);
      torch::jit::StaticModuleOptions opts{cleanup_memory, enable_out_variant};
      torch::jit::StaticModule smod(mod, opts);

      for (int batch_size : {1, 8, 32}) {
        for (int i = 0; i < 2; ++i) {

@@ -365,8 +360,8 @@ TEST(StaticRuntime, CleanUpMemory) {
          // run static runtime
          std::vector<at::Tensor> input_tensors(
              {ad_emb_packed, user_emb, wide});
          at::Tensor output_2 = runtime.run(input_tensors)[0];
          runtime.check_for_memory_leak();
          at::Tensor output_2 = smod(input_tensors)[0];
          smod.runtime().check_for_memory_leak();
          EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
        }
      }
@@ -5,25 +5,25 @@ from torch.testing._internal.common_utils import TestCase, run_tests

from typing import Dict, Optional

class StaticRuntime:
class StaticModule:
    def __init__(self, scripted):
        # this is an nn.Module
        if hasattr(scripted, "_c"):
            self.static_runtime = torch._C._jit_to_static_runtime(scripted._c)
            self.static_module = torch._C._jit_to_static_module(scripted._c)
        else:
            self.static_runtime = torch._C._jit_to_static_runtime(scripted.graph)
            self.static_module = torch._C._jit_to_static_module(scripted.graph)

    def __call__(self, *args, **kwargs):
        if not kwargs:
            return self.static_runtime.run(args)
            return self.static_module(args)
        else:
            return self.static_runtime.run(args, kwargs)
            return self.static_module(args, kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        self.static_runtime.benchmark(args, kwargs, warmup_runs, main_runs)
        self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        return self.static_runtime.benchmark_individual_ops(
        return self.static_module.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )

@@ -121,7 +121,7 @@ def output_graph(a, b, c, iters : int):
        d[i] = k + i
    return d

class TestStaticRuntime(TestCase):
class TestStaticModule(TestCase):
    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8

@@ -140,7 +140,7 @@ class TestStaticRuntime(TestCase):
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticRuntime(attention)
        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

@@ -165,7 +165,7 @@ class TestStaticRuntime(TestCase):

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticRuntime(attention)
        attention_a = StaticModule(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        metrics = attention_a.benchmark_individual_ops(

@@ -179,9 +179,9 @@ class TestStaticRuntime(TestCase):
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticRuntime(bot_l)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticRuntime(top_l)
        top_l_acc = StaticModule(top_l)
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])

@@ -206,7 +206,7 @@ class TestStaticRuntime(TestCase):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticRuntime(tg)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)[0]
        torch.testing.assert_allclose(o_ref, o_test)

@@ -214,7 +214,7 @@ class TestStaticRuntime(TestCase):
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticRuntime(tg)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)[0]
        torch.testing.assert_allclose(o_ref, o_test)

@@ -222,7 +222,7 @@ class TestStaticRuntime(TestCase):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_runtime(tg.graph)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_allclose(o_ref, o_test)

@@ -245,7 +245,7 @@ class TestStaticRuntime(TestCase):
        attention.eval()
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_runtime(attention._c)
        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):

@@ -257,7 +257,7 @@ class TestStaticRuntime(TestCase):
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_runtime(lg.graph)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_allclose(o_ref, o_test)

@@ -268,7 +268,7 @@ class TestStaticRuntime(TestCase):
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_runtime(og.graph)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
@@ -1335,7 +1335,7 @@ void initJITBindings(PyObject* module) {
  initTreeViewBindings(module);
  initJitScriptBindings(module);
  initJitBackendBindings(module);
  initStaticRuntimeBindings(module);
  initStaticModuleBindings(module);
  initTensorExprBindings(module);

  setPrintHandler([](const std::string& str) {
@@ -2,7 +2,10 @@

#include <ATen/core/interned_strings.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/static/impl.h>

@@ -14,24 +17,30 @@ namespace jit {
void createFusionGroups(Block* block, AliasDb* aliasDb);

void fuseStaticSubgraphs(std::shared_ptr<Graph> graph) {
  PrepareGraphForStaticRuntime(graph);
  Inline(*graph);
  ConstantPropagation(graph);
  Canonicalize(graph);
  ConstantPropagation(graph);
  RemoveTensorMutation(graph);
  ConstantPropagation(graph);
  EliminateDeadCode(graph);
  auto aliasDb = torch::make_unique<AliasDb>(graph);
  createFusionGroups(graph->block(), aliasDb.get());
  torch::jit::EliminateDeadCode(graph);
}

Operation createStaticSubgraphRuntime(const Node* node) {
  auto g = torch::jit::PrepareForStaticRuntime(node->g(attr::Subgraph));
  auto runtime = std::make_shared<torch::jit::StaticRuntime>(g);
  auto num_inputs = runtime->num_inputs();
  return [runtime, num_inputs](Stack* stack) {
  auto g = node->g(attr::Subgraph);
  auto module = std::make_shared<torch::jit::StaticModule>(g);
  auto num_inputs = module->num_inputs();
  return [module, num_inputs](Stack* stack) {
    RECORD_FUNCTION("Static Runtime", std::vector<c10::IValue>());
    auto inps = torch::jit::last(stack, num_inputs);
    // TODO maybe avoid call to vec
    auto outputs = runtime->run(inps.vec(), {});
    auto outputs = (*module)(inps.vec(), {});
    torch::jit::drop(stack, num_inputs);

    if (runtime->num_outputs() > 1) {
    if (module->num_outputs() > 1) {
      for (auto& o : outputs.toTuple()->elements()) {
        push_one(*stack, std::move(o));
      }
@@ -18,7 +18,9 @@
namespace torch {
namespace jit {

void PrepareGraphForStaticRuntime(std::shared_ptr<torch::jit::Graph> graph) {
namespace {

void OptimizeGraph(std::shared_ptr<torch::jit::Graph>& graph) {
  Inline(*graph);
  ConstantPropagation(graph);
  Canonicalize(graph);

@@ -26,11 +28,6 @@ void PrepareGraphForStaticRuntime(std::shared_ptr<torch::jit::Graph> graph) {
  RemoveTensorMutation(graph);
  ConstantPropagation(graph);
  EliminateDeadCode(graph);
}

namespace {
void OptimizeGraph(std::shared_ptr<torch::jit::Graph>& graph) {
  PrepareGraphForStaticRuntime(graph);
  FuseInferenceOpsForSparseNN(graph);

  // TODO: we can avoid this guard by moving operations

@@ -82,11 +79,10 @@ void RemoveSelfFromGraphInput(std::shared_ptr<torch::jit::Graph>& graph) {
}

// remove "self" from function schema
std::unique_ptr<c10::FunctionSchema> RemoveSelfFromSchema(
    const c10::FunctionSchema& s) {
c10::FunctionSchema RemoveSelfFromSchema(const c10::FunctionSchema& s) {
  TORCH_CHECK(s.arguments().size() >= 1 && s.arguments()[0].name() == "self");
  std::vector<Argument> args({s.arguments().begin() + 1, s.arguments().end()});
  return std::make_unique<c10::FunctionSchema>(s.cloneWithArguments(args));
  return s.cloneWithArguments(args);
}

bool mayContainAlias(AliasDb& db, const Value* a, const Value* b) {

@@ -427,74 +423,78 @@ std::unordered_map<const Value*, std::vector<const Value*>> FindShared(

  return shared;
}

} // namespace

void InferenceModule::init() {
void PrepareGraphForStaticModule(std::shared_ptr<torch::jit::Graph> graph) {
  OptimizeGraph(graph);
  CheckGraphEligibility(graph);
  RemoveSelfFromGraphInput(graph);
}

InferenceModule::InferenceModule(const torch::jit::Module& m)
    : module(m.copy()), graph(nullptr), schema(nullptr) {
std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
PrepareForStaticModule(const torch::jit::Module& m) {
  auto module = m.copy();
  module.eval();

  Method method = module.get_method("forward");
  graph = method.graph();
  auto graph = method.graph();
  // Move this pass to before running the freeze_module pass so that the
  // sigrid_hash_compute_multipler_shift op can be precomputed and the results
  // being folded into the module as constants. This is required to enable the
  // ClipRangesGatherRangesX2SigridHashPrecompute pass. See D26833478
  SplitOutPrecomputeOpsForSparseNN(graph);

  module = freeze_module(module);
  method = module.get_method("forward");
  graph = method.graph();
  graph = module.get_method("forward").graph();
  PrepareGraphForStaticModule(graph);

  const c10::FunctionSchema& s = method.function().getSchema();
  schema = RemoveSelfFromSchema(s);

  init();
  c10::FunctionSchema s = RemoveSelfFromSchema(method.function().getSchema());
  return std::make_pair(graph, s);
}

InferenceModule::InferenceModule(std::shared_ptr<torch::jit::Graph> g)
    : module(), graph(std::move(g)), schema(nullptr) {
  init();
std::pair<std::shared_ptr<Graph>, c10::optional<c10::FunctionSchema>>
PrepareForStaticModule(std::shared_ptr<torch::jit::Graph> graph) {
  PrepareGraphForStaticModule(graph);
  return std::make_pair(graph, c10::nullopt);
}

StaticRuntime::StaticRuntime(
StaticModule::StaticModule(
    std::shared_ptr<torch::jit::Graph> g,
    const StaticModuleOptions& opts)
    : StaticModule(PrepareForStaticModule(g), opts) {}

StaticModule::StaticModule(
    const torch::jit::Module& m,
    const StaticRuntimeOptions& opts)
    : StaticRuntime(PrepareForStaticRuntime(m), opts) {}
    const StaticModuleOptions& opts)
    : StaticModule(PrepareForStaticModule(m), opts) {}

StaticRuntime::StaticRuntime(
    std::shared_ptr<InferenceModule> m,
    const StaticRuntimeOptions& opts)
    : module_(m), opts_(opts) {
  TORCH_CHECK(
      module_ != nullptr,
      "std::shared_ptr<InferenceModule> module_ cannot be nullptr")

  Graph* graph = module_->graph.get();
StaticModule::StaticModule(
    std::pair<
        std::shared_ptr<torch::jit::Graph>,
        c10::optional<c10::FunctionSchema>> graph_and_schema,
    const StaticModuleOptions& opts)
    : opts_(opts),
      graph_(std::move(graph_and_schema.first)),
      schema_(std::move(graph_and_schema.second)) {
  std::unordered_map<Value*, IValue*> val_to_ival;
  // value -> index into nodes, index into outputs of node
  std::unordered_map<Value*, std::pair<int, int>> val_to_idx;

  // NB: create an unchanging std::vector<IValue> we can reference
  for (auto input : graph->inputs()) {
    inputs_.emplace_back();
  }
  for (auto i = 0; i < graph->inputs().size(); ++i) {
    Value* input = graph->inputs()[i];
    val_to_ival[input] = &(inputs_[i]);
  // N inputs map to the first N entries in storage
  for (auto i = 0; i < graph_->inputs().size(); ++i) {
    Value* input = graph_->inputs()[i];
    val_to_ival[input] = nullptr;
    // input denoted by -1
    val_to_idx[input] = std::make_pair(-1, i);
  }

  // fill workspace_ with constants and create ProcessedNodes
  // NB: before optimizing the order of execution, ensure that the
  // memory optimization pass (GetLivenessInformation + AssignRegisters) is
  // memory optimization pass (LivenessMap) is
  // aware of the new order!

  // Fill constants first, so we have a std::vector<IValue> we can reference
  // later
  for (Node* node : graph->nodes()) {
  for (Node* node : graph_->nodes()) {
    if (node->kind() != prim::Constant) {
      continue;
    }
@@ -504,37 +504,46 @@ StaticRuntime::StaticRuntime(
  }
  {
    int i = 0;
    for (Node* node : graph->nodes()) {
    for (Node* node : graph_->nodes()) {
      if (node->kind() != prim::Constant) {
        continue;
      }
      auto* v = node->output();
      // constants denoted -2, i
      val_to_idx[v] = std::make_pair(-2, i);
      val_to_ival[v] = &(constants_[i++]);
    }
  }
  for (Node* node : graph->nodes()) {

  int node_idx = 0;
  for (Node* node : graph_->nodes()) {
    if (node->kind() == prim::Constant) {
      continue;
    }
    std::vector<const IValue*> inputs;
    std::vector<std::pair<int, int>> indices;
    for (Value* input : node->inputs()) {
      inputs.emplace_back(val_to_ival.at(input));
      indices.emplace_back(val_to_idx.at(input));
    }
    index_map_[node_idx] = indices;
    nodes_.emplace_back(
        ProcessedNode(node, std::move(inputs), opts.enable_out_variant));
    for (auto i = 0; i < node->outputs().size(); ++i) {
      val_to_ival[node->outputs()[i]] = &nodes_.back().Output(i);
      val_to_ival[node->outputs()[i]] = nullptr;
      val_to_idx[node->outputs()[i]] = std::make_pair(node_idx, i);
    }
    node_idx++;
  }
  for (auto output : graph->outputs()) {
    outputs_.emplace_back(val_to_ival.at(output));
  for (auto output : graph_->outputs()) {
    output_indices_.emplace_back(val_to_idx[output]);
  }

  AliasDb alias_db(module_->graph);
  auto lm = GetLivenessInformation(module_->graph, alias_db);
  AliasDb alias_db(graph_);
  auto lm = GetLivenessInformation(graph_, alias_db);
  external_values_ = lm.second;
  if (opts_.optimize_memory) {
    auto values = GetOptimizableValues(module_->graph);
    auto values = GetOptimizableValues(graph_);
    if (!opts_.enable_out_variant) {
      values.first = {};
    }

@@ -542,7 +551,82 @@ StaticRuntime::StaticRuntime(
  }
}

std::vector<at::Tensor> StaticRuntime::run(
const StaticModuleOptions& StaticModule::opts() const {
  return opts_;
}

size_t StaticModule::num_outputs() const {
  return graph_->outputs().size();
}

size_t StaticModule::num_inputs() const {
  return graph_->inputs().size();
}

StaticRuntime& StaticModule::runtime() {
  if (!cached_runtime_) {
    cached_runtime_ = std::make_unique<StaticRuntime>(*this);
  }
  return *cached_runtime_;
}

std::vector<at::Tensor> StaticModule::operator()(
    const std::vector<at::Tensor>& inps) {
  return runtime()(inps);
}
c10::IValue StaticModule::operator()(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs) {
  return runtime()(args, kwargs);
}

StaticRuntime::StaticRuntime(const StaticModule& sm) : static_module_(sm) {
  // NB: create unchanging std::vector<IValue>s we can reference
  inputs_.resize(sm.num_inputs());
  nodes_.resize(sm.nodes().size());
  for (auto idx = 0; idx < sm.nodes().size(); ++idx) {
    const auto& n_ref = sm.nodes()[idx];
    nodes_[idx] = n_ref; // copy the node
    auto& n = nodes_[idx];
    // hook up the inputs
    for (auto i = 0; i < n.inputs().size(); ++i) {
      if (n.inputs()[i] == nullptr) {
        int node_idx;
        int out_idx;
        std::tie(node_idx, out_idx) = sm.index_map().at(idx)[i];
        DCHECK(out_idx >= 0);
        // input
        if (node_idx == -1) {
          n.set_input(i, &inputs_[out_idx]);
        } else if (node_idx == -2) {
          n.set_input(i, &sm.constants()[out_idx]);
        } else {
          n.set_input(i, &(nodes_[node_idx].Output(out_idx)));
        }
      }
    }
  }

  for (const auto& index_pair : sm.output_indices()) {
    int node_idx;
    int out_idx;
    std::tie(node_idx, out_idx) = index_pair;
    if (node_idx == -1) {
      outputs_.emplace_back(&inputs_[out_idx]);
    } else if (node_idx == -2) {
      // This is a very rare case where const correctness
      // breaks -- the user is returning a constant from
      // the graph.
      outputs_.emplace_back(const_cast<IValue*>(&sm.constants()[out_idx]));
    } else {
      auto& n = nodes_.at(node_idx);
      auto* out = &n.Output(out_idx);
      outputs_.emplace_back(out);
    }
  }
}

std::vector<at::Tensor> StaticRuntime::operator()(
    const std::vector<at::Tensor>& inps) {
  std::vector<c10::IValue> stack;
  stack.resize(inps.size());
@@ -550,7 +634,8 @@ std::vector<at::Tensor> StaticRuntime::run(
    stack[i] = inps[i];
  }

  c10::IValue v = run(stack, std::unordered_map<std::string, c10::IValue>());
  c10::IValue v =
      (*this)(stack, std::unordered_map<std::string, c10::IValue>());

  std::vector<at::Tensor> out;

@@ -565,7 +650,7 @@ std::vector<at::Tensor> StaticRuntime::run(
  return out;
}

c10::IValue StaticRuntime::run(
c10::IValue StaticRuntime::operator()(
    const std::vector<c10::IValue>& args,
    const std::unordered_map<std::string, c10::IValue>& kwargs) {
  // We assume inference workloads, so we do not need

@@ -581,11 +666,11 @@ c10::IValue StaticRuntime::run(
  if (!kwargs.empty()) {
    // This is not ideal
    TORCH_CHECK(
        module_->schema != nullptr,
        static_module_.schema(),
        "Schema is not available. Consider creating the Static Runtime "
        "with StaticRuntime(const torch::jit::Module& m) instead.");
        "with StaticModule(const torch::jit::Module& m) instead.");
    std::vector<c10::IValue> s = args;
    module_->schema->checkAndNormalizeInputs(s, kwargs);
    static_module_.schema()->checkAndNormalizeInputs(s, kwargs);
    for (size_t i = 0; i < s.size(); i++) {
      Input(i) = std::move(s[i]);
    }

@@ -596,16 +681,19 @@ c10::IValue StaticRuntime::run(
  }

  // NB: before optimizing the order of execution, ensure that the
  // memory optimization pass (GetLivenessInformation + AssignRegisters) is
  // memory optimization pass (LivenessMap) is
  // aware of the new order!
  for (auto& n : nodes_) {
    n.run();
  }

  if (opts_.cleanup_activations) {
  if (static_module_.opts().cleanup_activations) {
    if (!planner_) {
      planner_ = std::make_unique<MemoryPlanner>(
          this, shared_values_, external_values_, opts_.enable_out_variant);
          this,
          static_module_.shared_values(),
          static_module_.external_values(),
          static_module_.opts().enable_out_variant);
    }
    planner_->deallocate();
    // clean up owning refs of input tensors

@@ -615,10 +703,10 @@ c10::IValue StaticRuntime::run(
  }

  // no need to keep references of outputs in static runtime anymore
  if (num_outputs() > 1) {
  if (static_module_.num_outputs() > 1) {
    std::vector<c10::IValue> outputs;
    outputs.reserve(num_outputs());
    for (auto i = 0; i < num_outputs(); ++i) {
    outputs.reserve(static_module_.num_outputs());
    for (auto i = 0; i < static_module_.num_outputs(); ++i) {
      // use move here. Otherwise, clean up outputs_[i] explicitly
      outputs.emplace_back(std::move(*outputs_[i]));
    }

@@ -646,7 +734,7 @@ void StaticRuntime::benchmark(
  benchmark_individual_ops(args, kwargs, warmup_runs, main_runs);

  for (size_t i = 0; i < nodes_.size(); i++) {
    const Node* node = nodes_[i].get_node();
    const Node* node = nodes_[i].node();
    std::cout << "Node #" << i << ": " << results.time_per_node[i]
              << " ms/iter, ";
    node->print(std::cout, 0, nullptr, false);

@@ -682,7 +770,7 @@ void StaticRuntime::benchmark(
  if (planner_) {
    std::cout << "Total memory managed: " << planner_->total_managed()
              << " bytes" << std::endl;
    if (opts_.optimize_memory) {
    if (static_module_.opts().optimize_memory) {
      std::cout << "Total number of reused tensors: "
                << planner_->total_reused_tensors() << std::endl;
    }

@@ -697,11 +785,11 @@ float StaticRuntime::benchmark_model(
  TORCH_CHECK(warmup_runs >= 0 && main_runs >= 1);

  for (int i = 0; i < warmup_runs; i++) {
    run(args, kwargs);
    operator()(args, kwargs);
  }
  caffe2::Timer timer;
  for (int i = 0; i < main_runs; i++) {
    run(args, kwargs);
    operator()(args, kwargs);
  }
  float millis = timer.MilliSeconds();
  return millis / static_cast<float>(main_runs);

@@ -727,10 +815,10 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
  if (!kwargs.empty()) {
    // This is not ideal
    TORCH_CHECK(
        module_->schema != nullptr,
        static_module_.schema(),
        "Schema is not available. Consider creating the Static Runtime "
        "with StaticRuntime(const torch::jit::Module& m) instead.");
    module_->schema->checkAndNormalizeInputs(stack, kwargs);
        "with StaticModule(const torch::jit::Module& m) instead.");
    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
  }
  for (size_t i = 0; i < stack.size(); i++) {
    Input(i) = stack[i];

@@ -739,7 +827,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(

  // warmup runs
  for (int i = 0; i < warmup_runs; i++) {
    run(args, kwargs);
    operator()(args, kwargs);
  }

  // main runs

@@ -761,10 +849,13 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
    results.time_per_node[i] += millis;
  }
  timer.Start();
  if (opts_.cleanup_activations) {
  if (static_module_.opts().cleanup_activations) {
    if (!planner_) {
      planner_ = std::make_unique<MemoryPlanner>(
          this, shared_values_, external_values_, opts_.enable_out_variant);
          this,
          static_module_.shared_values(),
          static_module_.external_values(),
          static_module_.opts().enable_out_variant);
    }
    planner_->deallocate();
    // clean up owning refs of input tensors

@@ -778,10 +869,10 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
  timer.Start();
  // no need to keep references of outputs in static runtime anymore
  c10::IValue output;
  if (num_outputs() > 1) {
  if (static_module_.num_outputs() > 1) {
    std::vector<c10::IValue> outputs;
    outputs.reserve(num_outputs());
    for (auto i = 0; i < num_outputs(); ++i) {
    outputs.reserve(static_module_.num_outputs());
    for (auto i = 0; i < static_module_.num_outputs(); ++i) {
      // use move here. Otherwise, clean up outputs_[i] explicitly
      outputs.emplace_back(std::move(*outputs_[i]));
    }

@@ -802,7 +893,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(

  // post processing
  for (size_t i = 0; i < nodes_.size(); i++) {
    const Node* node = nodes_[i].get_node();
    const Node* node = nodes_[i].node();
    std::string kind = std::string(node->kind().toQualString());
    results.time_per_node[i] /= static_cast<float>(main_runs);
    results.time_per_node_type[kind] += results.time_per_node[i];

@@ -820,7 +911,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
}

void StaticRuntime::check_for_memory_leak(bool output_returned) {
  if (!opts_.cleanup_activations) {
  if (!static_module_.opts().cleanup_activations) {
    return;
  }

@@ -832,16 +923,16 @@ void StaticRuntime::check_for_memory_leak(bool output_returned) {
  std::unordered_set<const IValue*> output_ivalues(
      outputs_.begin(), outputs_.end());
  for (size_t n = 0; n < nodes_.size(); n++) {
    auto& node = nodes_[n];
    for (size_t i = 0; i < node.outputs().size(); i++) {
      const IValue* ival = &node.Output(i);
    auto& pnode = nodes_[n];
    for (size_t i = 0; i < pnode.outputs().size(); i++) {
      const IValue* ival = &pnode.Output(i);
      const std::string error_msg = "Output " + c10::to_string(i) +
          " of node " + c10::to_string(n) + " was not cleaned up";
      if (output_ivalues.count(ival) == 0) {
        // check for intermediates
        if (!ival->isNone()) {
          TORCH_CHECK(
              ival->isTensor() || canOptimizeConstruct(node.get_node()),
              ival->isTensor() || canOptimizeConstruct(pnode.node()),
              error_msg);
          if (ival->isTensor()) {
            const auto& t = ival->toTensor();

@@ -870,17 +961,17 @@ MemoryPlanner::MemoryPlanner(
  // collect register indices of outputs of ops with out variant
  std::unordered_set<const Value*> managed_values;
  std::unordered_set<IValue*> unmanaged_ivalue_set;
  for (ProcessedNode& pnode : runtime->get_nodes()) {
    if (canReuseInputsOutputs(pnode.get_node())) {
  for (ProcessedNode& pnode : runtime->nodes()) {
    if (canReuseInputsOutputs(pnode.node())) {
      for (auto i = 0; i < pnode.outputs().size(); ++i) {
        // Types are stored in the underlying TorchScript IR
        const Value* out_v = pnode.get_node()->outputs()[i];
        const Value* out_v = pnode.node()->outputs()[i];
        IValue& out = pnode.Output(i);
        const auto& type = out_v->type();
        if (out_variants && !external_values.count(out_v)) {
          if (type->cast<TensorType>()) {
            managed_values.insert(out_v);
          } else if (canOptimizeConstruct(pnode.get_node())) {
          } else if (canOptimizeConstruct(pnode.node())) {
            // We "leak" containers of this type
          } else {
            unmanaged_ivalue_set.insert(&out);

@@ -896,10 +987,8 @@ MemoryPlanner::MemoryPlanner(
    }
  }

  const InferenceModule* module = runtime->get_inference_module();

  // remove model outputs from managed_values and unmanaged_ivalue_set
  for (Value* output : module->graph->outputs()) {
  for (const Value* output : runtime->graph().outputs()) {
    managed_values.erase(output);
  }
  for (IValue* output : runtime->outputs()) {

@@ -918,10 +1007,10 @@ MemoryPlanner::MemoryPlanner(
  std::unordered_set<c10::StorageImpl*> managed_storage_impls;

  // Snapshot of the current memory state
  for (const auto& pnode : runtime->get_nodes()) {
  for (const auto& pnode : runtime->nodes()) {
    for (auto i = 0; i < pnode.outputs().size(); ++i) {
      const auto& ival = pnode.outputs()[i];
      const auto* val = pnode.get_node()->outputs()[i];
      const auto* val = pnode.node()->outputs()[i];
      if (managed_values.count(val)) {
        TORCH_CHECK(ival.isTensor());
        auto* impl = ival.toTensor().storage().unsafeGetStorageImpl();
@@ -6,42 +6,46 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/inliner.h>

namespace torch {
namespace jit {

struct TORCH_API StaticRuntimeOptions {
struct TORCH_API StaticModuleOptions {
  bool cleanup_activations{true};
  bool enable_out_variant{true};
  bool optimize_memory{true};
};

/// Static runime supports two execution modes.
/// The static runime supports two execution modes.
///
/// Mode 1: single-threaded with no parallelism except for intra-op parallelism
/// For this mode, you can do either:
/// @code
///   // m is the TorchScript module
///   auto runtime = StaticRuntime(m, opts);
///   auto output = runtime.run(args, kwargs);
///   // m is a TorchScript module
///   auto module = StaticModule(m, opts);
///   auto output = module(args, kwargs);
/// @endcode
///
/// or
///
/// @code
///   auto mod = PrepareForStaticRuntime(m);
///   auto runtime = StaticRuntime(mod, opts);
///   auto output = runtime.run(args, kwargs);
///   // g is the TorchScript graph
///   auto module = StaticModule(g, opts);
///   auto output = module(args, kwargs);
/// @endcode
///
/// Mode 2: similar to data parallelism, run the same model for different inputs
/// on different threads at the same time. In this case, run
/// PrepareForStaticRuntime to prepare the graph for Static Runtime. You
/// should have one InferenceModule instance per model, and one Static Runtime
/// instance per running thread. To avoiding creating StaticRuntime on the fly,
/// use a synchronized stack (i.e. boost::lockfree::stack) to cache all the
/// Static Runtime instances in your code.
/// on different threads at the same time.
/// You should have one StaticModule per model, and one StaticRuntime instance
/// per running thread. To avoiding creating StaticRuntimes on the fly, use a
/// synchronized stack (i.e. boost::lockfree::stack) to cache all the
/// StaticRuntime instances in your code.
/// @code
///   // initialization
///   auto mod = PrepareForStaticRuntime(m);
///   auto module = std::make_shared<StaticModule>(m, opts);
///
///   // 128 is good for most cases. Pick a number that works for you
///   boost::lockfree::stack<std::shared_ptr<StaticRuntime>,
///     boost::lockfree::fixed_sized<true>> pool(128);

@@ -50,58 +54,113 @@ struct TORCH_API StaticRuntimeOptions {
///   std::shared_ptr<StaticRuntime> runtime = nullptr;
///   pool.pop(runtime);
///   if (!runtime) {
///     runtime = std::make_shared<StaticRuntime>(mod, opts);
///     // holds a reference to the underlying module
///     // but does its own memory management
///     runtime = std::make_shared<StaticRuntime>(*module);
///   }
///   auto output = runtime->run(args, kwargs);
///   auto output = runtime(args, kwargs);
///   pool.push(runtime);
/// @endcode
///

// Group readonly data structures into InferenceModule
struct TORCH_API InferenceModule {
 public:
  explicit InferenceModule(const torch::jit::Module& m);
  explicit InferenceModule(std::shared_ptr<torch::jit::Graph> g);
  torch::jit::Module module;
  std::shared_ptr<torch::jit::Graph> graph;
  std::unique_ptr<c10::FunctionSchema> schema;

 private:
  void init();
};

TORCH_API void PrepareGraphForStaticRuntime(
    std::shared_ptr<torch::jit::Graph> g);

inline TORCH_API std::shared_ptr<InferenceModule> PrepareForStaticRuntime(
    const torch::jit::Module& m) {
  return std::make_shared<InferenceModule>(m);
}

inline TORCH_API std::shared_ptr<InferenceModule> PrepareForStaticRuntime(
    const std::shared_ptr<torch::jit::Graph>& g) {
  return std::make_shared<InferenceModule>(g);
}

class MemoryPlanner;
class ProcessedNode;
class StaticRuntime;
class TORCH_API StaticModule {
 public:
  explicit StaticModule(
      std::shared_ptr<torch::jit::Graph> g,
      const StaticModuleOptions& opts = StaticModuleOptions());

  explicit StaticModule(
      const torch::jit::Module& m,
      const StaticModuleOptions& opts = StaticModuleOptions());

 private:
  explicit StaticModule(
      std::pair<
          std::shared_ptr<torch::jit::Graph>,
          c10::optional<c10::FunctionSchema>> graph_and_schema,
      const StaticModuleOptions& opts);

 public:
  std::vector<at::Tensor> operator()(const std::vector<at::Tensor>& inps);

  // This interface only works if StaticModule was initialized
  // with a TorchScript module, otherwise use the above interface
  c10::IValue operator()(
      const std::vector<c10::IValue>& args,
      const std::unordered_map<std::string, c10::IValue>& kwargs);

  const Graph& graph() const {
    return *graph_;
  }

  const StaticModuleOptions& opts() const;
  size_t num_inputs() const;
  size_t num_outputs() const;

  inline const std::unordered_map<int, std::vector<std::pair<int, int>>>&
  index_map() const {
    return index_map_;
  }

  inline const std::vector<std::pair<int, int>>& output_indices() const {
    return output_indices_;
  }

  inline const std::vector<IValue>& constants() const {
    return constants_;
  }

  inline const std::vector<ProcessedNode>& nodes() const {
    return nodes_;
  }

  inline const c10::optional<c10::FunctionSchema>& schema() const {
    return schema_;
  }

  inline const std::unordered_map<const Value*, std::vector<const Value*>>&
  shared_values() const {
    return shared_values_;
  }

  inline const std::unordered_set<const Value*>& external_values() const {
    return external_values_;
  }

  StaticRuntime& runtime();

 private:
  // Static runtime states
  StaticModuleOptions opts_;
  std::unique_ptr<StaticRuntime> cached_runtime_;
  // IValue table (including inputs, outputs, intermediates, and weights)
  std::vector<IValue> constants_;
  std::vector<std::pair<int, int>> output_indices_;
  std::unordered_map<int, std::vector<std::pair<int, int>>> index_map_;
  // The nodes we need to run
  std::vector<ProcessedNode> nodes_;
  // Output of liveness analyis. A mapping from a value to the set of values
  // with which it could potentially share memory.
  std::unordered_map<const Value*, std::vector<const Value*>> shared_values_;
  std::unordered_set<const Value*> external_values_;

  // Original input
  std::shared_ptr<torch::jit::Graph> graph_;
  c10::optional<c10::FunctionSchema> schema_;
};

class TORCH_API StaticRuntime {
 public:
  // InferenceModule m is created by PrepareForStaticRuntime
  explicit StaticRuntime(
      std::shared_ptr<InferenceModule> m,
      const StaticRuntimeOptions& opts = StaticRuntimeOptions());
  explicit StaticRuntime(const StaticModule& sm);

  // m is unoptimized
  explicit StaticRuntime(
      const torch::jit::Module& m,
      const StaticRuntimeOptions& opts = StaticRuntimeOptions());
  std::vector<at::Tensor> operator()(const std::vector<at::Tensor>& inps);

  std::vector<at::Tensor> run(const std::vector<at::Tensor>& inps);

  // This interface only works module_ that has a non-empty TorchScript module
  // member; otherwise use the above interface
  c10::IValue run(
  // This interface only works if StaticModule was initialized
  // with a TorchScript module, otherwise use the above interface
  c10::IValue operator()(
      const std::vector<c10::IValue>& args,
      const std::unordered_map<std::string, c10::IValue>& kwargs);
@@ -135,52 +194,6 @@ class TORCH_API StaticRuntime {
      const int warmup_runs,
      const int main_runs);

  const InferenceModule* get_inference_module() {
    return module_.get();
  }

  const std::vector<ProcessedNode>& get_nodes() const {
    return nodes_;
  }

  std::vector<ProcessedNode>& get_nodes() {
    return nodes_;
  }

  size_t num_inputs() const {
    return inputs_.size();
  }

  size_t num_outputs() const {
    return outputs_.size();
  }

  inline const std::vector<IValue*>& outputs() const {
    return outputs_;
  }

  void check_for_memory_leak(bool output_returned = true);

 private:
  // Static runtime states
  std::shared_ptr<InferenceModule> module_;
  StaticRuntimeOptions opts_;
  // IValue table (including inputs, outputs, intermediates, and weights)
  std::vector<IValue> constants_;
  std::vector<IValue> inputs_;
  std::vector<IValue*> outputs_;
  // The nodes we need to run
  std::vector<ProcessedNode> nodes_;
  // Output of liveness analyis. A mapping from a value to the set of values
  // with which it could potentially share memory.
  std::unordered_map<const Value*, std::vector<const Value*>> shared_values_;
  std::unordered_set<const Value*> external_values_;

  // Memory planning is only enabled if opts_.cleanup_activations is true.
  // Otherwise, the memory used by activations is cached inside the static
  // runtime.
  std::unique_ptr<MemoryPlanner> planner_;

  // Input is readwrite
  IValue& Input(size_t i) {
    DCHECK(i < inputs_.size());

@@ -192,6 +205,34 @@ class TORCH_API StaticRuntime {
    DCHECK(i < outputs_.size());
    return *outputs_[i];
  }

  const std::vector<IValue*> outputs() const {
    return outputs_;
  }

  inline const std::vector<ProcessedNode>& nodes() const {
    return nodes_;
  }

  inline std::vector<ProcessedNode>& nodes() {
    return nodes_;
  }

  const Graph& graph() const {
    return static_module_.graph();
  }

  void check_for_memory_leak(bool output_returned = true);

 private:
  // Memory planning is only enabled if sm->opts().cleanup_activations is true.
  // Otherwise, the memory used by activations is cached inside the static
  // runtime.
  std::unique_ptr<MemoryPlanner> planner_;
  std::vector<IValue> inputs_;
  std::vector<IValue*> outputs_;
  const StaticModule& static_module_;
  std::vector<ProcessedNode> nodes_;
};

/// There are three types of ops in a processed graph in Static Runtime:

@@ -223,8 +264,7 @@ class MemoryPlanner {
 public:
  explicit MemoryPlanner(
      StaticRuntime* runtime,
      const std::unordered_map<const Value*, std::vector<const Value*>>&
          should_share,
      const std::unordered_map<const Value*, std::vector<const Value*>>&,
      const std::unordered_set<const Value*>& external_values,
      bool out_variants);

@@ -254,6 +294,7 @@ class MemoryPlanner {

class ProcessedNode {
 public:
  ProcessedNode() = default;
  ProcessedNode(
      Node* n,
      std::vector<const IValue*>&& inputs,

@@ -261,10 +302,14 @@ class ProcessedNode {

  void run();

  Node* get_node() const {
  Node* node() const {
    return node_;
  }

  inline void set_input(size_t index, const IValue* ival) {
    inputs_[index] = ival;
  }

  // Input is readonly
  const IValue& Input(size_t i) const {
    DCHECK(i < inputs_.size());

@@ -295,7 +340,7 @@ class ProcessedNode {
  std::function<void(ProcessedNode*)> fn_;
  std::function<void(ProcessedNode*)> native_fn_;
  std::vector<const IValue*> inputs_; // unowned
  std::vector<IValue> outputs_; // TODO make list for safety
  std::vector<IValue> outputs_;
};

} // namespace jit
@@ -7,11 +7,11 @@
namespace torch {
namespace jit {

void initStaticRuntimeBindings(PyObject* module) {
void initStaticModuleBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  py::class_<StaticRuntime> static_runtime(m, "StaticRuntime");
  py::class_<StaticModule> static_module(m, "StaticModule");
  py::class_<StaticRuntime::IndividualMetrics>(
      static_runtime, "IndividualMetrics")
      static_module, "IndividualMetrics")
      .def_readonly("setup_time", &StaticRuntime::IndividualMetrics::setup_time)
      .def_readonly(
          "memory_alloc_time",

@@ -34,25 +34,25 @@ void initStaticRuntimeBindings(PyObject* module) {
      .def_readonly(
          "instances_per_node_type",
          &StaticRuntime::IndividualMetrics::instances_per_node_type);
  static_runtime
  static_module
      .def(
          "run",
          "__call__",
          py::overload_cast<const std::vector<at::Tensor>&>(
              &StaticRuntime::run))
              &StaticModule::operator()))
      .def(
          "run",
          [](StaticRuntime& self,
          "__call__",
          [](StaticModule& self,
             const std::vector<at::Tensor>& args,
             const std::unordered_map<std::string, at::Tensor>& kwargs) {
            std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
            std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
                kwargs.begin(), kwargs.end()};
            c10::IValue ret = self.run(arg_ivalues, kwarg_ivalues);
            c10::IValue ret = self(arg_ivalues, kwarg_ivalues);
            return toPyObject(ret);
          })
      .def(
          "benchmark",
          [](StaticRuntime& self,
          [](StaticModule& self,
             const std::vector<at::Tensor>& args,
             const std::unordered_map<std::string, at::Tensor>& kwargs,
             const int warmup_runs,

@@ -60,11 +60,12 @@ void initStaticRuntimeBindings(PyObject* module) {
            std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
            std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
                kwargs.begin(), kwargs.end()};
            self.benchmark(arg_ivalues, kwarg_ivalues, warmup_runs, main_runs);
            self.runtime().benchmark(
                arg_ivalues, kwarg_ivalues, warmup_runs, main_runs);
          })
      .def(
          "benchmark_individual_ops",
          [](StaticRuntime& self,
          [](StaticModule& self,
             const std::vector<at::Tensor>& args,
             const std::unordered_map<std::string, at::Tensor>& kwargs,
             const int warmup_runs,

@@ -72,21 +73,17 @@ void initStaticRuntimeBindings(PyObject* module) {
            std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
            std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
                kwargs.begin(), kwargs.end()};
            return self.benchmark_individual_ops(
            return self.runtime().benchmark_individual_ops(
                arg_ivalues, kwarg_ivalues, warmup_runs, main_runs);
          });
  m.def(
      "_jit_to_static_runtime",
      [](std::shared_ptr<torch::jit::Graph> g) {
        return StaticRuntime(PrepareForStaticRuntime(g));
      })
      "_jit_to_static_module",
      [](std::shared_ptr<torch::jit::Graph> g) { return StaticModule(g); })
      .def(
          "_jit_to_static_runtime",
          [](const torch::jit::Module& m) {
            return StaticRuntime(PrepareForStaticRuntime(m));
          })
          "_jit_to_static_module",
          [](const torch::jit::Module& module) { return StaticModule(module); })
      .def(
          "_fuse_to_static_runtime",
          "_fuse_to_static_module",
          [](torch::jit::Module& module) {
            module.eval();
            module = freeze_module(module);

@@ -95,7 +92,7 @@ void initStaticRuntimeBindings(PyObject* module) {
            auto graph = method.graph();
            fuseStaticSubgraphs(graph);
          })
      .def("_fuse_to_static_runtime", [](std::shared_ptr<torch::jit::Graph> g) {
      .def("_fuse_to_static_module", [](std::shared_ptr<torch::jit::Graph> g) {
        fuseStaticSubgraphs(g);
      });
}
@@ -3,7 +3,7 @@
namespace torch {
namespace jit {

void initStaticRuntimeBindings(PyObject* module);
void initStaticModuleBindings(PyObject* module);

} // namespace jit
} // namespace torch
@@ -643,11 +643,9 @@ REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
    at::native::copy_(out_t, in0_t, false);
  };
});
REGISTER_OPERATOR_FUNCTOR_OPT(
REGISTER_OPERATOR_FUNCTOR(
    quantized::embedding_bag_byte_rowwise_offsets,
    quantized_embedding_bag_byte_rowwise_offsets,
    false, // don't reuse byte inputs
    true,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        const auto& weight = p_node->Input(0).toTensor();

@@ -677,11 +675,9 @@ REGISTER_OPERATOR_FUNCTOR_OPT(
            include_last_offset);
      };
    });
REGISTER_OPERATOR_FUNCTOR_OPT(
REGISTER_OPERATOR_FUNCTOR(
    quantized::embedding_bag_4bit_rowwise_offsets,
    embedding_bag_4bit_rowwise_offsets,
    false, // don't reuse byte inputs
    true,
    [](Node* n) -> SROperator {
      return [](ProcessedNode* p_node) {
        const auto& weight = p_node->Input(0).toTensor();

@@ -918,7 +914,7 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
      stack.emplace_back(p_node->Input(i));
    }
    // run op
    auto* node = p_node->get_node();
    auto* node = p_node->node();
    const auto& type = node->output()->type()->expect<TupleType>();
    if (type->name().has_value()) {
      namedTupleConstruct(stack, type, node->inputs().size());

@@ -940,7 +936,7 @@ std::function<void(ProcessedNode*)> getNativeOperation(Node* n) {
    // run op
    listConstruct(
        stack,
        p_node->get_node()->output()->type()->expectRef<ListType>(),
        p_node->node()->output()->type()->expectRef<ListType>(),
        p_node->inputs().size());
    // put output back
    p_node->Output(0) = std::move(stack[0]);
@@ -27,24 +27,15 @@ C10_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);

// TODO: reuse_inp reuse_out can be deprecated with further analysis
// try to avoid this API.
#define REGISTER_OPERATOR_FUNCTOR_OPT(name, id, reuse_inp, reuse_out, ...) \
  struct SROperatorFunctor_##id : public SROperatorFunctor { \
    const SROpFunctor fn = __VA_ARGS__; \
    bool CanReuseInput() override { \
      return reuse_inp; \
    } \
    bool CanReuseOutput() override { \
      return reuse_out; \
    } \
    SROperator Generate(Node* n) override { \
      return fn(n); \
    } \
  }; \
#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \
  struct SROperatorFunctor_##id : public SROperatorFunctor { \
    const SROpFunctor fn = __VA_ARGS__; \
    SROperator Generate(Node* n) override { \
      return fn(n); \
    } \
  }; \
  C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id);

#define REGISTER_OPERATOR_FUNCTOR(name, id, ...) \
  REGISTER_OPERATOR_FUNCTOR_OPT(name, id, true, true, __VA_ARGS__)

#define REGISTER_VIEW_OPERATOR_FUNCTOR(name, id, ...) \
  struct SROperatorFunctor_##id : public SROperatorFunctor { \
    const SROpFunctor fn = __VA_ARGS__; \