mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[jit] fix segfault in attribute lookup on loaded ScriptModules (#43284)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43284 The IR emitter looks for attributes on modules like: 1. Check the JIT type for the attribute 2. Check the originating Python class, in order to fulfill requests for, e.g. static methods or ignored methods. In the case where you do: ``` inner_module = torch.jit.load("inner.pt") wrapped = Wrapper(inner_module) # wrap the loaded ScriptModule in an nn.Module torch.jit.script(wrapped) ``` The IR emitter may check for attributes on `inner_module`. There is no originating Python class for `inner_module`, since it was directly compiled from the serialized format. Due to a bug in the code, we don't guard for this case an a segfault results if the wrapper asks for an undefined attribute. The lookup in this case looks like: 1. Check the JIT type for the attribute (not there!) 2. Check the originating Python class (this is a nullptr! segfault!) This PR guards this case and properly just raises an attribute missing compiler error instead of segfaulting. Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D23224337 Pulled By: suo fbshipit-source-id: 0cf3060c427f2253286f76f646765ec37b9c4c49
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e64879e180
commit
74f18476a2
@ -499,21 +499,26 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
|
||||
// 5. Check if it's an attribute of the original Python class that this
|
||||
// ScriptModule was derived from. The only class attributes we handle are
|
||||
// methods.
|
||||
const auto maybePyClass = concreteType_->getPyClass();
|
||||
if (!maybePyClass) {
|
||||
// ConcreteType doesn't always have an originating Python class, e.g. if it
|
||||
// was derived from a serialized ScriptModule. In this case, we've exhausted
|
||||
// our options for attr lookup.
|
||||
return nullptr;
|
||||
}
|
||||
py::object unboundMethod = py::getattr(
|
||||
concreteType_->getPyClass(),
|
||||
field.c_str(),
|
||||
pybind11::cast<pybind11::none>(Py_None));
|
||||
*maybePyClass, field.c_str(), pybind11::cast<pybind11::none>(Py_None));
|
||||
|
||||
if (py::isinstance<py::function>(unboundMethod)) {
|
||||
bool isStaticFn = py::cast<bool>(
|
||||
py::module::import("torch._jit_internal")
|
||||
.attr("is_static_fn")(concreteType_->getPyClass(), field.c_str()));
|
||||
bool isStaticFn =
|
||||
py::cast<bool>(py::module::import("torch._jit_internal")
|
||||
.attr("is_static_fn")(*maybePyClass, field.c_str()));
|
||||
if (isStaticFn) {
|
||||
// Functions within the module annotated with @staticmethod do not need
|
||||
// binding.
|
||||
py::object staticFn = py::module::import("torch._jit_internal")
|
||||
.attr("get_static_fn")(
|
||||
concreteType_->getPyClass(), field.c_str());
|
||||
py::object staticFn =
|
||||
py::module::import("torch._jit_internal")
|
||||
.attr("get_static_fn")(*maybePyClass, field.c_str());
|
||||
return toSugaredValue(staticFn, m, loc);
|
||||
}
|
||||
// For Python methods that we're trying to call directly, we need to bind
|
||||
|
Reference in New Issue
Block a user