mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update __torch_dispatch__ to return op overload instead of the opoverload packet function (#72673)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72673 Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D34627164 Pulled By: anjali411 fbshipit-source-id: 3cb6406a392d530bf9da36b4d8e0a62b30e6497e (cherry picked from commit 65b85a0a67df4d0f16ac8964e2b685d478a610fb)
This commit is contained in:
committed by
PyTorch MergeBot
parent
a3d099ea18
commit
086645ad77
@ -7941,7 +7941,7 @@ class TestAutogradForwardMode(TestCase):
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
if func == torch.ops.aten.alias:
|
||||
if func.overloadpacket == torch.ops.aten.alias:
|
||||
counter[0] += 1
|
||||
|
||||
with no_dispatch():
|
||||
|
@ -10,8 +10,8 @@ class TestPerOverloadAPI(TestCase):
|
||||
add_packet = torch.ops.aten.add
|
||||
|
||||
# class attributes
|
||||
self.assertEqual(add_packet.op_name, 'add')
|
||||
self.assertEqual(add_packet.qualified_op_name, 'aten.add')
|
||||
self.assertEqual(add_packet.__name__, 'add')
|
||||
self.assertEqual(str(add_packet), 'aten.add')
|
||||
|
||||
# callable
|
||||
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
@ -27,7 +27,7 @@ class TestPerOverloadAPI(TestCase):
|
||||
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
|
||||
|
||||
# pretty print
|
||||
self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
|
||||
self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")
|
||||
|
||||
self.assertRaises(AttributeError, lambda: add_packet.foo)
|
||||
|
||||
@ -36,9 +36,9 @@ class TestPerOverloadAPI(TestCase):
|
||||
add_tensoroverload = add_packet.Tensor
|
||||
|
||||
# class attributes
|
||||
self.assertEqual(add_tensoroverload.name, 'aten.add')
|
||||
self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
|
||||
self.assertEqual(add_tensoroverload.overload_packet, add_packet)
|
||||
self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
|
||||
self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
|
||||
self.assertEqual(add_tensoroverload.overloadpacket, add_packet)
|
||||
|
||||
# deepcopy is a no-op
|
||||
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
|
||||
@ -48,7 +48,7 @@ class TestPerOverloadAPI(TestCase):
|
||||
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
|
||||
|
||||
# pretty print
|
||||
self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
|
||||
self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")
|
||||
|
||||
# callable
|
||||
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
|
@ -31,11 +31,11 @@ class TestPythonDispatch(TestCase):
|
||||
# self.assertEqual(saved_x._version, x._version)
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = torch._ops.aten.mul($0, $0)
|
||||
$1 = torch._ops.aten.mul.Tensor($0, $0)
|
||||
$2 = input('grad_y')
|
||||
$3 = torch._ops.aten.mul($2, $0)
|
||||
$4 = torch._ops.aten.mul($2, $0)
|
||||
$5 = torch._ops.aten.add($4, $3)''')
|
||||
$3 = torch._ops.aten.mul.Tensor($2, $0)
|
||||
$4 = torch._ops.aten.mul.Tensor($2, $0)
|
||||
$5 = torch._ops.aten.add.Tensor($4, $3)''')
|
||||
|
||||
def test_out(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
@ -51,7 +51,7 @@ $5 = torch._ops.aten.add($4, $3)''')
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = input('y')
|
||||
$2 = torch._ops.aten.abs($0, out=$1)''')
|
||||
$2 = torch._ops.aten.abs.out($0, out=$1)''')
|
||||
|
||||
|
||||
def test_kwarg_only(self) -> None:
|
||||
@ -74,11 +74,11 @@ $2 = torch._ops.aten.abs($0, out=$1)''')
|
||||
$0 = input('x')
|
||||
$1 = input('y')
|
||||
$2 = input('z')
|
||||
$3 = torch._ops.aten.addmv($0, $1, $2)
|
||||
$4 = torch._ops.aten.addmv($0, $1, $2)
|
||||
$5 = torch._ops.aten.addmv($0, $1, $2, beta=2)
|
||||
$6 = torch._ops.aten.addmv($0, $1, $2, alpha=2)
|
||||
$7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')
|
||||
$3 = torch._ops.aten.addmv.default($0, $1, $2)
|
||||
$4 = torch._ops.aten.addmv.default($0, $1, $2)
|
||||
$5 = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
|
||||
$6 = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
|
||||
$7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
|
||||
|
||||
def test_kwarg_only_and_positional_default(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
@ -96,10 +96,10 @@ $7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = input('y')
|
||||
$2 = torch._ops.aten.kl_div($0, $1)
|
||||
$3 = torch._ops.aten.kl_div($0, $1, 2)
|
||||
$4 = torch._ops.aten.kl_div($0, $1, log_target=True)
|
||||
$5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
|
||||
$2 = torch._ops.aten.kl_div.default($0, $1)
|
||||
$3 = torch._ops.aten.kl_div.default($0, $1, 2)
|
||||
$4 = torch._ops.aten.kl_div.default($0, $1, log_target=True)
|
||||
$5 = torch._ops.aten.kl_div.default($0, $1, 2, log_target=True)''')
|
||||
|
||||
def test_list_ret(self) -> None:
|
||||
# test all sequence types are permissible returns
|
||||
@ -111,7 +111,7 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
if func == torch.ops.aten.split:
|
||||
if func.overloadpacket == torch.ops.aten.split:
|
||||
with no_dispatch():
|
||||
return list_type(torch.split(*args))
|
||||
else:
|
||||
@ -152,8 +152,8 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
|
||||
# would be bad if calling .detach() once emits 3+ detaches).
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = torch._ops.aten.detach($0)
|
||||
$2 = torch._ops.aten.detach($1)''')
|
||||
$1 = torch._ops.aten.detach.default($0)
|
||||
$2 = torch._ops.aten.detach.default($1)''')
|
||||
|
||||
def test_metadata_change_not_allowed(self) -> None:
|
||||
x = LoggingTensor(torch.ones(1))
|
||||
@ -264,11 +264,11 @@ $2 = torch._ops.aten.detach($1)''')
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = input('x.grad')
|
||||
$2 = torch._ops.aten.pow($0, 2)
|
||||
$2 = torch._ops.aten.pow.Tensor_Scalar($0, 2)
|
||||
$3 = input('grad_output')
|
||||
$4 = torch._ops.aten.mul($3, tensor(2))
|
||||
$5 = torch._ops.aten.mul($4, $0)
|
||||
$6 = torch._ops.aten.add_($1, $5)''')
|
||||
$4 = torch._ops.aten.mul.Tensor($3, tensor(2))
|
||||
$5 = torch._ops.aten.mul.Tensor($4, $0)
|
||||
$6 = torch._ops.aten.add_.Tensor($1, $5)''')
|
||||
|
||||
def test_subclass_creation(self):
|
||||
# Make sure these statements runs without error
|
||||
@ -376,7 +376,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
if func.__name__ == "clone":
|
||||
if func.overloadpacket.__name__ == "clone":
|
||||
# Return a plain tensor from clone().
|
||||
return args[0].elem.clone()
|
||||
raise RuntimeError("NYI")
|
||||
@ -444,7 +444,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
||||
idxs = (MyTensor(torch.tensor(0)),)
|
||||
v = torch.randn(1)
|
||||
res = x.index_put_(idxs, v)
|
||||
self.assertEqual(called_funcs, [torch.ops.aten.index_put_])
|
||||
self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])
|
||||
|
||||
def test_enable_python_mode_error(self) -> None:
|
||||
with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
|
||||
@ -594,7 +594,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
||||
# It prevents infinite recursion.
|
||||
with no_dispatch():
|
||||
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
|
||||
if func.__name__ == "add":
|
||||
if func.overloadpacket.__name__ == "add":
|
||||
return None
|
||||
else:
|
||||
return rs
|
||||
@ -659,7 +659,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
||||
x = torch.randn(2)
|
||||
y = torch.randn(2)
|
||||
self.assertEqual(SubTensor(x) + SubTensor(y), x + y)
|
||||
self.assertEqual(called, [torch.ops.aten.add])
|
||||
self.assertEqual(called, [torch.ops.aten.add.Tensor])
|
||||
|
||||
def test_dispatch_super_dont_autograd(self):
|
||||
called = []
|
||||
@ -685,7 +685,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
||||
|
||||
x = SubTensor(torch.randn(2, requires_grad=True))
|
||||
x.neg()
|
||||
self.assertEqual(called, [torch.ops.aten.neg])
|
||||
self.assertEqual(called, [torch.ops.aten.neg.default])
|
||||
|
||||
def test_multiple_ops_subclass(self):
|
||||
# This is a Direct Subclass, don't do that!
|
||||
|
@ -32,13 +32,17 @@ class OpOverload:
|
||||
self._op = op
|
||||
self._schema = schema
|
||||
self._overloadpacket = overloadpacket
|
||||
self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name
|
||||
self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname)
|
||||
self.__module__ = overloadpacket.__module__
|
||||
op.__module__ = overloadpacket.__module__
|
||||
|
||||
# it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
|
||||
def __deepcopy__(self, memo=None):
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
return "OpOverload(op='{}.{}', overload='{}')".format(*self._schema.name.split("::"), self.overload_name)
|
||||
def __repr__(self):
|
||||
return "<OpOverload(op='{}.{}', overload='{}')>".format(*self._schema.name.split("::"), self._overloadname)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._op(*args, **kwargs or {})
|
||||
@ -46,17 +50,15 @@ class OpOverload:
|
||||
def __getattr__(self, key):
|
||||
return getattr(self._op, key)
|
||||
|
||||
# `my_namespace::my_op`
|
||||
@property
|
||||
def name(self):
|
||||
return "{}.{}".format(*self._schema.name.split("::"))
|
||||
def __hash__(self):
|
||||
return hash(self._op)
|
||||
|
||||
# `my_namespace.my_op_name.overload_name`
|
||||
def __str__(self):
|
||||
return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
|
||||
|
||||
@property
|
||||
def overload_name(self):
|
||||
return self._schema.overload_name
|
||||
|
||||
@property
|
||||
def overload_packet(self):
|
||||
def overloadpacket(self):
|
||||
return self._overloadpacket
|
||||
|
||||
@property
|
||||
@ -72,24 +74,22 @@ class OpOverloadPacket:
|
||||
# These attributes are accessible on the object through the properties
|
||||
# defined below but are immutable
|
||||
self._qualified_op_name = qualified_op_name
|
||||
self._op_name = op_name
|
||||
self.__name__ = op_name
|
||||
self._op = op
|
||||
|
||||
# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
|
||||
def __deepcopy__(self, memo=None):
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return "<OpOverloadPacket(op='{}.{}')>".format(*self._qualified_op_name.split("::"))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._op)
|
||||
|
||||
def __str__(self):
|
||||
return "OpOverloadPacket(op='{}.{}')".format(*self._qualified_op_name.split("::"))
|
||||
|
||||
@property
|
||||
def qualified_op_name(self):
|
||||
return "{}.{}".format(*self._qualified_op_name.split("::"))
|
||||
|
||||
@property
|
||||
def op_name(self):
|
||||
return self._op_name
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
return self._op
|
||||
|
@ -1684,6 +1684,7 @@ void concrete_dispatch_fn(
|
||||
// Parse the name into namespace and name (no overload_name)
|
||||
// TODO: put this into the library
|
||||
const auto& qualified_name = op.operator_name().name;
|
||||
const auto& overload_name = schema.overload_name();
|
||||
auto pos = qualified_name.find("::");
|
||||
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
|
||||
// Make me some null terminated strings
|
||||
@ -1704,6 +1705,12 @@ void concrete_dispatch_fn(
|
||||
// overload resolution but is more complicated (need to expose separate
|
||||
// functions per overload)
|
||||
py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
||||
py::handle torch_api_function_overload;
|
||||
if (overload_name == "") {
|
||||
torch_api_function_overload = torch_api_function.attr("default");
|
||||
} else {
|
||||
torch_api_function_overload = torch_api_function.attr(overload_name.c_str());
|
||||
}
|
||||
std::string module_name_str = "torch.ops." + ns_str;
|
||||
|
||||
// About all the pointers:
|
||||
@ -1795,7 +1802,7 @@ void concrete_dispatch_fn(
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
func_name,
|
||||
torch_api_function.ptr(),
|
||||
torch_api_function_overload.ptr(),
|
||||
module_name_str.c_str(),
|
||||
"__torch_dispatch__"
|
||||
));
|
||||
@ -1835,7 +1842,7 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
"detach",
|
||||
py::module::import("torch").attr("ops").attr("aten").attr("detach").ptr(),
|
||||
py::module::import("torch").attr("ops").attr("aten").attr("detach").attr("default").ptr(),
|
||||
"torch.ops.aten",
|
||||
"__torch_dispatch__"
|
||||
));
|
||||
|
@ -1232,9 +1232,10 @@ void initJITBindings(PyObject* module) {
|
||||
auto operations = getAllOperatorsFor(symbol);
|
||||
for (const auto& op : operations) {
|
||||
if (op->schema().overload_name() == overload_name) {
|
||||
auto func =
|
||||
py::cpp_function([op](py::args args, py::kwargs kwargs) {
|
||||
return invokeOperatorFromPython({op}, args, kwargs);
|
||||
auto func = py::cpp_function(
|
||||
[op, symbol](py::args args, py::kwargs kwargs) {
|
||||
return _get_operation_for_overload_or_packet(
|
||||
{op}, symbol, args, kwargs, true);
|
||||
});
|
||||
return func;
|
||||
}
|
||||
@ -1265,60 +1266,8 @@ void initJITBindings(PyObject* module) {
|
||||
|
||||
auto func = py::cpp_function(
|
||||
[operations, symbol](py::args args, py::kwargs kwargs) {
|
||||
std::vector<py::handle> overloaded_args;
|
||||
size_t total_arg_num = args.size() + kwargs.size();
|
||||
for (const auto i : c10::irange(args.size())) {
|
||||
is_tensor_and_append_overloaded(
|
||||
args[i].ptr(), &overloaded_args);
|
||||
is_tensor_list_and_append_overloaded(
|
||||
args[i].ptr(),
|
||||
&overloaded_args,
|
||||
static_cast<int>(total_arg_num),
|
||||
false /* throw_error */);
|
||||
}
|
||||
// NB: for kwargs, we cannot guarantee the order of appending
|
||||
// is the same as the argument order in operator's schema.
|
||||
// This is suboptimal, but should be fine. Later when we have
|
||||
// better schema matching and argument parsing, we could
|
||||
// match the operator in `operations` first, then the order will
|
||||
// be guaranteed.
|
||||
for (auto item : kwargs) {
|
||||
is_tensor_and_append_overloaded(
|
||||
item.second.ptr(), &overloaded_args);
|
||||
is_tensor_list_and_append_overloaded(
|
||||
item.second.ptr(),
|
||||
&overloaded_args,
|
||||
total_arg_num,
|
||||
false /* throw_error */);
|
||||
}
|
||||
if (overloaded_args.size() > 0) {
|
||||
std::vector<py::object> overloaded_types;
|
||||
overloaded_types.reserve(overloaded_args.size());
|
||||
for (auto& oarg : overloaded_args) {
|
||||
overloaded_types.push_back(
|
||||
py::reinterpret_borrow<py::object>(
|
||||
(PyObject*)Py_TYPE(oarg.ptr())));
|
||||
}
|
||||
py::tuple py_types = py::cast(overloaded_types);
|
||||
py::object ret;
|
||||
std::string ns = symbol.ns().toUnqualString();
|
||||
std::string method_name = symbol.toUnqualString();
|
||||
auto self_func = py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr(ns.c_str())
|
||||
.attr(method_name.c_str());
|
||||
std::string module_name("torch.ops");
|
||||
module_name.append(ns);
|
||||
return pybind11::reinterpret_steal<py::object>(
|
||||
handle_torch_function_no_python_arg_parser(
|
||||
overloaded_args,
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
method_name.c_str(),
|
||||
self_func.ptr(),
|
||||
module_name.c_str()));
|
||||
}
|
||||
return invokeOperatorFromPython(operations, args, kwargs);
|
||||
return _get_operation_for_overload_or_packet(
|
||||
operations, symbol, args, kwargs, false);
|
||||
},
|
||||
py::name(symbol.toUnqualString()),
|
||||
py::doc(docstring.str().c_str()));
|
||||
|
@ -1159,5 +1159,72 @@ inline py::object invokeOperatorFromPython(
|
||||
return createPyObjectForStack(std::move(stack));
|
||||
}
|
||||
|
||||
inline py::object _get_operation_for_overload_or_packet(
|
||||
const std::vector<std::shared_ptr<Operator>>& operations,
|
||||
Symbol symbol,
|
||||
py::args args,
|
||||
const py::kwargs& kwargs,
|
||||
bool is_overload) {
|
||||
std::vector<py::handle> overloaded_args;
|
||||
size_t total_arg_num = args.size() + kwargs.size();
|
||||
for (const auto i : c10::irange(args.size())) {
|
||||
is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
|
||||
is_tensor_list_and_append_overloaded(
|
||||
args[i].ptr(),
|
||||
&overloaded_args,
|
||||
static_cast<int>(total_arg_num),
|
||||
false /* throw_error */);
|
||||
}
|
||||
// NB: for kwargs, we cannot guarantee the order of appending
|
||||
// is the same as the argument order in operator's schema.
|
||||
// This is suboptimal, but should be fine. Later when we have
|
||||
// better schema matching and argument parsing, we could
|
||||
// match the operator in `operations` first, then the order will
|
||||
// be guaranteed.
|
||||
for (auto item : kwargs) {
|
||||
is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
|
||||
is_tensor_list_and_append_overloaded(
|
||||
item.second.ptr(),
|
||||
&overloaded_args,
|
||||
total_arg_num,
|
||||
false /* throw_error */);
|
||||
}
|
||||
if (overloaded_args.size() > 0) {
|
||||
std::vector<py::object> overloaded_types;
|
||||
overloaded_types.reserve(overloaded_args.size());
|
||||
for (auto& oarg : overloaded_args) {
|
||||
overloaded_types.push_back(
|
||||
py::reinterpret_borrow<py::object>((PyObject*)Py_TYPE(oarg.ptr())));
|
||||
}
|
||||
py::tuple py_types = py::cast(overloaded_types);
|
||||
py::object ret;
|
||||
std::string ns = symbol.ns().toUnqualString();
|
||||
std::string method_name = symbol.toUnqualString();
|
||||
auto self_func = py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr(ns.c_str())
|
||||
.attr(method_name.c_str());
|
||||
if (is_overload) {
|
||||
auto overload_name = operations[0]->schema().overload_name();
|
||||
if (overload_name == "") {
|
||||
self_func = self_func.attr("default");
|
||||
} else {
|
||||
self_func.attr(overload_name.c_str());
|
||||
}
|
||||
}
|
||||
std::string module_name("torch.ops");
|
||||
module_name.append(ns);
|
||||
return pybind11::reinterpret_steal<py::object>(
|
||||
handle_torch_function_no_python_arg_parser(
|
||||
overloaded_args,
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
method_name.c_str(),
|
||||
self_func.ptr(),
|
||||
module_name.c_str()));
|
||||
}
|
||||
return invokeOperatorFromPython(operations, args, kwargs);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -7,7 +7,7 @@ import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
|
||||
from torch._jit_internal import boolean_dispatched
|
||||
from ._compatibility import compatibility
|
||||
from torch._ops import OpOverloadPacket
|
||||
from torch._ops import OpOverloadPacket, OpOverload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .node import Argument
|
||||
@ -134,13 +134,12 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
|
||||
return_schemas=True, returns a tuple containing the optional Python signatures
|
||||
and the optional TorchScript Function signature
|
||||
"""
|
||||
if isinstance(op, OpOverloadPacket) or isinstance(op, OpOverload):
|
||||
op = op.op
|
||||
override = _manual_overrides.get(op)
|
||||
if override:
|
||||
return (override, None) if return_schemas else None
|
||||
|
||||
if isinstance(op, OpOverloadPacket):
|
||||
op = op._op
|
||||
|
||||
aten_fn = torch.jit._builtins._find_builtin(op)
|
||||
|
||||
if aten_fn is None:
|
||||
|
@ -41,7 +41,7 @@ def check_metadata_consistency(wrapper_tensor):
|
||||
check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor)
|
||||
|
||||
def is_view_fn(func):
|
||||
return func.__name__ in {
|
||||
return func.overloadpacket.__name__ in {
|
||||
'as_strided',
|
||||
'detach',
|
||||
'diagonal',
|
||||
@ -81,7 +81,7 @@ def is_view_fn(func):
|
||||
# manually populated from native_functions that have inplace_view: True.
|
||||
# In the future we will probably be able to grab that list directly
|
||||
def is_inplace_view_fn(func):
|
||||
return func.__name__ in {
|
||||
return func.overloadpacket.__name__ in {
|
||||
'as_strided_',
|
||||
'detach_',
|
||||
'squeeze_',
|
||||
@ -122,7 +122,7 @@ class CompositeCompliantTensor(torch.Tensor):
|
||||
def wrap(e):
|
||||
return CompositeCompliantTensor(e) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
if func.__name__ in ('set_', 'resize_'):
|
||||
if func.overloadpacket.__name__ in ('set_', 'resize_'):
|
||||
raise RuntimeError(
|
||||
f"{func.__name__} is not allowed to be called inside of "
|
||||
f"CompositeImplicitAutograd operators.")
|
||||
|
Reference in New Issue
Block a user