Update __torch_dispatch__ to pass the op overload instead of the op overload packet (#72673)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72673

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D34627164

Pulled By: anjali411

fbshipit-source-id: 3cb6406a392d530bf9da36b4d8e0a62b30e6497e
(cherry picked from commit 65b85a0a67df4d0f16ac8964e2b685d478a610fb)
Authored by: anjali411, 2022-03-07 14:32:41 -08:00
Committed by: PyTorch MergeBot
Parent: a3d099ea18
Commit: 086645ad77
9 changed files with 142 additions and 120 deletions
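
The user-visible change: the `func` argument that `__torch_dispatch__` receives is now an `OpOverload` (e.g. `torch.ops.aten.add.Tensor`) rather than an `OpOverloadPacket` (e.g. `torch.ops.aten.add`). A minimal sketch of how call sites match ops under the new convention, mirroring the test updates below (the helper names here are illustrative, not part of the diff):

import torch

def matches_any_add(func):
    # `func` is what __torch_dispatch__ now receives: a specific OpOverload.
    # Compare its packet to match every overload of an op.
    return func.overloadpacket == torch.ops.aten.add

def matches_add_tensor(func):
    # Or compare against one exact overload.
    return func == torch.ops.aten.add.Tensor

print(matches_any_add(torch.ops.aten.add.Tensor))    # True
print(matches_add_tensor(torch.ops.aten.add.Scalar)) # False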

View File

@@ -7941,7 +7941,7 @@ class TestAutogradForwardMode(TestCase):
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-                if func == torch.ops.aten.alias:
+                if func.overloadpacket == torch.ops.aten.alias:
                     counter[0] += 1
                     with no_dispatch():

View File

@@ -10,8 +10,8 @@ class TestPerOverloadAPI(TestCase):
         add_packet = torch.ops.aten.add

         # class attributes
-        self.assertEqual(add_packet.op_name, 'add')
-        self.assertEqual(add_packet.qualified_op_name, 'aten.add')
+        self.assertEqual(add_packet.__name__, 'add')
+        self.assertEqual(str(add_packet), 'aten.add')

         # callable
         self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))

@@ -27,7 +27,7 @@ class TestPerOverloadAPI(TestCase):
         self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))

         # pretty print
-        self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
+        self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")

         self.assertRaises(AttributeError, lambda: add_packet.foo)

@@ -36,9 +36,9 @@ class TestPerOverloadAPI(TestCase):
         add_tensoroverload = add_packet.Tensor

         # class attributes
-        self.assertEqual(add_tensoroverload.name, 'aten.add')
-        self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
-        self.assertEqual(add_tensoroverload.overload_packet, add_packet)
+        self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
+        self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
+        self.assertEqual(add_tensoroverload.overloadpacket, add_packet)

         # deepcopy is a no-op
         self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))

@@ -48,7 +48,7 @@ class TestPerOverloadAPI(TestCase):
         self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))

         # pretty print
-        self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
+        self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")

         # callable
         self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))

View File

@@ -31,11 +31,11 @@ class TestPythonDispatch(TestCase):
         # self.assertEqual(saved_x._version, x._version)
         self.assertExpectedInline('\n'.join(logs), '''\
 $0 = input('x')
-$1 = torch._ops.aten.mul($0, $0)
+$1 = torch._ops.aten.mul.Tensor($0, $0)
 $2 = input('grad_y')
-$3 = torch._ops.aten.mul($2, $0)
-$4 = torch._ops.aten.mul($2, $0)
-$5 = torch._ops.aten.add($4, $3)''')
+$3 = torch._ops.aten.mul.Tensor($2, $0)
+$4 = torch._ops.aten.mul.Tensor($2, $0)
+$5 = torch._ops.aten.add.Tensor($4, $3)''')

     def test_out(self) -> None:
         with capture_logs() as logs:

@@ -51,7 +51,7 @@ $5 = torch._ops.aten.add($4, $3)''')
         self.assertExpectedInline('\n'.join(logs), '''\
 $0 = input('x')
 $1 = input('y')
-$2 = torch._ops.aten.abs($0, out=$1)''')
+$2 = torch._ops.aten.abs.out($0, out=$1)''')

     def test_kwarg_only(self) -> None:

@@ -74,11 +74,11 @@ $2 = torch._ops.aten.abs($0, out=$1)''')
 $0 = input('x')
 $1 = input('y')
 $2 = input('z')
-$3 = torch._ops.aten.addmv($0, $1, $2)
-$4 = torch._ops.aten.addmv($0, $1, $2)
-$5 = torch._ops.aten.addmv($0, $1, $2, beta=2)
-$6 = torch._ops.aten.addmv($0, $1, $2, alpha=2)
-$7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')
+$3 = torch._ops.aten.addmv.default($0, $1, $2)
+$4 = torch._ops.aten.addmv.default($0, $1, $2)
+$5 = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
+$6 = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
+$7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')

     def test_kwarg_only_and_positional_default(self) -> None:
         with capture_logs() as logs:

@@ -96,10 +96,10 @@ $7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')
         self.assertExpectedInline('\n'.join(logs), '''\
 $0 = input('x')
 $1 = input('y')
-$2 = torch._ops.aten.kl_div($0, $1)
-$3 = torch._ops.aten.kl_div($0, $1, 2)
-$4 = torch._ops.aten.kl_div($0, $1, log_target=True)
-$5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
+$2 = torch._ops.aten.kl_div.default($0, $1)
+$3 = torch._ops.aten.kl_div.default($0, $1, 2)
+$4 = torch._ops.aten.kl_div.default($0, $1, log_target=True)
+$5 = torch._ops.aten.kl_div.default($0, $1, 2, log_target=True)''')

     def test_list_ret(self) -> None:
         # test all sequence types are permissible returns

@@ -111,7 +111,7 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-                if func == torch.ops.aten.split:
+                if func.overloadpacket == torch.ops.aten.split:
                     with no_dispatch():
                         return list_type(torch.split(*args))
                 else:

@@ -152,8 +152,8 @@ $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
         # would be bad if calling .detach() once emits 3+ detaches).
         self.assertExpectedInline('\n'.join(logs), '''\
 $0 = input('x')
-$1 = torch._ops.aten.detach($0)
-$2 = torch._ops.aten.detach($1)''')
+$1 = torch._ops.aten.detach.default($0)
+$2 = torch._ops.aten.detach.default($1)''')

     def test_metadata_change_not_allowed(self) -> None:
         x = LoggingTensor(torch.ones(1))

@@ -264,11 +264,11 @@ $2 = torch._ops.aten.detach($1)''')
         self.assertExpectedInline('\n'.join(logs), '''\
 $0 = input('x')
 $1 = input('x.grad')
-$2 = torch._ops.aten.pow($0, 2)
+$2 = torch._ops.aten.pow.Tensor_Scalar($0, 2)
 $3 = input('grad_output')
-$4 = torch._ops.aten.mul($3, tensor(2))
-$5 = torch._ops.aten.mul($4, $0)
-$6 = torch._ops.aten.add_($1, $5)''')
+$4 = torch._ops.aten.mul.Tensor($3, tensor(2))
+$5 = torch._ops.aten.mul.Tensor($4, $0)
+$6 = torch._ops.aten.add_.Tensor($1, $5)''')

     def test_subclass_creation(self):
         # Make sure these statements runs without error

@@ -376,7 +376,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-                if func.__name__ == "clone":
+                if func.overloadpacket.__name__ == "clone":
                     # Return a plain tensor from clone().
                     return args[0].elem.clone()
                 raise RuntimeError("NYI")

@@ -444,7 +444,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
         idxs = (MyTensor(torch.tensor(0)),)
         v = torch.randn(1)
         res = x.index_put_(idxs, v)
-        self.assertEqual(called_funcs, [torch.ops.aten.index_put_])
+        self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])

     def test_enable_python_mode_error(self) -> None:
         with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):

@@ -594,7 +594,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
                 # It prevents infinite recursion.
                 with no_dispatch():
                     rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
-                if func.__name__ == "add":
+                if func.overloadpacket.__name__ == "add":
                     return None
                 else:
                     return rs

@@ -659,7 +659,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
         x = torch.randn(2)
         y = torch.randn(2)
         self.assertEqual(SubTensor(x) + SubTensor(y), x + y)
-        self.assertEqual(called, [torch.ops.aten.add])
+        self.assertEqual(called, [torch.ops.aten.add.Tensor])

     def test_dispatch_super_dont_autograd(self):
         called = []

@@ -685,7 +685,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
         x = SubTensor(torch.randn(2, requires_grad=True))
         x.neg()
-        self.assertEqual(called, [torch.ops.aten.neg])
+        self.assertEqual(called, [torch.ops.aten.neg.default])

     def test_multiple_ops_subclass(self):
         # This is a Direct Subclass, don't do that!

View File

@@ -32,13 +32,17 @@ class OpOverload:
         self._op = op
         self._schema = schema
         self._overloadpacket = overloadpacket
+        self._overloadname = 'default' if schema.overload_name == '' else schema.overload_name
+        self.__name__ = "{}.{}".format(self._schema.name.split("::")[1], self._overloadname)
+        self.__module__ = overloadpacket.__module__
+        op.__module__ = overloadpacket.__module__

     # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
     def __deepcopy__(self, memo=None):
         return self

-    def __str__(self):
-        return "OpOverload(op='{}.{}', overload='{}')".format(*self._schema.name.split("::"), self.overload_name)
+    def __repr__(self):
+        return "<OpOverload(op='{}.{}', overload='{}')>".format(*self._schema.name.split("::"), self._overloadname)

     def __call__(self, *args, **kwargs):
         return self._op(*args, **kwargs or {})

@@ -46,17 +50,15 @@ class OpOverload:
     def __getattr__(self, key):
         return getattr(self._op, key)

-    # `my_namespace::my_op`
-    @property
-    def name(self):
-        return "{}.{}".format(*self._schema.name.split("::"))
+    def __hash__(self):
+        return hash(self._op)
+
+    # `my_namespace.my_op_name.overload_name`
+    def __str__(self):
+        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)

     @property
-    def overload_name(self):
-        return self._schema.overload_name
-
-    @property
-    def overload_packet(self):
+    def overloadpacket(self):
         return self._overloadpacket

     @property

@@ -72,24 +74,22 @@ class OpOverloadPacket:
         # These attributes are accessible on the object through the properties
         # defined below but are immutable
         self._qualified_op_name = qualified_op_name
-        self._op_name = op_name
+        self.__name__ = op_name
         self._op = op

     # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
     def __deepcopy__(self, memo=None):
         return self

+    def __repr__(self):
+        return "<OpOverloadPacket(op='{}.{}')>".format(*self._qualified_op_name.split("::"))
+
+    def __hash__(self):
+        return hash(self._op)
+
     def __str__(self):
-        return "OpOverloadPacket(op='{}.{}')".format(*self._qualified_op_name.split("::"))
-
-    @property
-    def qualified_op_name(self):
         return "{}.{}".format(*self._qualified_op_name.split("::"))

-    @property
-    def op_name(self):
-        return self._op_name
-
     @property
     def op(self):
         return self._op
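
Putting the new torch/_ops.py surface together with the assertions in TestPerOverloadAPI above, the per-overload API now behaves roughly like this (a sketch; the printed values mirror the test expectations):

import torch

add_packet = torch.ops.aten.add        # OpOverloadPacket
add_overload = add_packet.Tensor       # OpOverload

print(add_packet.__name__)             # 'add'
print(str(add_packet))                 # 'aten.add'
print(repr(add_packet))                # "<OpOverloadPacket(op='aten.add')>"

print(add_overload.__name__)           # 'add.Tensor'
print(str(add_overload))               # 'aten.add.Tensor'
print(repr(add_overload))              # "<OpOverload(op='aten.add', overload='Tensor')>"
print(add_overload.overloadpacket == add_packet)  # True: overloads point back to their packet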

View File

@@ -1684,6 +1684,7 @@ void concrete_dispatch_fn(
   // Parse the name into namespace and name (no overload_name)
   // TODO: put this into the library
   const auto& qualified_name = op.operator_name().name;
+  const auto& overload_name = schema.overload_name();
   auto pos = qualified_name.find("::");
   TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
   // Make me some null terminated strings

@@ -1704,6 +1705,12 @@ void concrete_dispatch_fn(
   // overload resolution but is more complicated (need to expose separate
   // functions per overload)
   py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
+  py::handle torch_api_function_overload;
+  if (overload_name == "") {
+    torch_api_function_overload = torch_api_function.attr("default");
+  } else {
+    torch_api_function_overload = torch_api_function.attr(overload_name.c_str());
+  }
   std::string module_name_str = "torch.ops." + ns_str;

   // About all the pointers:

@@ -1795,7 +1802,7 @@ void concrete_dispatch_fn(
       args.ptr(),
       kwargs.ptr(),
       func_name,
-      torch_api_function.ptr(),
+      torch_api_function_overload.ptr(),
       module_name_str.c_str(),
       "__torch_dispatch__"
   ));

@@ -1835,7 +1842,7 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
       args.ptr(),
       kwargs.ptr(),
       "detach",
-      py::module::import("torch").attr("ops").attr("aten").attr("detach").ptr(),
+      py::module::import("torch").attr("ops").attr("aten").attr("detach").attr("default").ptr(),
       "torch.ops.aten",
       "__torch_dispatch__"
   ));

View File

@@ -1232,9 +1232,10 @@ void initJITBindings(PyObject* module) {
           auto operations = getAllOperatorsFor(symbol);
           for (const auto& op : operations) {
             if (op->schema().overload_name() == overload_name) {
-              auto func =
-                  py::cpp_function([op](py::args args, py::kwargs kwargs) {
-                    return invokeOperatorFromPython({op}, args, kwargs);
+              auto func = py::cpp_function(
+                  [op, symbol](py::args args, py::kwargs kwargs) {
+                    return _get_operation_for_overload_or_packet(
+                        {op}, symbol, args, kwargs, true);
                   });
               return func;
             }

@@ -1265,60 +1266,8 @@ void initJITBindings(PyObject* module) {
         auto func = py::cpp_function(
             [operations, symbol](py::args args, py::kwargs kwargs) {
-              std::vector<py::handle> overloaded_args;
-              size_t total_arg_num = args.size() + kwargs.size();
-              for (const auto i : c10::irange(args.size())) {
-                is_tensor_and_append_overloaded(
-                    args[i].ptr(), &overloaded_args);
-                is_tensor_list_and_append_overloaded(
-                    args[i].ptr(),
-                    &overloaded_args,
-                    static_cast<int>(total_arg_num),
-                    false /* throw_error */);
-              }
-              // NB: for kwargs, we cannot guarantee the order of appending
-              // is the same as the argument order in operator's schema.
-              // This is suboptimal, but should be fine. Later when we have
-              // better schema matching and argument parsing, we could
-              // match the operator in `operations` first, then the order will
-              // be guaranteed.
-              for (auto item : kwargs) {
-                is_tensor_and_append_overloaded(
-                    item.second.ptr(), &overloaded_args);
-                is_tensor_list_and_append_overloaded(
-                    item.second.ptr(),
-                    &overloaded_args,
-                    total_arg_num,
-                    false /* throw_error */);
-              }
-              if (overloaded_args.size() > 0) {
-                std::vector<py::object> overloaded_types;
-                overloaded_types.reserve(overloaded_args.size());
-                for (auto& oarg : overloaded_args) {
-                  overloaded_types.push_back(
-                      py::reinterpret_borrow<py::object>(
-                          (PyObject*)Py_TYPE(oarg.ptr())));
-                }
-                py::tuple py_types = py::cast(overloaded_types);
-                py::object ret;
-                std::string ns = symbol.ns().toUnqualString();
-                std::string method_name = symbol.toUnqualString();
-                auto self_func = py::module::import("torch")
-                                     .attr("ops")
-                                     .attr(ns.c_str())
-                                     .attr(method_name.c_str());
-                std::string module_name("torch.ops");
-                module_name.append(ns);
-                return pybind11::reinterpret_steal<py::object>(
-                    handle_torch_function_no_python_arg_parser(
-                        overloaded_args,
-                        args.ptr(),
-                        kwargs.ptr(),
-                        method_name.c_str(),
-                        self_func.ptr(),
-                        module_name.c_str()));
-              }
-              return invokeOperatorFromPython(operations, args, kwargs);
+              return _get_operation_for_overload_or_packet(
+                  operations, symbol, args, kwargs, false);
             },
             py::name(symbol.toUnqualString()),
             py::doc(docstring.str().c_str()));

View File

@@ -1159,5 +1159,72 @@ inline py::object invokeOperatorFromPython(
   return createPyObjectForStack(std::move(stack));
 }

+inline py::object _get_operation_for_overload_or_packet(
+    const std::vector<std::shared_ptr<Operator>>& operations,
+    Symbol symbol,
+    py::args args,
+    const py::kwargs& kwargs,
+    bool is_overload) {
+  std::vector<py::handle> overloaded_args;
+  size_t total_arg_num = args.size() + kwargs.size();
+  for (const auto i : c10::irange(args.size())) {
+    is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args);
+    is_tensor_list_and_append_overloaded(
+        args[i].ptr(),
+        &overloaded_args,
+        static_cast<int>(total_arg_num),
+        false /* throw_error */);
+  }
+  // NB: for kwargs, we cannot guarantee the order of appending
+  // is the same as the argument order in operator's schema.
+  // This is suboptimal, but should be fine. Later when we have
+  // better schema matching and argument parsing, we could
+  // match the operator in `operations` first, then the order will
+  // be guaranteed.
+  for (auto item : kwargs) {
+    is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
+    is_tensor_list_and_append_overloaded(
+        item.second.ptr(),
+        &overloaded_args,
+        total_arg_num,
+        false /* throw_error */);
+  }
+  if (overloaded_args.size() > 0) {
+    std::vector<py::object> overloaded_types;
+    overloaded_types.reserve(overloaded_args.size());
+    for (auto& oarg : overloaded_args) {
+      overloaded_types.push_back(
+          py::reinterpret_borrow<py::object>((PyObject*)Py_TYPE(oarg.ptr())));
+    }
+    py::tuple py_types = py::cast(overloaded_types);
+    py::object ret;
+    std::string ns = symbol.ns().toUnqualString();
+    std::string method_name = symbol.toUnqualString();
+    auto self_func = py::module::import("torch")
+                         .attr("ops")
+                         .attr(ns.c_str())
+                         .attr(method_name.c_str());
+    if (is_overload) {
+      auto overload_name = operations[0]->schema().overload_name();
+      if (overload_name == "") {
+        self_func = self_func.attr("default");
+      } else {
+        self_func = self_func.attr(overload_name.c_str());
+      }
+    }
+    std::string module_name("torch.ops");
+    module_name.append(ns);
+    return pybind11::reinterpret_steal<py::object>(
+        handle_torch_function_no_python_arg_parser(
+            overloaded_args,
+            args.ptr(),
+            kwargs.ptr(),
+            method_name.c_str(),
+            self_func.ptr(),
+            module_name.c_str()));
+  }
+  return invokeOperatorFromPython(operations, args, kwargs);
+}
+
 } // namespace jit
 } // namespace torch
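
Both the helper above and concrete_dispatch_fn map an empty schema overload name to the packet's `default` attribute. Expressed in Python, the lookup rule is roughly this (a sketch; `resolve_overload` is an illustrative name, not code from this diff):

import torch

def resolve_overload(packet, overload_name):
    # An empty overload name in the schema corresponds to `.default`
    # on the OpOverloadPacket; any other name is looked up directly.
    return packet.default if overload_name == "" else getattr(packet, overload_name)

print(resolve_overload(torch.ops.aten.detach, ""))     # aten.detach.default
print(resolve_overload(torch.ops.aten.add, "Tensor"))  # aten.add.Tensor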

View File

@@ -7,7 +7,7 @@ import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
 from torch._jit_internal import boolean_dispatched
 from ._compatibility import compatibility
-from torch._ops import OpOverloadPacket
+from torch._ops import OpOverloadPacket, OpOverload

 if TYPE_CHECKING:
     from .node import Argument

@@ -134,13 +134,12 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
         return_schemas=True, returns a tuple containing the optional Python signatures
         and the optional TorchScript Function signature
     """
+    if isinstance(op, OpOverloadPacket) or isinstance(op, OpOverload):
+        op = op.op
     override = _manual_overrides.get(op)
     if override:
         return (override, None) if return_schemas else None

-    if isinstance(op, OpOverloadPacket):
-        op = op._op
-
     aten_fn = torch.jit._builtins._find_builtin(op)
     if aten_fn is None:

View File

@@ -41,7 +41,7 @@ def check_metadata_consistency(wrapper_tensor):
         check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor)

 def is_view_fn(func):
-    return func.__name__ in {
+    return func.overloadpacket.__name__ in {
         'as_strided',
         'detach',
         'diagonal',

@@ -81,7 +81,7 @@ def is_view_fn(func):
 # manually populated from native_functions that have inplace_view: True.
 # In the future we will probably be able to grab that list directly
 def is_inplace_view_fn(func):
-    return func.__name__ in {
+    return func.overloadpacket.__name__ in {
         'as_strided_',
         'detach_',
         'squeeze_',

@@ -122,7 +122,7 @@ class CompositeCompliantTensor(torch.Tensor):
         def wrap(e):
             return CompositeCompliantTensor(e) if isinstance(e, torch.Tensor) else e

-        if func.__name__ in ('set_', 'resize_'):
+        if func.overloadpacket.__name__ in ('set_', 'resize_'):
             raise RuntimeError(
                 f"{func.__name__} is not allowed to be called inside of "
                 f"CompositeImplicitAutograd operators.")