mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] Speed up saving in case of many classes (#44589)
Summary: There's an annoying O(N^2) in the module export logic that makes saving some models (those with many classes) take an eternity. I'm not familiar enough with this code to properly untangle the deps and make it a pure hash lookup, so I just added a side lookup table for raw pointers. It's still quadratic, but it's O(num_classes^2) instead of O(num_classes * num_references), which already gives huge savings. Pull Request resolved: https://github.com/pytorch/pytorch/pull/44589 Test Plan: Tested with one of the offending models - just loading and saving a TorchScript file: ``` Before: load 1.9239683151245117 save 165.74712467193604 After: load 1.9409027099609375 save 1.4711427688598633 ``` Reviewed By: suo Differential Revision: D23675278 Pulled By: dzhulgakov fbshipit-source-id: 8f3fa7730941085ea20d9255b49a149ac1bf64fe
This commit is contained in:
committed by
Facebook GitHub Bot
parent
285ba0d068
commit
2f4c31ce3a
@ -77,7 +77,7 @@ bool is_enabled(const char* cfname, JitLoggingLevels level) {
|
||||
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
torch::jit::GraphFunction func("source_dump", graph, nullptr);
|
||||
std::vector<at::IValue> constants;
|
||||
std::vector<c10::NamedTypePtr> deps;
|
||||
PrintDepsTable deps;
|
||||
PythonPrint pp(constants, deps);
|
||||
pp.printFunction(func);
|
||||
return pp.str();
|
||||
|
@ -1014,7 +1014,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||
"code",
|
||||
[](Module& self) {
|
||||
std::vector<at::IValue> constants;
|
||||
std::vector<c10::NamedTypePtr> deps;
|
||||
PrintDepsTable deps;
|
||||
PythonPrint pp(constants, deps);
|
||||
pp.printNamedType(self.type());
|
||||
return pp.str();
|
||||
@ -1023,7 +1023,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||
"code_with_constants",
|
||||
[](Module& self) {
|
||||
std::vector<at::IValue> constants;
|
||||
std::vector<c10::NamedTypePtr> deps;
|
||||
PrintDepsTable deps;
|
||||
PythonPrint pp(constants, deps);
|
||||
pp.printNamedType(self.type());
|
||||
std::map<std::string, at::IValue> consts;
|
||||
@ -1177,7 +1177,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||
"code",
|
||||
[](const StrongFunctionPtr& self) {
|
||||
std::vector<at::IValue> constants;
|
||||
std::vector<c10::NamedTypePtr> deps;
|
||||
PrintDepsTable deps;
|
||||
|
||||
PythonPrint pp(constants, deps);
|
||||
pp.printFunction(*self.function_);
|
||||
@ -1222,14 +1222,14 @@ void initJitScriptBindings(PyObject* module) {
|
||||
"code",
|
||||
[](Method& self) {
|
||||
std::vector<at::IValue> constants;
|
||||
std::vector<c10::NamedTypePtr> deps;
|
||||
PrintDepsTable deps;
|
||||
PythonPrint pp(constants, deps);
|
||||
pp.printMethod(self.function());
|
||||
return pp.str();
|
||||
})
|
||||
.def_property_readonly("code_with_constants", [](Method& self) {
|
||||
std::vector<at::IValue> constants;
|
||||
std::vector<c10::NamedTypePtr> deps;
|
||||
PrintDepsTable deps;
|
||||
PythonPrint pp(constants, deps);
|
||||
pp.printMethod(self.function());
|
||||
std::map<std::string, at::IValue> consts;
|
||||
|
@ -369,7 +369,7 @@ class ScriptModuleSerializer {
|
||||
}
|
||||
|
||||
void writeCode(const at::NamedTypePtr& root_type) {
|
||||
class_deps_.push_back(root_type);
|
||||
class_deps_.add(root_type);
|
||||
for (size_t i = 0; i < class_deps_.size(); ++i) {
|
||||
// note: convertNameType may extend class_deps_, so re-checking
|
||||
// .size() is necessary
|
||||
@ -459,7 +459,7 @@ class ScriptModuleSerializer {
|
||||
caffe2::serialize::PyTorchStreamWriter writer_;
|
||||
std::vector<at::IValue> constant_table_;
|
||||
std::unordered_set<c10::NamedTypePtr> converted_types_;
|
||||
std::vector<c10::NamedTypePtr> class_deps_;
|
||||
PrintDepsTable class_deps_;
|
||||
TypeNameUniquer type_name_uniquer_;
|
||||
|
||||
// qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
|
||||
|
@ -87,6 +87,27 @@ const static std::unordered_set<std::string> reserved_names = {
|
||||
"unchecked_cast",
|
||||
};
|
||||
|
||||
// Helper to avoid duplicating class types.
// Records `type` in table_ unless an equal type is already stored there.
// New entries are only ever appended (push_back), so indices of earlier
// entries do not change when the table grows.
|
||||
void PrintDepsTable::add(const c10::NamedTypePtr& type) {
|
||||
// Despite doing the linear search below, we don't want to do
|
||||
// wasteful work and only try to insert each instance once.
// non_unique_ remembers every handle passed in so far; insert().second is
// false when this exact handle was seen before, and we bail out early.
// NOTE(review): presumably the set compares handles by pointer identity
// (default hash/equality of the pointer type) -- confirm.
|
||||
if (!non_unique_.insert(type).second) {
|
||||
return;
|
||||
}
|
||||
// Need to do actual equality comparison, not a pointer equality. This is
|
||||
// because for some types (e.g. FunctionType), we may have multiple
|
||||
// TypePtr's that represent the same underlying thing.
|
||||
// TODO: this should be really swapped for something more efficient
|
||||
auto it = std::find_if(
|
||||
table_.cbegin(), table_.cend(), [&](const c10::NamedTypePtr& dep) {
|
||||
// Compare the pointed-to types by value, not the handles themselves.
return *dep == *type;
|
||||
});
|
||||
|
||||
// Only append when no value-equal type is already in the table.
if (it == table_.cend()) {
|
||||
table_.push_back(type);
|
||||
}
|
||||
}
|
||||
|
||||
struct PythonPrintImpl {
|
||||
using SourceRangeStack = std::vector<SourceRange>;
|
||||
SourceRangeStack source_range_stack_ = {SourceRange()};
|
||||
@ -169,21 +190,6 @@ struct PythonPrintImpl {
|
||||
const SourceRangeStack* srs_;
|
||||
};
|
||||
|
||||
// Helper to avoid duplicating class types.
// Appends `type` to deps_table_ unless a value-equal type is already there.
// Every call scans the whole table linearly.
|
||||
void registerDependency(const c10::NamedTypePtr& type) {
|
||||
// Need to do actual equality comparison, not a pointer equality. This is
|
||||
// because for some types (e.g. FunctionType), we may have multiple
|
||||
// TypePtr's that represent the same underlying thing.
|
||||
auto it = std::find_if(
|
||||
deps_table_.cbegin(),
|
||||
deps_table_.cend(),
|
||||
// Predicate dereferences both handles to compare the types by value.
[&](const c10::NamedTypePtr& dep) { return *dep == *type; });
|
||||
|
||||
// No equal entry found: record the new dependency at the end.
if (it == deps_table_.cend()) {
|
||||
deps_table_.push_back(type);
|
||||
}
|
||||
}
|
||||
|
||||
// scanValue, scanNode, scanBlock:
|
||||
// decide if it is safe to omit the output of a temporary variable,
|
||||
// and inline the expression into its use
|
||||
@ -667,15 +673,15 @@ struct PythonPrintImpl {
|
||||
// Recursively check contained types for any class dependencies
|
||||
void registerClassDependencies(const TypePtr& type) {
|
||||
if (const auto classType = type->cast<ClassType>()) {
|
||||
registerDependency(classType);
|
||||
deps_table_.add(classType);
|
||||
} else if (const auto tupleType = type->cast<TupleType>()) {
|
||||
if (tupleType->name()) {
|
||||
registerDependency(tupleType);
|
||||
deps_table_.add(tupleType);
|
||||
}
|
||||
} else if (const auto interfaceType = type->cast<InterfaceType>()) {
|
||||
registerDependency(interfaceType);
|
||||
deps_table_.add(interfaceType);
|
||||
} else if (const auto enumType = type->cast<EnumType>()) {
|
||||
registerDependency(enumType);
|
||||
deps_table_.add(enumType);
|
||||
}
|
||||
for (const auto& containedType : type->containedTypes()) {
|
||||
registerClassDependencies(containedType);
|
||||
@ -925,7 +931,7 @@ struct PythonPrintImpl {
|
||||
if (node->outputs().size() == 1 &&
|
||||
node->output()->type()->kind() == TypeKind::FunctionType) {
|
||||
auto fn = node->output()->type()->expect<FunctionType>();
|
||||
registerDependency(fn);
|
||||
deps_table_.add(fn);
|
||||
stmt << fn->annotation_str(type_printer_);
|
||||
} else if (!node->mustBeNone()) {
|
||||
IValue v = toIValue(node->output()).value();
|
||||
@ -1055,13 +1061,13 @@ struct PythonPrintImpl {
|
||||
stmt << ")";
|
||||
|
||||
if (auto selfClass = self->type()->cast<ClassType>()) {
|
||||
registerDependency(selfClass);
|
||||
deps_table_.add(selfClass);
|
||||
const Function& method = selfClass->getMethod(node->s(attr::name));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
method.qualname() ==
|
||||
QualifiedName(selfClass->name()->qualifiedName(), methodName));
|
||||
} else if (auto selfInterface = self->type()->cast<InterfaceType>()) {
|
||||
registerDependency(selfInterface);
|
||||
deps_table_.add(selfInterface);
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "method call to unhandled type in serialization");
|
||||
@ -1260,13 +1266,13 @@ struct PythonPrintImpl {
|
||||
|
||||
PythonPrintImpl(
|
||||
std::vector<at::IValue>& constant_table,
|
||||
std::vector<c10::NamedTypePtr>& deps_table,
|
||||
PrintDepsTable& deps_table,
|
||||
c10::TypePrinter type_printer,
|
||||
bool enforce_importable)
|
||||
: body_(&source_range_stack_),
|
||||
constant_table_(constant_table),
|
||||
deps_table_(deps_table),
|
||||
type_printer_(type_printer),
|
||||
type_printer_(std::move(type_printer)),
|
||||
enforce_importable_(enforce_importable) {}
|
||||
|
||||
void printClass(const ClassTypePtr& classType) {
|
||||
@ -1461,7 +1467,7 @@ struct PythonPrintImpl {
|
||||
|
||||
// Any NamedTypes (classes, functions, NamedTuples) used are written to this
|
||||
// table.
|
||||
std::vector<c10::NamedTypePtr>& deps_table_;
|
||||
PrintDepsTable& deps_table_;
|
||||
|
||||
// A function that, given a named type, returns us the correct string to print
|
||||
// for it.
|
||||
@ -1477,13 +1483,13 @@ struct PythonPrintImpl {
|
||||
|
||||
PythonPrint::PythonPrint(
|
||||
std::vector<at::IValue>& constant_table,
|
||||
std::vector<c10::NamedTypePtr>& deps_table,
|
||||
PrintDepsTable& deps_table,
|
||||
c10::TypePrinter type_printer,
|
||||
bool enforce_importable)
|
||||
: pImpl(std::make_shared<PythonPrintImpl>(
|
||||
constant_table,
|
||||
deps_table,
|
||||
type_printer,
|
||||
std::move(type_printer),
|
||||
enforce_importable)) {}
|
||||
|
||||
void PythonPrint::printNamedType(const c10::NamedTypePtr& type) {
|
||||
|
@ -11,10 +11,26 @@ struct Method;
|
||||
struct Module;
|
||||
struct PythonPrintImpl;
|
||||
|
||||
// Table of named-type dependencies that is filled through add() and read
// back by index through size() and operator[].
struct PrintDepsTable {
|
||||
// Registers `type` as a dependency; defined out of line.
void add(const c10::NamedTypePtr& type);
|
||||
|
||||
// Number of entries currently stored in table_.
size_t size() const {
|
||||
return table_.size();
|
||||
}
|
||||
|
||||
// Returns the entry at `index`. Forwards to table_[index], which performs
// no bounds checking, so callers must keep index < size().
const c10::NamedTypePtr& operator[](size_t index) const {
|
||||
return table_[index];
|
||||
}
|
||||
|
||||
private:
|
||||
// Entries exposed through size() / operator[].
std::vector<c10::NamedTypePtr> table_;
|
||||
// Side lookup set of type handles; not exposed through the public
// accessors above.
std::unordered_set<c10::NamedTypePtr> non_unique_;
|
||||
};
|
||||
|
||||
struct TORCH_API PythonPrint {
|
||||
PythonPrint(
|
||||
std::vector<IValue>& constant_table,
|
||||
std::vector<c10::NamedTypePtr>& deps_table,
|
||||
PrintDepsTable& deps_table,
|
||||
c10::TypePrinter type_printer = nullptr,
|
||||
bool enforce_importable = false);
|
||||
|
||||
|
Reference in New Issue
Block a user