fix __len__, __contains__, getitem inherited from interface class derived from nn container (closes #40603) (#40789)

Summary:
Define static script implementation of __len__ and __contains__ on any subclass derived from a type such as ModuleList, Sequential, or ModuleDict.  Implement getitem for classes derived from ModuleDict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40789

Reviewed By: eellison

Differential Revision: D22325159

Pulled By: wconstab

fbshipit-source-id: fc1562c29640fe800e13b5a1dd48e595c2c7239b
This commit is contained in:
Will Constable
2020-07-04 15:43:52 -07:00
committed by Facebook GitHub Bot
parent 8223858cc1
commit 8ecd4f36aa
4 changed files with 439 additions and 210 deletions

View File

@ -235,6 +235,24 @@ SugaredValuePtr ModuleValue::getitem(
Value* idx) {
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
return getSugaredDict(loc, m)->getModules()->getitem(loc, m, idx);
} else if (
concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (auto ivalue = toIValue(idx)) {
auto sd = getSugaredDict(loc, m);
auto idx_str = ivalue->toStringRef();
auto keys_iter = sd->keys_;
auto module_values_iter = sd->modules_;
for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
auto key = keys_iter->tup_.at(i);
auto key_str = toIValue(key->asValue(loc, m))->toStringRef();
if (key_str == idx_str) {
return module_values_iter->tup_.at(i);
}
}
throw ErrorReport(loc) << "Key Error, " << idx_str;
}
throw ErrorReport(loc)
<< "Unable to extract string literal index. ModuleDict indexing is only supported with string literals.";
}
throw ErrorReport(loc)
<< "Only ModuleList, Sequential, and ModuleDict modules are subscriptable";