mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
sup torch script parameterlist
Fixes #61176 Pull Request resolved: https://github.com/pytorch/pytorch/pull/75479 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
81722f6630
commit
91e9fcf5b0
@ -326,6 +326,10 @@ SugaredValuePtr ModuleValue::getitem(
|
||||
return getSugaredDict(loc, m)->getModules()->getitem(
|
||||
loc, m, idx, type_hint);
|
||||
}
|
||||
} else if (
|
||||
concreteType_->getIterableModuleKind() == IterableModuleKind::PARAMLIST) {
|
||||
return getSugaredNamedParameterList(loc, m)->getModules()->getitem(
|
||||
loc, m, idx, type_hint);
|
||||
} else if (
|
||||
concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
|
||||
if (auto ivalue = toIValue(idx)) {
|
||||
@ -443,6 +447,34 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
|
||||
std::make_shared<SugaredTupleValue>(values));
|
||||
}
|
||||
|
||||
std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedParameterList(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m) {
|
||||
std::vector<std::string> paramNames;
|
||||
std::vector<SugaredValuePtr> values;
|
||||
|
||||
const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
|
||||
for (size_t i = 0; i < selfType->numAttributes(); ++i) {
|
||||
if (selfType->is_parameter(i)) {
|
||||
paramNames.push_back(selfType->getAttributeName(i));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<SugaredValuePtr> keys;
|
||||
for (const auto& name : paramNames) {
|
||||
auto name_v =
|
||||
std::make_shared<SimpleValue>(insertConstant(*m.graph(), name));
|
||||
m.graph()->insertGetAttr(self_, name);
|
||||
values.push_back(tryGetAttr(loc, m, name));
|
||||
keys.push_back(name_v);
|
||||
}
|
||||
|
||||
return std::make_shared<SugaredDict>(
|
||||
std::make_shared<ModuleValue>(self_, concreteType_),
|
||||
std::make_shared<SugaredTupleValue>(keys),
|
||||
std::make_shared<SugaredTupleValue>(values));
|
||||
}
|
||||
|
||||
std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m) {
|
||||
|
Reference in New Issue
Block a user