Add fallback() to torch.library (#131707)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131707
Approved by: https://github.com/zou3519
This commit is contained in:
albanD
2024-07-26 15:40:17 -04:00
committed by PyTorch MergeBot
parent 8e5a367311
commit 466ea8ce54
10 changed files with 286 additions and 54 deletions

View File

@ -36,7 +36,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
c10::DispatchKey, c10::DispatchKey,
c10::DispatchKeySet keyset, c10::DispatchKeySet keyset,
torch::jit::Stack* stack, torch::jit::Stack* stack,
bool with_keyset) const override { bool with_keyset,
bool with_op) const override {
PANIC(python_op_registration_trampoline); PANIC(python_op_registration_trampoline);
} }

View File

@ -150,7 +150,8 @@ struct C10_API PyInterpreterVTable {
c10::DispatchKey, c10::DispatchKey,
c10::DispatchKeySet keyset, c10::DispatchKeySet keyset,
torch::jit::Stack* stack, torch::jit::Stack* stack,
bool with_keyset) const = 0; bool with_keyset,
bool with_op) const = 0;
virtual void throw_abstract_impl_not_imported_error( virtual void throw_abstract_impl_not_imported_error(
std::string opname, std::string opname,

View File

@ -69,6 +69,137 @@ class TestPythonRegistration(TestCase):
if hasattr(torch.ops, self.test_ns): if hasattr(torch.ops, self.test_ns):
del torch.ops._test_python_registration del torch.ops._test_python_registration
def test_fallback(self) -> None:
    """Check that a Python fallback registered on the global ("_") namespace
    intercepts both factory functions and regular ops, and that its return
    value is what the user receives."""
    test_key = "TESTING_ONLY_GenericMode"
    test_keyset = torch._C.DispatchKeySet(test_key)
    include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
    exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

    with _scoped_library("_", "IMPL") as my_lib:
        expected_op = None
        expected_args = None
        expected_kwargs = None
        # Use this out shape to make sure the result from our fallback
        # is what is returned to the user
        out_shape = None

        def my_fallback(op, *args, **kwargs):
            # Disable our handler during checks and generating the output
            with torch._C._ForceDispatchKeyGuard(
                include_to_set, exclude_to_set | test_keyset
            ):
                self.assertIs(op, expected_op)
                self.assertEqual(args, expected_args)
                self.assertEqual(kwargs, expected_kwargs)
                # Return something specific
                return torch.empty(out_shape)

        my_lib.fallback(my_fallback, test_key)

        a, b = torch.rand(2), torch.rand(2)
        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            # Check a factory function
            expected_op = torch.ops.aten.empty.memory_format
            expected_args = ((2, 2),)
            # Extra kwargs to bypass issues with default args in factory functions
            expected_kwargs = {
                "dtype": torch.float64,
                "pin_memory": False,
                "device": torch.device("cpu"),
            }
            out_shape = (3,)
            out = torch.empty(*expected_args, **expected_kwargs)
            self.assertEqual(out.size(), out_shape)

            # Check a regular function
            expected_op = torch.ops.aten.add.Tensor
            expected_args = (a, b)
            expected_kwargs = {}
            out_shape = (4,)
            out = a + b
            self.assertEqual(out.size(), out_shape)
def test_fallback_keyset(self) -> None:
    """Check redispatching between two stacked Python fallbacks: the first one
    (registered with with_keyset=True) redispatches down to the second, which
    then triggers a recursive dispatcher call handled by the first again."""
    test_key_first = "TESTING_ONLY_GenericMode"
    test_key_second = "TESTING_ONLY_GenericWrapper"
    test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
        test_key_second
    )
    include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
    exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

    with _scoped_library("_", "IMPL") as my_lib:
        first_called = False
        second_called = False

        def first_fallback(keyset, op, *args, **kwargs):
            nonlocal first_called
            if second_called:
                # Recursive call
                first_called = True
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    return op(*args, **kwargs)
            else:
                # Redispatch down
                keyset = keyset.remove(test_key_first)
                return op.redispatch(keyset, *args, **kwargs)

        def second_fallback(op, *args, **kwargs):
            nonlocal second_called
            # Set to avoid infinite recursion
            second_called = True
            # New dispatcher call should hit the first callback again
            self.assertFalse(first_called)
            a, b = args
            # Make a subtraction here instead of add !
            c = a - b
            self.assertTrue(first_called)
            return c

        my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
        my_lib.fallback(second_fallback, test_key_second)

        a, b = torch.rand(2), torch.rand(2)
        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            c = a + b

        self.assertEqual(c, a - b)
        self.assertTrue(first_called)
        self.assertTrue(second_called)
def test_fallback_fallthrough(self) -> None:
    """Check that registering torch.library.fallthrough_kernel as a fallback
    makes the dispatcher fall through to the next key, where a regular Python
    fallback then handles the call."""
    test_key_first = "TESTING_ONLY_GenericMode"
    test_key_second = "TESTING_ONLY_GenericWrapper"
    test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
        test_key_second
    )
    include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
    exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

    with _scoped_library("_", "IMPL") as my_lib:
        is_called = False

        def my_fallback(op, *args, **kwargs):
            nonlocal is_called
            is_called = True
            # Disable our handler while producing the real result
            with torch._C._ForceDispatchKeyGuard(
                include_to_set, exclude_to_set | test_keyset
            ):
                return op(*args, **kwargs)

        my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
        my_lib.fallback(my_fallback, test_key_second)

        a, b = torch.rand(2), torch.rand(2)
        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            c = a + b

        self.assertEqual(c, a + b)
        self.assertTrue(is_called)
def test_override_aten_ops_with_multiple_libraries(self) -> None: def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2]) x = torch.tensor([1, 2])
with _scoped_library("aten", "IMPL") as my_lib2: with _scoped_library("aten", "IMPL") as my_lib2:

View File

@ -12,6 +12,8 @@ using namespace torch;
using namespace at; using namespace at;
using namespace c10; using namespace c10;
namespace torch::detail {
namespace { namespace {
// NB: This is a macro and not a template function (like it was before) // NB: This is a macro and not a template function (like it was before)
@ -62,9 +64,10 @@ struct ConcretePyInterpreterVTable final
c10::DispatchKey key, c10::DispatchKey key,
c10::DispatchKeySet keyset, c10::DispatchKeySet keyset,
torch::jit::Stack* stack, torch::jit::Stack* stack,
bool with_keyset) const override { bool with_keyset,
bool with_op) const override {
torch::impl::dispatch::python_op_registration_trampoline_impl( torch::impl::dispatch::python_op_registration_trampoline_impl(
op, key, keyset, stack, with_keyset); op, key, keyset, stack, with_keyset, with_op);
} }
void throw_abstract_impl_not_imported_error( void throw_abstract_impl_not_imported_error(
std::string opname, std::string opname,
@ -272,30 +275,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
Py_DECREF(pyobj); Py_DECREF(pyobj);
}; };
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
// Parse the name into namespace and name (no overload_name)
// TODO: put this into the library
const auto& schema = op.schema();
const auto& qualified_name = op.operator_name().name;
const auto& overload_name = schema.overload_name();
auto pos = qualified_name.find("::");
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
// Make me some null terminated strings
std::string ns_str = qualified_name.substr(0, pos);
const char* ns = ns_str.c_str();
const char* func_name = qualified_name.c_str() + pos + strlen("::");
py::handle torch_api_function =
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
if (overload_name.empty()) {
return torch_api_function.attr("default").ptr();
} else {
return torch_api_function.attr(overload_name.c_str()).ptr();
}
});
}
bool isPythonTensor(const at::Tensor& tensor) { bool isPythonTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
} }
@ -956,20 +935,46 @@ void ConcretePyInterpreterVTable::reset_backward_hooks(
END_HANDLE_TH_ERRORS_PYBIND END_HANDLE_TH_ERRORS_PYBIND
} }
PyInterpreterHolder self_interpreter;
} // anonymous namespace
c10::impl::PyInterpreter* getPyInterpreter() {
return self_interpreter.get();
}
bool isMainPyInterpreter() {
return self_interpreter.is_main_interpreter();
}
std::string ConcretePyInterpreterVTable::name() const { std::string ConcretePyInterpreterVTable::name() const {
std::stringstream ss; std::stringstream ss;
ss << getPyInterpreter(); ss << getPyInterpreter();
return ss.str(); return ss.str();
} }
PyInterpreterHolder self_interpreter;
} // anonymous namespace
// Resolve the Python-level torch.ops callable (an OpOverload) that
// corresponds to the given dispatcher operator. The result is cached on the
// operator via OperatorHandle::getPythonOp, so the lambda below only runs on
// the first lookup for a given interpreter.
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
    // Parse the name into namespace and name (no overload_name)
    // TODO: put this into the library
    const auto& schema = op.schema();
    const auto& qualified_name = op.operator_name().name;
    const auto& overload_name = schema.overload_name();
    auto pos = qualified_name.find("::");
    TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
    // Make me some null terminated strings
    // NB: ns_str must stay alive while `ns` is used below.
    std::string ns_str = qualified_name.substr(0, pos);
    const char* ns = ns_str.c_str();
    const char* func_name = qualified_name.c_str() + pos + strlen("::");
    py::handle torch_api_function =
        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
    // An empty overload name maps to the "default" overload attribute.
    if (overload_name.empty()) {
      return torch_api_function.attr("default").ptr();
    } else {
      return torch_api_function.attr(overload_name.c_str()).ptr();
    }
  });
}
} // namespace torch::detail
c10::impl::PyInterpreter* getPyInterpreter() {
return torch::detail::self_interpreter.get();
}
bool isMainPyInterpreter() {
return torch::detail::self_interpreter.is_main_interpreter();
}

View File

@ -2,6 +2,12 @@
#include <c10/core/impl/PyInterpreter.h> #include <c10/core/impl/PyInterpreter.h>
#include <torch/csrc/Export.h> #include <torch/csrc/Export.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::detail {
TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op);
}
// TODO: Move these to a proper namespace
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
TORCH_PYTHON_API bool isMainPyInterpreter(); TORCH_PYTHON_API bool isMainPyInterpreter();

View File

@ -108,15 +108,19 @@ class PythonKernelHolder : public c10::OperatorKernel {
c10::DispatchKey dispatch_key_; c10::DispatchKey dispatch_key_;
// If "with_keyset", then we expect a keyset as the first arg. // If "with_keyset", then we expect a keyset as the first arg.
bool with_keyset_; bool with_keyset_;
// If "with_op", then we expect the op as first arg (or second if keyset)
bool with_op_;
public: public:
PythonKernelHolder( PythonKernelHolder(
py::object func, py::object func,
c10::DispatchKey dispatch_key, c10::DispatchKey dispatch_key,
bool with_keyset = false) bool with_keyset = false,
bool with_op = false)
: func_(func.release().ptr(), getPyInterpreter()), : func_(func.release().ptr(), getPyInterpreter()),
dispatch_key_(dispatch_key), dispatch_key_(dispatch_key),
with_keyset_(with_keyset) {} with_keyset_(with_keyset),
with_op_(with_op) {}
void operator()( void operator()(
const c10::OperatorHandle& op, const c10::OperatorHandle& op,
@ -132,7 +136,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
cur_torch_dispatch_mode_state->pyinterpreter() cur_torch_dispatch_mode_state->pyinterpreter()
->python_op_registration_trampoline( ->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_); op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return; return;
} }
@ -150,7 +154,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
at::DispatchKey::Python)) { at::DispatchKey::Python)) {
(*interpreter) (*interpreter)
->python_op_registration_trampoline( ->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_); op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return; return;
} }
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) { } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
@ -166,7 +170,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) { nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
(*interpreter) (*interpreter)
->python_op_registration_trampoline( ->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_); op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return; return;
} }
} }
@ -189,9 +193,18 @@ class PythonKernelHolder : public c10::OperatorKernel {
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto func = auto func =
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter())); py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
auto obj = with_keyset_ auto obj = with_op_ ? with_keyset_
? func(keyset, *args_kwargs.first, **args_kwargs.second) ? func(
: func(*args_kwargs.first, **args_kwargs.second); keyset,
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: func(
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
: func(*args_kwargs.first, **args_kwargs.second);
if (!obj) { if (!obj) {
throw python_error(); throw python_error();
} }
@ -461,7 +474,33 @@ void initDispatchBindings(PyObject* module) {
return self; return self;
}, },
"", "",
py::arg("dispatch") = ""); py::arg("dispatch") = "")
.def(
"fallback",
[](const py::object& self,
c10::DispatchKey dispatch,
const py::object& func,
bool with_keyset) {
HANDLE_TH_ERRORS
auto& lib = self.cast<torch::Library&>();
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
if (func.is(py::module::import("torch.library")
.attr("fallthrough_kernel"))) {
lib.fallback(
torch::dispatch(dispatch, CppFunction::makeFallthrough()));
} else {
lib.fallback(torch::dispatch(
dispatch,
CppFunction::makeFromBoxedFunctor(
std::make_unique<PythonKernelHolder>(
func, dispatch, with_keyset, /*with_op*/ true))));
}
END_HANDLE_TH_ERRORS_PYBIND
},
"",
py::arg("dispatch"),
py::arg("func"),
py::arg("with_keyset") = false);
m.def( m.def(
"_dispatch_library", "_dispatch_library",
@ -954,7 +993,8 @@ void python_op_registration_trampoline_impl(
c10::DispatchKey key, c10::DispatchKey key,
c10::DispatchKeySet keyset, c10::DispatchKeySet keyset,
torch::jit::Stack* stack, torch::jit::Stack* stack,
bool with_keyset) { bool with_keyset,
bool with_op) {
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
py::gil_scoped_acquire g; py::gil_scoped_acquire g;
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
@ -963,9 +1003,17 @@ void python_op_registration_trampoline_impl(
auto* pyobj = func->ptr(getPyInterpreter()); auto* pyobj = func->ptr(getPyInterpreter());
TORCH_INTERNAL_ASSERT(pyobj != nullptr); TORCH_INTERNAL_ASSERT(pyobj != nullptr);
auto callable = py::reinterpret_borrow<py::object>(pyobj); auto callable = py::reinterpret_borrow<py::object>(pyobj);
auto obj = with_keyset auto obj = with_op ? with_keyset ? callable(
? callable(keyset, *args_kwargs.first, **args_kwargs.second) keyset,
: callable(*args_kwargs.first, **args_kwargs.second); torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: callable(
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
: callable(*args_kwargs.first, **args_kwargs.second);
if (!obj) { if (!obj) {
throw python_error(); throw python_error();
} }

View File

@ -10,6 +10,7 @@ void python_op_registration_trampoline_impl(
c10::DispatchKey key, c10::DispatchKey key,
c10::DispatchKeySet keyset, c10::DispatchKeySet keyset,
torch::jit::Stack* stack, torch::jit::Stack* stack,
bool with_keyset); bool with_keyset,
bool with_op);
} // namespace torch::impl::dispatch } // namespace torch::impl::dispatch

View File

@ -278,6 +278,8 @@ class Library:
to register a fallthrough. to register a fallthrough.
dispatch_key: dispatch key that the input function should be registered for. By default, it uses dispatch_key: dispatch key that the input function should be registered for. By default, it uses
the dispatch key that the library was created with. the dispatch key that the library was created with.
with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
Example:: Example::
>>> my_lib = Library("aten", "IMPL") >>> my_lib = Library("aten", "IMPL")
@ -345,6 +347,43 @@ class Library:
_impls.add(key) _impls.add(key)
self._op_impls.add(key) self._op_impls.add(key)
def fallback(self, fn, dispatch_key="", *, with_keyset=False):
    r"""Registers the function implementation as the fallback for the given key.

    This function only works for a library with global namespace ("_").

    Args:
        fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
            to register a fallthrough.
        dispatch_key: dispatch key that the input function should be registered for. By default, it uses
            the dispatch key that the library was created with.
        with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
            to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

    Example::
        >>> my_lib = Library("_", "IMPL")
        >>> def fallback_kernel(op, *args, **kwargs):
        >>>     # Handle all autocast ops generically
        >>>     # ...
        >>> my_lib.fallback(fallback_kernel, "Autocast")
    """
    # Fall back to the key the library was created with when none is given.
    if dispatch_key == "":
        dispatch_key = self.dispatch_key

    if self.ns != "_":
        raise RuntimeError(
            f"""Fallback can only be registered using library fragment on the global namespace "_" but it is {self.ns}"""
        )

    assert dispatch_key != ""
    assert self.m is not None
    self.m.fallback(
        dispatch_key,
        fn,
        with_keyset,
    )
def _destroy(self): def _destroy(self):
if self.m is not None: if self.m is not None:
self.m.reset() self.m.reset()