Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[static runtime] Add _out variants and reuse memory (#44128)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44128

Test Plan: Imported from OSS

Reviewed By: hlu1

Differential Revision: D23604304

Pulled By: bwasti

fbshipit-source-id: 06a23cb75700a0fc733069071843b7b498e7b9e9
Committed by: Facebook GitHub Bot
Parent: d1d9017a66
Commit: d1a11618f5
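At a high level, the change replaces functional ATen calls (which allocate a fresh output tensor on every invocation) with their _out variants, writing into tensors cached in the runtime's workspace so repeated runs reuse the same storage. The snippet below is a minimal standalone sketch of that pattern, not code from this PR; the shapes, the loop, and the name cached_out are illustrative assumptions.

#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::randn({2048, 512});
  at::Tensor b = at::randn({2048, 512});

  // Functional form: every call allocates a brand-new output tensor.
  at::Tensor fresh = at::add(a, b);

  // _out form: allocate a placeholder once, then keep writing into it
  // (mirrors create_empty_from() in ops.cpp below).
  at::Tensor cached_out = at::empty({0}, a.options());
  for (int i = 0; i < 5; ++i) {
    // Resized on the first call; the storage is reused on later calls.
    at::add_out(cached_out, a, b);
  }
  return 0;
}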
@@ -1,3 +1,7 @@
list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc)
list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc)
list(APPEND STATIC_RUNTIME_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt_bench.cc)
set(STATIC_RUNTIME_BENCHMARK_SRCS ${STATIC_RUNTIME_BENCHMARK_SRCS} PARENT_SCOPE)

list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/deep_wide_pt.cc)
list(APPEND STATIC_RUNTIME_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/test_static_runtime.cc)
set(STATIC_RUNTIME_TEST_SRCS ${STATIC_RUNTIME_TEST_SRCS} PARENT_SCOPE)
@@ -1247,7 +1247,9 @@ endif()
if(BUILD_STATIC_RUNTIME_BENCHMARK)
  add_subdirectory(${TORCH_ROOT}/benchmarks/static_runtime ${PROJECT_BINARY_DIR}/bin)
  add_executable(static_runtime_bench "${STATIC_RUNTIME_BENCHMARK_SRCS}")
  add_executable(static_runtime_test "${STATIC_RUNTIME_TEST_SRCS}")
  target_link_libraries(static_runtime_bench torch_library benchmark)
  target_link_libraries(static_runtime_test torch_library gtest_main)
endif()

if(BUILD_MOBILE_BENCHMARK)
@@ -106,7 +106,8 @@ class TestStaticRuntime(TestCase):
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
@@ -129,8 +130,9 @@ class TestStaticRuntime(TestCase):
        bot_l_acc = StaticRuntime(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticRuntime(top_l)
        bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
        top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)[0]
            torch.testing.assert_allclose(acc_bot, ref_bot)
@@ -138,8 +140,9 @@ class TestStaticRuntime(TestCase):
            acc_top = top_l_acc(top_inp)[0]
            torch.testing.assert_allclose(acc_top, ref_top)
        for _ in range(5):
            bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
            top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
                ref_bot = bot_l(bot_inp)
                acc_bot = bot_l_acc(bot_inp)[0]
                torch.testing.assert_allclose(acc_bot, ref_bot)
@@ -219,6 +219,7 @@ core_sources_full = [
    "torch/csrc/jit/runtime/profiling_record.cpp",
    "torch/csrc/jit/runtime/symbolic_script.cpp",
    "torch/csrc/jit/runtime/static/impl.cpp",
    "torch/csrc/jit/runtime/static/ops.cpp",
    "torch/csrc/jit/serialization/import.cpp",
    "torch/csrc/jit/serialization/import_export_helpers.cpp",
    "torch/csrc/jit/serialization/import_source.cpp",
@@ -4,6 +4,7 @@
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

namespace torch {
@@ -12,48 +13,6 @@ namespace jit {
using c10::DispatchKey;
using c10::RegisterOperators;

static auto reg =
    RegisterOperators()
        .op("static::add(Tensor a, Tensor b) -> Tensor",
            RegisterOperators::options().kernel(
                DispatchKey::CPU,
                [](at::Tensor a, at::Tensor b) -> at::Tensor { return a + b; }))
        .op("static::mul.a(Tensor a, Tensor b) -> Tensor",
            RegisterOperators::options().kernel(
                DispatchKey::CPU,
                [](at::Tensor a, at::Tensor b) -> at::Tensor { return a * b; }))
        .op("static::mul.b(Tensor a, int b) -> Tensor",
            RegisterOperators::options().kernel(
                DispatchKey::CPU,
                [](at::Tensor a, int64_t b) -> at::Tensor { return a * b; }));

#define SUPPORTED_OPS(F) \
  F(aten::__getitem__)   \
  F(aten::add)           \
  F(aten::addmm)         \
  F(aten::bmm)           \
  F(aten::cat)           \
  F(aten::clamp)         \
  F(aten::contiguous)    \
  F(aten::div)           \
  F(aten::flatten)       \
  F(aten::index_put_)    \
  F(aten::isnan)         \
  F(aten::matmul)        \
  F(aten::mul)           \
  F(aten::permute)       \
  F(aten::relu)          \
  F(aten::sigmoid)       \
  F(aten::size)          \
  F(aten::softmax)       \
  F(aten::t)             \
  F(aten::to)            \
  F(aten::transpose)     \
  F(aten::view)          \
  F(prim::Constant)      \
  F(prim::ListConstruct) \
  F(prim::TupleConstruct)

StaticRuntime::StaticRuntime(const torch::jit::Module& m)
    : module_(m.copy()), graph_(nullptr) {
  module_.eval();
@@ -84,19 +43,6 @@ StaticRuntime::StaticRuntime(const torch::jit::Module& m)
    }
  }

  SubgraphRewriter sr;
  sr.RegisterRewritePattern(
      R"IR(
  graph(%x, %w, %s):
    %r = aten::add(%x, %w, %s)
    return (%r))IR",
      R"IR(
  graph(%x, %w, %s):
    %y = static::add(%x, %w)
    %r = static::mul(%y, %s)
    return (%r))IR");
  sr.runOnGraph(graph_);

  // remove unused input 0 from graph
  if (graph_->inputs().at(0)->type()->is_module()) {
    if (!graph_->inputs().at(0)->hasUses()) {
@@ -157,10 +103,13 @@ ProcessedNode::ProcessedNode(Node* node) : node_(node) {
    CHECK(op.hasOperation());
    op_ = op.getOperation(node);
  }
  if (canRunOutOfPlace(node)) {
    fn_ = getOutOfPlaceOperation(node);
  }
}

void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const {
  if (use_stack_) {
  if (!fn_) {
    std::vector<IValue> stack;
    const size_t size = node_->inputs().size();
    stack.reserve(size);
@@ -201,7 +150,7 @@ void ProcessedNode::run(StaticRuntime::ConstantMap& workspace) const {
      workspace[node_->outputs()[i]] = stack[i];
    }
  } else {
    TORCH_CHECK(0, "Non-stack execution not yet implemented");
    (*fn_)(workspace);
  }
}
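As a schematic aside (not the PR's code): the two branches above are a generic stack-based path, which pushes input IValues, invokes the registered Operation, and copies outputs back into the workspace, and an optimized path that calls the cached out-variant lambda directly on the workspace. A hedged sketch with simplified stand-in types; Workspace, MiniProcessedNode, inputs, outputs, and generic_op are illustrative names, not the real StaticRuntime types:

#include <functional>
#include <unordered_map>
#include <vector>
#include <ATen/core/ivalue.h>

// Stand-in for StaticRuntime::ConstantMap (keyed by an integer id here).
using Workspace = std::unordered_map<int, c10::IValue>;

struct MiniProcessedNode {
  std::vector<int> inputs, outputs;
  std::function<void(std::vector<c10::IValue>&)> generic_op; // fallback: stack-based
  std::function<void(Workspace&)> fn;                        // optimized: writes in place

  void run(Workspace& ws) const {
    if (!fn) {
      std::vector<c10::IValue> stack;
      stack.reserve(inputs.size());
      for (int in : inputs) stack.push_back(ws.at(in)); // gather inputs from workspace
      generic_op(stack);                                // op leaves outputs on the stack
      for (size_t i = 0; i < outputs.size(); ++i) ws[outputs[i]] = stack[i];
    } else {
      fn(ws); // out-variant kernel reuses the cached output tensors
    }
  }
};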
@@ -53,8 +53,7 @@ class ProcessedNode {
 private:
  Node* node_;
  c10::optional<Operation> op_;
  // if false, we have an optimized version
  bool use_stack_ = true;
  c10::optional<std::function<void(StaticRuntime::ConstantMap&)>> fn_;
};

} // namespace jit
torch/csrc/jit/runtime/static/ops.cpp (new file, 128 lines)
@@ -0,0 +1,128 @@
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

bool canRunOutOfPlace(Node* n) {
  auto str = std::string(n->kind().toQualString());
  if ((str == "aten::add") || (str == "aten::mul") || (str == "aten::addmm") ||
      (str == "aten::bmm") || (str == "aten::sigmoid") ||
      (str == "aten::cat")) {
    return true;
  }
  return false;
}

std::function<void(StaticRuntime::ConstantMap&)> getOutOfPlaceOperation(
    Node* n) {
  auto create_empty_from = [](const at::Tensor& t) {
    return at::empty({0}, t.options());
  };

  if (n->kind() == c10::Symbol::fromQualString("aten::add")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    auto in2 = n->inputs().at(2);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_t = ws.at(in1).toTensor();
      auto in2_s = ws.at(in2).toScalar();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::add_out(out_t, in0_t, in1_t, in2_s);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::mul")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_t = ws.at(in1).toTensor();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::mul_out(out_t, in0_t, in1_t);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::addmm")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    auto in2 = n->inputs().at(2);
    auto in3 = n->inputs().at(3);
    auto in4 = n->inputs().at(4);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_t = ws.at(in1).toTensor();
      auto in2_t = ws.at(in2).toTensor();
      auto in3_s = ws.at(in3).toScalar();
      auto in4_s = ws.at(in4).toScalar();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::clamp")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    auto in2 = n->inputs().at(2);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_s = ws.at(in1).toScalar();
      auto in2_s = ws.at(in2).toScalar();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::clamp_out(out_t, in0_t, in1_s, in2_s);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::bmm")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      auto in1_t = ws.at(in1).toTensor();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::bmm_out_cpu(out_t, in0_t, in1_t);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::cat")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    auto in1 = n->inputs().at(1);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_tl = ws.at(in0).toTensorVector();
      auto in1_i = ws.at(in1).toInt();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_tl[0]));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::cat_out(out_t, in0_tl, in1_i);
    };
  } else if (n->kind() == c10::Symbol::fromQualString("aten::sigmoid")) {
    auto out = n->outputs().at(0);
    auto in0 = n->inputs().at(0);
    return [=](StaticRuntime::ConstantMap& ws) {
      auto in0_t = ws.at(in0).toTensor();
      if (!ws.count(out)) {
        ws.emplace(out, create_empty_from(in0_t));
      }
      auto out_t = ws.at(out).toTensor();
      at::native::sigmoid_out(out_t, in0_t);
    };
  }

  return [](StaticRuntime::ConstantMap&) { TORCH_CHECK(0); };
}

} // namespace jit
} // namespace torch
torch/csrc/jit/runtime/static/ops.h (new file, 41 lines)
@@ -0,0 +1,41 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch {
namespace jit {

bool canRunOutOfPlace(Node* n);
std::function<void(StaticRuntime::ConstantMap&)> getOutOfPlaceOperation(
    Node* n);

#define SUPPORTED_OPS(F) \
  F(aten::__getitem__)   \
  F(aten::add)           \
  F(aten::addmm)         \
  F(aten::bmm)           \
  F(aten::cat)           \
  F(aten::clamp)         \
  F(aten::contiguous)    \
  F(aten::div)           \
  F(aten::flatten)       \
  F(aten::index_put_)    \
  F(aten::isnan)         \
  F(aten::matmul)        \
  F(aten::mul)           \
  F(aten::permute)       \
  F(aten::relu)          \
  F(aten::sigmoid)       \
  F(aten::size)          \
  F(aten::softmax)       \
  F(aten::t)             \
  F(aten::to)            \
  F(aten::transpose)     \
  F(aten::view)          \
  F(prim::Constant)      \
  F(prim::ListConstruct) \
  F(prim::TupleConstruct)

} // namespace jit
} // namespace torch