Add __torch_function__ for methods (#37091)

Summary:
According to pytorch/rfcs#3

From the goals in the RFC:

1. Support subclassing `torch.Tensor` in Python (done here)
2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them (done here)
3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor`
   subclasses (done in https://github.com/pytorch/pytorch/issues/30730)
4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods. (done here)
5. Propagating subclass instances correctly also with operators, using
   views/slices/indexing/etc. (done here)
6. Preserve subclass attributes when using methods or views/slices/indexing. (done here)
7. A way to insert code that operates on both functions and methods uniformly
   (so we can write a single function that overrides all operators). (done here)
8. The ability to give external libraries a way to also define
   functions/methods that follow the `__torch_function__` protocol. (will be addressed in a separate PR)

This PR makes the following changes:

1. Adds the `self` argument to the arg parser.
2. Dispatches on `self` as well if `self` is not `nullptr`.
3. Adds a `torch._C.DisableTorchFunction` context manager to disable `__torch_function__`.
4. Adds a `torch::torch_function_enabled()` and `torch._C._torch_function_enabled()` to check the state of `__torch_function__`.
5. Dispatches all `torch._C.TensorBase` and `torch.Tensor` methods via `__torch_function__`.

TODO:

- [x] Sequence Methods
- [x] Docs
- [x] Tests

Closes https://github.com/pytorch/pytorch/issues/28361

Benchmarks in https://github.com/pytorch/pytorch/pull/37091#issuecomment-633657778

Pull Request resolved: https://github.com/pytorch/pytorch/pull/37091

Reviewed By: ngimel

Differential Revision: D22765678

Pulled By: ezyang

fbshipit-source-id: 53f8aa17ddb8b1108c0997f6a7aa13cb5be73de0
This commit is contained in:
Hameer Abbasi
2020-08-05 20:39:27 -07:00
committed by Facebook GitHub Bot
parent 92b7347fd7
commit 3d46e02ea1
25 changed files with 1387 additions and 105 deletions

View File

@ -135,11 +135,70 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
}
}
auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* {
auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject* {
py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str());
std::string module_name = "torch.Tensor." + property_name;
return handle_torch_function((PyObject *)self, "__get__", nullptr, torch_api.ptr(), module_name);
}
auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int {
py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str());
std::string module_name = "torch.Tensor." + property_name;
if (value != nullptr)
{
py::tuple args_ = py::make_tuple(py::handle(value));
handle_torch_function((PyObject *)self, "__set__", args_.ptr(), torch_api.ptr(), module_name);
}
else {
handle_torch_function((PyObject *)self, "__delete__", nullptr, torch_api.ptr(), module_name);
}
return 0;
}
// Combines self and args into one tuple.
auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple {
if (args == nullptr) {
return py::make_tuple(py::handle(self));
}
else if (self == nullptr) {
return py::reinterpret_borrow<py::tuple>(args);
}
auto py_args = py::reinterpret_borrow<py::tuple>(args);
size_t n = py_args.size();
auto args_ = py::tuple(n + 1);
args_[0] = py::handle(self);
for (size_t i = 0; i < n; i++) {
args_[i+1] = py_args[i];
}
return args_;
}
auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* torch_api, const std::string& module_name) -> PyObject* {
py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str());
TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist");
py::tuple args_ = combine_self_args(self, args);
py::tuple py_types = py::make_tuple(py::handle(PyObject_Type(self)));
py::object torch_function = PyObject_FastGetAttrString(self, "__torch_function__");
py::object ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), NULL));
if (ret.ptr() == nullptr) {
// if an exception occurred in a user's implementation of
// __torch_function__, throw it
throw python_error();
}
if (ret.ptr() == Py_NotImplemented) {
std::string error_msg = "no implementation found for " + module_name + "." + func_name + "' on types that implement __torch_function__: [" + self->ob_type->tp_name + "]";
PyErr_SetString(PyExc_TypeError, error_msg.c_str());
throw python_error();
}
return ret.release().ptr();
}
auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* {
py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)r.get_func_name().c_str());
TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist");
py::object ret;
py::tuple args_ = combine_self_args(self, args);
// overloaded_args already all have unique types
std::vector<py::object> overloaded_types;
overloaded_types.reserve(r.signature.overloaded_args.size());
@ -150,7 +209,7 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb
for (auto &arg : r.signature.overloaded_args) {
py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), "__torch_function__");
ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args, kwargs, NULL));
ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), kwargs, NULL));
if (ret.ptr() != Py_NotImplemented) {
// Return the reference to the result. This also covers the case where ret
// is NULL and __torch_function__ raised an exception, which we throw below
@ -159,7 +218,7 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb
}
if (ret.ptr() == nullptr) {
// if an exception occurred in a user's implementation of
// __array_function__, throw it
// __torch_function__, throw it
throw python_error();
}
else if (ret.ptr() == Py_NotImplemented) {
@ -184,6 +243,11 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb
return ret.release().ptr();
}
auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject*
{
return handle_torch_function(r, nullptr, args, kwargs, torch_api, module_name);
}
/*
* obj has a __torch_function__ implementation and may either be a
* subclass of Tensor or a Tensor-like duck type. We may need to
@ -653,9 +717,9 @@ static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, ssize_t
throw TypeError("invalid keyword arguments");
}
bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* dst[], // NOLINT
bool raise_exception) {
auto nargs = PyTuple_GET_SIZE(args);
auto nargs = args ? PyTuple_GET_SIZE(args) : 0;
ssize_t remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
ssize_t arg_pos = 0;
bool allow_varargs_intlist = false;
@ -679,6 +743,9 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
}
int i = 0;
if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) {
append_overloaded_arg(this->overloaded_args, self);
}
for (auto& param : params) {
PyObject* obj = nullptr;
bool is_kwd = false;
@ -798,25 +865,25 @@ void PythonArgParser::check_deprecated(const FunctionSignature & signature) {
}
}
PythonArgs PythonArgParser::raw_parse(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {
PythonArgs PythonArgParser::raw_parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { // NOLINT
if (signatures_.size() == 1) {
auto& signature = signatures_[0];
signature.parse(args, kwargs, parsed_args, true);
signature.parse(self, args, kwargs, parsed_args, true);
check_deprecated(signature);
return PythonArgs(traceable, signature, parsed_args);
}
for (auto& signature : signatures_) {
if (signature.parse(args, kwargs, parsed_args, false)) {
if (signature.parse(self, args, kwargs, parsed_args, false)) {
check_deprecated(signature);
return PythonArgs(traceable, signature, parsed_args);
}
}
print_error(args, kwargs, parsed_args);
print_error(self, args, kwargs, parsed_args);
}
void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {
void PythonArgParser::print_error(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { // NOLINT
auto num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0);
std::vector<int> plausible_idxs;
ssize_t i = 0;
@ -829,7 +896,7 @@ void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* pa
if (plausible_idxs.size() == 1) {
auto& signature = signatures_[plausible_idxs[0]];
signature.parse(args, kwargs, parsed_args, true);
signature.parse(self, args, kwargs, parsed_args, true);
}
auto options = get_signatures();