mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add fallback() to torch.library (#131707)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131707 Approved by: https://github.com/zou3519
This commit is contained in:
@ -36,7 +36,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
|||||||
c10::DispatchKey,
|
c10::DispatchKey,
|
||||||
c10::DispatchKeySet keyset,
|
c10::DispatchKeySet keyset,
|
||||||
torch::jit::Stack* stack,
|
torch::jit::Stack* stack,
|
||||||
bool with_keyset) const override {
|
bool with_keyset,
|
||||||
|
bool with_op) const override {
|
||||||
PANIC(python_op_registration_trampoline);
|
PANIC(python_op_registration_trampoline);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -150,7 +150,8 @@ struct C10_API PyInterpreterVTable {
|
|||||||
c10::DispatchKey,
|
c10::DispatchKey,
|
||||||
c10::DispatchKeySet keyset,
|
c10::DispatchKeySet keyset,
|
||||||
torch::jit::Stack* stack,
|
torch::jit::Stack* stack,
|
||||||
bool with_keyset) const = 0;
|
bool with_keyset,
|
||||||
|
bool with_op) const = 0;
|
||||||
|
|
||||||
virtual void throw_abstract_impl_not_imported_error(
|
virtual void throw_abstract_impl_not_imported_error(
|
||||||
std::string opname,
|
std::string opname,
|
||||||
|
|||||||
@ -69,6 +69,137 @@ class TestPythonRegistration(TestCase):
|
|||||||
if hasattr(torch.ops, self.test_ns):
|
if hasattr(torch.ops, self.test_ns):
|
||||||
del torch.ops._test_python_registration
|
del torch.ops._test_python_registration
|
||||||
|
|
||||||
|
def test_fallback(self) -> None:
|
||||||
|
test_key = "TESTING_ONLY_GenericMode"
|
||||||
|
test_keyset = torch._C.DispatchKeySet(test_key)
|
||||||
|
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
|
||||||
|
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
|
||||||
|
|
||||||
|
with _scoped_library("_", "IMPL") as my_lib:
|
||||||
|
expected_op = None
|
||||||
|
expected_args = None
|
||||||
|
expected_kwargs = None
|
||||||
|
# Use this out shape to make sure the result from our fallback
|
||||||
|
# is what is returned to the user
|
||||||
|
out_shape = None
|
||||||
|
|
||||||
|
def my_fallback(op, *args, **kwargs):
|
||||||
|
# Disable our handler during checks and generating the output
|
||||||
|
with torch._C._ForceDispatchKeyGuard(
|
||||||
|
include_to_set, exclude_to_set | test_keyset
|
||||||
|
):
|
||||||
|
self.assertIs(op, expected_op)
|
||||||
|
self.assertEqual(args, expected_args)
|
||||||
|
self.assertEqual(kwargs, expected_kwargs)
|
||||||
|
# Return something specific
|
||||||
|
return torch.empty(out_shape)
|
||||||
|
|
||||||
|
my_lib.fallback(my_fallback, test_key)
|
||||||
|
|
||||||
|
a, b = torch.rand(2), torch.rand(2)
|
||||||
|
|
||||||
|
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
|
||||||
|
# Check a factory function
|
||||||
|
expected_op = torch.ops.aten.empty.memory_format
|
||||||
|
expected_args = ((2, 2),)
|
||||||
|
# Extra kwargs to bypass issues with default args in factory functions
|
||||||
|
expected_kwargs = {
|
||||||
|
"dtype": torch.float64,
|
||||||
|
"pin_memory": False,
|
||||||
|
"device": torch.device("cpu"),
|
||||||
|
}
|
||||||
|
out_shape = (3,)
|
||||||
|
out = torch.empty(*expected_args, **expected_kwargs)
|
||||||
|
self.assertEqual(out.size(), out_shape)
|
||||||
|
|
||||||
|
# Check a regular function
|
||||||
|
expected_op = torch.ops.aten.add.Tensor
|
||||||
|
expected_args = (a, b)
|
||||||
|
expected_kwargs = {}
|
||||||
|
out_shape = (4,)
|
||||||
|
out = a + b
|
||||||
|
self.assertEqual(out.size(), out_shape)
|
||||||
|
|
||||||
|
def test_fallback_keyset(self) -> None:
|
||||||
|
test_key_first = "TESTING_ONLY_GenericMode"
|
||||||
|
test_key_second = "TESTING_ONLY_GenericWrapper"
|
||||||
|
test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
|
||||||
|
test_key_second
|
||||||
|
)
|
||||||
|
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
|
||||||
|
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
|
||||||
|
|
||||||
|
with _scoped_library("_", "IMPL") as my_lib:
|
||||||
|
first_called = False
|
||||||
|
second_called = False
|
||||||
|
|
||||||
|
def first_fallback(keyset, op, *args, **kwargs):
|
||||||
|
nonlocal first_called
|
||||||
|
if second_called:
|
||||||
|
# Recursive call
|
||||||
|
first_called = True
|
||||||
|
with torch._C._ForceDispatchKeyGuard(
|
||||||
|
include_to_set, exclude_to_set | test_keyset
|
||||||
|
):
|
||||||
|
return op(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
# Redispatch down
|
||||||
|
keyset = keyset.remove(test_key_first)
|
||||||
|
return op.redispatch(keyset, *args, **kwargs)
|
||||||
|
|
||||||
|
def second_fallback(op, *args, **kwargs):
|
||||||
|
nonlocal second_called
|
||||||
|
# Set to avoid infinite recursion
|
||||||
|
second_called = True
|
||||||
|
# New dispatcher call should hit the first callback again
|
||||||
|
self.assertFalse(first_called)
|
||||||
|
a, b = args
|
||||||
|
# Make a substraction here instead of add !
|
||||||
|
c = a - b
|
||||||
|
self.assertTrue(first_called)
|
||||||
|
return c
|
||||||
|
|
||||||
|
my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
|
||||||
|
my_lib.fallback(second_fallback, test_key_second)
|
||||||
|
|
||||||
|
a, b = torch.rand(2), torch.rand(2)
|
||||||
|
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
|
||||||
|
c = a + b
|
||||||
|
|
||||||
|
self.assertEqual(c, a - b)
|
||||||
|
self.assertTrue(first_called)
|
||||||
|
self.assertTrue(second_called)
|
||||||
|
|
||||||
|
def test_fallback_fallthrough(self) -> None:
|
||||||
|
test_key_first = "TESTING_ONLY_GenericMode"
|
||||||
|
test_key_second = "TESTING_ONLY_GenericWrapper"
|
||||||
|
test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
|
||||||
|
test_key_second
|
||||||
|
)
|
||||||
|
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
|
||||||
|
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
|
||||||
|
|
||||||
|
with _scoped_library("_", "IMPL") as my_lib:
|
||||||
|
is_called = False
|
||||||
|
|
||||||
|
def my_fallback(op, *args, **kwargs):
|
||||||
|
nonlocal is_called
|
||||||
|
is_called = True
|
||||||
|
with torch._C._ForceDispatchKeyGuard(
|
||||||
|
include_to_set, exclude_to_set | test_keyset
|
||||||
|
):
|
||||||
|
return op(*args, **kwargs)
|
||||||
|
|
||||||
|
my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
|
||||||
|
my_lib.fallback(my_fallback, test_key_second)
|
||||||
|
|
||||||
|
a, b = torch.rand(2), torch.rand(2)
|
||||||
|
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
|
||||||
|
c = a + b
|
||||||
|
|
||||||
|
self.assertEqual(c, a + b)
|
||||||
|
self.assertTrue(is_called)
|
||||||
|
|
||||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||||
x = torch.tensor([1, 2])
|
x = torch.tensor([1, 2])
|
||||||
with _scoped_library("aten", "IMPL") as my_lib2:
|
with _scoped_library("aten", "IMPL") as my_lib2:
|
||||||
|
|||||||
@ -12,6 +12,8 @@ using namespace torch;
|
|||||||
using namespace at;
|
using namespace at;
|
||||||
using namespace c10;
|
using namespace c10;
|
||||||
|
|
||||||
|
namespace torch::detail {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// NB: This is a macro and not a template function (like it was before)
|
// NB: This is a macro and not a template function (like it was before)
|
||||||
@ -62,9 +64,10 @@ struct ConcretePyInterpreterVTable final
|
|||||||
c10::DispatchKey key,
|
c10::DispatchKey key,
|
||||||
c10::DispatchKeySet keyset,
|
c10::DispatchKeySet keyset,
|
||||||
torch::jit::Stack* stack,
|
torch::jit::Stack* stack,
|
||||||
bool with_keyset) const override {
|
bool with_keyset,
|
||||||
|
bool with_op) const override {
|
||||||
torch::impl::dispatch::python_op_registration_trampoline_impl(
|
torch::impl::dispatch::python_op_registration_trampoline_impl(
|
||||||
op, key, keyset, stack, with_keyset);
|
op, key, keyset, stack, with_keyset, with_op);
|
||||||
}
|
}
|
||||||
void throw_abstract_impl_not_imported_error(
|
void throw_abstract_impl_not_imported_error(
|
||||||
std::string opname,
|
std::string opname,
|
||||||
@ -272,30 +275,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
|
|||||||
Py_DECREF(pyobj);
|
Py_DECREF(pyobj);
|
||||||
};
|
};
|
||||||
|
|
||||||
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
|
|
||||||
return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
|
|
||||||
// Parse the name into namespace and name (no overload_name)
|
|
||||||
// TODO: put this into the library
|
|
||||||
const auto& schema = op.schema();
|
|
||||||
const auto& qualified_name = op.operator_name().name;
|
|
||||||
const auto& overload_name = schema.overload_name();
|
|
||||||
auto pos = qualified_name.find("::");
|
|
||||||
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
|
|
||||||
// Make me some null terminated strings
|
|
||||||
std::string ns_str = qualified_name.substr(0, pos);
|
|
||||||
const char* ns = ns_str.c_str();
|
|
||||||
const char* func_name = qualified_name.c_str() + pos + strlen("::");
|
|
||||||
|
|
||||||
py::handle torch_api_function =
|
|
||||||
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
|
||||||
if (overload_name.empty()) {
|
|
||||||
return torch_api_function.attr("default").ptr();
|
|
||||||
} else {
|
|
||||||
return torch_api_function.attr(overload_name.c_str()).ptr();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isPythonTensor(const at::Tensor& tensor) {
|
bool isPythonTensor(const at::Tensor& tensor) {
|
||||||
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
|
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
|
||||||
}
|
}
|
||||||
@ -956,20 +935,46 @@ void ConcretePyInterpreterVTable::reset_backward_hooks(
|
|||||||
END_HANDLE_TH_ERRORS_PYBIND
|
END_HANDLE_TH_ERRORS_PYBIND
|
||||||
}
|
}
|
||||||
|
|
||||||
PyInterpreterHolder self_interpreter;
|
|
||||||
|
|
||||||
} // anonymous namespace
|
|
||||||
|
|
||||||
c10::impl::PyInterpreter* getPyInterpreter() {
|
|
||||||
return self_interpreter.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isMainPyInterpreter() {
|
|
||||||
return self_interpreter.is_main_interpreter();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string ConcretePyInterpreterVTable::name() const {
|
std::string ConcretePyInterpreterVTable::name() const {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << getPyInterpreter();
|
ss << getPyInterpreter();
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyInterpreterHolder self_interpreter;
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
|
||||||
|
return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
|
||||||
|
// Parse the name into namespace and name (no overload_name)
|
||||||
|
// TODO: put this into the library
|
||||||
|
const auto& schema = op.schema();
|
||||||
|
const auto& qualified_name = op.operator_name().name;
|
||||||
|
const auto& overload_name = schema.overload_name();
|
||||||
|
auto pos = qualified_name.find("::");
|
||||||
|
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
|
||||||
|
// Make me some null terminated strings
|
||||||
|
std::string ns_str = qualified_name.substr(0, pos);
|
||||||
|
const char* ns = ns_str.c_str();
|
||||||
|
const char* func_name = qualified_name.c_str() + pos + strlen("::");
|
||||||
|
|
||||||
|
py::handle torch_api_function =
|
||||||
|
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
||||||
|
if (overload_name.empty()) {
|
||||||
|
return torch_api_function.attr("default").ptr();
|
||||||
|
} else {
|
||||||
|
return torch_api_function.attr(overload_name.c_str()).ptr();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace torch::detail
|
||||||
|
|
||||||
|
c10::impl::PyInterpreter* getPyInterpreter() {
|
||||||
|
return torch::detail::self_interpreter.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isMainPyInterpreter() {
|
||||||
|
return torch::detail::self_interpreter.is_main_interpreter();
|
||||||
|
}
|
||||||
|
|||||||
@ -2,6 +2,12 @@
|
|||||||
|
|
||||||
#include <c10/core/impl/PyInterpreter.h>
|
#include <c10/core/impl/PyInterpreter.h>
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
|
namespace torch::detail {
|
||||||
|
TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Move these to a proper namespace
|
||||||
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
|
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
|
||||||
TORCH_PYTHON_API bool isMainPyInterpreter();
|
TORCH_PYTHON_API bool isMainPyInterpreter();
|
||||||
|
|||||||
@ -108,15 +108,19 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
|||||||
c10::DispatchKey dispatch_key_;
|
c10::DispatchKey dispatch_key_;
|
||||||
// If "with_keyset", then we expect a keyset as the first arg.
|
// If "with_keyset", then we expect a keyset as the first arg.
|
||||||
bool with_keyset_;
|
bool with_keyset_;
|
||||||
|
// If "with_op", then we expect the op as first arg (or second if keyset)
|
||||||
|
bool with_op_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PythonKernelHolder(
|
PythonKernelHolder(
|
||||||
py::object func,
|
py::object func,
|
||||||
c10::DispatchKey dispatch_key,
|
c10::DispatchKey dispatch_key,
|
||||||
bool with_keyset = false)
|
bool with_keyset = false,
|
||||||
|
bool with_op = false)
|
||||||
: func_(func.release().ptr(), getPyInterpreter()),
|
: func_(func.release().ptr(), getPyInterpreter()),
|
||||||
dispatch_key_(dispatch_key),
|
dispatch_key_(dispatch_key),
|
||||||
with_keyset_(with_keyset) {}
|
with_keyset_(with_keyset),
|
||||||
|
with_op_(with_op) {}
|
||||||
|
|
||||||
void operator()(
|
void operator()(
|
||||||
const c10::OperatorHandle& op,
|
const c10::OperatorHandle& op,
|
||||||
@ -132,7 +136,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
|||||||
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
|
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
|
||||||
cur_torch_dispatch_mode_state->pyinterpreter()
|
cur_torch_dispatch_mode_state->pyinterpreter()
|
||||||
->python_op_registration_trampoline(
|
->python_op_registration_trampoline(
|
||||||
op, dispatch_key_, keyset, stack, with_keyset_);
|
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,7 +154,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
|||||||
at::DispatchKey::Python)) {
|
at::DispatchKey::Python)) {
|
||||||
(*interpreter)
|
(*interpreter)
|
||||||
->python_op_registration_trampoline(
|
->python_op_registration_trampoline(
|
||||||
op, dispatch_key_, keyset, stack, with_keyset_);
|
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
|
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
|
||||||
@ -166,7 +170,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
|||||||
nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
|
nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
|
||||||
(*interpreter)
|
(*interpreter)
|
||||||
->python_op_registration_trampoline(
|
->python_op_registration_trampoline(
|
||||||
op, dispatch_key_, keyset, stack, with_keyset_);
|
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -189,9 +193,18 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
|||||||
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
||||||
auto func =
|
auto func =
|
||||||
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
|
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
|
||||||
auto obj = with_keyset_
|
auto obj = with_op_ ? with_keyset_
|
||||||
? func(keyset, *args_kwargs.first, **args_kwargs.second)
|
? func(
|
||||||
: func(*args_kwargs.first, **args_kwargs.second);
|
keyset,
|
||||||
|
torch::detail::getTorchApiFunction(op),
|
||||||
|
*args_kwargs.first,
|
||||||
|
**args_kwargs.second)
|
||||||
|
: func(
|
||||||
|
torch::detail::getTorchApiFunction(op),
|
||||||
|
*args_kwargs.first,
|
||||||
|
**args_kwargs.second)
|
||||||
|
: with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
|
||||||
|
: func(*args_kwargs.first, **args_kwargs.second);
|
||||||
if (!obj) {
|
if (!obj) {
|
||||||
throw python_error();
|
throw python_error();
|
||||||
}
|
}
|
||||||
@ -461,7 +474,33 @@ void initDispatchBindings(PyObject* module) {
|
|||||||
return self;
|
return self;
|
||||||
},
|
},
|
||||||
"",
|
"",
|
||||||
py::arg("dispatch") = "");
|
py::arg("dispatch") = "")
|
||||||
|
.def(
|
||||||
|
"fallback",
|
||||||
|
[](const py::object& self,
|
||||||
|
c10::DispatchKey dispatch,
|
||||||
|
const py::object& func,
|
||||||
|
bool with_keyset) {
|
||||||
|
HANDLE_TH_ERRORS
|
||||||
|
auto& lib = self.cast<torch::Library&>();
|
||||||
|
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
|
||||||
|
if (func.is(py::module::import("torch.library")
|
||||||
|
.attr("fallthrough_kernel"))) {
|
||||||
|
lib.fallback(
|
||||||
|
torch::dispatch(dispatch, CppFunction::makeFallthrough()));
|
||||||
|
} else {
|
||||||
|
lib.fallback(torch::dispatch(
|
||||||
|
dispatch,
|
||||||
|
CppFunction::makeFromBoxedFunctor(
|
||||||
|
std::make_unique<PythonKernelHolder>(
|
||||||
|
func, dispatch, with_keyset, /*with_op*/ true))));
|
||||||
|
}
|
||||||
|
END_HANDLE_TH_ERRORS_PYBIND
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
py::arg("dispatch"),
|
||||||
|
py::arg("func"),
|
||||||
|
py::arg("with_keyset") = false);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"_dispatch_library",
|
"_dispatch_library",
|
||||||
@ -954,7 +993,8 @@ void python_op_registration_trampoline_impl(
|
|||||||
c10::DispatchKey key,
|
c10::DispatchKey key,
|
||||||
c10::DispatchKeySet keyset,
|
c10::DispatchKeySet keyset,
|
||||||
torch::jit::Stack* stack,
|
torch::jit::Stack* stack,
|
||||||
bool with_keyset) {
|
bool with_keyset,
|
||||||
|
bool with_op) {
|
||||||
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
|
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
|
||||||
py::gil_scoped_acquire g;
|
py::gil_scoped_acquire g;
|
||||||
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
||||||
@ -963,9 +1003,17 @@ void python_op_registration_trampoline_impl(
|
|||||||
auto* pyobj = func->ptr(getPyInterpreter());
|
auto* pyobj = func->ptr(getPyInterpreter());
|
||||||
TORCH_INTERNAL_ASSERT(pyobj != nullptr);
|
TORCH_INTERNAL_ASSERT(pyobj != nullptr);
|
||||||
auto callable = py::reinterpret_borrow<py::object>(pyobj);
|
auto callable = py::reinterpret_borrow<py::object>(pyobj);
|
||||||
auto obj = with_keyset
|
auto obj = with_op ? with_keyset ? callable(
|
||||||
? callable(keyset, *args_kwargs.first, **args_kwargs.second)
|
keyset,
|
||||||
: callable(*args_kwargs.first, **args_kwargs.second);
|
torch::detail::getTorchApiFunction(op),
|
||||||
|
*args_kwargs.first,
|
||||||
|
**args_kwargs.second)
|
||||||
|
: callable(
|
||||||
|
torch::detail::getTorchApiFunction(op),
|
||||||
|
*args_kwargs.first,
|
||||||
|
**args_kwargs.second)
|
||||||
|
: with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
|
||||||
|
: callable(*args_kwargs.first, **args_kwargs.second);
|
||||||
if (!obj) {
|
if (!obj) {
|
||||||
throw python_error();
|
throw python_error();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,6 +10,7 @@ void python_op_registration_trampoline_impl(
|
|||||||
c10::DispatchKey key,
|
c10::DispatchKey key,
|
||||||
c10::DispatchKeySet keyset,
|
c10::DispatchKeySet keyset,
|
||||||
torch::jit::Stack* stack,
|
torch::jit::Stack* stack,
|
||||||
bool with_keyset);
|
bool with_keyset,
|
||||||
|
bool with_op);
|
||||||
|
|
||||||
} // namespace torch::impl::dispatch
|
} // namespace torch::impl::dispatch
|
||||||
|
|||||||
@ -278,6 +278,8 @@ class Library:
|
|||||||
to register a fallthrough.
|
to register a fallthrough.
|
||||||
dispatch_key: dispatch key that the input function should be registered for. By default, it uses
|
dispatch_key: dispatch key that the input function should be registered for. By default, it uses
|
||||||
the dispatch key that the library was created with.
|
the dispatch key that the library was created with.
|
||||||
|
with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
|
||||||
|
to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
>>> my_lib = Library("aten", "IMPL")
|
>>> my_lib = Library("aten", "IMPL")
|
||||||
@ -345,6 +347,43 @@ class Library:
|
|||||||
_impls.add(key)
|
_impls.add(key)
|
||||||
self._op_impls.add(key)
|
self._op_impls.add(key)
|
||||||
|
|
||||||
|
def fallback(self, fn, dispatch_key="", *, with_keyset=False):
|
||||||
|
r"""Registers the function implementation as the fallback for the given key.
|
||||||
|
|
||||||
|
This function only works for a library with global namespace ("_").
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
|
||||||
|
to register a fallthrough.
|
||||||
|
dispatch_key: dispatch key that the input function should be registered for. By default, it uses
|
||||||
|
the dispatch key that the library was created with.
|
||||||
|
with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
|
||||||
|
to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
>>> my_lib = Library("_", "IMPL")
|
||||||
|
>>> def fallback_kernel(op, *args, **kwargs):
|
||||||
|
>>> # Handle all autocast ops generically
|
||||||
|
>>> # ...
|
||||||
|
>>> my_lib.fallback(fallback_kernel, "Autocast")
|
||||||
|
"""
|
||||||
|
if dispatch_key == "":
|
||||||
|
dispatch_key = self.dispatch_key
|
||||||
|
|
||||||
|
if self.ns != "_":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"""Fallback can only be registered using libary fragment on the global namespace "_" but it is {self.ns}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dispatch_key != ""
|
||||||
|
assert self.m is not None
|
||||||
|
|
||||||
|
self.m.fallback(
|
||||||
|
dispatch_key,
|
||||||
|
fn,
|
||||||
|
with_keyset,
|
||||||
|
)
|
||||||
|
|
||||||
def _destroy(self):
|
def _destroy(self):
|
||||||
if self.m is not None:
|
if self.m is not None:
|
||||||
self.m.reset()
|
self.m.reset()
|
||||||
|
|||||||
Reference in New Issue
Block a user