mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
fix __len__, __contains__, getitem inherited from interface class derived from nn container (closes #40603) (#40789)
Summary: Define static script implementation of __len__ and __contains__ on any subclass derived from a type such as ModuleList, Sequential, or ModuleDict. Implement getitem for classes derived from ModuleDict. Pull Request resolved: https://github.com/pytorch/pytorch/pull/40789 Reviewed By: eellison Differential Revision: D22325159 Pulled By: wconstab fbshipit-source-id: fc1562c29640fe800e13b5a1dd48e595c2c7239b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8223858cc1
commit
8ecd4f36aa
@ -235,6 +235,24 @@ SugaredValuePtr ModuleValue::getitem(
|
||||
Value* idx) {
|
||||
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
|
||||
return getSugaredDict(loc, m)->getModules()->getitem(loc, m, idx);
|
||||
} else if (
|
||||
concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
|
||||
if (auto ivalue = toIValue(idx)) {
|
||||
auto sd = getSugaredDict(loc, m);
|
||||
auto idx_str = ivalue->toStringRef();
|
||||
auto keys_iter = sd->keys_;
|
||||
auto module_values_iter = sd->modules_;
|
||||
for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
|
||||
auto key = keys_iter->tup_.at(i);
|
||||
auto key_str = toIValue(key->asValue(loc, m))->toStringRef();
|
||||
if (key_str == idx_str) {
|
||||
return module_values_iter->tup_.at(i);
|
||||
}
|
||||
}
|
||||
throw ErrorReport(loc) << "Key Error, " << idx_str;
|
||||
}
|
||||
throw ErrorReport(loc)
|
||||
<< "Unable to extract string literal index. ModuleDict indexing is only supported with string literals.";
|
||||
}
|
||||
throw ErrorReport(loc)
|
||||
<< "Only ModuleList, Sequential, and ModuleDict modules are subscriptable";
|
||||
|
Reference in New Issue
Block a user