mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] support parameterlist iteration
Followup to https://github.com/pytorch/pytorch/pull/75479. This adds support for iterating through parameterlists Pull Request resolved: https://github.com/pytorch/pytorch/pull/76140 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
272890998e
commit
82421b0fb8
@ -831,10 +831,14 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, GraphFunction& m) {
|
||||
<< "Only constant Sequential, ModueList, or ModuleDict can be used as an iterable";
|
||||
}
|
||||
|
||||
auto module_dict = getSugaredDict(loc, m);
|
||||
if (iterableModuleKind == IterableModuleKind::DICT) {
|
||||
auto module_dict = getSugaredDict(loc, m);
|
||||
return module_dict->keys_;
|
||||
} else if (iterableModuleKind == IterableModuleKind::LIST) {
|
||||
auto module_dict = getSugaredDict(loc, m);
|
||||
return module_dict->modules_;
|
||||
} else if (iterableModuleKind == IterableModuleKind::PARAMLIST) {
|
||||
auto module_dict = getSugaredNamedParameterList(loc, m);
|
||||
return module_dict->modules_;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(false);
|
||||
|
Reference in New Issue
Block a user