Add a demo backend with compiler (#52603)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52603

This PR introduces a backend with minimal compilation capability into the to_<backend> flow. The goals are:

- Demonstrate the end-to-end flow of adding a backend: registration -> compilation -> runtime.
- Show how backend compilation errors are surfaced to the user, together with the original model's source code information. (C++ only in this PR; Python APIs will be demonstrated in a follow-up PR.)

Changes:

- Compilation

1. A backend with minimal compilation features, "backend_with_compiler_demo", is added.
2. The compilation happens AOT in the ```preprocess``` function registered to this backend (a condensed sketch follows this list).
3. Compiled results are stored in a string blob for each method. They are serialized to the lowered module with the ```__getstate__``` function.
4. For features not handled by the backend compiler, an error message carrying the model's source code is thrown.
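
A condensed sketch of that AOT step, simplified from the ```test_backend_compiler_lib.cpp``` listing later in this diff (the full version also encodes constant values and raises a source-range error on unsupported nodes):

```
#include <sstream>

#include <torch/csrc/jit/backends/backend.h>

// Sketch only: serialize each method of the module into a comma-separated
// token blob, keyed by method name.
c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
  c10::Dict<c10::IValue, c10::IValue> compiled(
      c10::StringType::get(), c10::StringType::get());
  for (const auto& method : mod.get_methods()) {
    std::stringstream ss;
    for (const auto& node : method.function().graph()->nodes()) {
      ss << node->kind().toQualString() << ",";
    }
    std::string blob = ss.str();
    if (!blob.empty()) {
      blob.pop_back(); // drop the trailing comma
    }
    compiled.insert(method.name(), blob);
  }
  return compiled;
}
```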

- Runtime

1. The compiled blob is loaded in the ```__setstate__``` method.
2. The ```compile``` function of the backend parses the AOT-compiled blob into a format (a list of tokens) that the backend can understand.
3. The ```execute``` function of the backend executes the specified method (handle). (A usage sketch follows this list.)
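
At the user level, the lowered module behaves like any other TorchScript module. A usage sketch that mirrors ```BackendTest.TestCompiler``` from this diff, assuming the ```backend_with_compiler``` library is linked in:

```
// Lower a module to the demo backend, then run the lowered module.
torch::jit::Module m("m");
m.define(R"(
  def forward(self, x, h):
      return x + h
)");
// Per-method compile spec; the demo backend does not inspect its contents.
c10::Dict<c10::IValue, c10::IValue> compile_spec(
    c10::StringType::get(), c10::AnyType::get());
c10::Dict<c10::IValue, c10::IValue> fake_dict(
    c10::StringType::get(), c10::AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty =
    c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
// Lowering runs the backend's preprocess function, i.e. the AOT compilation.
auto lm = torch::jit::detail::codegen_backend_module(
    "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
// forward() dispatches through the backend's compile/execute path.
std::vector<c10::IValue> inputs{2.0 * torch::ones({}), 1.0 * torch::ones({})};
auto out = lm.forward(inputs);
```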

Test Plan:
- ```BackendTest.TestCompiler```: an end-to-end C++ demonstration on a supported model. After compilation and running, the lowered module produces the same result as the original TorchScript model.
- ```BackendTest.TestCompilerNotSupport```: demonstrates the error message from AOT compilation for a feature the backend does not support in the input module. The error message looks like:

```
"The node of aten::mul is not supported in this compiler. Source code:   File "<string>", line 3

    def forward(self, x, h):
        return x * h
               ~~~~~ <--- HERE
```
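
That message is assembled at AOT compilation time from the unsupported node's kind and source range; in the compiler listing below it comes from the default branch of the node-kind switch:

```
TORCH_CHECK(
    false,
    "The node of ",
    node->kind().toQualString(),
    " is not supported in this compiler. Source code: ",
    node->sourceRange().str());
```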

Reviewed By: raziel

Differential Revision: D26593968

Pulled By: iseeyuan

fbshipit-source-id: 8f264f60a0470e9f07e36fdeccbf17da6c1d7cd7
Author: Martin Yuan
Date: 2021-02-26 11:51:29 -08:00
Committed by: Facebook GitHub Bot
Parent: 502a85990d
Commit: b2520ab3dc
7 changed files with 267 additions and 18 deletions


@@ -11,9 +11,13 @@ target_link_libraries(torchbind_test torch)
add_library(jitbackend_test SHARED ${JIT_TEST_ROOT}/test_backend_lib.cpp)
target_link_libraries(jitbackend_test torch)
add_library(backend_with_compiler SHARED ${JIT_TEST_ROOT}/test_backend_compiler_lib.cpp)
target_link_libraries(backend_with_compiler torch)
if(INSTALL_TEST)
install(TARGETS torchbind_test DESTINATION lib)
install(TARGETS jitbackend_test DESTINATION lib)
install(TARGETS backend_with_compiler DESTINATION lib)
endif()
# Build the cpp gtest binary containing the cpp-only tests.
@@ -74,7 +78,7 @@ if(USE_SYSTEM_ONNX)
target_link_libraries(test_jit PRIVATE onnx_proto onnx)
endif()
set(JIT_TEST_DEPENDENCIES jitbackend_test gtest)
set(JIT_TEST_DEPENDENCIES torch gtest jitbackend_test backend_with_compiler)
if(MSVC)
list(APPEND JIT_TEST_DEPENDENCIES onnx_library)
@@ -85,7 +89,7 @@ target_include_directories(test_jit PRIVATE ${ATen_CPU_INCLUDE})
if(LINUX)
#Update to target_link_options when CMake version can be upgraded
target_link_libraries(test_jit PRIVATE "-Wl,--no-as-needed,$<TARGET_FILE:jitbackend_test>,--as-needed")
target_link_libraries(test_jit PRIVATE "-Wl,--no-as-needed,$<TARGET_FILE:jitbackend_test>,$<TARGET_FILE:backend_with_compiler>,--as-needed")
endif()
if(USE_CUDA)


@@ -1,4 +1,5 @@
#include <gtest/gtest.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/torch.h>
@@ -32,9 +33,92 @@ TEST(BackendTest, ToBackend) {
// lowered module
auto lm = torch::jit::detail::codegen_backend_module(
"test_backend", m, compile_spec, any_dict_ty);
// lowered module code:
/*
class test_backendLoweredModule(Module):
__parameters__ = []
__buffers__ = []
__processed_module : Any
__method_compile_spec : Dict[str, Any]
__backend : __torch__.torch.classes.__backends__.test_backend
__handles : Dict[str, Any]
def __create_backend(self: torch.jit.test_backendLoweredModule) -> None:
    _0 = __torch__.torch.classes.__backends__.test_backend.__new__(__torch__.torch.classes.__backends__.test_backend)
    _1 = (_0).__init__()
    self.__backend = _0
    return None
def __getstate__(self: torch.jit.test_backendLoweredModule) -> Tuple[Dict[str, Any], Any]:
    _2 = (self.__method_compile_spec, self.__processed_module)
    return _2
def __setstate__(self: torch.jit.test_backendLoweredModule, state: Tuple[Dict[str, Any], Any]) -> None:
    self.__method_compile_spec = (state)[0]
    self.__processed_module = (state)[1]
    _3 = (self).__create_backend()
    _4 = (self.__backend).compile(self.__processed_module, self.__method_compile_spec, )
    self.__handles = _4
    return None
def forward(self: torch.jit.test_backendLoweredModule, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
    _5 = uninitialized(Tensor)
    typed_inputs = annotate(List[Any], [x, h])
    _6 = (self.__backend).execute((self.__handles)["forward"], typed_inputs, )
    _7, _8, = _6
    _9 = isinstance(_7, Tensor)
    if _9:
        _10 = unchecked_cast(Tensor, _7)
    else:
        ops.prim.RaiseException("AssertionError: ")
        _10 = _5
    _11 = isinstance(_8, Tensor)
    if _11:
        _12 = unchecked_cast(Tensor, _8)
    else:
        ops.prim.RaiseException("AssertionError: ")
        _12 = _5
    return (_10, _12)
*/
auto res = lm.forward(inputs).toTuple()->elements();
AT_ASSERT(res[0].toTensor().equal(ref[0].toTensor()));
AT_ASSERT(res[1].toTensor().equal(ref[1].toTensor()));
}
TEST(BackendTest, TestCompiler) {
Module m("m");
m.define(R"(
def forward(self, x, h):
return x + h
)");
std::vector<IValue> inputs;
inputs.emplace_back(2.0 * torch::ones({}));
inputs.emplace_back(1.0 * torch::ones({}));
auto ref = m.forward(inputs);
c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
// lowered module
auto lm = torch::jit::detail::codegen_backend_module(
"backend_with_compiler_demo", m, compile_spec, any_dict_ty);
auto res = lm.forward(inputs);
AT_ASSERT(res.toTensor().equal(ref.toTensor()));
}
TEST(BackendTest, TestCompilerNotSupport) {
Module m("m");
m.define(R"(
def forward(self, x, h):
return x * h
)");
c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
// lowered module
ASSERT_THROWS_WITH_MESSAGE(
torch::jit::detail::codegen_backend_module(
"backend_with_compiler_demo", m, compile_spec, any_dict_ty),
"The node of aten::mul is not supported in this compiler. Source code:");
}
} // namespace jit
} // namespace torch


@@ -0,0 +1,158 @@
#include <torch/csrc/jit/backends/backend.h>
namespace torch {
namespace jit {
// Implementation of a PyTorch Backend that can process, compile and execute
// TorchScript Modules composed of 'add' and 'sub' operators. It only supports
// modules that compute a sum or difference of two inputs (i.e. in1 + in2 or
// in1 - in2), so the methods of the models expect exactly 2 inputs of type
// Tensor. This backend is used to demonstrate the flow of compilation and
// execution with a minimal amount of work. It's not intended to be a
// practical backend for actual inference.
// Implementation details:
//
// Compilation
// 1. A backend with minimum compilation features, "backend_with_compiler_demo"
// is added.
// 2. The compilation happens AOT in the preprocess function registered to this
// backend.
// 3. Compiled results are stored in a string blob for each method. They are
// serialized to the lowered module with __getstate__ function.
// 4. For features not handled by the backend compiler, an error message that
// includes the model's source code is thrown.
//
// Runtime
// 1. The compiled blob is loaded in __setstate__ method.
// 2. The compile function of the backend parses the preprocessed blob into the
// format (a list of tokens) that the backend can understand.
// 3. The execute function of the backend executes the specified method
// (handle).
namespace {
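// Split the comma-separated blob produced by preprocess back into tokens.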
std::vector<std::string> parseMethodHandle(const std::string& blob) {
std::vector<std::string> result;
std::stringstream s_stream(blob);
while (s_stream.good()) {
std::string substr;
getline(s_stream, substr, ',');
result.push_back(substr);
}
return result;
}
} // namespace
class BackendWithCompiler : public PyTorchBackendInterface {
public:
// Constructor.
explicit BackendWithCompiler() {}
virtual ~BackendWithCompiler() = default;
// Since the actual compilation is done AOT in preprocess, compile here only
// parses the preprocessed blob into per-method token lists (the handles).
c10::impl::GenericDict compile(
c10::IValue processed,
c10::impl::GenericDict method_compile_spec) override {
auto dict = processed.toGenericDict();
auto handles = c10::Dict<std::string, std::vector<std::string>>();
for (const auto& kv : dict) {
auto tokens = parseMethodHandle(kv.value().toStringRef());
handles.insert(kv.key().toStringRef(), tokens);
}
return c10::impl::toGenericDict(handles);
}
c10::impl::GenericList execute(
c10::IValue handle,
c10::impl::GenericList inputs) override {
TORCH_INTERNAL_ASSERT(inputs.size() == 2);
c10::IValue val0 = inputs[0];
at::Tensor x = val0.toTensor();
c10::IValue val1 = inputs[1];
at::Tensor h = val1.toTensor();
c10::List<at::Tensor> output_list;
// Interpret the token stream; each token is either a constant
// ("prim::Constant#<val>") or an instruction ("aten::add" / "aten::sub").
for (const auto& token : handle.toList()) {
auto instruction = std::string(IValue(token).toStringRef());
double const_val = 1.0;
if (instruction.rfind("prim::Constant", 0) == 0) {
TORCH_CHECK(
instruction.size() > 15,
"Constant value is expected in ",
instruction);
auto sub = instruction.substr(15);
const_val = stod(sub);
} else if (instruction == "aten::add") {
output_list.emplace_back(x.add_(h, const_val));
} else if (instruction == "aten::sub") {
output_list.emplace_back(x.sub_(h, const_val));
} else {
TORCH_CHECK(
false,
"Instruction, ",
instruction,
" is not supported. ",
"Contact the backend POC for details. ");
}
}
return c10::impl::toList(output_list);
}
};
namespace {
// For this backend, the actual compilation happens AOT in this preprocess
// function; it is placed here to demonstrate the backend as a whole piece.
// It's used when compilation is required. A dummy function can be passed
// when the runtime backend library makes no use of compilation.
c10::IValue preprocess(
const Module& mod,
const c10::Dict<IValue, IValue>& method_compile_spec) {
// The output of this process is a dictionary with:
// Key: method name.
// Val: compiled blob (represented by a string).
c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
for (const auto& method : mod.get_methods()) {
const auto graph = method.function().graph()->copy();
auto key = method.name();
std::stringstream ss;
for (const auto& node : graph->nodes()) {
switch (node->kind()) {
case prim::Constant:
ss << node->kind().toDisplayString() << "#"
<< toIValue(node->output()).value();
break;
case aten::add:
ss << node->kind().toQualString();
break;
case aten::sub:
ss << node->kind().toQualString();
break;
default:
TORCH_CHECK(
false,
"The node of ",
node->kind().toQualString(),
" is not supported in this compiler. Source code: ",
node->sourceRange().str());
break;
}
ss << ",";
}
std::string blob = ss.str();
if (!blob.empty()) {
blob.pop_back();
}
compiled.insert(method.name(), blob);
}
return compiled;
}
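// Register the backend under the name used at lowering time
// ("backend_with_compiler_demo"); preprocess runs during lowering, AOT.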
static auto cls = torch::jit::backend<BackendWithCompiler>(
"backend_with_compiler_demo",
preprocess);
} // namespace
} // namespace jit
} // namespace torch


@@ -61,15 +61,15 @@ class TestBackend : public PyTorchBackendInterface {
}
};
namespace {
c10::IValue preprocess(
const Module& mod,
const c10::Dict<IValue, IValue>& method_compile_spec) {
return mod._ivalue();
}
namespace {
static auto cls = torch::jit::backend<TestBackend>("test_backend", preprocess);
}
} // namespace
} // namespace jit
} // namespace torch


@@ -1,3 +1,5 @@
#include <test/cpp/jit/test_utils.h>
#include <gtest/gtest.h>
#include <c10/core/TensorOptions.h>
@@ -13,14 +15,6 @@
#include <unordered_set>
#define ASSERT_THROWS_WITH(statement, substring) \
try { \
(void)statement; \
ASSERT_TRUE(false); \
} catch (const std::exception& e) { \
ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
}
// Tests go in torch::jit
namespace torch {
namespace jit {
@@ -281,7 +275,7 @@ TEST(LiteInterpreterTest, LoadOrigJit) {
)");
std::stringstream ss;
m.save(ss);
ASSERT_THROWS_WITH(_load_for_mobile(ss), "file not found");
ASSERT_THROWS_WITH_MESSAGE(_load_for_mobile(ss), "file not found");
}
TEST(LiteInterpreterTest, WrongMethodName) {
@@ -298,7 +292,8 @@ TEST(LiteInterpreterTest, WrongMethodName) {
std::vector<IValue> inputs;
auto minput = 5 * torch::ones({});
inputs.emplace_back(minput);
ASSERT_THROWS_WITH(bc.get_method("forward")(inputs), "is not defined");
ASSERT_THROWS_WITH_MESSAGE(
bc.get_method("forward")(inputs), "is not defined");
}
TEST(LiteInterpreterTest, SetState) {


@@ -1,9 +1,17 @@
#pragma once
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/testing/file_check.h>
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/runtime/autodiff.h"
#include "torch/csrc/jit/runtime/interpreter.h"
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \
try { \
(void)statement; \
FAIL(); \
} catch (const std::exception& e) { \
ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
}
namespace torch {
namespace jit {


@@ -131,7 +131,7 @@ class class_ {
/// taking an `int` and a `std::string` as argument.
template <typename... Types>
class_& def(
detail::types<void, Types...>,
torch::detail::types<void, Types...>,
std::string doc_string = "",
std::initializer_list<arg> default_args = {}) { // Used in combination with
// torch::init<...>()