mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Per-overload torch.ops API (#67254)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67254 Fixes https://github.com/pytorch/pytorch/issues/65997 TODO: disallow `default` as an overload name for aten operators. BC breaking: `output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))` now fails with the error `TypeError: __call__() got multiple values for argument 'self'` since we call into `OpOverloadBundle`'s `__call__` method that has `self` bound to it as its first argument. cc ezyang gchanan Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D33262228 Pulled By: anjali411 fbshipit-source-id: 600dbf511514ea9b41aea3e6b1bc1102dab08909
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f9e1a1c97f
commit
8e6d1738a4
@ -1104,6 +1104,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
}
|
||||
}
|
||||
|
||||
auto opoverloadpacket_type = py::module::import("torch").attr("_ops").attr("OpOverloadPacket");
|
||||
py::bool_ is_overloadpacket = py::isinstance(obj, opoverloadpacket_type);
|
||||
if (is_overloadpacket) {
|
||||
obj = py::getattr(obj, "op");
|
||||
}
|
||||
|
||||
bool isRpcAvailable = py::cast<bool>(
|
||||
py::module::import("torch.distributed.rpc").attr("is_available")());
|
||||
|
||||
|
Reference in New Issue
Block a user