// in memory description of all ATen Ops similar to Caffe2 schema // once C10 exists this can be removed, or stubbed out, but we need // it now to implement correct semantic checking for script #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch::jit { struct Node; using ::c10::Argument; using ::c10::FunctionSchema; using ::c10::Symbol; using OperationCreator = Operation (*)(const Node*); namespace { const std::array kJitOnlyOperatorTags = { at::Tag::pt2_compliant_tag}; } /* * Note: JIT relies on Operator instances having static lifetime, because * it for example stores a non-owning FunctionSchema* pointer in the Node class, * which points to the function schema stored in the Operator instance. * Also, jit::Operator is meant to store more operator related information like * symbolic derivatives, which also requires them to have static lifetime * so that changes to symbolic derivatives are remembered. * * Currently, the JIT operator library contains a jit::Operator instance * with a wrapper for each c10 operator. The c10 operator library registers * those wrappers using listeners in register_c10_ops.cpp. * TODO Instead of doing it this way, we should only have pure-jit ops in * the jit library but have the JIT operator lookup look into the c10 library * too. */ // An Operator is a thin wrapper around either a pure JIT operator (e.g. prim // ops) or a c10 operator, allowing some common operations and abstracting away // the concrete operator nature. struct TORCH_API Operator { private: struct C10Operator final { c10::OperatorHandle handle_; Operation op_; }; struct UnparsedFunctionSchema final { std::string schema_string_; mutable std::optional alias_analysis_; }; struct JitOnlyOperator final { // The only valid transition for schema_ is from right->left, i.e. // when the schema gets parsed. mutable std::variant schema_; std::variant op_; }; public: Operator(c10::OperatorHandle opHandle, Operation operation) : op_(C10Operator{std::move(opHandle), std::move(operation)}) {} Operator( std::string schema, Operation op, c10::AliasAnalysisKind alias_analysis) : op_(JitOnlyOperator{ UnparsedFunctionSchema{std::move(schema), alias_analysis}, Operation(std::move(op))}) {} Operator( std::string name, std::string overload_name, std::vector arguments, std::vector returns, Operation op, c10::AliasAnalysisKind alias_analysis) : op_(JitOnlyOperator{ FunctionSchema(varArgSchemaWithName( std::move(name), std::move(overload_name), std::move(arguments), std::move(returns), alias_analysis)), std::move(op)}) {} Operator( std::string schema, OperationCreator op_creator, c10::AliasAnalysisKind alias_analysis) : op_(JitOnlyOperator{ UnparsedFunctionSchema{std::move(schema), alias_analysis}, op_creator}) {} // Helper constructor to register `op` to run // run for _every_ IR Node where n.kind() == name, regardless of arguments. // This is accomplished by marking the schema varargs and having no required // arguments. Operator( Symbol name, OperationCreator op_creator, c10::AliasAnalysisKind alias_analysis) : op_(JitOnlyOperator{ FunctionSchema(varArgSchemaWithName(name, alias_analysis)), op_creator}) {} Operation getOperation(const Node* node = nullptr) const { return std::visit( c10::overloaded( [](const C10Operator& op) { return op.op_; }, [node](const JitOnlyOperator& op) { return std::visit( c10::overloaded( [](const Operation& op) { return op; }, [node](const OperationCreator& op_creator) { return op_creator(node); }), op.op_); }), op_); } Operation getOperationForDispatchKey(c10::DispatchKey dk) const { // TODO: some sort of caching mechanism? return std::visit( c10::overloaded( [dk](const C10Operator& op) { return Operation([op, dk](Stack& stack) { op.handle_.callBoxedForDispatchKey(dk, stack); }); }, [](const JitOnlyOperator& op) { TORCH_CHECK( false, "calling a JIT operator for dispatch key is not supported"); return Operation(nullptr); }), op_); } const FunctionSchema& schema() const { return std::visit( c10::overloaded( [](const C10Operator& op) -> const FunctionSchema& { return op.handle_.schema(); }, [](const JitOnlyOperator& op) -> const FunctionSchema& { // we lazily parse schema initialized from strings so that // we do less work during static operator registration if (op.schema_.index() == 1) { auto& unmaterializedSchema = std::get(op.schema_); FunctionSchema schema = parseSchema(unmaterializedSchema.schema_string_); if (unmaterializedSchema.alias_analysis_.has_value()) { // TODO What if it gets set later? schema.setAliasAnalysis( *unmaterializedSchema.alias_analysis_); } op.schema_ = std::move(schema); } return std::get(op.schema_); }), op_); } c10::ArrayRef getTags() const { return std::visit( c10::overloaded( [](const C10Operator& op) { return op.handle_.getTags(); }, [](const JitOnlyOperator& op) { // JitOnlyOperators don't have an c10::OperatorHandle or a way to // specify tags. We're grandfathering them all into // pt2_compliant_tag, but for anything else, please just stop // using JitOnlyOperator. return c10::ArrayRef(kJitOnlyOperatorTags); }), op_); } bool isC10Op() const { return op_.index() == 0; } c10::AliasAnalysisKind aliasAnalysisKind() const { const FunctionSchema& schemaRef = schema(); c10::AliasAnalysisKind alias_analysis = schemaRef.aliasAnalysis(); TORCH_CHECK( alias_analysis == AliasAnalysisKind::FROM_SCHEMA || !schemaRef.hasAnyAliasInfo(), "In operator registration: Tried to register operator ", schemaRef, " with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA."); return alias_analysis; } bool hasOperation() const { return std::visit( c10::overloaded( [](const C10Operator&) { return true; }, [](const JitOnlyOperator& op) { return op.op_.index() == 0; }), op_); } private: static FunctionSchema varArgSchemaWithName( Symbol name, AliasAnalysisKind alias_analysis) { auto result = FunctionSchema( name, "", {}, {}, /*is_vararg*/ true, /*is_varret*/ true); result.setAliasAnalysis(alias_analysis); return result; } static FunctionSchema varArgSchemaWithName( std::string name, std::string overload_name, std::vector arguments, std::vector returns, AliasAnalysisKind alias_analysis) { auto result = FunctionSchema( std::move(name), std::move(overload_name), std::move(arguments), std::move(returns), /*is_vararg*/ false, /*is_varret*/ false); result.setAliasAnalysis(alias_analysis); return result; } std::variant op_; }; TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); TORCH_API const std::vector> getAllOperators(); TORCH_API const std::vector>& getAllOperatorsFor( Symbol name); // Returns operators in the order which OpOverloadPacket resolves them. TORCH_API std::vector> getAllSortedOperatorsFor( Symbol name); // given a operator with an overload name, find the specific operator related to // it, may return nullptr if no operator exists. TORCH_API std::shared_ptr findOperatorFor( const c10::OperatorName& full_name); TORCH_API std::vector findSimilarOperators(Symbol input_op); TORCH_API void registerOperator(Operator&& op); TORCH_API void deregisterOperator(const FunctionSchema& schema); // XXX: this function is meant to be used with string literals only! TORCH_API std::shared_ptr getOperatorForLiteral( const char* signature); // Ensure the thing that registers c10 ops is defined. // Otherwise, our registry will not have c10 ops. You can run into this // scenario if you're querying registered ops during static init. // // This fn is defined in register_c10_ops.cpp TORCH_API void ensure_c10_registerer_defined(); // Used to assert that unschematized operators have an analysis method written TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym); // A factory function to generate an optional operator. It has two // instantiations depending on the template bool arg value. The arg can be a // compile-time function for the selective op registration based on schema // string. template std::optional OperatorGenerator( const char* schema_str, Func&& op, AliasAnalysisKind alias_analysis) { return std::optional(Operator( std::string(schema_str), std::forward(op), alias_analysis)); } template std::optional OperatorGenerator( torch::detail::SelectiveStr schema_str, Func&& op, AliasAnalysisKind alias_analysis) { return OperatorGenerator( static_cast(schema_str), std::forward(op), alias_analysis); } template std::optional OperatorGenerator( torch::detail::SelectiveStr schema_str, Func&& op, AliasAnalysisKind alias_analysis) { return std::nullopt; } template std::optional OperatorGenerator( const std::string name, const std::string overload_name, const std::vector arguments, const std::vector returns, Func&& op, AliasAnalysisKind alias_analysis) { return std::optional(Operator( name, overload_name, arguments, returns, std::forward(op), alias_analysis)); } } // namespace torch::jit