mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 06:11:27 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52603 This PR introduces a backend with minimal compilation capability to the to_<backend> flow. The targets are: - Demonstrate the end-to-end flow: adding a backend -> compilation -> runtime. - Demonstrate how backend compilation errors are surfaced to the user, together with the original model's source code information. (C++ only in this PR. Python APIs will be demonstrated in a following PR.) Changes: - Compilation 1. A backend with minimal compilation features, "backend_with_compiler_demo", is added. 2. The compilation happens AOT in the ```pre_process``` function registered to this backend. 3. Compiled results are stored in a string blob for each method. They are serialized to the lowered module with the ```__get_state__``` function. 4. An error message containing the model source code is thrown for features not handled by the backend compiler. - Runtime 1. The compiled blob is loaded in the ```__set_state__``` method. 2. The ```compile``` function of the backend passes through the AOT-compiled blob. (TODO: parsing the blob into a format that the backend can understand can happen here.) 3. The ```execute``` function of the backend executes the specified method (handle). Test Plan: - ```BackendTest.TestCompiler```: the C++ end-to-end demonstration on a supported model. After compilation and running, the lowered model produces the same result as the original TorchScript model. - ```BackendTest.TestCompilerNotSupport```: demonstrates the error message from the AOT compilation for a feature not supported in the input module. The error message looks like: ``` "The node of aten::mul is not supported in this compiler. Source code: File "<string>", line 3 def forward(self, x, h): return x * h ~~~~~ <--- HERE ``` Reviewed By: raziel Differential Revision: D26593968 Pulled By: iseeyuan fbshipit-source-id: 8f264f60a0470e9f07e36fdeccbf17da6c1d7cd7
76 lines
2.3 KiB
C++
76 lines
2.3 KiB
C++
#include <torch/csrc/jit/backends/backend.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
// This test JIT backend is intended to do the minimal amount of work
|
|
// necessary to test that the JIT backend registration endpoints and
|
|
// code generation are working correctly. It is not intended to
|
|
// produce numerically correct results.
|
|
class TestBackend : public PyTorchBackendInterface {
|
|
public:
|
|
// Constructor.
|
|
explicit TestBackend() {}
|
|
virtual ~TestBackend() = default;
|
|
|
|
c10::impl::GenericDict compile(
|
|
c10::IValue processed,
|
|
c10::impl::GenericDict method_compile_spec) override {
|
|
auto spec =
|
|
c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
|
|
|
|
// Return the same string as a value for every key in method_compile_spec.
|
|
auto handles = c10::Dict<std::string, std::string>();
|
|
for (const auto& it : spec) {
|
|
handles.insert(it.key(), it.key());
|
|
}
|
|
return c10::impl::toGenericDict(handles);
|
|
}
|
|
c10::impl::GenericList execute(
|
|
c10::IValue handle,
|
|
c10::impl::GenericList inputs) override {
|
|
TORCH_INTERNAL_ASSERT(handle.isString());
|
|
TORCH_INTERNAL_ASSERT(inputs.size() > 0);
|
|
|
|
c10::List<at::Tensor> output_list;
|
|
|
|
// Implement simple accumulator and negative accumulator (?) ops. Return one
|
|
// or both of them depending on the handle to make sure multiple outputs are
|
|
// handled.
|
|
c10::IValue value = inputs[0];
|
|
at::Tensor accum = value.toTensor();
|
|
accum = accum.clone();
|
|
at::Tensor sub_accum = value.toTensor();
|
|
sub_accum = sub_accum.clone();
|
|
|
|
for (size_t i = 1, e = inputs.size(); i < e; ++i) {
|
|
value = inputs[i];
|
|
accum.add_(value.toTensor(), 1.0);
|
|
sub_accum.sub_(value.toTensor(), 1.0);
|
|
}
|
|
|
|
if (handle.toStringRef() == "accum") {
|
|
output_list.emplace_back(accum);
|
|
} else if (handle.toStringRef() == "sub_accum") {
|
|
output_list.emplace_back(sub_accum);
|
|
} else if (handle.toStringRef() == "forward") {
|
|
output_list.emplace_back(accum);
|
|
output_list.emplace_back(sub_accum);
|
|
}
|
|
|
|
return c10::impl::toList(output_list);
|
|
}
|
|
};
|
|
|
|
namespace {
|
|
c10::IValue preprocess(
|
|
const Module& mod,
|
|
const c10::Dict<IValue, IValue>& method_compile_spec) {
|
|
return mod._ivalue();
|
|
}
|
|
|
|
static auto cls = torch::jit::backend<TestBackend>("test_backend", preprocess);
|
|
} // namespace
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|