sup torch script parameterlist

Fixes #61176

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75479
Approved by: https://github.com/davidberard98
This commit is contained in:
jishaomin
2022-04-20 20:53:07 +00:00
committed by PyTorch MergeBot
parent 81722f6630
commit 91e9fcf5b0
7 changed files with 61 additions and 3 deletions

View File

@ -326,6 +326,10 @@ SugaredValuePtr ModuleValue::getitem(
return getSugaredDict(loc, m)->getModules()->getitem(
loc, m, idx, type_hint);
}
} else if (
concreteType_->getIterableModuleKind() == IterableModuleKind::PARAMLIST) {
return getSugaredNamedParameterList(loc, m)->getModules()->getitem(
loc, m, idx, type_hint);
} else if (
concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
if (auto ivalue = toIValue(idx)) {
@ -443,6 +447,34 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
std::make_shared<SugaredTupleValue>(values));
}
std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedParameterList(
const SourceRange& loc,
GraphFunction& m) {
std::vector<std::string> paramNames;
std::vector<SugaredValuePtr> values;
const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
for (size_t i = 0; i < selfType->numAttributes(); ++i) {
if (selfType->is_parameter(i)) {
paramNames.push_back(selfType->getAttributeName(i));
}
}
std::vector<SugaredValuePtr> keys;
for (const auto& name : paramNames) {
auto name_v =
std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
m.graph()->insertGetAttr(self_, name);
values.push_back(tryGetAttr(loc, m, name));
keys.push_back(name_v);
}
return std::make_shared<SugaredDict>(
std::make_shared<ModuleValue>(self_, concreteType_),
std::make_shared<SugaredTupleValue>(keys),
std::make_shared<SugaredTupleValue>(values));
}
std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
const SourceRange& loc,
GraphFunction& m) {