Update __torch_dispatch__ to pass the OpOverload instead of the OpOverloadPacket function (#72673)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72673

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D34627164

Pulled By: anjali411

fbshipit-source-id: 3cb6406a392d530bf9da36b4d8e0a62b30e6497e
(cherry picked from commit 65b85a0a67df4d0f16ac8964e2b685d478a610fb)
Authored by: anjali411 on 2022-03-07 14:32:41 -08:00
Committed by: PyTorch MergeBot
Parent: a3d099ea18
Commit: 086645ad77
9 changed files with 142 additions and 120 deletions
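
The user-visible effect of this change is that the func argument received by __torch_dispatch__ is now an OpOverload (e.g. torch.ops.aten.add.Tensor) rather than an OpOverloadPacket (e.g. torch.ops.aten.add). A minimal sketch of the handler-side adjustment, mirroring the test updates below (the handle function is hypothetical and not part of this diff):

import torch

def handle(func, types, args=(), kwargs=None):
    # func is now an OpOverload such as torch.ops.aten.add.Tensor.
    # Code that previously compared against the packet, e.g.
    #     if func == torch.ops.aten.add: ...
    # should compare the packet attribute, or match the exact overload.
    if func.overloadpacket == torch.ops.aten.add:  # packet-level match
        pass
    if func == torch.ops.aten.add.Tensor:          # exact-overload match
        pass
    return func(*args, **(kwargs or {}))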


@@ -7941,7 +7941,7 @@ class TestAutogradForwardMode(TestCase):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- if func == torch.ops.aten.alias:
+ if func.overloadpacket == torch.ops.aten.alias:
counter[0] += 1
with no_dispatch():


@@ -10,8 +10,8 @@ class TestPerOverloadAPI(TestCase):
add_packet = torch.ops.aten.add
# class attributes
- self.assertEqual(add_packet.op_name, 'add')
- self.assertEqual(add_packet.qualified_op_name, 'aten.add')
+ self.assertEqual(add_packet.__name__, 'add')
+ self.assertEqual(str(add_packet), 'aten.add')
# callable
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
@@ -27,7 +27,7 @@ class TestPerOverloadAPI(TestCase):
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
# pretty print
- self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
+ self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")
self.assertRaises(AttributeError, lambda: add_packet.foo)
@@ -36,9 +36,9 @@ class TestPerOverloadAPI(TestCase):
add_tensoroverload = add_packet.Tensor
# class attributes
- self.assertEqual(add_tensoroverload.name, 'aten.add')
- self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
- self.assertEqual(add_tensoroverload.overload_packet, add_packet)
+ self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
+ self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
+ self.assertEqual(add_tensoroverload.overloadpacket, add_packet)
# deepcopy is a no-op
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
@@ -48,7 +48,7 @@ class TestPerOverloadAPI(TestCase):
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
# pretty print
- self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
+ self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")
# callable
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))


@@ -31,11 +31,11 @@ class TestPythonDispatch(TestCase):
# self.assertEqual(saved_x._version, x._version)
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
- $1 = torch._ops.aten.mul($0, $0)
+ $1 = torch._ops.aten.mul.Tensor($0, $0)
$2 = input('grad_y')
- $3 = torch._ops.aten.mul($2, $0)
- $4 = torch._ops.aten.mul($2, $0)
- $5 = torch._ops.aten.add($4, $3)''')
+ $3 = torch._ops.aten.mul.Tensor($2, $0)
+ $4 = torch._ops.aten.mul.Tensor($2, $0)
+ $5 = torch._ops.aten.add.Tensor($4, $3)''')
def test_out(self) -> None:
with capture_logs() as logs:
@@ -51,7 +51,7 @@ $5 = torch._ops.aten.add($4, $3)''')
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
- $2 = torch._ops.aten.abs($0, out=$1)''')
+ $2 = torch._ops.aten.abs.out($0, out=$1)''')
def test_kwarg_only(self) -> None:
@@ -74,11 +74,11 @@ $2 = torch._ops.aten.abs($0, out=$1)''')
$0 = input('x')
$1 = input('y')
$2 = input('z')
- $3 = torch._ops.aten.addmv($0, $1, $2)
- $4 = torch._ops.aten.addmv($0, $1, $2)
- $5 = torch._ops.aten.addmv($0, $1, $2, beta=2)
- $6 = torch._ops.aten.addmv($0, $1, $2, alpha=2)
- $7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')
+ $3 = torch._ops.aten.addmv.default($0, $1, $2)
+ $4 = torch._ops.aten.addmv.default($0, $1, $2)
+ $5 = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
+ $6 = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
+ $7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
def test_kwarg_only_and_positional_default(self) -> None:
with capture_logs() as logs:
@@ -96,10 +96,10 @@ $7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
- $2 = torch._ops.aten.kl_div($0, $1)
- $3 = torch._ops.aten.kl_div($0, $1, 2)
- $4 = torch._ops.aten.kl_div($0, $1, log_target=True)
- $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
+ $2 = torch._ops.aten.kl_div.default($0, $1)
+ $3 = torch._ops.aten.kl_div.default($0, $1, 2)
+ $4 = torch._ops.aten.kl_div.default($0, $1, log_target=True)
+ $5 = torch._ops.aten.kl_div.default($0, $1, 2, log_target=True)''')
def test_list_ret(self) -> None:
# test all sequence types are permissible returns
@@ -111,7 +111,7 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- if func == torch.ops.aten.split:
+ if func.overloadpacket == torch.ops.aten.split:
with no_dispatch():
return list_type(torch.split(*args))
else:
@@ -152,8 +152,8 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
# would be bad if calling .detach() once emits 3+ detaches).
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
- $1 = torch._ops.aten.detach($0)
- $2 = torch._ops.aten.detach($1)''')
+ $1 = torch._ops.aten.detach.default($0)
+ $2 = torch._ops.aten.detach.default($1)''')
def test_metadata_change_not_allowed(self) -> None:
x = LoggingTensor(torch.ones(1))
@@ -264,11 +264,11 @@ $2 = torch._ops.aten.detach($1)''')
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('x.grad')
- $2 = torch._ops.aten.pow($0, 2)
+ $2 = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3 = input('grad_output')
- $4 = torch._ops.aten.mul($3, tensor(2))
- $5 = torch._ops.aten.mul($4, $0)
- $6 = torch._ops.aten.add_($1, $5)''')
+ $4 = torch._ops.aten.mul.Tensor($3, tensor(2))
+ $5 = torch._ops.aten.mul.Tensor($4, $0)
+ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
def test_subclass_creation(self):
# Make sure these statements runs without error
@@ -376,7 +376,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- if func.__name__ == "clone":
+ if func.overloadpacket.__name__ == "clone":
# Return a plain tensor from clone().
return args[0].elem.clone()
raise RuntimeError("NYI")
@@ -444,7 +444,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
idxs = (MyTensor(torch.tensor(0)),)
v = torch.randn(1)
res = x.index_put_(idxs, v)
- self.assertEqual(called_funcs, [torch.ops.aten.index_put_])
+ self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])
def test_enable_python_mode_error(self) -> None:
with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
@@ -594,7 +594,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
# It prevents infinite recursion.
with no_dispatch():
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
- if func.__name__ == "add":
+ if func.overloadpacket.__name__ == "add":
return None
else:
return rs
@@ -659,7 +659,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
x = torch.randn(2)
y = torch.randn(2)
self.assertEqual(SubTensor(x) + SubTensor(y), x + y)
- self.assertEqual(called, [torch.ops.aten.add])
+ self.assertEqual(called, [torch.ops.aten.add.Tensor])
def test_dispatch_super_dont_autograd(self):
called = []
@@ -685,7 +685,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
x = SubTensor(torch.randn(2, requires_grad=True))
x.neg()
- self.assertEqual(called, [torch.ops.aten.neg])
+ self.assertEqual(called, [torch.ops.aten.neg.default])
def test_multiple_ops_subclass(self):
# This is a Direct Subclass, don't do that!


@@ -32,13 +32,17 @@ class OpOverload:
self._op = op
self._schema = schema
self._overloadpacket = overloadpacket
+ self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name
+ self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname)
self.__module__ = overloadpacket.__module__
op.__module__ = overloadpacket.__module__
# it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
def __deepcopy__(self, memo=None):
return self
- def __str__(self):
- return "OpOverload(op='{}.{}', overload='{}')".format(*self._schema.name.split("::"), self.overload_name)
+ def __repr__(self):
+ return "<OpOverload(op='{}.{}', overload='{}')>".format(*self._schema.name.split("::"), self._overloadname)
def __call__(self, *args, **kwargs):
return self._op(*args, **kwargs or {})
@@ -46,17 +50,15 @@ class OpOverload:
def __getattr__(self, key):
return getattr(self._op, key)
- # `my_namespace::my_op`
- @property
- def name(self):
- return "{}.{}".format(*self._schema.name.split("::"))
+ def __hash__(self):
+ return hash(self._op)
+ # `my_namespace.my_op_name.overload_name`
+ def __str__(self):
+ return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
- @property
- def overload_name(self):
- return self._schema.overload_name
@property
- def overload_packet(self):
+ def overloadpacket(self):
return self._overloadpacket
@property
@@ -72,24 +74,22 @@ class OpOverloadPacket:
# These attributes are accessible on the object through the properties
# defined below but are immutable
self._qualified_op_name = qualified_op_name
- self._op_name = op_name
+ self.__name__ = op_name
self._op = op
# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
def __deepcopy__(self, memo=None):
return self
+ def __repr__(self):
+ return "<OpOverloadPacket(op='{}.{}')>".format(*self._qualified_op_name.split("::"))
def __hash__(self):
return hash(self._op)
def __str__(self):
return "OpOverloadPacket(op='{}.{}')".format(*self._qualified_op_name.split("::"))
@property
def qualified_op_name(self):
return "{}.{}".format(*self._qualified_op_name.split("::"))
@property
def op_name(self):
return self._op_name
@property
def op(self):
return self._op
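
For reference, the per-overload API exercised by the updated TestPerOverloadAPI assertions above behaves as follows (a small usage sketch restating those assertions, not additional functionality):

import torch

add_packet = torch.ops.aten.add        # OpOverloadPacket
add_overload = add_packet.Tensor       # OpOverload

assert add_packet.__name__ == 'add'
assert str(add_packet) == 'aten.add'
assert add_overload.__name__ == 'add.Tensor'
assert str(add_overload) == 'aten.add.Tensor'
assert add_overload.overloadpacket == add_packet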


@@ -1684,6 +1684,7 @@ void concrete_dispatch_fn(
// Parse the name into namespace and name (no overload_name)
// TODO: put this into the library
const auto& qualified_name = op.operator_name().name;
+ const auto& overload_name = schema.overload_name();
auto pos = qualified_name.find("::");
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
// Make me some null terminated strings
@@ -1704,6 +1705,12 @@ void concrete_dispatch_fn(
// overload resolution but is more complicated (need to expose separate
// functions per overload)
py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
+ py::handle torch_api_function_overload;
+ if (overload_name == "") {
+ torch_api_function_overload = torch_api_function.attr("default");
+ } else {
+ torch_api_function_overload = torch_api_function.attr(overload_name.c_str());
+ }
std::string module_name_str = "torch.ops." + ns_str;
// About all the pointers:
@@ -1795,7 +1802,7 @@ void concrete_dispatch_fn(
args.ptr(),
kwargs.ptr(),
func_name,
- torch_api_function.ptr(),
+ torch_api_function_overload.ptr(),
module_name_str.c_str(),
"__torch_dispatch__"
));
@@ -1835,7 +1842,7 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
args.ptr(),
kwargs.ptr(),
"detach",
- py::module::import("torch").attr("ops").attr("aten").attr("detach").ptr(),
+ py::module::import("torch").attr("ops").attr("aten").attr("detach").attr("default").ptr(),
"torch.ops.aten",
"__torch_dispatch__"
));
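
The C++ changes above resolve the overload attribute on the packet before handing it to __torch_dispatch__: an empty overload name maps to the "default" attribute, otherwise to the named overload. From Python this corresponds to the following (illustrative only):

import torch

assert str(torch.ops.aten.detach.default) == 'aten.detach.default'   # empty overload name -> .default
assert str(torch.ops.aten.add.Tensor) == 'aten.add.Tensor'           # named overload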


@@ -1232,9 +1232,10 @@ void initJITBindings(PyObject* module) {
auto operations = getAllOperatorsFor(symbol);
for (const auto& op : operations) {
if (op->schema().overload_name() == overload_name) {
- auto func =
- py::cpp_function([op](py::args args, py::kwargs kwargs) {
- return invokeOperatorFromPython({op}, args, kwargs);
+ auto func = py::cpp_function(
+ [op, symbol](py::args args, py::kwargs kwargs) {
+ return _get_operation_for_overload_or_packet(
+ {op}, symbol, args, kwargs, true);
});
return func;
}
@@ -1265,60 +1266,8 @@ void initJITBindings(PyObject* module) {
auto func = py::cpp_function(
[operations, symbol](py::args args, py::kwargs kwargs) {
- std::vector<py::handle> overloaded_args;
- size_t total_arg_num = args.size() + kwargs.size();
- for (const auto i : c10::irange(args.size())) {
- is_tensor_and_append_overloaded(
- args[i].ptr(), &overloaded_args);
- is_tensor_list_and_append_overloaded(
- args[i].ptr(),
- &overloaded_args,
- static_cast<int>(total_arg_num),
- false /* throw_error */);
- }
- // NB: for kwargs, we cannot guarantee the order of appending
- // is the same as the argument order in operator's schema.
- // This is suboptimal, but should be fine. Later when we have
- // better schema matching and argument parsing, we could
- // match the operator in `operations` first, then the order will
- // be guaranteed.
- for (auto item : kwargs) {
- is_tensor_and_append_overloaded(
- item.second.ptr(), &overloaded_args);
- is_tensor_list_and_append_overloaded(
- item.second.ptr(),
- &overloaded_args,
- total_arg_num,
- false /* throw_error */);
- }
- if (overloaded_args.size() > 0) {
- std::vector<py::object> overloaded_types;
- overloaded_types.reserve(overloaded_args.size());
- for (auto& oarg : overloaded_args) {
- overloaded_types.push_back(
- py::reinterpret_borrow<py::object>(
- (PyObject*)Py_TYPE(oarg.ptr())));
- }
- py::tuple py_types = py::cast(overloaded_types);
- py::object ret;
- std::string ns = symbol.ns().toUnqualString();
- std::string method_name = symbol.toUnqualString();
- auto self_func = py::module::import("torch")
- .attr("ops")
- .attr(ns.c_str())
- .attr(method_name.c_str());
- std::string module_name("torch.ops");
- module_name.append(ns);
- return pybind11::reinterpret_steal<py::object>(
- handle_torch_function_no_python_arg_parser(
- overloaded_args,
- args.ptr(),
- kwargs.ptr(),
- method_name.c_str(),
- self_func.ptr(),
- module_name.c_str()));
- }
- return invokeOperatorFromPython(operations, args, kwargs);
+ return _get_operation_for_overload_or_packet(
+ operations, symbol, args, kwargs, false);
},
py::name(symbol.toUnqualString()),
py::doc(docstring.str().c_str()));


@@ -1159,5 +1159,72 @@ inline py::object invokeOperatorFromPython(
return createPyObjectForStack(std::move(stack));
}
+ inline py::object _get_operation_for_overload_or_packet(
+ const std::vector<std::shared_ptr<Operator>>& operations,
+ Symbol symbol,
+ py::args args,
+ const py::kwargs& kwargs,
+ bool is_overload) {
+ std::vector<py::handle> overloaded_args;
+ size_t total_arg_num = args.size() + kwargs.size();
+ for (const auto i : c10::irange(args.size())) {
+ is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
+ is_tensor_list_and_append_overloaded(
+ args[i].ptr(),
+ &overloaded_args,
+ static_cast<int>(total_arg_num),
+ false /* throw_error */);
+ }
+ // NB: for kwargs, we cannot guarantee the order of appending
+ // is the same as the argument order in operator's schema.
+ // This is suboptimal, but should be fine. Later when we have
+ // better schema matching and argument parsing, we could
+ // match the operator in `operations` first, then the order will
+ // be guaranteed.
+ for (auto item : kwargs) {
+ is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
+ is_tensor_list_and_append_overloaded(
+ item.second.ptr(),
+ &overloaded_args,
+ total_arg_num,
+ false /* throw_error */);
+ }
+ if (overloaded_args.size() > 0) {
+ std::vector<py::object> overloaded_types;
+ overloaded_types.reserve(overloaded_args.size());
+ for (auto& oarg : overloaded_args) {
+ overloaded_types.push_back(
+ py::reinterpret_borrow<py::object>((PyObject*)Py_TYPE(oarg.ptr())));
+ }
+ py::tuple py_types = py::cast(overloaded_types);
+ py::object ret;
+ std::string ns = symbol.ns().toUnqualString();
+ std::string method_name = symbol.toUnqualString();
+ auto self_func = py::module::import("torch")
+ .attr("ops")
+ .attr(ns.c_str())
+ .attr(method_name.c_str());
+ if (is_overload) {
+ auto overload_name = operations[0]->schema().overload_name();
+ if (overload_name == "") {
+ self_func = self_func.attr("default");
+ } else {
+ self_func = self_func.attr(overload_name.c_str());
+ }
+ }
+ std::string module_name("torch.ops");
+ module_name.append(ns);
+ return pybind11::reinterpret_steal<py::object>(
+ handle_torch_function_no_python_arg_parser(
+ overloaded_args,
+ args.ptr(),
+ kwargs.ptr(),
+ method_name.c_str(),
+ self_func.ptr(),
+ module_name.c_str()));
+ }
+ return invokeOperatorFromPython(operations, args, kwargs);
+ }
} // namespace jit
} // namespace torch


@@ -7,7 +7,7 @@ import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
from torch._jit_internal import boolean_dispatched
from ._compatibility import compatibility
- from torch._ops import OpOverloadPacket
+ from torch._ops import OpOverloadPacket, OpOverload
if TYPE_CHECKING:
from .node import Argument
@@ -134,13 +134,12 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
return_schemas=True, returns a tuple containing the optional Python signatures
and the optional TorchScript Function signature
"""
+ if isinstance(op, OpOverloadPacket) or isinstance(op, OpOverload):
+ op = op.op
override = _manual_overrides.get(op)
if override:
return (override, None) if return_schemas else None
- if isinstance(op, OpOverloadPacket):
- op = op._op
aten_fn = torch.jit._builtins._find_builtin(op)
if aten_fn is None:


@@ -41,7 +41,7 @@ def check_metadata_consistency(wrapper_tensor):
check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor)
def is_view_fn(func):
- return func.__name__ in {
+ return func.overloadpacket.__name__ in {
'as_strided',
'detach',
'diagonal',
@@ -81,7 +81,7 @@ def is_view_fn(func):
# manually populated from native_functions that have inplace_view: True.
# In the future we will probably be able to grab that list directly
def is_inplace_view_fn(func):
- return func.__name__ in {
+ return func.overloadpacket.__name__ in {
'as_strided_',
'detach_',
'squeeze_',
@@ -122,7 +122,7 @@ class CompositeCompliantTensor(torch.Tensor):
def wrap(e):
return CompositeCompliantTensor(e) if isinstance(e, torch.Tensor) else e
- if func.__name__ in ('set_', 'resize_'):
+ if func.overloadpacket.__name__ in ('set_', 'resize_'):
raise RuntimeError(
f"{func.__name__} is not allowed to be called inside of "
f"CompositeImplicitAutograd operators.")