pytorch/test/custom_backend/custom_backend.h
Kimish Patel 2ce21b2e61 [Pytorch backend delegation] Preprocess to accept BackendDebugInfoRecorder (#58873)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58873

Prior to this PR:
In order to generate debug handles corresponding to the graph being
lowered, a backend's preprocess would call the free function
generate_debug_handles and get back a map of Node*-to-debug_handles.
To facilitate this, to_backend owned a BackendDebugInfoRecorder and
initialized a thread-local pointer to it. generate_debug_handles would
query the thread-local pointer to see whether there was a valid
BackendDebugInfoRecorder for the context and, if so, generate debug
handles.

After this PR:
The signature of preprocess has changed: backends must register a
preprocess function that accepts an instance of
BackendDebugInfoRecorder by reference. generate_debug_handles is no
longer a free function; it is now part of the API of
BackendDebugInfoRecorder. A backend's preprocess function calls
generate_debug_handles on the BackendDebugInfoRecorder instead of the
free function.
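
A backend's preprocess now looks roughly like the sketch below. This
is illustrative only, based on the preprocess signature used in the
file in this commit, where the recorder's handle generation is exposed
to preprocess as a BackendDebugHandleGenerator callable; the "forward"
method name is just an example:

  c10::IValue preprocess(
      const torch::jit::Module& mod,
      const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
      const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
    // Ask for Node*-to-debug_handle mappings for the graph being lowered.
    auto graph = mod.get_method("forward").graph();
    auto node_to_debug_handles = generate_debug_handles(graph);
    // ... record the handles alongside the compiled payload ...
    return mod._ivalue();
  }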

Reason for this change:
An RAII object that initializes a thread-local pointer results in a
loose contract with backends, which may lead to backends not storing
debug information at all. Making the recorder part of the API forces
backends to be aware of BackendDebugInfoRecorder and to explicitly
choose not to generate/store debug information if that is what they
want.

Test Plan:
backend tests

Imported from OSS

Reviewed By: jbschlosser, raziel

Differential Revision: D28648613

fbshipit-source-id: c9b7e7bf0f78e87023ea7bc08612cf893b08cb98
2021-06-11 10:16:00 -07:00

#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/api/module.h>
namespace torch {
namespace custom_backend {
// This custom JIT backend is intended to do the minimal amount of work
// necessary to test that the JIT backend registration endpoints and
// code generation are working correctly. It is not intended to
// produce numerically correct results.
class CustomBackend : public torch::jit::PyTorchBackendInterface {
 public:
  // Constructor.
  explicit CustomBackend() {}
  virtual ~CustomBackend() = default;

  bool is_available() override {
    return true;
  }

  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    auto spec =
        c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);

    // Return the same string as a value for every key in method_compile_spec.
    auto handles = c10::Dict<std::string, std::string>();
    for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
      handles.insert(it->key(), it->key());
    }
    return c10::impl::toGenericDict(handles);
  }

  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    TORCH_INTERNAL_ASSERT(handle.isString());
    TORCH_INTERNAL_ASSERT(inputs.size() > 0);

    c10::List<at::Tensor> output_list;

    // Implement simple accumulator and subtractive accumulator ops. Return
    // one or both of them depending on the handle to make sure multiple
    // outputs are handled.
    c10::IValue value = inputs[0];
    at::Tensor accum = value.toTensor();
    accum = accum.clone();
    at::Tensor sub_accum = value.toTensor();
    sub_accum = sub_accum.clone();

    for (size_t i = 1, e = inputs.size(); i < e; ++i) {
      value = inputs[i];
      accum.add_(value.toTensor(), 1.0);
      sub_accum.sub_(value.toTensor(), 1.0);
    }

    if (handle.toStringRef() == "accum") {
      output_list.emplace_back(accum);
    } else if (handle.toStringRef() == "sub_accum") {
      output_list.emplace_back(sub_accum);
    } else if (handle.toStringRef() == "forward") {
      output_list.emplace_back(accum);
      output_list.emplace_back(sub_accum);
    }
    return c10::impl::toList(output_list);
  }
};
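
// Illustration of the contract above: compile() maps every method name in
// method_compile_spec to itself, so a module with a "forward" method gets the
// handle "forward", and for two input tensors a and b,
// execute("forward", {a, b}) yields {a + b, a - b} (the accumulator followed
// by the subtractive accumulator).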

c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
  // This test backend stores no debug information, so generate_debug_handles
  // is deliberately unused; the module is returned unchanged.
  return mod._ivalue();
}
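
// CUSTOM_BACKEND_API marks getBackendName() for export from the shared
// library: when the library itself is being built on Windows (CMake defines
// custom_ops_EXPORTS for the target), the symbol is dllexport'ed; consumers
// see dllimport; on non-Windows platforms the macro expands to nothing.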
// clang-format off
#if defined(_WIN32)
#  if defined(custom_ops_EXPORTS)
#    define CUSTOM_BACKEND_API __declspec(dllexport)
#  else
#    define CUSTOM_BACKEND_API __declspec(dllimport)
#  endif
#else
#  define CUSTOM_BACKEND_API
#endif
// clang-format on
CUSTOM_BACKEND_API std::string getBackendName();
} // namespace custom_backend
} // namespace torch
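
For reference, the companion .cpp would typically register the backend and
its preprocess function along these lines (a sketch; kBackendName and the
"custom_backend" string are illustrative):

#include <torch/csrc/jit/backends/backend_preprocess.h>

namespace torch {
namespace custom_backend {
namespace {
constexpr auto kBackendName = "custom_backend";
// Register the backend under kBackendName and associate preprocess with it
// so that to_backend can find both when lowering a module.
static auto cls = torch::jit::backend<CustomBackend>(kBackendName);
static auto pre_reg =
    torch::jit::backend_preprocess_register(kBackendName, preprocess);
} // namespace

std::string getBackendName() {
  return std::string(kBackendName);
}
} // namespace custom_backend
} // namespace torch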