Add standardOps match more input type in ORT (#53813) (#56172)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56172

Enable the standardOps include **Add\Sub\Mul\Div\Gemm\Pow\Mod**  with low precision input in ORT

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D27866136

Pulled By: SplitInfinity

fbshipit-source-id: f2cf5649fffefd68c0cc7b6dce94198751636727
This commit is contained in:
BowenBao
2021-04-21 17:54:49 -07:00
committed by Facebook GitHub Bot
parent 43ad172c54
commit 818ce1d0d2
7 changed files with 232 additions and 10 deletions

View File

@ -207,7 +207,17 @@ void initJITBindings(PyObject* module) {
return paramsDict;
},
pybind11::return_value_policy::move)
.def("_jit_pass_onnx_scalar_type_analysis", ScalarTypeAnalysisForONNX)
.def(
"_jit_pass_onnx_scalar_type_analysis",
[](std::shared_ptr<Graph>& graph,
bool lowprecision_cast,
int opset_version) {
return ScalarTypeAnalysisForONNX(
graph, lowprecision_cast, opset_version);
},
py::arg("graph"),
py::arg("lowprecision_cast") = true,
py::arg("opset_version"))
.def(
"_jit_pass_onnx_remove_inplace_ops_for_onnx", RemoveInplaceOpsForONNX)
.def(