[jit] Speed up saving in case of many classes (#44589)

Summary:
There's an annoying O(N^2) in module export logic that makes saving some of the models (if they have many classes) take eternity.

I'm not super familiar with this code to properly untangle the deps and make it a pure hash lookup. So I just added a side lookup table for raw pointers. It's still quadratic, but it's O(num_classes^2) instead of O(num_classes * num_references) which already gives huge savings.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44589

Test Plan:
Tested with one of the offending models - just loading a saving a Torchscript file:

```
Before:
load 1.9239683151245117
save 165.74712467193604

After:
load 1.9409027099609375
save 1.4711427688598633
```

Reviewed By: suo

Differential Revision: D23675278

Pulled By: dzhulgakov

fbshipit-source-id: 8f3fa7730941085ea20d9255b49a149ac1bf64fe
This commit is contained in:
Dmytro Dzhulgakov
2020-09-15 13:08:39 -07:00
committed by Facebook GitHub Bot
parent 285ba0d068
commit 2f4c31ce3a
5 changed files with 58 additions and 36 deletions

View File

@ -77,7 +77,7 @@ bool is_enabled(const char* cfname, JitLoggingLevels level) {
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
torch::jit::GraphFunction func("source_dump", graph, nullptr);
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printFunction(func);
return pp.str();

View File

@ -1014,7 +1014,7 @@ void initJitScriptBindings(PyObject* module) {
"code",
[](Module& self) {
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
return pp.str();
@ -1023,7 +1023,7 @@ void initJitScriptBindings(PyObject* module) {
"code_with_constants",
[](Module& self) {
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printNamedType(self.type());
std::map<std::string, at::IValue> consts;
@ -1177,7 +1177,7 @@ void initJitScriptBindings(PyObject* module) {
"code",
[](const StrongFunctionPtr& self) {
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printFunction(*self.function_);
@ -1222,14 +1222,14 @@ void initJitScriptBindings(PyObject* module) {
"code",
[](Method& self) {
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
return pp.str();
})
.def_property_readonly("code_with_constants", [](Method& self) {
std::vector<at::IValue> constants;
std::vector<c10::NamedTypePtr> deps;
PrintDepsTable deps;
PythonPrint pp(constants, deps);
pp.printMethod(self.function());
std::map<std::string, at::IValue> consts;

View File

@ -369,7 +369,7 @@ class ScriptModuleSerializer {
}
void writeCode(const at::NamedTypePtr& root_type) {
class_deps_.push_back(root_type);
class_deps_.add(root_type);
for (size_t i = 0; i < class_deps_.size(); ++i) {
// note: convertNameType may extend class_deps_, so re-checking
// .size() is necessary
@ -459,7 +459,7 @@ class ScriptModuleSerializer {
caffe2::serialize::PyTorchStreamWriter writer_;
std::vector<at::IValue> constant_table_;
std::unordered_set<c10::NamedTypePtr> converted_types_;
std::vector<c10::NamedTypePtr> class_deps_;
PrintDepsTable class_deps_;
TypeNameUniquer type_name_uniquer_;
// qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be

View File

@ -87,6 +87,27 @@ const static std::unordered_set<std::string> reserved_names = {
"unchecked_cast",
};
// Helper to avoid duplicating class types
void PrintDepsTable::add(const c10::NamedTypePtr& type) {
// Despite doing the linear search below, we don't want to do
// wasteful work and only try to insert each instance once.
if (!non_unique_.insert(type).second) {
return;
}
// Need to do actual equality comparison, not a pointer equality. This is
// because for some types (e.g. FunctionType), we may have multiple
// TypePtr's that represent the same underlying thing.
// TODO: this should be really swapped for something more efficient
auto it = std::find_if(
table_.cbegin(), table_.cend(), [&](const c10::NamedTypePtr& dep) {
return *dep == *type;
});
if (it == table_.cend()) {
table_.push_back(type);
}
}
struct PythonPrintImpl {
using SourceRangeStack = std::vector<SourceRange>;
SourceRangeStack source_range_stack_ = {SourceRange()};
@ -169,21 +190,6 @@ struct PythonPrintImpl {
const SourceRangeStack* srs_;
};
// Helper to avoid duplicating class types
void registerDependency(const c10::NamedTypePtr& type) {
// Need to do actual equality comparison, not a pointer equality. This is
// because for some types (e.g. FunctionType), we may have multiple
// TypePtr's that represent the same underlying thing.
auto it = std::find_if(
deps_table_.cbegin(),
deps_table_.cend(),
[&](const c10::NamedTypePtr& dep) { return *dep == *type; });
if (it == deps_table_.cend()) {
deps_table_.push_back(type);
}
}
// scanValue, scanNode, scanBlock:
// decide if it is safe to omit the output of a temporary variable,
// and inline the expression into its use
@ -667,15 +673,15 @@ struct PythonPrintImpl {
// Recursively check contained types for any class dependencies
void registerClassDependencies(const TypePtr& type) {
if (const auto classType = type->cast<ClassType>()) {
registerDependency(classType);
deps_table_.add(classType);
} else if (const auto tupleType = type->cast<TupleType>()) {
if (tupleType->name()) {
registerDependency(tupleType);
deps_table_.add(tupleType);
}
} else if (const auto interfaceType = type->cast<InterfaceType>()) {
registerDependency(interfaceType);
deps_table_.add(interfaceType);
} else if (const auto enumType = type->cast<EnumType>()) {
registerDependency(enumType);
deps_table_.add(enumType);
}
for (const auto& containedType : type->containedTypes()) {
registerClassDependencies(containedType);
@ -925,7 +931,7 @@ struct PythonPrintImpl {
if (node->outputs().size() == 1 &&
node->output()->type()->kind() == TypeKind::FunctionType) {
auto fn = node->output()->type()->expect<FunctionType>();
registerDependency(fn);
deps_table_.add(fn);
stmt << fn->annotation_str(type_printer_);
} else if (!node->mustBeNone()) {
IValue v = toIValue(node->output()).value();
@ -1055,13 +1061,13 @@ struct PythonPrintImpl {
stmt << ")";
if (auto selfClass = self->type()->cast<ClassType>()) {
registerDependency(selfClass);
deps_table_.add(selfClass);
const Function& method = selfClass->getMethod(node->s(attr::name));
TORCH_INTERNAL_ASSERT(
method.qualname() ==
QualifiedName(selfClass->name()->qualifiedName(), methodName));
} else if (auto selfInterface = self->type()->cast<InterfaceType>()) {
registerDependency(selfInterface);
deps_table_.add(selfInterface);
} else {
TORCH_INTERNAL_ASSERT(
false, "method call to unhandled type in serialization");
@ -1260,13 +1266,13 @@ struct PythonPrintImpl {
PythonPrintImpl(
std::vector<at::IValue>& constant_table,
std::vector<c10::NamedTypePtr>& deps_table,
PrintDepsTable& deps_table,
c10::TypePrinter type_printer,
bool enforce_importable)
: body_(&source_range_stack_),
constant_table_(constant_table),
deps_table_(deps_table),
type_printer_(type_printer),
type_printer_(std::move(type_printer)),
enforce_importable_(enforce_importable) {}
void printClass(const ClassTypePtr& classType) {
@ -1461,7 +1467,7 @@ struct PythonPrintImpl {
// Any NamedTypes (classes, functions, NamedTuples) used are written to this
// table.
std::vector<c10::NamedTypePtr>& deps_table_;
PrintDepsTable& deps_table_;
// A function that, given a named type, returns us the correct string to print
// for it.
@ -1477,13 +1483,13 @@ struct PythonPrintImpl {
PythonPrint::PythonPrint(
std::vector<at::IValue>& constant_table,
std::vector<c10::NamedTypePtr>& deps_table,
PrintDepsTable& deps_table,
c10::TypePrinter type_printer,
bool enforce_importable)
: pImpl(std::make_shared<PythonPrintImpl>(
constant_table,
deps_table,
type_printer,
std::move(type_printer),
enforce_importable)) {}
void PythonPrint::printNamedType(const c10::NamedTypePtr& type) {

View File

@ -11,10 +11,26 @@ struct Method;
struct Module;
struct PythonPrintImpl;
struct PrintDepsTable {
void add(const c10::NamedTypePtr& type);
size_t size() const {
return table_.size();
}
const c10::NamedTypePtr& operator[](size_t index) const {
return table_[index];
}
private:
std::vector<c10::NamedTypePtr> table_;
std::unordered_set<c10::NamedTypePtr> non_unique_;
};
struct TORCH_API PythonPrint {
PythonPrint(
std::vector<IValue>& constant_table,
std::vector<c10::NamedTypePtr>& deps_table,
PrintDepsTable& deps_table,
c10::TypePrinter type_printer = nullptr,
bool enforce_importable = false);