Add fallback() to torch.library (#131707)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131707
Approved by: https://github.com/zou3519
Author: albanD
Date: 2024-07-26 15:40:17 -04:00
Committed by: PyTorch MergeBot
parent 8e5a367311
commit 466ea8ce54
10 changed files with 286 additions and 54 deletions
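Before the file-by-file diffs, a sense of the new surface: torch.library.Library.fallback registers a Python handler for every operator on a dispatch key, and (unlike impl) passes the handler the OpOverload being dispatched. A minimal sketch distilled from the docstring and tests below; the "AutocastCPU" key and the handler body are illustrative only:

import torch
from torch.library import Library

# Fallbacks may only be registered on the global namespace "_".
my_lib = Library("_", "IMPL")

def my_fallback(keyset, op, *args, **kwargs):
    # "op" is the OpOverload being dispatched, e.g. torch.ops.aten.add.Tensor.
    # Drop our key from the keyset and redispatch to the next kernel in line.
    return op.redispatch(keyset.remove("AutocastCPU"), *args, **kwargs)

my_lib.fallback(my_fallback, "AutocastCPU", with_keyset=True)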

View File

@@ -36,7 +36,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
c10::DispatchKey,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack,
bool with_keyset) const override {
bool with_keyset,
bool with_op) const override {
PANIC(python_op_registration_trampoline);
}

View File

@@ -150,7 +150,8 @@ struct C10_API PyInterpreterVTable {
c10::DispatchKey,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack,
bool with_keyset) const = 0;
bool with_keyset,
bool with_op) const = 0;
virtual void throw_abstract_impl_not_imported_error(
std::string opname,

View File

@@ -69,6 +69,137 @@ class TestPythonRegistration(TestCase):
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_python_registration

    def test_fallback(self) -> None:
        test_key = "TESTING_ONLY_GenericMode"
        test_keyset = torch._C.DispatchKeySet(test_key)
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            expected_op = None
            expected_args = None
            expected_kwargs = None
            # Use this out shape to make sure the result from our fallback
            # is what is returned to the user
            out_shape = None

            def my_fallback(op, *args, **kwargs):
                # Disable our handler during checks and generating the output
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    self.assertIs(op, expected_op)
                    self.assertEqual(args, expected_args)
                    self.assertEqual(kwargs, expected_kwargs)
                    # Return something specific
                    return torch.empty(out_shape)

            my_lib.fallback(my_fallback, test_key)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                # Check a factory function
                expected_op = torch.ops.aten.empty.memory_format
                expected_args = ((2, 2),)
                # Extra kwargs to bypass issues with default args in factory functions
                expected_kwargs = {
                    "dtype": torch.float64,
                    "pin_memory": False,
                    "device": torch.device("cpu"),
                }
                out_shape = (3,)
                out = torch.empty(*expected_args, **expected_kwargs)
                self.assertEqual(out.size(), out_shape)

                # Check a regular function
                expected_op = torch.ops.aten.add.Tensor
                expected_args = (a, b)
                expected_kwargs = {}
                out_shape = (4,)
                out = a + b
                self.assertEqual(out.size(), out_shape)

    def test_fallback_keyset(self) -> None:
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"
        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            first_called = False
            second_called = False

            def first_fallback(keyset, op, *args, **kwargs):
                nonlocal first_called
                if second_called:
                    # Recursive call
                    first_called = True
                    with torch._C._ForceDispatchKeyGuard(
                        include_to_set, exclude_to_set | test_keyset
                    ):
                        return op(*args, **kwargs)
                else:
                    # Redispatch down
                    keyset = keyset.remove(test_key_first)
                    return op.redispatch(keyset, *args, **kwargs)

            def second_fallback(op, *args, **kwargs):
                nonlocal second_called
                # Set to avoid infinite recursion
                second_called = True
                # New dispatcher call should hit the first callback again
                self.assertFalse(first_called)
                a, b = args
                # Make a subtraction here instead of an add!
                c = a - b
                self.assertTrue(first_called)
                return c

            my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
            my_lib.fallback(second_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a - b)
            self.assertTrue(first_called)
            self.assertTrue(second_called)

    def test_fallback_fallthrough(self) -> None:
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"
        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            is_called = False

            def my_fallback(op, *args, **kwargs):
                nonlocal is_called
                is_called = True
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    return op(*args, **kwargs)

            my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
            my_lib.fallback(my_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a + b)
            self.assertTrue(is_called)

    def test_override_aten_ops_with_multiple_libraries(self) -> None:
        x = torch.tensor([1, 2])
        with _scoped_library("aten", "IMPL") as my_lib2:

View File

@@ -12,6 +12,8 @@ using namespace torch;
using namespace at;
using namespace c10;
namespace torch::detail {
namespace {
// NB: This is a macro and not a template function (like it was before)
@@ -62,9 +64,10 @@ struct ConcretePyInterpreterVTable final
c10::DispatchKey key,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack,
bool with_keyset) const override {
bool with_keyset,
bool with_op) const override {
torch::impl::dispatch::python_op_registration_trampoline_impl(
op, key, keyset, stack, with_keyset);
op, key, keyset, stack, with_keyset, with_op);
}
void throw_abstract_impl_not_imported_error(
std::string opname,
@@ -272,30 +275,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
Py_DECREF(pyobj);
};
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
// Parse the name into namespace and name (no overload_name)
// TODO: put this into the library
const auto& schema = op.schema();
const auto& qualified_name = op.operator_name().name;
const auto& overload_name = schema.overload_name();
auto pos = qualified_name.find("::");
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
// Make me some null terminated strings
std::string ns_str = qualified_name.substr(0, pos);
const char* ns = ns_str.c_str();
const char* func_name = qualified_name.c_str() + pos + strlen("::");
py::handle torch_api_function =
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
if (overload_name.empty()) {
return torch_api_function.attr("default").ptr();
} else {
return torch_api_function.attr(overload_name.c_str()).ptr();
}
});
}
bool isPythonTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}
@@ -956,20 +935,46 @@ void ConcretePyInterpreterVTable::reset_backward_hooks(
END_HANDLE_TH_ERRORS_PYBIND
}
PyInterpreterHolder self_interpreter;
} // anonymous namespace
c10::impl::PyInterpreter* getPyInterpreter() {
return self_interpreter.get();
}
bool isMainPyInterpreter() {
return self_interpreter.is_main_interpreter();
}
std::string ConcretePyInterpreterVTable::name() const {
std::stringstream ss;
ss << getPyInterpreter();
return ss.str();
}
PyInterpreterHolder self_interpreter;
} // anonymous namespace
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
// Parse the name into namespace and name (no overload_name)
// TODO: put this into the library
const auto& schema = op.schema();
const auto& qualified_name = op.operator_name().name;
const auto& overload_name = schema.overload_name();
auto pos = qualified_name.find("::");
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
// Make me some null terminated strings
std::string ns_str = qualified_name.substr(0, pos);
const char* ns = ns_str.c_str();
const char* func_name = qualified_name.c_str() + pos + strlen("::");
py::handle torch_api_function =
py::module::import("torch").attr("ops").attr(ns).attr(func_name);
if (overload_name.empty()) {
return torch_api_function.attr("default").ptr();
} else {
return torch_api_function.attr(overload_name.c_str()).ptr();
}
});
}
} // namespace torch::detail
c10::impl::PyInterpreter* getPyInterpreter() {
return torch::detail::self_interpreter.get();
}
bool isMainPyInterpreter() {
return torch::detail::self_interpreter.is_main_interpreter();
}
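getTorchApiFunction, moved out of the anonymous namespace here so that python_dispatch.cpp can reach it, is what resolves a C++ OperatorHandle to the torch.ops object handed to Python fallbacks. Restated in Python, its lookup is roughly the following (a sketch for exposition, not the actual binding):

import torch

def torch_api_function(qualified_name: str, overload_name: str):
    # Split "aten::add" into namespace "aten" and name "add"...
    ns, name = qualified_name.split("::", 1)
    packet = getattr(getattr(torch.ops, ns), name)
    # ...then pick the overload, defaulting to ".default" when unnamed.
    return getattr(packet, overload_name or "default")

assert torch_api_function("aten::add", "Tensor") is torch.ops.aten.add.Tensor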

View File

@@ -2,6 +2,12 @@
#include <c10/core/impl/PyInterpreter.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::detail {
TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op);
}
// TODO: Move these to a proper namespace
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
TORCH_PYTHON_API bool isMainPyInterpreter();

View File

@@ -108,15 +108,19 @@ class PythonKernelHolder : public c10::OperatorKernel {
c10::DispatchKey dispatch_key_;
// If "with_keyset", then we expect a keyset as the first arg.
bool with_keyset_;
// If "with_op", then we expect the op as first arg (or second if keyset)
bool with_op_;
public:
PythonKernelHolder(
py::object func,
c10::DispatchKey dispatch_key,
bool with_keyset = false)
bool with_keyset = false,
bool with_op = false)
: func_(func.release().ptr(), getPyInterpreter()),
dispatch_key_(dispatch_key),
with_keyset_(with_keyset) {}
with_keyset_(with_keyset),
with_op_(with_op) {}
void operator()(
const c10::OperatorHandle& op,
@@ -132,7 +136,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
cur_torch_dispatch_mode_state->pyinterpreter()
->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_);
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return;
}
@@ -150,7 +154,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
at::DispatchKey::Python)) {
(*interpreter)
->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_);
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return;
}
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
@@ -166,7 +170,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
(*interpreter)
->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_);
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return;
}
}
@@ -189,9 +193,18 @@ class PythonKernelHolder : public c10::OperatorKernel {
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto func =
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
auto obj = with_keyset_
? func(keyset, *args_kwargs.first, **args_kwargs.second)
: func(*args_kwargs.first, **args_kwargs.second);
auto obj = with_op_ ? with_keyset_
? func(
keyset,
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: func(
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
: func(*args_kwargs.first, **args_kwargs.second);
if (!obj) {
throw python_error();
}
@@ -461,7 +474,33 @@ void initDispatchBindings(PyObject* module) {
return self;
},
"",
py::arg("dispatch") = "");
py::arg("dispatch") = "")
.def(
"fallback",
[](const py::object& self,
c10::DispatchKey dispatch,
const py::object& func,
bool with_keyset) {
HANDLE_TH_ERRORS
auto& lib = self.cast<torch::Library&>();
TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
if (func.is(py::module::import("torch.library")
.attr("fallthrough_kernel"))) {
lib.fallback(
torch::dispatch(dispatch, CppFunction::makeFallthrough()));
} else {
lib.fallback(torch::dispatch(
dispatch,
CppFunction::makeFromBoxedFunctor(
std::make_unique<PythonKernelHolder>(
func, dispatch, with_keyset, /*with_op*/ true))));
}
END_HANDLE_TH_ERRORS_PYBIND
},
"",
py::arg("dispatch"),
py::arg("func"),
py::arg("with_keyset") = false);
m.def(
"_dispatch_library",
@@ -954,7 +993,8 @@ void python_op_registration_trampoline_impl(
c10::DispatchKey key,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack,
bool with_keyset) {
bool with_keyset,
bool with_op) {
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
py::gil_scoped_acquire g;
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
@@ -963,9 +1003,17 @@ void python_op_registration_trampoline_impl(
auto* pyobj = func->ptr(getPyInterpreter());
TORCH_INTERNAL_ASSERT(pyobj != nullptr);
auto callable = py::reinterpret_borrow<py::object>(pyobj);
auto obj = with_keyset
? callable(keyset, *args_kwargs.first, **args_kwargs.second)
: callable(*args_kwargs.first, **args_kwargs.second);
auto obj = with_op ? with_keyset ? callable(
keyset,
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: callable(
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
: callable(*args_kwargs.first, **args_kwargs.second);
if (!obj) {
throw python_error();
}
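Between them, with_keyset and with_op select among four Python callback signatures. The nested ternaries above implement the following convention, restated as a hypothetical Python helper for exposition:

def call_user_fn(fn, keyset, op, args, kwargs, with_keyset, with_op):
    # The keyset, when requested, always comes first, then the op,
    # then the operator's own arguments.
    prefix = ()
    if with_keyset:
        prefix += (keyset,)
    if with_op:
        prefix += (op,)
    return fn(*prefix, *args, **kwargs)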

View File

@@ -10,6 +10,7 @@ void python_op_registration_trampoline_impl(
c10::DispatchKey key,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack,
bool with_keyset);
bool with_keyset,
bool with_op);
} // namespace torch::impl::dispatch

View File

@@ -278,6 +278,8 @@ class Library:
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with.
            with_keyset: flag controlling whether the current dispatcher call keyset should be passed as the first argument
                to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("aten", "IMPL")
@@ -345,6 +347,43 @@ class Library:
        _impls.add(key)
        self._op_impls.add(key)

    def fallback(self, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation as the fallback for the given key.

        This function only works for a library with the global namespace ("_").

        Args:
            fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with.
            with_keyset: flag controlling whether the current dispatcher call keyset should be passed as the first argument
                to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("_", "IMPL")
            >>> def fallback_kernel(op, *args, **kwargs):
            >>>     # Handle all autocast ops generically
            >>>     # ...
            >>> my_lib.fallback(fallback_kernel, "Autocast")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if self.ns != "_":
            raise RuntimeError(
                f"""Fallback can only be registered using a library fragment on the global namespace "_" but it is {self.ns}"""
            )

        assert dispatch_key != ""
        assert self.m is not None

        self.m.fallback(
            dispatch_key,
            fn,
            with_keyset,
        )

    def _destroy(self):
        if self.m is not None:
            self.m.reset()
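Finally, a quick illustration of the namespace guard in fallback() above: any library other than a global "_" fragment is rejected before anything is registered (a sketch; the error text comes from the code above):

import torch
from torch.library import Library

aten_lib = Library("aten", "IMPL")
try:
    # Rejected: fallbacks must be registered on the global "_" namespace.
    aten_lib.fallback(lambda op, *args, **kwargs: op(*args, **kwargs), "Autocast")
except RuntimeError as e:
    print(e)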