[JIT] Enable ModuleList non-literal indexing (#53410)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53410

**Summary**
This commit enables indexing into `ModuleList` using a non-literal
index if the LHS of the assignment statement of which the indexing is
the RHS is annotated with an interface type.

This feature already exists for `ModuleDict`, and this commit builds on
top of that implementation. A `prim::ModuleContainerIndex` operator is
emitted for any statement of the form `lhs: InterfaceType =
module_container[idx]`. The same operator has to be used for both
`ModuleDict` and `ModuleList` because serialization does not preserve
the metadata that indicates whether a `Module` is a `ModuleDict` or
`ModuleList`.

**Testing**
This commit extends the existing unit tests for non-literal `ModuleDict`
indexing to test non-literal `ModuleList` indexing.

**Fixes**
This commit fixes #47496.

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D26857597

Pulled By: SplitInfinity

fbshipit-source-id: d56678700a264d79aae3de37ad6b08b080175f7c
This commit is contained in:
Meghan Lele
2021-03-09 16:07:39 -08:00
committed by Facebook GitHub Bot
parent 5dca8ff6de
commit 60ed8fb244
9 changed files with 163 additions and 38 deletions

View File

@ -258,14 +258,55 @@ SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) {
<< "Only ModuleList or Sequential modules can be used as tuple";
}
bool ModuleValue::areAllSubmodulesSubtypeOf(
const TypePtr& ty,
std::ostream* why_not) const {
const auto& self_type = concreteType_->getJitType()->expect<ClassType>();
for (size_t i = 0; i < self_type->numAttributes(); ++i) {
const auto& attr_type = self_type->getAttribute(i);
if (attr_type->is_module()) {
std::stringstream ss;
if (!attr_type->isSubtypeOfExt(ty, &ss)) {
if (why_not) {
*why_not << "Attribute " << self_type->getAttributeName(i)
<< " is not of annotated type " << ty->annotation_str()
<< ": " << ss.str();
}
return false;
}
}
}
return true;
}
SugaredValuePtr ModuleValue::getitem(
const SourceRange& loc,
Function& m,
Value* idx,
TypePtr type_hint) {
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
return getSugaredDict(loc, m)->getModules()->getitem(
loc, m, idx, type_hint);
if (type_hint) {
// Check that all submodules comply with the type hint.
std::stringstream ss;
if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
throw ErrorReport(loc) << ss.str();
}
// Emit a prim::ModuleContainerIndex operator. This is needed because
// it's difficult to construct a list in the graph representing the
// ModuleList and use aten::__getitem__ ops to index into it because
// any call to ModuleList.setitem would invalidate that emitted list.
auto graph = m.graph();
auto* getitem_node = graph->insertNode(
graph->create(prim::ModuleContainerIndex, {self_, idx}));
getitem_node->output(0)->setType(type_hint);
return std::make_shared<SimpleValue>(getitem_node->output(0));
} else {
return getSugaredDict(loc, m)->getModules()->getitem(
loc, m, idx, type_hint);
}
} else if (
concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (auto ivalue = toIValue(idx)) {
@ -283,28 +324,18 @@ SugaredValuePtr ModuleValue::getitem(
throw ErrorReport(loc) << "Key Error, " << idx_str;
} else if (type_hint) {
// Check that all submodules comply with the type hint.
const auto& self_type = concreteType_->getJitType()->expect<ClassType>();
for (size_t i = 0; i < self_type->numAttributes(); ++i) {
const auto& attr_type = self_type->getAttribute(i);
if (attr_type->is_module()) {
std::stringstream ss;
if (!attr_type->isSubtypeOfExt(type_hint, &ss)) {
auto loc = self_->node()->sourceRange();
throw ErrorReport(loc)
<< "Attribute " << self_type->getAttributeName(i)
<< " is not of annotated type " << type_hint->annotation_str()
<< ": " << ss.str();
}
}
std::stringstream ss;
if (!areAllSubmodulesSubtypeOf(type_hint, &ss)) {
throw ErrorReport(loc) << ss.str();
}
// Emit a prim::ModuleDictIndex operator. This is needed because it's
// difficult to construct a dict in the graph representing the ModuleDict
// and use aten::__getitem__ ops to index into it because any call to
// ModuleDict.setAttr would invalidate that emitted dict.
// Emit a prim::ModuleContainerIndex operator. This is needed because
// it's difficult to construct a dict in the graph representing the
// ModuleDict and use aten::__getitem__ ops to index into it because
// any call to ModuleDict.setAttr would invalidate that emitted dict.
auto graph = m.graph();
auto* getitem_node =
graph->insertNode(graph->create(prim::ModuleDictIndex, {self_, idx}));
auto* getitem_node = graph->insertNode(
graph->create(prim::ModuleContainerIndex, {self_, idx}));
getitem_node->output(0)->setType(type_hint);
return std::make_shared<SimpleValue>(getitem_node->output(0));
}