mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow users to override kernels for existing C++ ops through Python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75905 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
32ae584008
commit
55f55a4cf6
@ -7,7 +7,7 @@ namespace torch {
|
||||
namespace {
|
||||
// TODO: Consider representing debug info as a struct instead so you
|
||||
// don't have to allocate strings all the time
|
||||
std::string debugString(const char* file, uint32_t line) {
|
||||
std::string debugString(std::string file, uint32_t line) {
|
||||
#ifdef STRIP_ERROR_MESSAGES
|
||||
return std::string();
|
||||
#else
|
||||
@ -15,7 +15,7 @@ namespace {
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string debugString(std::string debug, const char* file, uint32_t line) {
|
||||
std::string debugString(std::string debug, std::string file, uint32_t line) {
|
||||
#ifdef STRIP_ERROR_MESSAGES
|
||||
return std::string();
|
||||
#else
|
||||
|
@ -3,6 +3,7 @@
|
||||
import tempfile
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
from torch.library import Library
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
|
||||
log_input, capture_logs, no_dispatch
|
||||
@ -12,6 +13,86 @@ from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
class TestPythonRegistration(TestCase):
|
||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||
x = torch.tensor([1, 2])
|
||||
my_lib1 = Library("aten", "IMPL")
|
||||
my_lib2 = Library("aten", "IMPL")
|
||||
|
||||
# Example 1
|
||||
def my_neg(*args, **kwargs):
|
||||
return args[0]._neg_view()
|
||||
|
||||
# Now we are secretly making the operator a view op so autograd needs to know how
|
||||
# to handle it
|
||||
my_lib1.impl('neg', my_neg, "AutogradCPU")
|
||||
|
||||
self.assertTrue(torch.neg(x).is_neg())
|
||||
|
||||
# RuntimeError: impl("aten::neg", ...):
|
||||
# Explicitly provided namespace (aten) in operator name does not match ...
|
||||
with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
|
||||
my_lib3 = Library("foo", "IMPL")
|
||||
my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
|
||||
del my_lib3
|
||||
|
||||
# Example 2
|
||||
def my_mul(*args, **kwargs):
|
||||
return torch.zeros_like(args[0])
|
||||
|
||||
# torch.ops.aten.mul.Tensor
|
||||
my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
|
||||
|
||||
y = torch._efficientzerotensor(2)
|
||||
self.assertFalse(torch.mul(x, y)._is_zerotensor())
|
||||
|
||||
# Assert that a user can't override the behavior of a (ns, op, dispatch_key)
|
||||
# combination if someone else has already overridden the behavior for that same combination
|
||||
with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
|
||||
my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
|
||||
|
||||
del my_lib1
|
||||
|
||||
# Validate that lib2 is not affected by removing lib1
|
||||
self.assertFalse(torch.mul(x, y)._is_zerotensor())
|
||||
|
||||
del my_lib2
|
||||
|
||||
# Validate that the old behavior is restored for neg and mul
|
||||
self.assertFalse(torch.neg(x).is_neg())
|
||||
self.assertTrue(torch.mul(x, y)._is_zerotensor())
|
||||
|
||||
def test_override_cpu_sum(self) -> None:
|
||||
# Example 1
|
||||
run = [False]
|
||||
|
||||
def my_sum(*args, **kwargs):
|
||||
run[0] = True
|
||||
return args[0]
|
||||
|
||||
my_lib1 = Library("aten", "IMPL")
|
||||
my_lib1.impl('aten::sum', my_sum, "CPU")
|
||||
x = torch.tensor([1, 2])
|
||||
self.assertEqual(torch.sum(x), x)
|
||||
self.assertTrue(run[0])
|
||||
del my_lib1
|
||||
# Validate that the old behavior is restored for sum
|
||||
self.assertEqual(torch.sum(x), torch.tensor(3))
|
||||
|
||||
def test_extend_library_with_dispatch_key_arg(self):
|
||||
def my_sum(*args, **kwargs):
|
||||
return args[0]
|
||||
my_lib1 = Library("aten", "IMPL", dispatch_key="CPU")
|
||||
|
||||
# RuntimeError: Explicitly provided dispatch key (Conjugate) is
|
||||
# inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
|
||||
with self.assertRaisesRegex(RuntimeError, "inconsistent with the dispatch key"):
|
||||
my_lib1.impl('sum', my_sum, "Conjugate")
|
||||
my_lib1.impl('aten::sum', my_sum)
|
||||
x = torch.tensor([1, 2])
|
||||
self.assertEqual(torch.sum(x), x)
|
||||
del my_lib1
|
||||
|
||||
class TestPythonDispatch(TestCase):
|
||||
def test_basic(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
|
@ -829,6 +829,10 @@ class Generator(object):
|
||||
def seed(self) -> _int: ...
|
||||
def initial_seed(self) -> _int: ...
|
||||
|
||||
|
||||
# Defined in torch/csrc/utils/python_dispatch.cpp
|
||||
def _dispatch_library(kind: str, name: str, dispatch: str, file: str, linenum: Any) -> Any: ...
|
||||
|
||||
# Defined in torch/csrc/utils/init.cpp
|
||||
class BenchmarkConfig(object):
|
||||
num_calling_threads: _int
|
||||
|
@ -918,3 +918,4 @@ def _register_device_module(device_type, module):
|
||||
|
||||
# expose return_types
|
||||
from . import return_types
|
||||
from . import library
|
||||
|
@ -4,7 +4,6 @@
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <c10/util/DeadlockDetection.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/csrc/Device.h>
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
@ -25,18 +24,19 @@
|
||||
#include <torch/csrc/jit/frontend/tracer.h>
|
||||
#include <torch/csrc/tensor/python_tensor.h>
|
||||
#include <torch/csrc/utils/cuda_lazy_init.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/pycfunction_helpers.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
#include <torch/csrc/utils/tensor_new.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
|
||||
#include <torch/library.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include <structmember.h>
|
||||
#include <cstdint>
|
||||
@ -51,6 +51,91 @@ using namespace at;
|
||||
using namespace torch;
|
||||
using namespace torch::autograd;
|
||||
|
||||
// Converts the boxed arguments of `op` into a Python (args, kwargs) pair.
//
// Layout of `arguments`, using f(int x, int y = 0, *, int z = 0) as an example:
//   [0, positional_default_start)                 -> emitted as positional args
//   [positional_default_start, kwarg_only_start)  -> trailing positional
//                                                    arguments equal to their
//                                                    schema default; omitted
//   [kwarg_only_start, arguments.size())          -> keyword-only arguments;
//                                                    emitted as kwargs unless
//                                                    equal to their default
//
// The GIL must be held by the caller (checked on entry).
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(const c10::OperatorHandle& op, const std::vector<c10::IValue>& arguments) {
  TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call parseIValuesToPyArgsKwargs");
  const auto& schema = op.schema();
  py::dict kwargs;

  // Walk backwards over the trailing keyword-only arguments. Scanning from the
  // right is cheap because most schemas have few (or no) kwarg-only arguments.
  // The "- 1" inspects the argument just before the current split point.
  int64_t kwarg_only_start = arguments.size();
  while (kwarg_only_start > 0 &&
         schema.arguments()[kwarg_only_start - 1].kwarg_only()) {
    --kwarg_only_start;
  }

  // True when the schema declares a default for argument `pos` and the value
  // actually passed compares equal to it.
  const auto matches_default = [&](int64_t pos) -> bool {
    const auto& declared = schema.arguments()[pos].default_value();
    if (!declared.has_value()) {
      return false;
    }
    return !(*declared != arguments[pos]);
  };

  // Walk backwards over the trailing positional arguments that still hold
  // their default value; those are not forwarded to Python.
  int64_t positional_default_start = kwarg_only_start;
  while (positional_default_start > 0 &&
         matches_default(positional_default_start - 1)) {
    --positional_default_start;
  }

  auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));

  // Fill the positional tuple; PyTuple_SET_ITEM steals the released reference.
  for (int64_t pos = 0; pos < positional_default_start; ++pos) {
    PyTuple_SET_ITEM(args.ptr(), pos, torch::jit::toPyObject(arguments[pos]).release().ptr());
  }

  // Fill the keyword dict, skipping keyword-only arguments left at default.
  const auto total_arguments = static_cast<int64_t>(arguments.size());
  for (int64_t pos = kwarg_only_start; pos < total_arguments; ++pos) {
    if (matches_default(pos)) {
      continue;
    }
    kwargs[py::cast(schema.arguments()[pos].name())] = torch::jit::toPyObject(arguments[pos]);
  }

  return std::make_pair(std::move(args), std::move(kwargs));
}
|
||||
|
||||
void pushPyOutToStack(
|
||||
const c10::OperatorHandle& op,
|
||||
torch::jit::Stack* stack,
|
||||
py::object out,
|
||||
const char* msg) {
|
||||
TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call pushPyOutToStack");
|
||||
auto schema_returns = op.schema().returns();
|
||||
const auto num_returns = schema_returns.size();
|
||||
if (num_returns == 0) {
|
||||
// Check that we got a None return from Python. Anything else is an error.
|
||||
TORCH_CHECK(out.is(py::none()), "Expected ", msg, " for ", op.operator_name(),
|
||||
" to return None but it returned something else instead.");
|
||||
} else if (num_returns == 1) {
|
||||
torch::jit::push(stack, torch::jit::toIValue(out.ptr(), schema_returns[0].type()));
|
||||
} else {
|
||||
auto outs = py::cast<py::sequence>(out);
|
||||
for (const auto idx : c10::irange(outs.size())) {
|
||||
torch::jit::push(stack, torch::jit::toIValue(outs[idx].ptr(), schema_returns[idx].type()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::string concrete_name_fn(const c10::impl::PyInterpreter* self) {
|
||||
@ -131,8 +216,6 @@ c10::impl::PyInterpreter* getPyInterpreter() {
|
||||
return self_interpreter.get();
|
||||
}
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PyObject *THPVariableClass = nullptr;
|
||||
|
||||
PyObject *ParameterClass = nullptr;
|
||||
@ -1751,8 +1834,6 @@ void concrete_dispatch_fn(
|
||||
torch::jit::Stack* stack,
|
||||
const std::shared_ptr<SafePyObject>& type) {
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
auto arguments = torch::jit::pop(*stack, num_arguments);
|
||||
|
||||
@ -1776,9 +1857,6 @@ void concrete_dispatch_fn(
|
||||
py::gil_scoped_acquire g;
|
||||
|
||||
std::vector<py::handle> overloaded_args;
|
||||
// For now, overloads get coalesced. Might be easier for users if they get
|
||||
// overload resolution but is more complicated (need to expose separate
|
||||
// functions per overload)
|
||||
py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
||||
py::handle torch_api_function_overload;
|
||||
if (overload_name == "") {
|
||||
@ -1788,51 +1866,6 @@ void concrete_dispatch_fn(
|
||||
}
|
||||
std::string module_name_str = "torch.ops." + ns_str;
|
||||
|
||||
// About all the pointers:
|
||||
//
|
||||
// f(int x, int y = 0, *, int z = 0)
|
||||
// ^- arguments.size()
|
||||
// ^- kwarg_only_start
|
||||
// ^- positional_default_start
|
||||
// ^- 0
|
||||
|
||||
// Find the split point between kwarg-only and regular. Since most functions
|
||||
// don't have kwarg-only arguments, it is more efficient to scan from the
|
||||
// right (but ideally, this would just be precomputed in FunctionSchema
|
||||
// itself). (NB: minus one in the loop is because we're testing if the
|
||||
// *next* argument is kwarg-only before we advance the starting index)
|
||||
int64_t kwarg_only_start = arguments.size();
|
||||
for (; kwarg_only_start > 0; kwarg_only_start--) {
|
||||
const auto& arg = schema.arguments()[kwarg_only_start - 1];
|
||||
if (!arg.kwarg_only()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Find the first positional argument that isn't defaulted
|
||||
auto is_default = [&](int64_t idx) -> bool {
|
||||
const auto& arg = schema.arguments()[idx];
|
||||
if (!arg.default_value().has_value()) {
|
||||
return false;
|
||||
}
|
||||
const auto& default_ivalue = *arg.default_value();
|
||||
const auto& ivalue = arguments[idx];
|
||||
if (default_ivalue != ivalue) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
int64_t positional_default_start = kwarg_only_start;
|
||||
for (; positional_default_start > 0; positional_default_start--) {
|
||||
if (!is_default(positional_default_start - 1)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
|
||||
py::dict kwargs;
|
||||
|
||||
if (type) {
|
||||
append_overloaded_type(&overloaded_args, type->ptr(getPyInterpreter()));
|
||||
}
|
||||
@ -1859,41 +1892,19 @@ void concrete_dispatch_fn(
|
||||
}
|
||||
}
|
||||
|
||||
// Populate positional arguments
|
||||
for (const auto idx : c10::irange(positional_default_start)) {
|
||||
PyTuple_SET_ITEM(args.ptr(), idx, torch::jit::toPyObject(std::move(arguments[idx])).release().ptr());
|
||||
}
|
||||
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
||||
auto args = std::move(args_kwargs.first);
|
||||
auto kwargs = std::move(args_kwargs.second);
|
||||
|
||||
// Populate keyword arguments
|
||||
for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) {
|
||||
// But don't populate default keyword arguments
|
||||
if (is_default(idx)) continue;
|
||||
const auto& arg = schema.arguments()[idx];
|
||||
kwargs[py::cast(arg.name())] = torch::jit::toPyObject(std::move(arguments[idx]));
|
||||
}
|
||||
|
||||
auto out = py::reinterpret_steal<py::object>(
|
||||
handle_torch_function_no_python_arg_parser(
|
||||
overloaded_args,
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
func_name,
|
||||
torch_api_function_overload.ptr(),
|
||||
module_name_str.c_str(),
|
||||
TorchFunctionName::TorchDispatch));
|
||||
|
||||
if (num_returns == 0) {
|
||||
// Check that we got a None return from Python. Anything else is an error.
|
||||
TORCH_CHECK(out.is(py::none()), "Expected __torch_dispatch__ for ", op.operator_name(),
|
||||
" to return None but it returned something else instead.");
|
||||
} else if (num_returns == 1) {
|
||||
torch::jit::push(stack, torch::jit::toIValue(out.ptr(), op.schema().returns()[0].type()));
|
||||
} else {
|
||||
auto outs = py::cast<py::sequence>(out);
|
||||
for (const auto idx : c10::irange(outs.size())) {
|
||||
torch::jit::push(stack, torch::jit::toIValue(outs[idx].ptr(), op.schema().returns()[idx].type()));
|
||||
}
|
||||
}
|
||||
PyObject* obj = handle_torch_function_no_python_arg_parser(
|
||||
overloaded_args,
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
func_name,
|
||||
torch_api_function_overload.ptr(),
|
||||
module_name_str.c_str(),
|
||||
TorchFunctionName::TorchDispatch);
|
||||
pushPyOutToStack(op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) {
|
||||
|
@ -7,6 +7,10 @@
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// Python object that backs torch.autograd.Variable
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
@ -62,3 +66,7 @@ inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
|
||||
}
|
||||
|
||||
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
|
||||
|
||||
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(const c10::OperatorHandle& op, const std::vector<c10::IValue>& arguments);
|
||||
|
||||
void pushPyOutToStack(const c10::OperatorHandle& op, torch::jit::Stack* stack, py::object out, const char* msg);
|
||||
|
@ -5,6 +5,10 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
@ -50,6 +54,20 @@ inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
|
||||
}
|
||||
}
|
||||
|
||||
class PythonKernelHolder : public c10::OperatorKernel {
|
||||
SafePyObject func_;
|
||||
public:
|
||||
PythonKernelHolder(py::object func) : func_(func.release().ptr(), getPyInterpreter()) {}
|
||||
|
||||
void operator()(const c10::OperatorHandle& op, c10::DispatchKeySet keyset, torch::jit::Stack* stack) {
|
||||
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
|
||||
py::gil_scoped_acquire g;
|
||||
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
||||
auto obj = py::reinterpret_steal<py::object>(PyObject_Call(func_.ptr(getPyInterpreter()), args_kwargs.first.ptr(), args_kwargs.second.ptr()));
|
||||
pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
|
||||
}
|
||||
};
|
||||
|
||||
void initDispatchBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
@ -122,6 +140,14 @@ void initDispatchBindings(PyObject* module) {
|
||||
);
|
||||
return self;
|
||||
}, "", py::arg("name"), py::arg("dispatch") = "", py::arg("debug") = "")
|
||||
.def("impl", [](py::object self, const char* name, const char* dispatch, py::object func) {
|
||||
HANDLE_TH_ERRORS
|
||||
self.cast<torch::Library&>().impl(
|
||||
name,
|
||||
dispatch_str(dispatch, CppFunction::makeFromBoxedFunctor(std::make_unique<PythonKernelHolder>(std::move(func))))
|
||||
);
|
||||
END_HANDLE_TH_ERRORS_PYBIND
|
||||
}, "", py::arg("name"), py::arg("dispatch"), py::arg("func"))
|
||||
.def("fallback_fallthrough", [](py::object self, const char* dispatch) {
|
||||
self.cast<torch::Library&>().fallback(
|
||||
dispatch_str(dispatch, CppFunction::makeFallthrough())
|
||||
@ -130,14 +156,17 @@ void initDispatchBindings(PyObject* module) {
|
||||
}, "", py::arg("dispatch") = "")
|
||||
;
|
||||
|
||||
m.def("_dispatch_library", [](const char* kind, std::string name, const char* dispatch) {
|
||||
m.def("_dispatch_library", [](const char* kind, std::string name, const char* dispatch, const char* file, uint32_t linenum) {
|
||||
HANDLE_TH_ERRORS
|
||||
return std::make_unique<torch::Library>(
|
||||
parseKind(kind),
|
||||
std::move(name),
|
||||
std::string(dispatch) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch)),
|
||||
"/dev/null",
|
||||
0);
|
||||
});
|
||||
file,
|
||||
linenum);
|
||||
END_HANDLE_TH_ERRORS_PYBIND
|
||||
}, "", py::arg("kind"), py::arg("name"), py::arg("dispatch"), py::arg("file")="/dev/null", py::arg("linenum")=0)
|
||||
;
|
||||
|
||||
m.def("_dispatch_dump", [](const char* name) -> std::string {
|
||||
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
|
||||
|
@ -555,7 +555,7 @@ class TORCH_API Library final {
|
||||
Library& operator=(const Library&) = delete;
|
||||
Library(Library&&) = default;
|
||||
Library& operator=(Library&&) = default;
|
||||
|
||||
// TODO: add gen_python_error boolean flag
|
||||
// Some notes about the API design here. We had the following constraints:
|
||||
//
|
||||
// - We need to support multiple "types" of arguments for schema and
|
||||
@ -800,7 +800,7 @@ class TORCH_API Library final {
|
||||
Kind kind_;
|
||||
c10::optional<std::string> ns_;
|
||||
c10::optional<c10::DispatchKey> dispatch_key_;
|
||||
const char* file_;
|
||||
std::string file_;
|
||||
uint32_t line_;
|
||||
|
||||
std::vector<c10::RegistrationHandleRAII> registrars_;
|
||||
|
71
torch/library.py
Normal file
71
torch/library.py
Normal file
@ -0,0 +1,71 @@
|
||||
from ._ops import OpOverload
|
||||
from typing import Set
|
||||
import traceback
|
||||
import torch._C as C
|
||||
|
||||
__all__ = ['Library']
|
||||
|
||||
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
|
||||
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
|
||||
# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
|
||||
# libraries calling into kernels not intended to be called.
|
||||
_impls: Set[str] = set()
|
||||
|
||||
class Library:
    """
    Class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can pass in a dispatch keyname if they only want the library to override kernels corresponding
    to only one specific dispatch key.

    Args:
        ns: library name
        kind: kind of library; only "IMPL" is currently supported
        dispatch_key: PyTorch dispatch key (default: "")

    Raises:
        ValueError: if ``kind`` is anything other than "IMPL".
    """
    def __init__(self, ns, kind, dispatch_key=""):
        # Created first so that __del__ is safe even if construction fails below.
        self._op_impls = set()
        # Validate before creating the underlying native library object.
        if kind != "IMPL":
            raise ValueError("Unsupported kind: {}".format(kind))
        # Record where the library was created; passed along as debug info.
        # NOTE(review): limit=3 keeps the three innermost frames (oldest first),
        # so [0] is the caller's caller when the stack is that deep -- confirm
        # that this is the frame meant to be reported.
        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m = C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
        self.ns = ns
        self.kind = kind
        self.dispatch_key = dispatch_key

    def __repr__(self):
        return "Library(kind={}, ns={}, dispatch_key={})".format(self.kind, self.ns, self.dispatch_key)

    def impl(self, op_name, fn, dispatch_key=''):
        """Register ``fn`` as the kernel of ``op_name`` for a dispatch key.

        Args:
            op_name: operator name as a string, or an ``OpOverload`` object
            fn: Python callable to register as the kernel
            dispatch_key: dispatch key to register for; when empty, the
                dispatch key the library was created with is used

        Raises:
            RuntimeError: if no dispatch key is available, if ``op_name`` is
                neither a string nor an ``OpOverload``, or if a kernel was
                already registered from Python for the same
                (namespace, operator, dispatch key) combination.
        """
        if dispatch_key == '':
            if self.dispatch_key == '':
                raise RuntimeError("Please specify the dispatch key that you want to register the kernel for.")
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")

        # Operator name without any "ns::" prefix; used for the key and the message.
        short_name = name.split("::")[-1]
        key = self.ns + "/" + short_name + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
                               "'s behavior for {} dispatch key and {} namespace.".
                               format(short_name, dispatch_key, self.ns))

        self.m.impl(name, dispatch_key, fn)
        # Only record the key once the registration call has succeeded.
        _impls.add(key)
        self._op_impls.add(key)

    def __del__(self):
        # Tolerate partially constructed instances: a finalizer must not fail
        # because __init__ raised before these attributes were set.
        for key in getattr(self, "_op_impls", ()):
            _impls.discard(key)
        if hasattr(self, "m"):
            del self.m
|
Reference in New Issue
Block a user