Add a demo backend with compiler (#52603)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52603

This PR introduces a backend with minimal compilation capability into the to_<backend> flow. Its goals are:
- Demonstrate the end-to-end flow of adding a backend -> compilation -> runtime.
- Show how backend compilation errors are surfaced to the user together with the original model's source code information. (C++ only in this PR. Python APIs will be demonstrated in a following PR.)

Changes:

Compilation
1. A backend with minimal compilation features, "backend_with_compiler_demo", is added.
2. Compilation happens AOT in the ```preprocess``` function registered to this backend.
3. Compiled results are stored in a string blob for each method. They are serialized to the lowered module with the ```__getstate__``` function.
4. For features not handled by the backend compiler, an error message including the model's source code is thrown.

Runtime
1. The compiled blob is loaded in the ```__setstate__``` method.
2. The ```compile``` function of the backend passes through the AOT-compiled blob. (TODO: parsing the blob into a format the backend can understand could happen here.)
3. The ```execute``` function of the backend executes the specified method (handle).

Test Plan:
- ```BackendTest.TestCompiler```: the C++ end-to-end demonstration on a supported model. After compilation and execution, the lowered model produces the same result as the original TorchScript model.
- ```BackendTest.TestCompilerNotSupport```: demonstrates the error message from AOT compilation for a feature the input module uses that the backend does not support. The error message looks like:
```
The node of aten::mul is not supported in this compiler. Source code:
  File "<string>", line 3

def forward(self, x, h):
    return x * h
           ~~~~~ <--- HERE
```

Reviewed By: raziel

Differential Revision: D26593968

Pulled By: iseeyuan

fbshipit-source-id: 8f264f60a0470e9f07e36fdeccbf17da6c1d7cd7
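In code, the whole flow boils down to lowering a module against the registered backend name and then calling the lowered module. The sketch below is condensed from the ```BackendTest.TestCompiler``` test added in this PR; the function name is ours, and it assumes the ```backend_with_compiler``` library is linked in so its static registration has already run.

```cpp
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/torch.h>

// Sketch of the end-to-end flow, condensed from BackendTest.TestCompiler.
void lower_and_run_demo() {
  torch::jit::Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  // One entry per method to lower; the demo backend ignores the contents.
  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  method_spec.insert("", ""); // placeholder, mirroring the test
  compile_spec.insert("forward", method_spec);
  auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());

  // AOT half: preprocess runs now and embeds the compiled string blob.
  auto lowered = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);

  // Runtime half: __setstate__ -> compile (parse blob) -> execute.
  std::vector<c10::IValue> inputs{
      2.0 * torch::ones({}), 1.0 * torch::ones({})};
  auto out = lowered.forward(inputs);
}
```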
Committed by: Facebook GitHub Bot
Parent: 502a85990d
Commit: b2520ab3dc
@@ -11,9 +11,13 @@ target_link_libraries(torchbind_test torch)
 add_library(jitbackend_test SHARED ${JIT_TEST_ROOT}/test_backend_lib.cpp)
 target_link_libraries(jitbackend_test torch)
+
+add_library(backend_with_compiler SHARED ${JIT_TEST_ROOT}/test_backend_compiler_lib.cpp)
+target_link_libraries(backend_with_compiler torch)

 if(INSTALL_TEST)
   install(TARGETS torchbind_test DESTINATION lib)
   install(TARGETS jitbackend_test DESTINATION lib)
+  install(TARGETS backend_with_compiler DESTINATION lib)
 endif()

 # Build the cpp gtest binary containing the cpp-only tests.
@@ -74,7 +78,7 @@ if(USE_SYSTEM_ONNX)
   target_link_libraries(test_jit PRIVATE onnx_proto onnx)
 endif()

-set(JIT_TEST_DEPENDENCIES jitbackend_test gtest)
+set(JIT_TEST_DEPENDENCIES torch gtest jitbackend_test backend_with_compiler)

 if(MSVC)
   list(APPEND JIT_TEST_DEPENDENCIES onnx_library)
@@ -85,7 +89,7 @@ target_include_directories(test_jit PRIVATE ${ATen_CPU_INCLUDE})

 if(LINUX)
   #Update to target_link_options when CMake version can be upgraded
-  target_link_libraries(test_jit PRIVATE "-Wl,--no-as-needed,$<TARGET_FILE:jitbackend_test>,--as-needed")
+  target_link_libraries(test_jit PRIVATE "-Wl,--no-as-needed,$<TARGET_FILE:jitbackend_test>,$<TARGET_FILE:backend_with_compiler>,--as-needed")
 endif()

 if(USE_CUDA)
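The `--no-as-needed` link flag above matters because the backend library registers itself purely through a static initializer; `test_jit` never references its symbols directly, so an as-needed linker would otherwise drop the dependency and the backend name would not resolve at runtime. The initializer being protected is the one from `test_backend_compiler_lib.cpp` below:

```cpp
// Runs at library load time and registers the backend under its name.
static auto cls = torch::jit::backend<BackendWithCompiler>(
    "backend_with_compiler_demo",
    preprocess);
```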
@@ -1,4 +1,5 @@
 #include <gtest/gtest.h>
+#include <test/cpp/jit/test_utils.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/backends/backend_detail.h>
 #include <torch/torch.h>
@@ -32,9 +33,92 @@ TEST(BackendTest, ToBackend) {
   // lowered module
   auto lm = torch::jit::detail::codegen_backend_module(
       "test_backend", m, compile_spec, any_dict_ty);
+  // lowered module code:
+  /*
+    class test_backendLoweredModule(Module):
+      __parameters__ = []
+      __buffers__ = []
+      __processed_module : Any
+      __method_compile_spec : Dict[str, Any]
+      __backend : __torch__.torch.classes.__backends__.test_backend
+      __handles : Dict[str, Any]
+      def __create_backend(self: torch.jit.test_backendLoweredModule) -> None:
+        _0 = __torch__.torch.classes.__backends__.test_backend.__new__(__torch__.torch.classes.__backends__.test_backend)
+        _1 = (_0).__init__()
+        self.__backend = _0
+        return None
+      def __getstate__(self: torch.jit.test_backendLoweredModule) -> Tuple[Dict[str, Any], Any]:
+        _2 = (self.__method_compile_spec, self.__processed_module)
+        return _2
+      def __setstate__(self: torch.jit.test_backendLoweredModule, state: Tuple[Dict[str, Any], Any]) -> None:
+        self.__method_compile_spec = (state)[0]
+        self.__processed_module = (state)[1]
+        _3 = (self).__create_backend()
+        _4 = (self.__backend).compile(self.__processed_module, self.__method_compile_spec, )
+        self.__handles = _4
+        return None
+      def forward(self: torch.jit.test_backendLoweredModule, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
+        _5 = uninitialized(Tensor)
+        typed_inputs = annotate(List[Any], [x, h])
+        _6 = (self.__backend).execute((self.__handles)["forward"], typed_inputs, )
+        _7, _8, = _6
+        _9 = isinstance(_7, Tensor)
+        if _9:
+          _10 = unchecked_cast(Tensor, _7)
+        else:
+          ops.prim.RaiseException("AssertionError: ")
+          _10 = _5
+        _11 = isinstance(_8, Tensor)
+        if _11:
+          _12 = unchecked_cast(Tensor, _8)
+        else:
+          ops.prim.RaiseException("AssertionError: ")
+          _12 = _5
+        return (_10, _12)
+  */
   auto res = lm.forward(inputs).toTuple()->elements();
   AT_ASSERT(res[0].toTensor().equal(ref[0].toTensor()));
   AT_ASSERT(res[1].toTensor().equal(ref[1].toTensor()));
 }

+TEST(BackendTest, TestCompiler) {
+  Module m("m");
+  m.define(R"(
+    def forward(self, x, h):
+        return x + h
+  )");
+
+  std::vector<IValue> inputs;
+  inputs.emplace_back(2.0 * torch::ones({}));
+  inputs.emplace_back(1.0 * torch::ones({}));
+  auto ref = m.forward(inputs);
+
+  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
+  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
+  fake_dict.insert("", "");
+  compile_spec.insert("forward", fake_dict);
+  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
+  // lowered module
+  auto lm = torch::jit::detail::codegen_backend_module(
+      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
+  auto res = lm.forward(inputs);
+  AT_ASSERT(res.toTensor().equal(ref.toTensor()));
+}
+
+TEST(BackendTest, TestCompilerNotSupport) {
+  Module m("m");
+  m.define(R"(
+    def forward(self, x, h):
+        return x * h
+  )");
+
+  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
+  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
+  fake_dict.insert("", "");
+  compile_spec.insert("forward", fake_dict);
+  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
+  // lowered module
+  ASSERT_THROWS_WITH_MESSAGE(
+      torch::jit::detail::codegen_backend_module(
+          "backend_with_compiler_demo", m, compile_spec, any_dict_ty),
+      "The node of aten::mul is not supported in this compiler. Source code:");
+}
 } // namespace jit
 } // namespace torch
test/cpp/jit/test_backend_compiler_lib.cpp (new file, 158 lines)
@@ -0,0 +1,158 @@
#include <torch/csrc/jit/backends/backend.h>

namespace torch {
namespace jit {

// Implementation of a PyTorch Backend that can process, compile and execute
// TorchScript Modules composed of 'add' and 'sub' operators. It only supports
// modules that implement a sum or subtraction of 2 inputs (i.e. in1 + in2 or
// in1 - in2). Hence the methods of the models expect exactly 2 inputs of type
// Tensor. This backend is used to demonstrate the flow of compilation and
// execution with a minimal amount of work. It's not intended to be a practical
// backend that can be used for actual inference.

// Implementation details:
//
// Compilation
// 1. A backend with minimal compilation features, "backend_with_compiler_demo",
//    is added.
// 2. The compilation happens AOT in the preprocess function registered to this
//    backend.
// 3. Compiled results are stored in a string blob for each method. They are
//    serialized to the lowered module with the __getstate__ function.
// 4. An error message with the model's source code is thrown for features not
//    handled by the backend compiler.
//
// Runtime
// 1. The compiled blob is loaded in the __setstate__ method.
// 2. The compile function of the backend parses the preprocessed blob into the
//    format (a list of tokens) that the backend can understand.
// 3. The execute function of the backend executes the specified method
//    (handle).

namespace {
std::vector<std::string> parseMethodHandle(const std::string& blob) {
  std::vector<std::string> result;
  std::stringstream s_stream(blob);
  while (s_stream.good()) {
    std::string substr;
    getline(s_stream, substr, ',');
    result.push_back(substr);
  }
  return result;
}
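// Illustrative example: for a method body `return x + h`, preprocess (below)
// produces the blob "prim::Constant#1,aten::add", which this helper splits
// into {"prim::Constant#1", "aten::add"}.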
} // namespace

class BackendWithCompiler : public PyTorchBackendInterface {
 public:
  // Constructor.
  explicit BackendWithCompiler() {}
  virtual ~BackendWithCompiler() = default;

  // Since the actual compilation is done AOT in preprocess, compile here only
  // parses the blob stored for each method into a token list (the handle).
  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    auto dict = processed.toGenericDict();
    auto handles = c10::Dict<std::string, std::vector<std::string>>();
    for (const auto& kv : dict) {
      auto tokens = parseMethodHandle(kv.value().toStringRef());
      handles.insert(kv.key().toStringRef(), tokens);
    }
    return c10::impl::toGenericDict(handles);
  }

  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    TORCH_INTERNAL_ASSERT(inputs.size() == 2);
    c10::IValue val0 = inputs[0];
    at::Tensor x = val0.toTensor();
    c10::IValue val1 = inputs[1];
    at::Tensor h = val1.toTensor();

    c10::List<at::Tensor> output_list;
    double scalar_val = 1.0;
    for (const auto& token : handle.toList()) {
      IValue val = token;
      auto instruction = std::string(IValue(token).toStringRef());
      double const_val = 1.0;
      if (instruction.rfind("prim::Constant", 0) == 0) {
        TORCH_CHECK(
            instruction.size() > 15,
            "Constant value is expected in ",
            instruction);
        auto sub = instruction.substr(15);
        const_val = stod(sub);
      } else if (token == "aten::add") {
        output_list.emplace_back(x.add_(h, const_val));
      } else if (token == "aten::sub") {
        output_list.emplace_back(x.sub_(h, const_val));
      } else {
        TORCH_CHECK(
            false,
            "Instruction, ",
            instruction,
            " is not supported. ",
            "Contact the backend POC for details. ");
      }
    }
    return c10::impl::toList(output_list);
  }
};

namespace {
// For this backend, the actual compilation happens AOT in the preprocess
// function. It is kept here to demonstrate the backend as a whole piece. It's
// used when compilation is required; a dummy function can be passed instead
// when the runtime backend lib has no use for compilation.
c10::IValue preprocess(
    const Module& mod,
    const c10::Dict<IValue, IValue>& method_compile_spec) {
  // The output of this process is a dictionary
  // Key: method name.
  // Val: compiled blob (represented by a string).
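  // Illustrative example: for `return x + h`, the entry would be
  //   "forward" -> "prim::Constant#1,aten::add".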
  c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
  for (const auto& method : mod.get_methods()) {
    const auto graph = method.function().graph()->copy();
    auto key = method.name();
    std::stringstream ss;
    for (const auto& node : graph->nodes()) {
      switch (node->kind()) {
        case prim::Constant:
          ss << node->kind().toDisplayString() << "#"
             << toIValue(node->output()).value();
          break;
        case aten::add:
          ss << node->kind().toQualString();
          break;
        case aten::sub:
          ss << node->kind().toQualString();
          break;
        default:
          TORCH_CHECK(
              false,
              "The node of ",
              node->kind().toQualString(),
              " is not supported in this compiler. Source code: ",
              node->sourceRange().str());
          break;
      }
      ss << ",";
    }
    std::string blob = ss.str();
    if (!blob.empty()) {
      blob.pop_back();
    }
    compiled.insert(method.name(), blob);
  }
  return compiled;
}

static auto cls = torch::jit::backend<BackendWithCompiler>(
    "backend_with_compiler_demo",
    preprocess);
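// The static registration above runs at library load time, exposing the
// backend to TorchScript as
// __torch__.torch.classes.__backends__.backend_with_compiler_demo and
// associating the AOT preprocess step with it, which is why to_<backend>
// can find this backend by name alone.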
} // namespace

} // namespace jit
} // namespace torch
@@ -61,15 +61,15 @@ class TestBackend : public PyTorchBackendInterface {
   }
 };

+namespace {
 c10::IValue preprocess(
     const Module& mod,
     const c10::Dict<IValue, IValue>& method_compile_spec) {
   return mod._ivalue();
 }

-namespace {
 static auto cls = torch::jit::backend<TestBackend>("test_backend", preprocess);
-}
+} // namespace

 } // namespace jit
 } // namespace torch
@@ -1,3 +1,5 @@
+#include <test/cpp/jit/test_utils.h>
+
 #include <gtest/gtest.h>

 #include <c10/core/TensorOptions.h>
@@ -13,14 +15,6 @@

 #include <unordered_set>

-#define ASSERT_THROWS_WITH(statement, substring)                          \
-  try {                                                                   \
-    (void)statement;                                                      \
-    ASSERT_TRUE(false);                                                   \
-  } catch (const std::exception& e) {                                     \
-    ASSERT_NE(std::string(e.what()).find(substring), std::string::npos);  \
-  }
-
 // Tests go in torch::jit
 namespace torch {
 namespace jit {
@@ -281,7 +275,7 @@ TEST(LiteInterpreterTest, LoadOrigJit) {
   )");
   std::stringstream ss;
   m.save(ss);
-  ASSERT_THROWS_WITH(_load_for_mobile(ss), "file not found");
+  ASSERT_THROWS_WITH_MESSAGE(_load_for_mobile(ss), "file not found");
 }

 TEST(LiteInterpreterTest, WrongMethodName) {
@@ -298,7 +292,8 @@ TEST(LiteInterpreterTest, WrongMethodName) {
   std::vector<IValue> inputs;
   auto minput = 5 * torch::ones({});
   inputs.emplace_back(minput);
-  ASSERT_THROWS_WITH(bc.get_method("forward")(inputs), "is not defined");
+  ASSERT_THROWS_WITH_MESSAGE(
+      bc.get_method("forward")(inputs), "is not defined");
 }

 TEST(LiteInterpreterTest, SetState) {
@@ -1,9 +1,17 @@
 #pragma once

+#include <torch/csrc/jit/ir/irparser.h>
+#include <torch/csrc/jit/runtime/autodiff.h>
+#include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/testing/file_check.h>
-#include "torch/csrc/jit/ir/irparser.h"
-#include "torch/csrc/jit/runtime/autodiff.h"
-#include "torch/csrc/jit/runtime/interpreter.h"
+
+#define ASSERT_THROWS_WITH_MESSAGE(statement, substring)                  \
+  try {                                                                   \
+    (void)statement;                                                      \
+    FAIL();                                                               \
+  } catch (const std::exception& e) {                                     \
+    ASSERT_NE(std::string(e.what()).find(substring), std::string::npos);  \
+  }

 namespace torch {
 namespace jit {
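The macro evaluates the statement, fails the test via `FAIL()` if nothing is thrown, and otherwise asserts that the exception message contains the given substring; the `(void)statement` cast keeps unused-result warnings quiet. Typical use, taken from the lite interpreter test above:

```cpp
// Passes only if the call throws and the message mentions the method.
ASSERT_THROWS_WITH_MESSAGE(
    bc.get_method("forward")(inputs), "is not defined");
```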
@@ -131,7 +131,7 @@ class class_ {
   /// taking an `int` and a `std::string` as argument.
   template <typename... Types>
   class_& def(
-      detail::types<void, Types...>,
+      torch::detail::types<void, Types...>,
       std::string doc_string = "",
       std::initializer_list<arg> default_args = {}) { // Used in combination with
                                                       // torch::init<...>()
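For context, this `def` overload is the one selected when a constructor is registered on a custom class via `torch::init`: the `torch::detail::types<void, Types...>` tag carries the argument types. A minimal sketch of such a registration, with a hypothetical `MyStack` class and assuming the standard custom-class API:

```cpp
#include <torch/custom_class.h>

// A toy custom class; the names here are illustrative only.
struct MyStack : torch::CustomClassHolder {
  MyStack(int64_t capacity, std::string name)
      : capacity_(capacity), name_(std::move(name)) {}
  int64_t capacity_;
  std::string name_;
};

// def(torch::init<int64_t, std::string>()) resolves to the overload above,
// registering a constructor taking an int and a std::string.
static auto reg =
    torch::class_<MyStack>("my_classes", "MyStack")
        .def(torch::init<int64_t, std::string>());
```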