[quant][graphmode] Different rule for add/add_/mul/mul_ (#38667)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38667

Test Plan: Imported from OSS

Differential Revision: D21633555

fbshipit-source-id: 03b0298e83bf4dbda41b048c0edc7bb92cd4e1df
This commit is contained in:
Jerry Zhang
2020-05-20 19:42:02 -07:00
committed by Facebook GitHub Bot
parent 57d6e19d6f
commit a8d8fc5532
9 changed files with 819 additions and 367 deletions

View File

@ -462,13 +462,15 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
pybind11::cast<pybind11::none>(Py_None));
if (py::isinstance<py::function>(unboundMethod)) {
bool isStaticFn =
py::cast<bool>(py::module::import("torch._jit_internal")
.attr("is_static_fn")(concreteType_->getPyClass(), field.c_str()));
bool isStaticFn = py::cast<bool>(
py::module::import("torch._jit_internal")
.attr("is_static_fn")(concreteType_->getPyClass(), field.c_str()));
if (isStaticFn) {
// Functions within the module annotated with @staticmethod do not need binding.
// Functions within the module annotated with @staticmethod do not need
// binding.
py::object staticFn = py::module::import("torch._jit_internal")
.attr("get_static_fn")(concreteType_->getPyClass(), field.c_str());
.attr("get_static_fn")(
concreteType_->getPyClass(), field.c_str());
return toSugaredValue(staticFn, m, loc);
}
// For Python methods that we're trying to call directly, we need to bind
@ -747,7 +749,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
} else if (
// RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on.
obj.ptr() ==
py::module::import("torch.distributed.rpc").attr("rpc_async").ptr()) {
py::module::import("torch.distributed.rpc").attr("rpc_async").ptr()) {
return SpecialFormValue::create(prim::rpc_async);
#endif
} else if (auto callee = as_module(obj)) {