mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 23:04:52 +08:00
[JIT] Enable ModuleList non-literal indexing (#53410)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53410 **Summary** This commit enables indexing into `ModuleList` using a non-literal index if the LHS of the assignment statement of which the indexing is the RHS is annotated with an interface type. This feature already exists for `ModuleDict`, and this commit builds on top of that implementation. A `prim::ModuleContainerIndex` operator is emitted for any statement of the form `lhs: InterfaceType = module_container[idx]`. The same operator has to be used for both `ModuleDict` and `ModuleList` because serialization does not preserve the metadata that indicates whether a `Module` is a `ModuleDict` or `ModuleList`. **Testing** This commit extends the existing unit tests for non-literal `ModuleDict` indexing to test non-literal `ModuleList` indexing. **Fixes** This commit fixes #47496. Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D26857597 Pulled By: SplitInfinity fbshipit-source-id: d56678700a264d79aae3de37ad6b08b080175f7c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5dca8ff6de
commit
60ed8fb244
@ -258,14 +258,55 @@ SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) {
|
||||
<< "Only ModuleList or Sequential modules can be used as tuple";
|
||||
}
|
||||
|
||||
bool ModuleValue::areAllSubmodulesSubtypeOf(
|
||||
const TypePtr& ty,
|
||||
std::ostream* why_not) const {
|
||||
const auto& self_type = concreteType_->getJitType()->expect<ClassType>();
|
||||
for (size_t i = 0; i < self_type->numAttributes(); ++i) {
|
||||
const auto& attr_type = self_type->getAttribute(i);
|
||||
if (attr_type->is_module()) {
|
||||
std::stringstream ss;
|
||||
if (!attr_type->isSubtypeOfExt(ty, &ss)) {
|
||||
if (why_not) {
|
||||
*why_not << "Attribute " << self_type->getAttributeName(i)
|
||||
<< " is not of annotated type " << ty->annotation_str()
|
||||
<< ": " << ss.str();
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
SugaredValuePtr ModuleValue::getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint) {
|
||||
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
|
||||
return getSugaredDict(loc, m)->getModules()->getitem(
|
||||
loc, m, idx, type_hint);
|
||||
if (type_hint) {
|
||||
// Check that all submodules comply with the type hint.
|
||||
std::stringstream ss;
|
||||
if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
|
||||
throw ErrorReport(loc) << ss.str();
|
||||
}
|
||||
|
||||
// Emit a prim::ModuleContainerIndex operator. This is needed because
|
||||
// it's difficult to construct a list in the graph representing the
|
||||
// ModuleList and use aten::__getitem__ ops to index into it because
|
||||
// any call to ModuleList.setitem would invalidate that emitted list.
|
||||
auto graph = m.graph();
|
||||
auto* getitem_node = graph->insertNode(
|
||||
graph->create(prim::ModuleContainerIndex, {self_, idx}));
|
||||
getitem_node->output(0)->setType(type_hint);
|
||||
return std::make_shared<SimpleValue>(getitem_node->output(0));
|
||||
} else {
|
||||
return getSugaredDict(loc, m)->getModules()->getitem(
|
||||
loc, m, idx, type_hint);
|
||||
}
|
||||
} else if (
|
||||
concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
|
||||
if (auto ivalue = toIValue(idx)) {
|
||||
@ -283,28 +324,18 @@ SugaredValuePtr ModuleValue::getitem(
|
||||
throw ErrorReport(loc) << "Key Error, " << idx_str;
|
||||
} else if (type_hint) {
|
||||
// Check that all submodules comply with the type hint.
|
||||
const auto& self_type = concreteType_->getJitType()->expect<ClassType>();
|
||||
for (size_t i = 0; i < self_type->numAttributes(); ++i) {
|
||||
const auto& attr_type = self_type->getAttribute(i);
|
||||
if (attr_type->is_module()) {
|
||||
std::stringstream ss;
|
||||
if (!attr_type->isSubtypeOfExt(type_hint, &ss)) {
|
||||
auto loc = self_->node()->sourceRange();
|
||||
throw ErrorReport(loc)
|
||||
<< "Attribute " << self_type->getAttributeName(i)
|
||||
<< " is not of annotated type " << type_hint->annotation_str()
|
||||
<< ": " << ss.str();
|
||||
}
|
||||
}
|
||||
std::stringstream ss;
|
||||
if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
|
||||
throw ErrorReport(loc) << ss.str();
|
||||
}
|
||||
|
||||
// Emit a prim::ModuleDictIndex operator. This is needed because it's
|
||||
// difficult to construct a dict in the graph representing the ModuleDict
|
||||
// and use aten::__getitem__ ops to index into it because any call to
|
||||
// ModuleDict.setAttr would invalidate that emitted dict.
|
||||
// Emit a prim::ModuleContainerIndex operator. This is needed because
|
||||
// it's difficult to construct a dict in the graph representing the
|
||||
// ModuleDict and use aten::__getitem__ ops to index into it because
|
||||
// any call to ModuleDict.setAttr would invalidate that emitted dict.
|
||||
auto graph = m.graph();
|
||||
auto* getitem_node =
|
||||
graph->insertNode(graph->create(prim::ModuleDictIndex, {self_, idx}));
|
||||
auto* getitem_node = graph->insertNode(
|
||||
graph->create(prim::ModuleContainerIndex, {self_, idx}));
|
||||
getitem_node->output(0)->setType(type_hint);
|
||||
return std::make_shared<SimpleValue>(getitem_node->output(0));
|
||||
}
|
||||
|
Reference in New Issue
Block a user