mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add fallback() to torch.library (#131707)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131707 Approved by: https://github.com/zou3519
This commit is contained in:
@ -36,7 +36,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
||||
c10::DispatchKey,
|
||||
c10::DispatchKeySet keyset,
|
||||
torch::jit::Stack* stack,
|
||||
bool with_keyset) const override {
|
||||
bool with_keyset,
|
||||
bool with_op) const override {
|
||||
PANIC(python_op_registration_trampoline);
|
||||
}
|
||||
|
||||
|
@ -150,7 +150,8 @@ struct C10_API PyInterpreterVTable {
|
||||
c10::DispatchKey,
|
||||
c10::DispatchKeySet keyset,
|
||||
torch::jit::Stack* stack,
|
||||
bool with_keyset) const = 0;
|
||||
bool with_keyset,
|
||||
bool with_op) const = 0;
|
||||
|
||||
virtual void throw_abstract_impl_not_imported_error(
|
||||
std::string opname,
|
||||
|
@ -69,6 +69,137 @@ class TestPythonRegistration(TestCase):
|
||||
if hasattr(torch.ops, self.test_ns):
|
||||
del torch.ops._test_python_registration
|
||||
|
||||
def test_fallback(self) -> None:
    """A Python fallback registered on a key sees the op, args and kwargs,
    and whatever it returns is what the caller gets back."""
    mode_key = "TESTING_ONLY_GenericMode"
    mode_keyset = torch._C.DispatchKeySet(mode_key)
    tls_exclude = torch._C._dispatch_tls_local_exclude_set()
    tls_include = torch._C._dispatch_tls_local_include_set() | mode_keyset
    # Exclude set used inside the handler: it masks our key again so the
    # checks and the output creation do not re-enter the handler.
    handler_off_exclude = tls_exclude | mode_keyset

    with _scoped_library("_", "IMPL") as lib:
        # What the handler must observe on its next invocation. "shape" is
        # deliberately different from what the real op would produce, so we
        # can tell that the fallback's result is what the user receives.
        expectation = {"op": None, "args": None, "kwargs": None, "shape": None}

        def my_fallback(op, *args, **kwargs):
            with torch._C._ForceDispatchKeyGuard(tls_include, handler_off_exclude):
                self.assertIs(op, expectation["op"])
                self.assertEqual(args, expectation["args"])
                self.assertEqual(kwargs, expectation["kwargs"])
                # Return something specific
                return torch.empty(expectation["shape"])

        lib.fallback(my_fallback, mode_key)

        lhs = torch.rand(2)
        rhs = torch.rand(2)

        with torch._C._ForceDispatchKeyGuard(tls_include, tls_exclude):
            # Factory function. The extra kwargs bypass issues with default
            # args in factory functions.
            expectation.update(
                op=torch.ops.aten.empty.memory_format,
                args=((2, 2),),
                kwargs={
                    "dtype": torch.float64,
                    "pin_memory": False,
                    "device": torch.device("cpu"),
                },
                shape=(3,),
            )
            result = torch.empty(*expectation["args"], **expectation["kwargs"])
            self.assertEqual(result.size(), expectation["shape"])

            # Regular function.
            expectation.update(
                op=torch.ops.aten.add.Tensor,
                args=(lhs, rhs),
                kwargs={},
                shape=(4,),
            )
            result = lhs + rhs
            self.assertEqual(result.size(), expectation["shape"])
|
||||
|
||||
def test_fallback_keyset(self) -> None:
    """Two stacked fallbacks: the first one (registered with the keyset)
    redispatches below itself, the second one issues a brand new op that
    must go through the first fallback again."""
    first_key = "TESTING_ONLY_GenericMode"
    second_key = "TESTING_ONLY_GenericWrapper"
    both_keys = torch._C.DispatchKeySet(first_key) | torch._C.DispatchKeySet(
        second_key
    )
    tls_exclude = torch._C._dispatch_tls_local_exclude_set()
    tls_include = torch._C._dispatch_tls_local_include_set() | both_keys

    with _scoped_library("_", "IMPL") as lib:
        first_called = False
        second_called = False

        def first_fallback(keyset, op, *args, **kwargs):
            nonlocal first_called
            if not second_called:
                # First visit: redispatch down, skipping our own key.
                return op.redispatch(keyset.remove(first_key), *args, **kwargs)

            # Recursive visit triggered from inside the second fallback:
            # record it and run the op with both test keys masked.
            first_called = True
            with torch._C._ForceDispatchKeyGuard(
                tls_include, tls_exclude | both_keys
            ):
                return op(*args, **kwargs)

        def second_fallback(op, *args, **kwargs):
            nonlocal second_called
            # Set before doing any work to avoid infinite recursion.
            second_called = True
            # The new dispatcher call below should hit the first callback again.
            self.assertFalse(first_called)
            lhs, rhs = args
            # Do a subtraction here instead of the requested add!
            difference = lhs - rhs
            self.assertTrue(first_called)
            return difference

        lib.fallback(first_fallback, first_key, with_keyset=True)
        lib.fallback(second_fallback, second_key)

        x = torch.rand(2)
        y = torch.rand(2)
        with torch._C._ForceDispatchKeyGuard(tls_include, tls_exclude):
            result = x + y

        self.assertEqual(result, x - y)
        self.assertTrue(first_called)
        self.assertTrue(second_called)
|
||||
|
||||
def test_fallback_fallthrough(self) -> None:
    """A fallthrough registered on the first key must let the call reach
    the Python fallback registered on the second key."""
    first_key = "TESTING_ONLY_GenericMode"
    second_key = "TESTING_ONLY_GenericWrapper"
    both_keys = torch._C.DispatchKeySet(first_key) | torch._C.DispatchKeySet(
        second_key
    )
    tls_exclude = torch._C._dispatch_tls_local_exclude_set()
    tls_include = torch._C._dispatch_tls_local_include_set() | both_keys

    with _scoped_library("_", "IMPL") as lib:
        is_called = False

        def my_fallback(op, *args, **kwargs):
            nonlocal is_called
            is_called = True
            # Run the real op with both test keys masked.
            with torch._C._ForceDispatchKeyGuard(
                tls_include, tls_exclude | both_keys
            ):
                return op(*args, **kwargs)

        lib.fallback(torch.library.fallthrough_kernel, first_key)
        lib.fallback(my_fallback, second_key)

        x = torch.rand(2)
        y = torch.rand(2)
        with torch._C._ForceDispatchKeyGuard(tls_include, tls_exclude):
            result = x + y

        self.assertEqual(result, x + y)
        self.assertTrue(is_called)
|
||||
|
||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||
x = torch.tensor([1, 2])
|
||||
with _scoped_library("aten", "IMPL") as my_lib2:
|
||||
|
@ -12,6 +12,8 @@ using namespace torch;
|
||||
using namespace at;
|
||||
using namespace c10;
|
||||
|
||||
namespace torch::detail {
|
||||
|
||||
namespace {
|
||||
|
||||
// NB: This is a macro and not a template function (like it was before)
|
||||
@ -62,9 +64,10 @@ struct ConcretePyInterpreterVTable final
|
||||
c10::DispatchKey key,
|
||||
c10::DispatchKeySet keyset,
|
||||
torch::jit::Stack* stack,
|
||||
bool with_keyset) const override {
|
||||
bool with_keyset,
|
||||
bool with_op) const override {
|
||||
torch::impl::dispatch::python_op_registration_trampoline_impl(
|
||||
op, key, keyset, stack, with_keyset);
|
||||
op, key, keyset, stack, with_keyset, with_op);
|
||||
}
|
||||
void throw_abstract_impl_not_imported_error(
|
||||
std::string opname,
|
||||
@ -272,30 +275,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
|
||||
Py_DECREF(pyobj);
|
||||
};
|
||||
|
||||
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
|
||||
return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
|
||||
// Parse the name into namespace and name (no overload_name)
|
||||
// TODO: put this into the library
|
||||
const auto& schema = op.schema();
|
||||
const auto& qualified_name = op.operator_name().name;
|
||||
const auto& overload_name = schema.overload_name();
|
||||
auto pos = qualified_name.find("::");
|
||||
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
|
||||
// Make me some null terminated strings
|
||||
std::string ns_str = qualified_name.substr(0, pos);
|
||||
const char* ns = ns_str.c_str();
|
||||
const char* func_name = qualified_name.c_str() + pos + strlen("::");
|
||||
|
||||
py::handle torch_api_function =
|
||||
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
||||
if (overload_name.empty()) {
|
||||
return torch_api_function.attr("default").ptr();
|
||||
} else {
|
||||
return torch_api_function.attr(overload_name.c_str()).ptr();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
bool isPythonTensor(const at::Tensor& tensor) {
|
||||
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
|
||||
}
|
||||
@ -956,20 +935,46 @@ void ConcretePyInterpreterVTable::reset_backward_hooks(
|
||||
END_HANDLE_TH_ERRORS_PYBIND
|
||||
}
|
||||
|
||||
PyInterpreterHolder self_interpreter;
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
c10::impl::PyInterpreter* getPyInterpreter() {
|
||||
return self_interpreter.get();
|
||||
}
|
||||
|
||||
bool isMainPyInterpreter() {
|
||||
return self_interpreter.is_main_interpreter();
|
||||
}
|
||||
|
||||
std::string ConcretePyInterpreterVTable::name() const {
|
||||
std::stringstream ss;
|
||||
ss << getPyInterpreter();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
PyInterpreterHolder self_interpreter;
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
|
||||
return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
|
||||
// Parse the name into namespace and name (no overload_name)
|
||||
// TODO: put this into the library
|
||||
const auto& schema = op.schema();
|
||||
const auto& qualified_name = op.operator_name().name;
|
||||
const auto& overload_name = schema.overload_name();
|
||||
auto pos = qualified_name.find("::");
|
||||
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
|
||||
// Make me some null terminated strings
|
||||
std::string ns_str = qualified_name.substr(0, pos);
|
||||
const char* ns = ns_str.c_str();
|
||||
const char* func_name = qualified_name.c_str() + pos + strlen("::");
|
||||
|
||||
py::handle torch_api_function =
|
||||
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
||||
if (overload_name.empty()) {
|
||||
return torch_api_function.attr("default").ptr();
|
||||
} else {
|
||||
return torch_api_function.attr(overload_name.c_str()).ptr();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace torch::detail
|
||||
|
||||
c10::impl::PyInterpreter* getPyInterpreter() {
|
||||
return torch::detail::self_interpreter.get();
|
||||
}
|
||||
|
||||
bool isMainPyInterpreter() {
|
||||
return torch::detail::self_interpreter.is_main_interpreter();
|
||||
}
|
||||
|
@ -2,6 +2,12 @@
|
||||
|
||||
#include <c10/core/impl/PyInterpreter.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch::detail {
|
||||
TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op);
|
||||
}
|
||||
|
||||
// TODO: Move these to a proper namespace
|
||||
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
|
||||
TORCH_PYTHON_API bool isMainPyInterpreter();
|
||||
|
@ -108,15 +108,19 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
||||
c10::DispatchKey dispatch_key_;
|
||||
// If "with_keyset", then we expect a keyset as the first arg.
|
||||
bool with_keyset_;
|
||||
// If "with_op", then we expect the op as first arg (or second if keyset)
|
||||
bool with_op_;
|
||||
|
||||
public:
|
||||
PythonKernelHolder(
|
||||
py::object func,
|
||||
c10::DispatchKey dispatch_key,
|
||||
bool with_keyset = false)
|
||||
bool with_keyset = false,
|
||||
bool with_op = false)
|
||||
: func_(func.release().ptr(), getPyInterpreter()),
|
||||
dispatch_key_(dispatch_key),
|
||||
with_keyset_(with_keyset) {}
|
||||
with_keyset_(with_keyset),
|
||||
with_op_(with_op) {}
|
||||
|
||||
void operator()(
|
||||
const c10::OperatorHandle& op,
|
||||
@ -132,7 +136,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
||||
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
|
||||
cur_torch_dispatch_mode_state->pyinterpreter()
|
||||
->python_op_registration_trampoline(
|
||||
op, dispatch_key_, keyset, stack, with_keyset_);
|
||||
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -150,7 +154,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
||||
at::DispatchKey::Python)) {
|
||||
(*interpreter)
|
||||
->python_op_registration_trampoline(
|
||||
op, dispatch_key_, keyset, stack, with_keyset_);
|
||||
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
|
||||
return;
|
||||
}
|
||||
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
|
||||
@ -166,7 +170,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
||||
nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
|
||||
(*interpreter)
|
||||
->python_op_registration_trampoline(
|
||||
op, dispatch_key_, keyset, stack, with_keyset_);
|
||||
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -189,9 +193,18 @@ class PythonKernelHolder : public c10::OperatorKernel {
|
||||
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
||||
auto func =
|
||||
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
|
||||
auto obj = with_keyset_
|
||||
? func(keyset, *args_kwargs.first, **args_kwargs.second)
|
||||
: func(*args_kwargs.first, **args_kwargs.second);
|
||||
auto obj = with_op_ ? with_keyset_
|
||||
? func(
|
||||
keyset,
|
||||
torch::detail::getTorchApiFunction(op),
|
||||
*args_kwargs.first,
|
||||
**args_kwargs.second)
|
||||
: func(
|
||||
torch::detail::getTorchApiFunction(op),
|
||||
*args_kwargs.first,
|
||||
**args_kwargs.second)
|
||||
: with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
|
||||
: func(*args_kwargs.first, **args_kwargs.second);
|
||||
if (!obj) {
|
||||
throw python_error();
|
||||
}
|
||||
@ -461,7 +474,33 @@ void initDispatchBindings(PyObject* module) {
|
||||
return self;
|
||||
},
|
||||
"",
|
||||
py::arg("dispatch") = "");
|
||||
py::arg("dispatch") = "")
|
||||
.def(
|
||||
"fallback",
|
||||
[](const py::object& self,
|
||||
c10::DispatchKey dispatch,
|
||||
const py::object& func,
|
||||
bool with_keyset) {
|
||||
HANDLE_TH_ERRORS
|
||||
auto& lib = self.cast<torch::Library&>();
|
||||
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
|
||||
if (func.is(py::module::import("torch.library")
|
||||
.attr("fallthrough_kernel"))) {
|
||||
lib.fallback(
|
||||
torch::dispatch(dispatch, CppFunction::makeFallthrough()));
|
||||
} else {
|
||||
lib.fallback(torch::dispatch(
|
||||
dispatch,
|
||||
CppFunction::makeFromBoxedFunctor(
|
||||
std::make_unique<PythonKernelHolder>(
|
||||
func, dispatch, with_keyset, /*with_op*/ true))));
|
||||
}
|
||||
END_HANDLE_TH_ERRORS_PYBIND
|
||||
},
|
||||
"",
|
||||
py::arg("dispatch"),
|
||||
py::arg("func"),
|
||||
py::arg("with_keyset") = false);
|
||||
|
||||
m.def(
|
||||
"_dispatch_library",
|
||||
@ -954,7 +993,8 @@ void python_op_registration_trampoline_impl(
|
||||
c10::DispatchKey key,
|
||||
c10::DispatchKeySet keyset,
|
||||
torch::jit::Stack* stack,
|
||||
bool with_keyset) {
|
||||
bool with_keyset,
|
||||
bool with_op) {
|
||||
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
|
||||
py::gil_scoped_acquire g;
|
||||
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
|
||||
@ -963,9 +1003,17 @@ void python_op_registration_trampoline_impl(
|
||||
auto* pyobj = func->ptr(getPyInterpreter());
|
||||
TORCH_INTERNAL_ASSERT(pyobj != nullptr);
|
||||
auto callable = py::reinterpret_borrow<py::object>(pyobj);
|
||||
auto obj = with_keyset
|
||||
? callable(keyset, *args_kwargs.first, **args_kwargs.second)
|
||||
: callable(*args_kwargs.first, **args_kwargs.second);
|
||||
auto obj = with_op ? with_keyset ? callable(
|
||||
keyset,
|
||||
torch::detail::getTorchApiFunction(op),
|
||||
*args_kwargs.first,
|
||||
**args_kwargs.second)
|
||||
: callable(
|
||||
torch::detail::getTorchApiFunction(op),
|
||||
*args_kwargs.first,
|
||||
**args_kwargs.second)
|
||||
: with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
|
||||
: callable(*args_kwargs.first, **args_kwargs.second);
|
||||
if (!obj) {
|
||||
throw python_error();
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ void python_op_registration_trampoline_impl(
|
||||
c10::DispatchKey key,
|
||||
c10::DispatchKeySet keyset,
|
||||
torch::jit::Stack* stack,
|
||||
bool with_keyset);
|
||||
bool with_keyset,
|
||||
bool with_op);
|
||||
|
||||
} // namespace torch::impl::dispatch
|
||||
|
@ -278,6 +278,8 @@ class Library:
|
||||
to register a fallthrough.
|
||||
dispatch_key: dispatch key that the input function should be registered for. By default, it uses
|
||||
the dispatch key that the library was created with.
|
||||
with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
|
||||
to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
|
||||
|
||||
Example::
|
||||
>>> my_lib = Library("aten", "IMPL")
|
||||
@ -345,6 +347,43 @@ class Library:
|
||||
_impls.add(key)
|
||||
self._op_impls.add(key)
|
||||
|
||||
def fallback(self, fn, dispatch_key="", *, with_keyset=False):
    r"""Registers the function implementation as the fallback for the given key.

    This function only works for a library with global namespace ("_").

    Args:
        fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
            to register a fallthrough. Unless it is the fallthrough kernel, :attr:`fn` is called
            with the operator being dispatched as its first positional argument (second one when
            :attr:`with_keyset` is set), followed by the arguments of the call.
        dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                      the dispatch key that the library was created with.
        with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                     to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

    Raises:
        RuntimeError: if this library was not created on the global namespace ("_").

    Example::

        >>> my_lib = Library("_", "IMPL")
        >>> def fallback_kernel(op, *args, **kwargs):
        >>>     # Handle all autocast ops generically
        >>>     # ...
        >>> my_lib.fallback(fallback_kernel, "Autocast")
    """
    # An empty key means "use the key this library was created with".
    if dispatch_key == "":
        dispatch_key = self.dispatch_key

    if self.ns != "_":
        raise RuntimeError(
            f"""Fallback can only be registered using library fragment on the global namespace "_" but it is {self.ns}"""
        )

    assert dispatch_key != ""
    assert self.m is not None

    # The binding takes the key first, then the callable.
    self.m.fallback(
        dispatch_key,
        fn,
        with_keyset,
    )
|
||||
|
||||
def _destroy(self):
|
||||
if self.m is not None:
|
||||
self.m.reset()
|
||||
|
Reference in New Issue
Block a user