mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:30:26 +08:00
Implement JIT Enum type serialization and deserialization (#43460)
Summary: [Re-review tips: nothing changed other than a type in python_ir.cpp to fix a windows build failure] Adds code printing for enum type Enhance enum type to include all contained enum names and values Adds code parsing for enum type in deserialization Enabled serialization/deserialization test in most TestCases. (With a few dangling issues to be addressed in later PRs to avoid this PR grows too large) Pull Request resolved: https://github.com/pytorch/pytorch/pull/43460 Reviewed By: albanD Differential Revision: D23284929 Pulled By: gmagogsfm fbshipit-source-id: e3e81d6106f18b7337ac3ff5cd1eeaff854904f3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0fa99d50bc
commit
35a36c1280
@ -672,6 +672,8 @@ struct PythonPrintImpl {
|
||||
}
|
||||
} else if (const auto interfaceType = type->cast<InterfaceType>()) {
|
||||
registerDependency(interfaceType);
|
||||
} else if (const auto enumType = type->cast<EnumType>()) {
|
||||
registerDependency(enumType);
|
||||
}
|
||||
for (const auto& containedType : type->containedTypes()) {
|
||||
registerClassDependencies(containedType);
|
||||
@ -1413,6 +1415,22 @@ struct PythonPrintImpl {
|
||||
body_ << " pass\n";
|
||||
}
|
||||
}
|
||||
} else if (auto enumType = type->cast<EnumType>()) {
|
||||
body_ << "class " << enumType->qualifiedClassName().name() << "(Enum):\n";
|
||||
|
||||
std::string value_wrapper = "";
|
||||
if (enumType->getValueType() == StringType::get()) {
|
||||
value_wrapper = "\"";
|
||||
}
|
||||
|
||||
{
|
||||
auto guard = WithIndented();
|
||||
for (const auto& name_value : enumType->enumNamesValues()) {
|
||||
indent();
|
||||
body_ << name_value.first << " = " << value_wrapper
|
||||
<< name_value.second << value_wrapper << "\n";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(false, "Unhandled NamedType");
|
||||
}
|
||||
|
Reference in New Issue
Block a user