Allow users to override kernels for existing C++ ops through Python

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75905

Approved by: https://github.com/ezyang
This commit is contained in:
anjali411
2022-05-04 21:51:09 +00:00
committed by PyTorch MergeBot
parent 32ae584008
commit 55f55a4cf6
9 changed files with 302 additions and 97 deletions

View File

@ -7,7 +7,7 @@ namespace torch {
namespace {
// TODO: Consider representing debug info as a struct instead so you
// don't have to allocate strings all the time
std::string debugString(const char* file, uint32_t line) {
std::string debugString(std::string file, uint32_t line) {
#ifdef STRIP_ERROR_MESSAGES
return std::string();
#else
@ -15,7 +15,7 @@ namespace {
#endif
}
std::string debugString(std::string debug, const char* file, uint32_t line) {
std::string debugString(std::string debug, std::string file, uint32_t line) {
#ifdef STRIP_ERROR_MESSAGES
return std::string();
#else

View File

@ -3,6 +3,7 @@
import tempfile
import torch
from copy import deepcopy
from torch.library import Library
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
log_input, capture_logs, no_dispatch
@ -12,6 +13,86 @@ from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_
import logging
from functools import partial
class TestPythonRegistration(TestCase):
def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
my_lib1 = Library("aten", "IMPL")
my_lib2 = Library("aten", "IMPL")
# Example 1
def my_neg(*args, **kwargs):
return args[0]._neg_view()
# Now we are secretly making the operator a view op so autograd needs to know how
# to handle it
my_lib1.impl('neg', my_neg, "AutogradCPU")
self.assertTrue(torch.neg(x).is_neg())
# RuntimeError: impl("aten::neg", ...):
# Explicitly provided namespace (aten) in operator name does not match ...
with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
my_lib3 = Library("foo", "IMPL")
my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
del my_lib3
# Example 2
def my_mul(*args, **kwargs):
return torch.zeros_like(args[0])
# torch.ops.aten.mul.Tensor
my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
y = torch._efficientzerotensor(2)
self.assertFalse(torch.mul(x, y)._is_zerotensor())
# Assert that a user can't override the behavior of a (ns, op, dispatch_key)
# combination if someone overrided the behavior for the same before them
with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
del my_lib1
# Validate that lib2 is not affected by removing lib1
self.assertFalse(torch.mul(x, y)._is_zerotensor())
del my_lib2
# Validate that the old behavior is restored for neg and mul
self.assertFalse(torch.neg(x).is_neg())
self.assertTrue(torch.mul(x, y)._is_zerotensor())
def test_override_cpu_sum(self) -> None:
# Example 1
run = [False]
def my_sum(*args, **kwargs):
run[0] = True
return args[0]
my_lib1 = Library("aten", "IMPL")
my_lib1.impl('aten::sum', my_sum, "CPU")
x = torch.tensor([1, 2])
self.assertEqual(torch.sum(x), x)
self.assertTrue(run[0])
del my_lib1
# Validate that the old behavior is restored for sum
self.assertEqual(torch.sum(x), torch.tensor(3))
def test_extend_library_with_dispatch_key_arg(self):
def my_sum(*args, **kwargs):
return args[0]
my_lib1 = Library("aten", "IMPL", dispatch_key="CPU")
# RuntimeError: Explicitly provided dispatch key (Conjugate) is
# inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
with self.assertRaisesRegex(RuntimeError, "inconsistent with the dispatch key"):
my_lib1.impl('sum', my_sum, "Conjugate")
my_lib1.impl('aten::sum', my_sum)
x = torch.tensor([1, 2])
self.assertEqual(torch.sum(x), x)
del my_lib1
class TestPythonDispatch(TestCase):
def test_basic(self) -> None:
with capture_logs() as logs:

View File

@ -829,6 +829,10 @@ class Generator(object):
def seed(self) -> _int: ...
def initial_seed(self) -> _int: ...
# Defined in torch/csrc/utils/python_dispatch.cpp
def _dispatch_library(kind: str, name: str, dispatch: str, file: str, linenum: Any) -> Any: ...
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig(object):
num_calling_threads: _int

View File

@ -918,3 +918,4 @@ def _register_device_module(device_type, module):
# expose return_types
from . import return_types
from . import library

View File

@ -4,7 +4,6 @@
#include <c10/core/SafePyObject.h>
#include <c10/util/DeadlockDetection.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
@ -25,18 +24,19 @@
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/cuda_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/library.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <structmember.h>
#include <cstdint>
@ -51,6 +51,91 @@ using namespace at;
using namespace torch;
using namespace torch::autograd;
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(const c10::OperatorHandle& op, const std::vector<c10::IValue>& arguments) {
TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call parseIValuesToPyArgsKwargs");
const auto& schema = op.schema();
py::dict kwargs;
// About all the pointers:
//
// f(int x, int y = 0, *, int z = 0)
// ^- arguments.size()
// ^- kwarg_only_start
// ^- positional_default_start
// ^- 0
// Find the split point between kwarg-only and regular. Since most functions
// don't have kwarg-only arguments, it is more efficient to scan from the
// right (but ideally, this would just be precomputed in FunctionSchema
// itself). (NB: minus one in the loop is because we're testing if the
// *next* argument is kwarg-only before we advance the starting index)
int64_t kwarg_only_start = arguments.size();
for (; kwarg_only_start > 0; kwarg_only_start--) {
const auto& arg = schema.arguments()[kwarg_only_start - 1];
if (!arg.kwarg_only()) {
break;
}
}
// Find the first positional argument that isn't defaulted
auto is_default = [&](int64_t idx) -> bool {
const auto& arg = schema.arguments()[idx];
if (!arg.default_value().has_value()) {
return false;
}
const auto& default_ivalue = *arg.default_value();
const auto& ivalue = arguments[idx];
if (default_ivalue != ivalue) {
return false;
}
return true;
};
int64_t positional_default_start = kwarg_only_start;
for (; positional_default_start > 0; positional_default_start--) {
if (!is_default(positional_default_start - 1)) {
break;
}
}
auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
// Populate positional arguments
for (const auto idx : c10::irange(positional_default_start)) {
PyTuple_SET_ITEM(args.ptr(), idx, torch::jit::toPyObject(arguments[idx]).release().ptr());
}
// Populate keyword arguments
for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) {
// But don't populate default keyword arguments
if (is_default(idx)) continue;
const auto& arg = schema.arguments()[idx];
kwargs[py::cast(arg.name())] = torch::jit::toPyObject(arguments[idx]);
}
return std::make_pair(std::move(args), std::move(kwargs));
}
void pushPyOutToStack(
const c10::OperatorHandle& op,
torch::jit::Stack* stack,
py::object out,
const char* msg) {
TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call pushPyOutToStack");
auto schema_returns = op.schema().returns();
const auto num_returns = schema_returns.size();
if (num_returns == 0) {
// Check that we got a None return from Python. Anything else is an error.
TORCH_CHECK(out.is(py::none()), "Expected ", msg, " for ", op.operator_name(),
" to return None but it returned something else instead.");
} else if (num_returns == 1) {
torch::jit::push(stack, torch::jit::toIValue(out.ptr(), schema_returns[0].type()));
} else {
auto outs = py::cast<py::sequence>(out);
for (const auto idx : c10::irange(outs.size())) {
torch::jit::push(stack, torch::jit::toIValue(outs[idx].ptr(), schema_returns[idx].type()));
}
}
}
namespace {
std::string concrete_name_fn(const c10::impl::PyInterpreter* self) {
@ -131,8 +216,6 @@ c10::impl::PyInterpreter* getPyInterpreter() {
return self_interpreter.get();
}
namespace py = pybind11;
PyObject *THPVariableClass = nullptr;
PyObject *ParameterClass = nullptr;
@ -1751,8 +1834,6 @@ void concrete_dispatch_fn(
torch::jit::Stack* stack,
const std::shared_ptr<SafePyObject>& type) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
auto arguments = torch::jit::pop(*stack, num_arguments);
@ -1776,9 +1857,6 @@ void concrete_dispatch_fn(
py::gil_scoped_acquire g;
std::vector<py::handle> overloaded_args;
// For now, overloads get coalesced. Might be easier for users if they get
// overload resolution but is more complicated (need to expose separate
// functions per overload)
py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
py::handle torch_api_function_overload;
if (overload_name == "") {
@ -1788,51 +1866,6 @@ void concrete_dispatch_fn(
}
std::string module_name_str = "torch.ops." + ns_str;
// About all the pointers:
//
// f(int x, int y = 0, *, int z = 0)
// ^- arguments.size()
// ^- kwarg_only_start
// ^- positional_default_start
// ^- 0
// Find the split point between kwarg-only and regular. Since most functions
// don't have kwarg-only arguments, it is more efficient to scan from the
// right (but ideally, this would just be precomputed in FunctionSchema
// itself). (NB: minus one in the loop is because we're testing if the
// *next* argument is kwarg-only before we advance the starting index)
int64_t kwarg_only_start = arguments.size();
for (; kwarg_only_start > 0; kwarg_only_start--) {
const auto& arg = schema.arguments()[kwarg_only_start - 1];
if (!arg.kwarg_only()) {
break;
}
}
// Find the first positional argument that isn't defaulted
auto is_default = [&](int64_t idx) -> bool {
const auto& arg = schema.arguments()[idx];
if (!arg.default_value().has_value()) {
return false;
}
const auto& default_ivalue = *arg.default_value();
const auto& ivalue = arguments[idx];
if (default_ivalue != ivalue) {
return false;
}
return true;
};
int64_t positional_default_start = kwarg_only_start;
for (; positional_default_start > 0; positional_default_start--) {
if (!is_default(positional_default_start - 1)) {
break;
}
}
auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
py::dict kwargs;
if (type) {
append_overloaded_type(&overloaded_args, type->ptr(getPyInterpreter()));
}
@ -1859,41 +1892,19 @@ void concrete_dispatch_fn(
}
}
// Populate positional arguments
for (const auto idx : c10::irange(positional_default_start)) {
PyTuple_SET_ITEM(args.ptr(), idx, torch::jit::toPyObject(std::move(arguments[idx])).release().ptr());
}
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto args = std::move(args_kwargs.first);
auto kwargs = std::move(args_kwargs.second);
// Populate keyword arguments
for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) {
// But don't populate default keyword arguments
if (is_default(idx)) continue;
const auto& arg = schema.arguments()[idx];
kwargs[py::cast(arg.name())] = torch::jit::toPyObject(std::move(arguments[idx]));
}
auto out = py::reinterpret_steal<py::object>(
handle_torch_function_no_python_arg_parser(
overloaded_args,
args.ptr(),
kwargs.ptr(),
func_name,
torch_api_function_overload.ptr(),
module_name_str.c_str(),
TorchFunctionName::TorchDispatch));
if (num_returns == 0) {
// Check that we got a None return from Python. Anything else is an error.
TORCH_CHECK(out.is(py::none()), "Expected __torch_dispatch__ for ", op.operator_name(),
" to return None but it returned something else instead.");
} else if (num_returns == 1) {
torch::jit::push(stack, torch::jit::toIValue(out.ptr(), op.schema().returns()[0].type()));
} else {
auto outs = py::cast<py::sequence>(out);
for (const auto idx : c10::irange(outs.size())) {
torch::jit::push(stack, torch::jit::toIValue(outs[idx].ptr(), op.schema().returns()[idx].type()));
}
}
PyObject* obj = handle_torch_function_no_python_arg_parser(
overloaded_args,
args.ptr(),
kwargs.ptr(),
func_name,
torch_api_function_overload.ptr(),
module_name_str.c_str(),
TorchFunctionName::TorchDispatch);
pushPyOutToStack(op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
}
c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) {

View File

@ -7,6 +7,10 @@
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/Exceptions.h>
#include <pybind11/pybind11.h>
#include <ATen/core/function_schema.h>
namespace py = pybind11;
// Python object that backs torch.autograd.Variable
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
@ -62,3 +66,7 @@ inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
}
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(const c10::OperatorHandle& op, const std::vector<c10::IValue>& arguments);
void pushPyOutToStack(const c10::OperatorHandle& op, torch::jit::Stack* stack, py::object out, const char* msg);

View File

@ -5,6 +5,10 @@
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/SafePyObject.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/autograd/python_variable.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
@ -50,6 +54,20 @@ inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
}
}
class PythonKernelHolder : public c10::OperatorKernel {
SafePyObject func_;
public:
PythonKernelHolder(py::object func) : func_(func.release().ptr(), getPyInterpreter()) {}
void operator()(const c10::OperatorHandle& op, c10::DispatchKeySet keyset, torch::jit::Stack* stack) {
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
py::gil_scoped_acquire g;
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto obj = py::reinterpret_steal<py::object>(PyObject_Call(func_.ptr(getPyInterpreter()), args_kwargs.first.ptr(), args_kwargs.second.ptr()));
pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}
};
void initDispatchBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
@ -122,6 +140,14 @@ void initDispatchBindings(PyObject* module) {
);
return self;
}, "", py::arg("name"), py::arg("dispatch") = "", py::arg("debug") = "")
.def("impl", [](py::object self, const char* name, const char* dispatch, py::object func) {
HANDLE_TH_ERRORS
self.cast<torch::Library&>().impl(
name,
dispatch_str(dispatch, CppFunction::makeFromBoxedFunctor(std::make_unique<PythonKernelHolder>(std::move(func))))
);
END_HANDLE_TH_ERRORS_PYBIND
}, "", py::arg("name"), py::arg("dispatch"), py::arg("func"))
.def("fallback_fallthrough", [](py::object self, const char* dispatch) {
self.cast<torch::Library&>().fallback(
dispatch_str(dispatch, CppFunction::makeFallthrough())
@ -130,14 +156,17 @@ void initDispatchBindings(PyObject* module) {
}, "", py::arg("dispatch") = "")
;
m.def("_dispatch_library", [](const char* kind, std::string name, const char* dispatch) {
m.def("_dispatch_library", [](const char* kind, std::string name, const char* dispatch, const char* file, uint32_t linenum) {
HANDLE_TH_ERRORS
return std::make_unique<torch::Library>(
parseKind(kind),
std::move(name),
std::string(dispatch) == "" ? c10::nullopt : c10::make_optional(c10::parseDispatchKey(dispatch)),
"/dev/null",
0);
});
file,
linenum);
END_HANDLE_TH_ERRORS_PYBIND
}, "", py::arg("kind"), py::arg("name"), py::arg("dispatch"), py::arg("file")="/dev/null", py::arg("linenum")=0)
;
m.def("_dispatch_dump", [](const char* name) -> std::string {
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));

View File

@ -555,7 +555,7 @@ class TORCH_API Library final {
Library& operator=(const Library&) = delete;
Library(Library&&) = default;
Library& operator=(Library&&) = default;
// TODO: add gen_python_error boolean flag
// Some notes about the API design here. We had the following constraints:
//
// - We need to support multiple "types" of arguments for schema and
@ -800,7 +800,7 @@ class TORCH_API Library final {
Kind kind_;
c10::optional<std::string> ns_;
c10::optional<c10::DispatchKey> dispatch_key_;
const char* file_;
std::string file_;
uint32_t line_;
std::vector<c10::RegistrationHandleRAII> registrars_;

71
torch/library.py Normal file
View File

@ -0,0 +1,71 @@
from ._ops import OpOverload
from typing import Set
import traceback
import torch._C as C
__all__ = ['Library']
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
# libraries calling into kernels not intended to be called.
_impls: Set[str] = set()
class Library:
"""
Class to create linraries that can be used to register new operators or
override operators in existing libraries from Python.
A user can pass in a dispatch keyname if they only want the library to override kernels corresponding
to only one specific dispatch key.
Args:
ns: library name
kind: "IMPL" by default
dispatch_key: PyTorch dispatch key (default: "")
"""
def __init__(self, ns, kind, dispatch_key=""):
frame = traceback.extract_stack(limit=3)[0]
filename, lineno = frame.filename, frame.lineno
self.m = C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
self.ns = ns
self._op_impls = set()
self.kind = kind
self.dispatch_key = dispatch_key
if kind != "IMPL":
raise ValueError("Unsupported kind: ", kind)
def __repr__(self):
return "Library(kind={}, ns={}, dispatch_key={})>".format(self.kind, self.ns, self.dispatch_key)
def impl(self, op_name, fn, dispatch_key=''):
if dispatch_key == '':
if self.dispatch_key == '':
raise RuntimeError("Please specify the dispatch key that you want to register the kernel for.")
dispatch_key = self.dispatch_key
if isinstance(op_name, str):
name = op_name
elif isinstance(op_name, OpOverload):
name = op_name._schema.name
overload_name = op_name._schema.overload_name
if overload_name != '':
name = name + '.' + overload_name
else:
raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
if key in _impls:
# TODO: in future, add more info about where the existing function is registered (this info is
# today already returned by the C++ warning when impl is called but we error out before that)
raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
"'s behavior for {} dispatch key and {} namespace.".
format(name.split("::")[-1], dispatch_key, self.ns))
self.m.impl(name, dispatch_key, fn)
_impls.add(key)
self._op_impls.add(key)
def __del__(self):
for key in self._op_impls:
_impls.remove(key)
del self.m