mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 04:04:57 +08:00 
			
		
		
		
	Add __torch_function__ for methods (#37091)
Summary: According to pytorch/rfcs#3 From the goals in the RFC: 1. Support subclassing `torch.Tensor` in Python (done here) 2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them (done here) 3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor` subclasses (done in https://github.com/pytorch/pytorch/issues/30730) 4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods. (done here) 5. Propagating subclass instances correctly also with operators, using views/slices/indexing/etc. (done here) 6. Preserve subclass attributes when using methods or views/slices/indexing. (done here) 7. A way to insert code that operates on both functions and methods uniformly (so we can write a single function that overrides all operators). (done here) 8. The ability to give external libraries a way to also define functions/methods that follow the `__torch_function__` protocol. (will be addressed in a separate PR) This PR makes the following changes: 1. Adds the `self` argument to the arg parser. 2. Dispatches on `self` as well if `self` is not `nullptr`. 3. Adds a `torch._C.DisableTorchFunction` context manager to disable `__torch_function__`. 4. Adds a `torch::torch_function_enabled()` and `torch._C._torch_function_enabled()` to check the state of `__torch_function__`. 5. Dispatches all `torch._C.TensorBase` and `torch.Tensor` methods via `__torch_function__`. TODO: - [x] Sequence Methods - [x] Docs - [x] Tests Closes https://github.com/pytorch/pytorch/issues/28361 Benchmarks in https://github.com/pytorch/pytorch/pull/37091#issuecomment-633657778 Pull Request resolved: https://github.com/pytorch/pytorch/pull/37091 Reviewed By: ngimel Differential Revision: D22765678 Pulled By: ezyang fbshipit-source-id: 53f8aa17ddb8b1108c0997f6a7aa13cb5be73de0
This commit is contained in:
		
				
					committed by
					
						 Facebook GitHub Bot
						Facebook GitHub Bot
					
				
			
			
				
	
			
			
			
						parent
						
							92b7347fd7
						
					
				
				
					commit
					3d46e02ea1
				
			| @ -135,11 +135,70 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) | ||||
|   } | ||||
| } | ||||
|  | ||||
| auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* { | ||||
| auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject* { | ||||
|   py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); | ||||
|   std::string module_name = "torch.Tensor." + property_name; | ||||
|   return handle_torch_function((PyObject *)self, "__get__", nullptr, torch_api.ptr(), module_name); | ||||
| } | ||||
|  | ||||
| auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int { | ||||
|   py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); | ||||
|   std::string module_name = "torch.Tensor." + property_name; | ||||
|   if (value != nullptr) | ||||
|   { | ||||
|     py::tuple args_ = py::make_tuple(py::handle(value)); | ||||
|     handle_torch_function((PyObject *)self, "__set__", args_.ptr(), torch_api.ptr(), module_name); | ||||
|   } | ||||
|   else { | ||||
|     handle_torch_function((PyObject *)self, "__delete__", nullptr, torch_api.ptr(), module_name); | ||||
|   } | ||||
|   return 0; | ||||
| } | ||||
|  | ||||
| // Combines self and args into one tuple. | ||||
| auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple { | ||||
|   if (args == nullptr) { | ||||
|     return py::make_tuple(py::handle(self)); | ||||
|   } | ||||
|   else if (self == nullptr) { | ||||
|     return py::reinterpret_borrow<py::tuple>(args); | ||||
|   } | ||||
|  | ||||
|   auto py_args = py::reinterpret_borrow<py::tuple>(args); | ||||
|   size_t n = py_args.size(); | ||||
|   auto args_ = py::tuple(n + 1); | ||||
|   args_[0] = py::handle(self); | ||||
|   for (size_t i = 0; i < n; i++) { | ||||
|     args_[i+1] = py_args[i]; | ||||
|   } | ||||
|   return args_; | ||||
| } | ||||
|  | ||||
| auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* torch_api, const std::string& module_name) -> PyObject* { | ||||
|   py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str()); | ||||
|   TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); | ||||
|   py::tuple args_ = combine_self_args(self, args); | ||||
|   py::tuple py_types = py::make_tuple(py::handle(PyObject_Type(self))); | ||||
|   py::object torch_function = PyObject_FastGetAttrString(self, "__torch_function__"); | ||||
|   py::object ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), NULL)); | ||||
|   if (ret.ptr() == nullptr) { | ||||
|     // if an exception occurred in a user's implementation of | ||||
|     // __torch_function__, throw it | ||||
|     throw python_error(); | ||||
|   } | ||||
|   if (ret.ptr() == Py_NotImplemented) { | ||||
|     std::string error_msg = "no implementation found for " + module_name + "." + func_name + "' on types that implement __torch_function__: [" + self->ob_type->tp_name + "]"; | ||||
|     PyErr_SetString(PyExc_TypeError, error_msg.c_str()); | ||||
|     throw python_error(); | ||||
|   } | ||||
|   return ret.release().ptr(); | ||||
| } | ||||
|  | ||||
| auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* { | ||||
|   py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)r.get_func_name().c_str()); | ||||
|   TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); | ||||
|   py::object ret; | ||||
|  | ||||
|   py::tuple args_ = combine_self_args(self, args); | ||||
|   // overloaded_args already all have unique types | ||||
|   std::vector<py::object> overloaded_types; | ||||
|   overloaded_types.reserve(r.signature.overloaded_args.size()); | ||||
| @ -150,7 +209,7 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb | ||||
|  | ||||
|   for (auto &arg : r.signature.overloaded_args) { | ||||
|     py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), "__torch_function__"); | ||||
|     ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args, kwargs, NULL)); | ||||
|     ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), kwargs, NULL)); | ||||
|     if (ret.ptr() != Py_NotImplemented) { | ||||
|       // Return the reference to the result. This also covers the case where ret | ||||
|       // is NULL and __torch_function__ raised an exception, which we throw below | ||||
| @ -159,7 +218,7 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb | ||||
|   } | ||||
|   if (ret.ptr() == nullptr) { | ||||
|     // if an exception occurred in a user's implementation of | ||||
|     // __array_function__, throw it | ||||
|     // __torch_function__, throw it | ||||
|     throw python_error(); | ||||
|   } | ||||
|   else if (ret.ptr() == Py_NotImplemented) { | ||||
| @ -184,6 +243,11 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb | ||||
|   return ret.release().ptr(); | ||||
| } | ||||
|  | ||||
| auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* | ||||
| { | ||||
|   return handle_torch_function(r, nullptr, args, kwargs, torch_api, module_name); | ||||
| } | ||||
|  | ||||
| /* | ||||
|  *  obj has a __torch_function__ implementation and may either be a | ||||
|  *  subclass of Tensor or a Tensor-like duck type. We may need to | ||||
| @ -653,9 +717,9 @@ static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, ssize_t | ||||
|   throw TypeError("invalid keyword arguments"); | ||||
| } | ||||
|  | ||||
| bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[], | ||||
| bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* dst[],  // NOLINT | ||||
|                               bool raise_exception) { | ||||
|   auto nargs = PyTuple_GET_SIZE(args); | ||||
|   auto nargs = args ? PyTuple_GET_SIZE(args) : 0; | ||||
|   ssize_t remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0; | ||||
|   ssize_t arg_pos = 0; | ||||
|   bool allow_varargs_intlist = false; | ||||
| @ -679,6 +743,9 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[], | ||||
|   } | ||||
|  | ||||
|   int i = 0; | ||||
|   if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) { | ||||
|     append_overloaded_arg(this->overloaded_args, self); | ||||
|   } | ||||
|   for (auto& param : params) { | ||||
|     PyObject* obj = nullptr; | ||||
|     bool is_kwd = false; | ||||
| @ -798,25 +865,25 @@ void PythonArgParser::check_deprecated(const FunctionSignature & signature) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| PythonArgs PythonArgParser::raw_parse(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { | ||||
| PythonArgs PythonArgParser::raw_parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {  // NOLINT | ||||
|   if (signatures_.size() == 1) { | ||||
|     auto& signature = signatures_[0]; | ||||
|     signature.parse(args, kwargs, parsed_args, true); | ||||
|     signature.parse(self, args, kwargs, parsed_args, true); | ||||
|     check_deprecated(signature); | ||||
|     return PythonArgs(traceable, signature, parsed_args); | ||||
|   } | ||||
|  | ||||
|   for (auto& signature : signatures_) { | ||||
|     if (signature.parse(args, kwargs, parsed_args, false)) { | ||||
|     if (signature.parse(self, args, kwargs, parsed_args, false)) { | ||||
|       check_deprecated(signature); | ||||
|       return PythonArgs(traceable, signature, parsed_args); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   print_error(args, kwargs, parsed_args); | ||||
|   print_error(self, args, kwargs, parsed_args); | ||||
| } | ||||
|  | ||||
| void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { | ||||
| void PythonArgParser::print_error(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {  // NOLINT | ||||
|   auto num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0); | ||||
|   std::vector<int> plausible_idxs; | ||||
|   ssize_t i = 0; | ||||
| @ -829,7 +896,7 @@ void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* pa | ||||
|  | ||||
|   if (plausible_idxs.size() == 1) { | ||||
|     auto& signature = signatures_[plausible_idxs[0]]; | ||||
|     signature.parse(args, kwargs, parsed_args, true); | ||||
|     signature.parse(self, args, kwargs, parsed_args, true); | ||||
|   } | ||||
|  | ||||
|   auto options = get_signatures(); | ||||
|  | ||||
		Reference in New Issue
	
	Block a user