Per-overload torch.ops API (#67254)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67254

Fixes https://github.com/pytorch/pytorch/issues/65997

TODO: disallow `default` as an overload name for aten operators.
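
For illustration, a minimal sketch of the resulting per-overload access pattern on a build containing this change (`aten.add` and `aten.relu` are stand-ins; any registered op works the same way):

```python
import torch

# torch.ops.aten.add is an OpOverloadPacket: calling it dispatches on
# argument types, exactly as before this change.
packet = torch.ops.aten.add
print(packet(torch.ones(2), torch.ones(2)))      # tensor([2., 2.])

# Each named overload is now addressable as an attribute of the packet,
# yielding an OpOverload pinned to a single schema:
add_tensor = torch.ops.aten.add.Tensor
print(add_tensor(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])

# Overloads with an empty name are exposed under the attribute name
# `default`, which is why the TODO above wants `default` reserved:
relu = torch.ops.aten.relu.default
print(relu(torch.tensor([-1.0, 2.0])))           # tensor([0., 2.])
```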

BC breaking:
`output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))` now fails with `TypeError: __call__() got multiple values for argument 'self'`, because the call is routed through `OpOverloadPacket.__call__`, whose first positional parameter `self` is already bound to the packet instance and therefore collides with the keyword argument.
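
A sketch of the breakage and its workaround, using the real `aten.relu` (whose schema also names its first argument `self`; the `_test.leaky_relu` op above is only registered in PyTorch's test suite). Behavior shown is as of this commit:

```python
import torch

x = torch.tensor(-1.0)

# Fails as of this commit: the keyword `self` collides with the bound
# instance argument of OpOverloadPacket.__call__(self, *args, **kwargs).
try:
    torch.ops.aten.relu(self=x)
except TypeError as e:
    print(e)  # __call__() got multiple values for argument 'self'

# Works: pass the tensor positionally instead of by keyword.
print(torch.ops.aten.relu(x))  # tensor(0.)
```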

cc ezyang gchanan

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33262228

Pulled By: anjali411

fbshipit-source-id: 600dbf511514ea9b41aea3e6b1bc1102dab08909
Author: anjali411
Date: 2022-01-05 15:16:17 -08:00
Committed by: Facebook GitHub Bot
Commit: 8e6d1738a4 (parent: f9e1a1c97f)
11 changed files with 245 additions and 25 deletions


@@ -1104,6 +1104,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     }
   }
+  auto opoverloadpacket_type =
+      py::module::import("torch").attr("_ops").attr("OpOverloadPacket");
+  py::bool_ is_overloadpacket = py::isinstance(obj, opoverloadpacket_type);
+  if (is_overloadpacket) {
+    obj = py::getattr(obj, "op");
+  }
   bool isRpcAvailable = py::cast<bool>(
       py::module::import("torch.distributed.rpc").attr("is_available")());