[JIT] support parameterlist iteration

Followup to https://github.com/pytorch/pytorch/pull/75479.

This adds support for iterating through parameterlists

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76140

Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
David Berard
2022-04-20 22:44:29 -07:00
committed by PyTorch MergeBot
parent 272890998e
commit 82421b0fb8
2 changed files with 21 additions and 2 deletions

View File

@ -831,10 +831,14 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, GraphFunction& m) {
<< "Only constant Sequential, ModueList, or ModuleDict can be used as an iterable";
}
auto module_dict = getSugaredDict(loc, m);
if (iterableModuleKind == IterableModuleKind::DICT) {
auto module_dict = getSugaredDict(loc, m);
return module_dict->keys_;
} else if (iterableModuleKind == IterableModuleKind::LIST) {
auto module_dict = getSugaredDict(loc, m);
return module_dict->modules_;
} else if (iterableModuleKind == IterableModuleKind::PARAMLIST) {
auto module_dict = getSugaredNamedParameterList(loc, m);
return module_dict->modules_;
} else {
TORCH_INTERNAL_ASSERT(false);