_get_operation_overload: dont raise exception when overload does not exist (#131554)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131554
Approved by: https://github.com/ezyang, https://github.com/zou3519
ghstack dependencies: #131403, #131482, #131665
This commit is contained in:
Brian Hirsh
2024-07-25 18:01:44 -07:00
committed by PyTorch MergeBot
parent eba2ffd278
commit 5612408735
2 changed files with 12 additions and 5 deletions

View File

@ -1665,7 +1665,8 @@ void initJITBindings(PyObject* module) {
m.def(
"_get_operation_overload",
[](const std::string& op_name, const std::string& overload_name) {
[](const std::string& op_name,
const std::string& overload_name) -> std::optional<py::tuple> {
try {
auto symbol = Symbol::fromQualString(op_name);
auto operations = getAllOperatorsFor(symbol);
@ -1688,11 +1689,11 @@ void initJITBindings(PyObject* module) {
return _get_operation_for_overload_or_packet(
{op}, symbol, args, kwargs, /*is_overload*/ true, dk);
});
return py::make_tuple(
func, func_dk, py::cast(op->getTags().vec()));
return std::make_optional(
py::make_tuple(func, func_dk, py::cast(op->getTags().vec())));
}
}
throw std::runtime_error("Found no matching operator overload");
return std::nullopt;
} catch (const c10::Error& e) {
auto msg = torch::get_cpp_stacktraces_enabled()
? e.what()