#include #include #include #include #include #include #include #include namespace torch { namespace jit { struct OpsValue : public SugaredValue { OpsValue(size_t version) : version_(version) {} std::string kind() const override { return "ops"; } std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& field) override { return std::make_shared(field, version_); } size_t version_; }; // Represents nested namespaces, like `foo.bar.Baz`. // Right now these namespaces can only contain other namespaces or NamedTypes struct TORCH_API ClassNamespaceValue : public SugaredValue { /** * @param name The fully qualified path, which can resolve either to a * namespace or a NamedType * @param si The source importer that searches for and loads * classes/functions. */ explicit ClassNamespaceValue( c10::QualifiedName name, std::shared_ptr si) : basename_(std::move(name)), si_(std::move(si)) {} std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& name) override; std::string kind() const override { return "Class Namespace"; } private: c10::QualifiedName basename_; std::shared_ptr si_; }; // This value maps attributes CONSTANTS.c0 CONSTANTS.c1 to entries // in the 'constants' vector. This table is will be stored in a container format // and given to the import_method when restoring the code. struct ConstantTableValue : public SugaredValue { explicit ConstantTableValue(const std::vector* constants) : constants_(constants) {} std::string kind() const override { return "CONSTANTS"; } // select an attribute on it, e.g. `this.field` std::shared_ptr attr( const SourceRange& loc, GraphFunction& m, const std::string& field) override { const char* field_s = field.c_str(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) char* end; int64_t offset = strtoll(field_s + 1, &end, 10); if (field.size() < 2 || *end != 0) throw ErrorReport(loc) << "invalid constant specifier: " << field; if (offset < 0 || size_t(offset) >= constants_->size()) { throw ErrorReport(loc) << "constant index " << offset << " is out of bounds (constant table has " << constants_->size() << " entries)"; } auto ivalue = constants_->at(offset); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) Value* value; // see [Constant Object Weak CompilationUnit Reference] if (ivalue.isObject() && !ivalue.toObject()->is_weak_compilation_ref()) { auto obj = ivalue.toObject(); if (!non_holding_object_cache.count(obj)) { non_holding_object_cache[obj] = obj->copy_to_weak_compilation_ref(); } value = m.graph()->insertConstant(non_holding_object_cache[obj], loc); } else { value = m.graph()->insertConstant(constants_->at(offset), loc); } // specializing tensor type on compilation messes up typing relations value->setType(unshapedType(value->type())); return std::make_shared(value); } private: std::unordered_map< c10::intrusive_ptr, c10::intrusive_ptr> non_holding_object_cache; const std::vector* constants_; }; SourceImporterImpl::SourceImporterImpl( std::shared_ptr cu, const std::vector* constant_table, SourceLoader source_loader, size_t version) : cu_(std::move(cu)), source_loader_(std::move(source_loader)) { env_ = { {"torch", std::make_shared("aten", version)}, {"ops", std::make_shared(version)}, // Constants present in the model. Used to resolve "CONSTANTS.n" to the // actual value {"CONSTANTS", std::make_shared(constant_table)}, {"fork", SpecialFormValue::create(prim::fork)}, {"annotate", SpecialFormValue::create(prim::annotate)}, {"unchecked_cast", SpecialFormValue::create(prim::unchecked_cast)}, {"uninitialized", SpecialFormValue::create(prim::Uninitialized)}, }; } TypePtr SourceImporterImpl::findNamedType(const QualifiedName& name) { if (auto custom_class = getCustomClass(name.qualifiedName())) { return custom_class; } parseSourceIfNeeded(name.prefix()); auto it = to_be_defined_.find(name); if (it != to_be_defined_.end() && it->second->kind() == TK_CLASS_DEF) { ClassDef cd(std::move(it->second)); to_be_defined_.erase(it); importNamedType(name.prefix(), cd); } return cu_->get_type(name); } Function* SourceImporterImpl::findFunction(const QualifiedName& name) { parseSourceIfNeeded(name.prefix()); auto it = to_be_defined_.find(name); if (it != to_be_defined_.end() && it->second->kind() == TK_DEF) { Def d(it->second); to_be_defined_.erase(it); importFunction(name.prefix(), d); } return cu_->find_function(name); } void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) { // qualifier may be blank, for instance checking if __torch__ is a class. if (qualifier == "" || loaded_sources_.count(qualifier)) { return; } loaded_sources_.insert(qualifier); std::shared_ptr src = source_loader_(qualifier); // The importer, when looking for classes/functions doesn't know if 'foo' // contains definitions or if it is a prefix of 'foo.bar', we only figure it // out by testing if `foo.py` exists in the source loader. If it doesn't // then there is nothing to load here if (!src) { return; } Parser p(src); parsePossibleVersionNumber(p.lexer()); auto& L = p.lexer(); while (L.cur().kind != TK_EOF) { parseImports(L); auto tk = L.cur(); auto kind = tk.kind; switch (kind) { case TK_CLASS_DEF: { auto parsed_treeref = ClassDef(p.parseClass()); to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] = parsed_treeref; } break; case TK_DEF: { auto parsed_treeref = Def(p.parseFunction(/*is_method=*/false)); to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] = parsed_treeref; } break; default: throw ErrorReport(L.cur().range) << "Unexpected token in code import: " << kindToString(kind); } } } void SourceImporterImpl::LEGACY_import_methods( const Module& mod, const std::shared_ptr& src) { auto self = SimpleSelf(mod.type()); c10::QualifiedName prefix = *mod.type()->name(); Parser p(src); parsePossibleVersionNumber(p.lexer()); parseImports(p.lexer()); std::vector definitions; std::vector resolvers; while (p.lexer().cur().kind != TK_EOF) { auto def = Def(p.parseFunction(/*is_method=*/true)); definitions.emplace_back(def); resolvers.emplace_back(shared_from_this()); } cu_->define( prefix, /*properties=*/{}, /*propResolvers=*/{}, definitions, resolvers, &self); } std::shared_ptr SourceImporterImpl::resolveValue( const std::string& name, GraphFunction& m, const SourceRange& loc) { auto it = env_.find(name); if (it != env_.end()) { return it->second; } auto graph = m.graph(); if (name == "inf") { return std::make_shared( graph->insertConstant(std::numeric_limits::infinity(), loc)); } if (name == "nan") { return std::make_shared( graph->insertConstant(std::numeric_limits::quiet_NaN(), loc)); } if (name == "infj") { return std::make_shared(graph->insertConstant( c10::complex(0, std::numeric_limits::infinity()), loc)); } if (name == "nanj") { return std::make_shared(graph->insertConstant( c10::complex(0, std::numeric_limits::quiet_NaN()), loc)); } if (name == "__torch__") { return std::make_shared( c10::QualifiedName(name), shared_from_this()); } return nullptr; } TypePtr SourceImporterImpl::resolveType( const std::string& name, const SourceRange& loc) { return findNamedType(QualifiedName(name)); } void SourceImporterImpl::importFunction( const std::string& qualifier, const Def& def) { std::vector definitions{def}; std::vector resolvers{shared_from_this()}; cu_->define( qualifier, /*properties=*/{}, /*propResolvers=*/{}, definitions, resolvers, nullptr); } void SourceImporterImpl::importNamedType( const std::string& qualifier, const ClassDef& class_def) { const auto qualified_name = QualifiedName(QualifiedName(qualifier), class_def.name().name()); if (!class_def.superclass().present()) { return importClass(qualified_name, class_def, /*is_module=*/false); } const auto& superclass_name = Var(class_def.superclass().get()).name().name(); if (superclass_name == "Module") { importClass(qualified_name, class_def, /*is_module=*/true); } else if (superclass_name == "NamedTuple") { // NamedTuples have special rules (since they are TupleTypes and not // ClassTypes) return importNamedTuple(qualified_name, class_def); } else if (superclass_name == "Interface") { cu_->define_interface( qualified_name, class_def, shared_from_this(), /*is_module=*/false); } else if (superclass_name == "ModuleInterface") { cu_->define_interface( qualified_name, class_def, shared_from_this(), /*is_module=*/true); } else if (superclass_name == "Enum") { importEnum(qualified_name, class_def); } else { throw ErrorReport(class_def.range()) << "Torchscript does not support class inheritance."; } } c10::optional SourceImporterImpl:: attributeAssignmentSpecialHandlingHack( const QualifiedName& qualified_classname, const Assign& assign) { struct AttrTypeReplacementDescr { std::string attr_name; std::string expected_type; std::string replacement_type; }; // module demangled qualname -> ReplacementDescr static std::unordered_map replacements{ {"__torch__.torch.nn.quantized.modules.linear.LinearPackedParams", {"_packed_params", "Tensor", "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, {"__torch__.torch.nn.quantized.modules.linear.Linear", {"_packed_params", "Tensor", "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, {"__torch__.torch.nn.quantized.dynamic.modules.linear.Linear", {"_packed_params", "Tensor", "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, {"__torch__.torch.nn.quantized.modules.conv.Conv2d", {"_packed_params", "Tensor", "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d", {"_packed_params", "Tensor", "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, {"__torch__.torch.nn.quantized.modules.conv.Conv3d", {"_packed_params", "Tensor", "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}, {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d", {"_packed_params", "Tensor", "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}}; // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful static std::regex mangle_re("\\.___torch_mangle_\\d+"); auto demangled_classname = std::regex_replace(qualified_classname.qualifiedName(), mangle_re, ""); if (replacements.count(demangled_classname)) { auto lhs = Var(assign.lhs()); if (!assign.type().present() || assign.type().get().kind() != TK_VAR) { return c10::nullopt; } auto type = Var(assign.type().get()); auto& attr_name = replacements.at(demangled_classname).attr_name; auto& expected_type = replacements.at(demangled_classname).expected_type; auto& replacement_type = replacements.at(demangled_classname).replacement_type; if (lhs.name().name() == attr_name && type.name().name() == expected_type) { Parser p(std::make_shared(replacement_type)); auto typename_expr = p.parseExp(); auto maybe_typename = Maybe::create(typename_expr.range(), typename_expr); return Assign::create( assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename); } } return c10::nullopt; } void SourceImporterImpl::importClass( const QualifiedName& qualified_classname, const ClassDef& class_def, bool is_module) { // BC for TorchBind classes // // Previously we would serialize TorchBind classes as actual // classes with methods that delegate to things in the // torch.ops.* namespace. We've switched away from this and // now just rely on those classes being present in the binary // and emit code for them based on the ClassType in memory. // // TODO: remove this once we no longer have old TorchBind code // in production models { static QualifiedName torch_classes_qualname("__torch__.torch.classes"); if (torch_classes_qualname.isPrefixOf(qualified_classname)) { return; } } auto class_type = ClassType::create( c10::QualifiedName(qualified_classname), cu_, is_module); std::vector methods; std::vector method_resolvers; std::map pre_hook_def_map; std::map hook_def_map; std::map pre_hook_resolver_map; std::map hook_resolver_map; std::vector attributes; std::vector constants; // Module-specific: which attrs are parameters? std::unordered_set parameter_names; std::unordered_set buffer_names; std::unordered_set pre_hook_names; std::unordered_set hook_names; // used to keep track of original ordering of hooks and prehooks // in case any are called more than once std::vector pre_hooks_order; std::vector hooks_order; // Process statements, splitting things into attribute and method // definitions. for (const auto& statement : class_def.body()) { switch (statement.kind()) { case TK_ASSIGN: { const auto assign = Assign(statement); switch (assign.lhs().kind()) { case TK_VAR: { const auto name = Var(assign.lhs()).name().name(); if (name == "__parameters__") { // Populate the module parameter list. This is a field that // looks like: // __parameters__ = ["foo", "bar", "baz"] // which tells us which attributes are module parameters. TORCH_INTERNAL_ASSERT( is_module, "Assignments in class body only " "supported on modules right now"); const auto param_list = ListLiteral(assign.rhs().get()).inputs(); for (const auto& param : param_list) { parameter_names.insert(StringLiteral(param).text()); } } else if (name == "__annotations__") { // This is to initialize the annotations dict, just ignore. continue; } else if (name == "__buffers__") { TORCH_INTERNAL_ASSERT( is_module, "Buffers only exist on modules at the moment"); const auto buffer_list = ListLiteral(assign.rhs().get()).inputs(); for (const auto& buffer : buffer_list) { buffer_names.insert(StringLiteral(buffer).text()); } } else if (name == "__forward_pre_hooks__") { TORCH_INTERNAL_ASSERT( is_module, "Forward pre hooks only exist on modules at the moment"); const auto pre_hook_list = ListLiteral(assign.rhs().get()).inputs(); for (const auto& pre_hook : pre_hook_list) { std::string pre_hook_name = StringLiteral(pre_hook).text(); pre_hook_names.insert(pre_hook_name); pre_hooks_order.emplace_back(pre_hook_name); } } else if (name == "__forward_hooks__") { TORCH_INTERNAL_ASSERT( is_module, "Forward hooks only exist on modules at the moment"); const auto hook_list = ListLiteral(assign.rhs().get()).inputs(); for (const auto& hook : hook_list) { std::string hook_name = StringLiteral(hook).text(); hook_names.insert(hook_name); hooks_order.emplace_back(hook_name); } } else { if (auto fixed_up = attributeAssignmentSpecialHandlingHack( qualified_classname, assign)) { attributes.push_back(std::move(*fixed_up)); } else if (assign.rhs().present()) { // This is a constant assignment, of the form: // foo : Final[int] = 3 constants.push_back(assign); } else { // This is a regular attribute assignment, of the form: // foo : Tensor attributes.push_back(assign); } } } break; case TK_SUBSCRIPT: { // This is a special attribute assignment where the attribute // is not a valid python, identifier. Looks like: // __annotations__["0"] = Tensor const auto lhs = Subscript(assign.lhs()); TORCH_INTERNAL_ASSERT( Var(lhs.value()).name().name() == "__annotations__"); TORCH_INTERNAL_ASSERT(lhs.subscript_exprs().size() == 1); attributes.push_back(assign); } break; default: { TORCH_INTERNAL_ASSERT( false, "Unexpected statement kind in module metadata: ", kindToString(statement.kind())); } } } break; case TK_DEF: { Def def = Def(statement); const auto def_name = def.name().name(); if (pre_hook_names.find(def_name) != pre_hook_names.end()) { pre_hook_def_map.emplace(def_name, def); pre_hook_resolver_map.emplace(def_name, shared_from_this()); } else if (hook_names.find(def_name) != hook_names.end()) { hook_def_map.emplace(def_name, def); hook_resolver_map.emplace(def_name, shared_from_this()); } else { methods.emplace_back(def); method_resolvers.push_back(shared_from_this()); } } break; default: { TORCH_INTERNAL_ASSERT( false, "Unexpected statement kind in class body: ", kindToString(statement.kind())); } } } // Populate class attributes ScriptTypeParser type_parser(shared_from_this()); for (const auto& assign : attributes) { switch (assign.lhs().kind()) { case TK_VAR: { const auto name = Var(assign.lhs()).name().name(); TORCH_INTERNAL_ASSERT(name != "__parameters__"); const auto type = type_parser.parseTypeFromExpr(assign.type().get()); const bool is_parameter = parameter_names.count(name); const bool is_buffer = buffer_names.count(name); class_type->addAttribute(name, type, is_parameter, is_buffer); } break; case TK_SUBSCRIPT: { const auto name = StringLiteral(Subscript(assign.lhs()).subscript_exprs()[0]).text(); const auto type = type_parser.parseTypeFromExpr(assign.rhs().get()); const bool is_parameter = parameter_names.count(name); const bool is_buffer = buffer_names.count(name); class_type->addAttribute(name, type, is_parameter, is_buffer); } } } // Populate class constants for (const auto& assign : constants) { auto const_val = type_parser.parseClassConstant(assign); const auto name = Var(assign.lhs()).name().name(); class_type->addConstant(name, const_val); } // build pre hook and hook def/resolver pairs // pairs are dedupped in ir_emitter.cpp's CompilationUnit::define_hooks() // ordering here is call order for hooks std::vector hooks; std::vector hook_resolvers; for (const std::string& hook_name : hooks_order) { hooks.emplace_back(hook_def_map.find(hook_name)->second); hook_resolvers.push_back(hook_resolver_map.find(hook_name)->second); } std::vector pre_hooks; std::vector pre_hook_resolvers; for (const std::string& pre_hook_name : pre_hooks_order) { pre_hooks.emplace_back(pre_hook_def_map.find(pre_hook_name)->second); pre_hook_resolvers.push_back( pre_hook_resolver_map.find(pre_hook_name)->second); } cu_->register_type(class_type); const auto self = SimpleSelf(class_type); cu_->define( qualified_classname, /*properties=*/{}, /*propResolvers=*/{}, methods, method_resolvers, &self); cu_->define_hooks( qualified_classname, hooks, hook_resolvers, pre_hooks, pre_hook_resolvers, &self); } void SourceImporterImpl::importEnum( const QualifiedName& qualified_name, const ClassDef& enum_def) { std::vector names_values; TypePtr value_type = nullptr; auto set_or_check_type = [&value_type]( const TypePtr& t, const SourceRange& loc) { if (!value_type) { value_type = t; } else if (value_type != t) { throw ErrorReport(loc) << "Enum class with varying value types are not supported."; } }; for (const auto& statement : enum_def.body()) { if (statement.kind() != TK_ASSIGN) { throw ErrorReport(statement.range()) << "Unexpected statement in Enum class body: " "only enum attribute definitions are currently supported."; } const auto assign = Assign(statement); const auto name = Var(assign.lhs()).name().name(); IValue ivalue; auto rhs = assign.rhs().get(); switch (rhs.kind()) { case TK_STRINGLITERAL: ivalue = IValue(StringLiteral(rhs).text()); set_or_check_type(StringType::get(), statement.range()); break; case TK_CONST: { auto numeric_const = Const(rhs); if (numeric_const.isFloatingPoint()) { ivalue = IValue(numeric_const.asFloatingPoint()); set_or_check_type(FloatType::get(), statement.range()); } else if (numeric_const.isIntegral()) { ivalue = IValue(numeric_const.asIntegral()); set_or_check_type(IntType::get(), statement.range()); } break; } default: throw ErrorReport(rhs.range()) << "Unsupported enum value type: " << rhs.kind() << ". Only Integers, Floats and Strings are supported."; } names_values.emplace_back(std::make_pair(name, ivalue)); } if (!value_type) { throw ErrorReport(enum_def.range()) << "No enum values defined for " << qualified_name.qualifiedName(); } auto enum_type = EnumType::create( qualified_name, std::move(value_type), std::move(names_values), cu_); cu_->register_type(enum_type); } void SourceImporterImpl::importNamedTuple( const QualifiedName& qualified_name, const ClassDef& named_tuple_def) { ScriptTypeParser type_parser(shared_from_this()); std::vector field_names; std::vector field_types; std::vector field_defaults; for (const auto& statement : named_tuple_def.body()) { if (statement.kind() != TK_ASSIGN) { throw ErrorReport(statement.range()) << "Unexpected statement in NamedTuple body: " "only attribute annotations are currently supported."; } const auto assign = Assign(statement); auto name = Var(Assign(statement).lhs()).name().name(); c10::optional default_val; if (assign.rhs().present()) { std::vector parsed = type_parser.evaluateDefaults( assign.rhs().range(), {assign.rhs().get()}, {assign.type().get()}); TORCH_INTERNAL_ASSERT(parsed.size() == 1); default_val = parsed[0]; } auto type = type_parser.parseTypeFromExpr(assign.type().get()); field_names.emplace_back(std::move(name)); field_types.emplace_back(std::move(type)); if (default_val) { field_defaults.emplace_back(std::move(*default_val)); } } auto tt = TupleType::createNamed( qualified_name, field_names, field_types, field_defaults); cu_->register_type(tt); } void SourceImporterImpl::parsePossibleVersionNumber(Lexer& L) { // Older versions of serialization produced an op_version_set string // per-file We now just use a single version which is handled by // PyTorchStreamReader. We used to check if op_version_set was _newer_ for // forward compatibility reasons but now that it doesn't exist there can't // be a newer one, so we just discard this. if (L.cur().kind == TK_IDENT && L.cur().text() == "op_version_set") { auto range = L.cur().range; L.next(); L.expect('='); std::string version_text = L.expect(TK_NUMBER).text(); L.expect(TK_NEWLINE); } } // older versions of serialization required import statements, // and defined classes file-at-a-time in import order. // The problem is that in Python // it is possible to construct cyclic dependencies between files even // when there are none between individual classes. New versions of loading // just compile class-at-a-time, so we no longer need to follow the import // order. Future serialization may stop producing the import code. void SourceImporterImpl::parseImports(Lexer& L) { while (L.nextIf(TK_IMPORT)) { std::ostringstream s; while (L.cur().kind != TK_NEWLINE) { s << L.cur().text(); L.next(); } L.expect(TK_NEWLINE); } } std::shared_ptr ClassNamespaceValue::attr( const SourceRange& loc, GraphFunction& m, const std::string& name) { auto fullName = c10::QualifiedName(basename_, name); // Could be a ClassType or NamedTuple constructor if (auto serializable_type = si_->findNamedType(fullName)) { if (auto classType = serializable_type->cast()) { return std::make_shared(classType); } else if (auto tupleType = serializable_type->cast()) { return std::make_shared(tupleType); } else if (auto enumType = serializable_type->cast()) { return std::make_shared(enumType); } } // Or it could be a free function if (auto fn = si_->findFunction(fullName)) { return std::make_shared(fn); } // If it's none of those things, assume it's another namespace return std::make_shared(std::move(fullName), si_); } SourceImporter::SourceImporter( // The compilation unit that will own the imported source std::shared_ptr cu, const std::vector* constant_table, SourceLoader loader, size_t version) : pImpl(std::make_shared( std::move(cu), constant_table, std::move(loader), version)) {} TypePtr SourceImporter::loadType(const QualifiedName& name) const { ScriptTypeParser type_parser(pImpl); TypePtr t = type_parser.parseType(name.qualifiedName()); return t; } void SourceImporter::LEGACY_import_methods( const Module& mod, const std::shared_ptr& src) { pImpl->LEGACY_import_methods(mod, src); } SourceImporter::~SourceImporter() = default; } // namespace jit } // namespace torch