#include #include #include #include #include #include #include #include #include #include #include #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH)) #include #endif #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace jit { using c10::AliasInfo; using c10::Argument; using c10::FunctionSchema; using c10::SchemaArgType; using c10::SchemaArgument; using c10::SymNode; using caffe2::serialize::PyTorchStreamReader; using caffe2::serialize::PyTorchStreamWriter; using torch::utils::SchemaInfo; namespace { using autograd::variable_list; bool loadPythonClasses() { // Leaving this code here, because it will likely be useful at some point // PyObject *jit_module = PyImport_ImportModule("torch.jit"); // THPUtils_assert(jit_module, "class loader couldn't access " //"torch.jit module"); // PyObject *jit_dict = PyModule_GetDict(jit_module); return true; } c10::optional toTypeInferredIValueOptional(py::handle input) { // Errors need to be caught here because toTypeInferredIValue errors out // on various object types, but we want it to work with all types. try { return toTypeInferredIValue(input); } catch (const c10::Error& e) { return c10::nullopt; } } } // anonymous namespace #if !defined(USE_ROCM) TORCH_API void runJITCPPTests(); #endif void initJITBindings(PyObject* module) { auto m = py::handle(module).cast(); auto jit = m.def_submodule("_jit"); static py::exception exc(m, "JITException"); py::register_exception_translator([](std::exception_ptr p) { try { if (p) { std::rethrow_exception(p); } } catch (const JITException& e) { // special handling of JITException, to set its python class name and msg py::gil_scoped_acquire acquire; const auto& className = e.getPythonClassName(); const auto& originalMsg = e.getOriginalMsg(); JITException::setCaughtOriginalMsg(originalMsg.value_or("")); JITException::setCaughtPythonClassName(className.value_or("")); exc(e.what()); } }); m.def( "_get_caught_jit_exception_class_name", JITException::getCaughtPythonClassName); m.def( "_get_caught_jit_exception_original_msg", JITException::getCaughtOriginalMsg); py::class_ iodescriptor( m, "IODescriptor"); // NOLINT(bugprone-unused-raii) m.def("_jit_init", loadPythonClasses) .def( "_jit_debug_fuser_num_cached_kernel_specs", torch::jit::fuser::debugNumCachedKernelSpecs) .def("_jit_pass_lower_all_tuples", LowerAllTuples) .def( "_new_symbolic_shape_symbol", []() { return c10::ShapeSymbol::newSymbol().value(); }) .def( "_jit_shape_compute_graph_for_node", [](Node* n) -> c10::optional> { if (!n->maybeSchema()) { return c10::nullopt; } return shapeComputeGraphForSchema(n->schema()); }) .def( "_jit_decomposition_graph_for_node", [](Node* n) -> c10::optional> { if (!n->maybeSchema()) { return c10::nullopt; } return GetDecomposition(n->schema()); }) .def("_jit_pass_run_decompositions", RunDecompositions) // using Node* here instead of Schema because looking up the schema // and passing it in from Python will have a different pointer than the // schema that is globally used for caching .def( "_jit_register_shape_compute_graph_for_node", [](Node* n, std::shared_ptr& graph) { if (n->maybeSchema()) { const FunctionSchema& schema = n->schema(); RegisterShapeComputeGraphForSchema(schema, graph); } else { TORCH_INTERNAL_ASSERT(false, "Expected schema", n); } }) .def( "_jit_register_decomposition_for_schema", [](const FunctionSchema& s, std::shared_ptr& graph) { // because this is invoked by python, the function schema * // becomes different, and we need to find and reuse the // one that is used for caching auto op = findOperatorFor(c10::OperatorName(s.name(), s.overload_name())); RegisterDecomposition(op->schema(), graph); }) .def("_jit_pass_propagate_shapes_on_graph", PropagateShapesOnGraph) .def( "_jit_pass_propagate_shapes_on_graph_and_build_compute", [](std::shared_ptr& graph) { return PropagateShapesAndBuildLargeShapeComputeGraph( graph, *graph->nodes().begin(), *graph->nodes().end()); }) .def( "_jit_pass_propagate_shapes_on_graph_and_build_compute", [](std::shared_ptr& graph, Node* beg) { return PropagateShapesAndBuildLargeShapeComputeGraph( graph, beg, *graph->nodes().end()); }) .def( "_jit_pass_propagate_shapes_on_graph_and_build_compute", PropagateShapesAndBuildLargeShapeComputeGraph) .def("_jit_pass_integer_value_refinement", RefineIntegerValues) .def( "_jit_set_symbolic_shapes_test_mode", &setSymbolicShapeAnalysisTestMode) .def( "_jit_symbolic_shapes_test_mode_enabled", &symbolicShapeAnalysisTestModeEnabled) .def("_jit_pass_autocast", Autocast) .def("_jit_set_autocast_mode", &setAutocastMode) .def("_jit_pass_fuse", FuseGraph) .def( "_jit_pass_replace_old_ops_with_upgraders", [](std::shared_ptr& g) { return ReplaceOldOperatorsWithUpgraders(g); }) .def( "_jit_pass_dce", [](std::shared_ptr& g) { return EliminateDeadCode(g->block()); // overload resolution }) .def( "_jit_pass_dce_allow_deleting_nodes_with_side_effects", [](std::shared_ptr& g) { return EliminateDeadCode( g->block(), true, DCESideEffectPolicy:: ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); // overload // resolution }) .def( "_jit_pass_cse", [](std::shared_ptr& g) { return EliminateCommonSubexpression(g); // overload resolution }) .def( "_jit_pass_fuse_quantized_add_relu", [](std::shared_ptr& g) { return FuseQuantizedAddRelu(g); // overload resolution }) .def( "_jit_pass_insert_observers", [](Module& module, const std::string& method_name, const py::dict& qconfig_dict, bool inplace, int quant_type_int) { auto dict = py::cast>>>(qconfig_dict); auto quant_type = static_cast(quant_type_int); return InsertObservers( module, method_name, dict, inplace, quant_type); }, py::arg("module"), py::arg("method_name"), py::arg("qconfig_dict"), py::arg("inplace"), py::arg("quant_type_int") = 1) .def( "_jit_pass_insert_observer_method_for_ondevice_ptq", [](Module& module, const std::string& method_name, const py::dict& qconfig_dict, bool inplace, int quant_type_int) { auto dict = py::cast>>>(qconfig_dict); auto quant_type = static_cast(quant_type_int); return InsertObserversForOnDevicePTQ( module, method_name, dict, inplace, quant_type); }, py::arg("module"), py::arg("method_name"), py::arg("qconfig_dict"), py::arg("inplace"), py::arg("quant_type_int") = 1) .def( "_jit_pass_insert_quant_dequant", [](Module& module, const std::string& method_name, bool inplace, bool debug, int quant_type_int) { auto quant_type = static_cast(quant_type_int); return InsertQuantDeQuant( module, method_name, inplace, debug, quant_type); }, py::arg("module"), py::arg("method_name"), py::arg("inplace"), py::arg("debug"), py::arg("quant_type_int") = 1) .def( "_jit_pass_insert_quant_dequant_for_ondevice_ptq", [](Module& module, const std::string& method_name, bool inplace, bool debug, int quant_type_int) { auto quant_type = static_cast(quant_type_int); return InsertQuantDeQuantOnDevicePTQ( module, method_name, inplace, debug, quant_type); }, py::arg("module"), py::arg("method_name"), py::arg("inplace"), py::arg("debug"), py::arg("quant_type_int") = 1) .def( "_jit_pass_insert_prepack_unpack", [](std::shared_ptr& g) { return InsertPrepackUnpack(g); }) .def( "_jit_pass_insert_prepack_unpack", [](Module& module) { return InsertPrepackUnpack(module); }) .def( "_jit_pass_quant_fusion", [](std::shared_ptr& g) { return QuantFusion(g); }) .def( "_jit_pass_fold_convbn", [](Module& module) { return FoldConvBatchNorm(module); }) .def( "_jit_pass_dbr_quant_remove_redundant_aliases", [](Module& module) { return DBRQuantRemoveRedundantAliases(module); }) .def( "_freeze_module", [](Module& module, std::vector& preservedAttrs, bool freezeInterfaces, bool preserveParameters) { return freeze_module( module, preservedAttrs, freezeInterfaces, preserveParameters); }, py::arg("module"), py::arg("preservedAttrs") = std::vector(), py::arg("freezeInterfaces") = true, py::arg("preserveParameters") = false) .def("_jit_pass_concat_frozen_linear", &FrozenConcatLinear) .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm) .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub) .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv) .def("_jit_pass_convert_frozen_ops_to_mkldnn", &ConvertFrozenOpsToMKLDNN) .def("_jit_pass_fuse_frozen_conv_add_relu", &FuseFrozenConvAddRelu) .def("_jit_pass_transpose_frozen_linear", &FrozenLinearTranspose) .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph) .def( "_jit_pass_optimize_for_inference", [](Module& module, std::vector other_methods) { optimize_for_inference(module, other_methods); }, py::arg("module"), py::arg("other_methods") = std::vector()) .def("_jit_pass_fuse_linear", &FuseLinear) .def( "_jit_pass_fuse_add_relu", [](std::shared_ptr& graph) { FuseAddRelu(graph); }) .def("_jit_pass_dedup_module_uses", &DedupModuleUses) .def("_jit_pass_replicate_dequantize", &ReplicateDeQuant) .def( "_jit_pass_swap_functional_linear", [](std::shared_ptr& graph) { SwapFunctionalLinear(graph); }) .def( "_jit_pass_swap_functional_linear", [](Module& module) { SwapFunctionalLinear(module); }) .def( "_jit_pass_quant_finalize", [](Module& module, int quant_type_int, const std::vector& preserved_attrs) { auto quant_type = static_cast(quant_type_int); return Finalize(module, quant_type, preserved_attrs); }, py::arg("module"), py::arg("quant_type_int") = 1, py::arg("preserved_attrs") = std::vector()) .def( "_jit_pass_quant_finalize_for_ondevice_ptq", [](Module& module, int quant_type_int, const std::string& method_name) { auto quant_type = static_cast(quant_type_int); return FinalizeOnDevicePTQ(module, quant_type, method_name); }, py::arg("module"), py::arg("quant_type_int") = 1, py::arg("preserved_attrs") = std::vector()) .def( "_jit_pass_pattern_based_rewrite", [](const Module& m) { return PatternBasedRewrite(m); }) .def( "_jit_pass_custom_pattern_based_rewrite", [](const std::string& pattern, const std::string& fused_node_name, const Module& m) { SubgraphRewriter subgraph_rewriter; subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name); subgraph_rewriter.runOnModule(m); }) .def( "_jit_pass_custom_pattern_based_rewrite_graph", [](const std::string& pattern, const std::string& fused_node_name, std::shared_ptr g, const std::vector>& value_name_pairs) { SubgraphRewriter subgraph_rewriter; subgraph_rewriter.RegisterRewritePattern( pattern, fused_node_name, value_name_pairs); subgraph_rewriter.runOnGraph(g); }, py::arg("pattern"), py::arg("fused_node_name"), py::arg("g"), py::arg("value_name_pairs") = std::vector>()) .def("_jit_pass_constant_pooling", ConstantPooling) // RemoveInplaceOps is used by CoreML so it must be removed with care. .def("_jit_pass_propagate_dtype", DtypePropagation) .def("_jit_pass_propagate_device", DeviceTypePropagation) .def( "_jit_pass_remove_inplace_ops", [](const std::shared_ptr& g) { return RemoveInplaceOps(g); }) .def( "_jit_pass_create_functional_graphs", [](std::shared_ptr& g) { return CreateFunctionalGraphs(g); }) .def( "_jit_pass_remove_mutation", [](std::shared_ptr& g) { RemoveListMutation(g); return RemoveTensorMutation(g); }) .def( "_jit_pass_functional_to_inplace_activation", [](std::shared_ptr& g) { return FunctionalToInplaceActivation(g); }) .def( "_jit_pass_inplace_to_functional_activation", [](std::shared_ptr& g) { return InplaceToFunctionalActivation(g); }) .def( "_jit_pass_inline_functional_graphs", [](std::shared_ptr& g) { return InlineFunctionalGraphs(g); }) .def( "_jit_pass_peephole", [](const std::shared_ptr& g, bool disable_shape_peepholes) { return PeepholeOptimize(g, disable_shape_peepholes); }, py::arg("graph"), py::arg("disable_shape_peepholes") = false) .def( "_jit_pass_peephole_list_idioms", [](const std::shared_ptr& g, bool refine_list_len) { return PeepholeOptimizeListIdioms(g, refine_list_len); }, py::arg("graph"), py::arg("refine_list_len") = false) .def( "_jit_pass_refine_integer_values", [](std::shared_ptr& g) { return RefineIntegerValues(g); }) .def( "_jit_pass_fuse_addmm", [](std::shared_ptr& g) { return FuseAddMM(g); }) .def( "_jit_pass_canonicalize", [](const std::shared_ptr& g, bool keep_unique_names = true) { return Canonicalize(g, keep_unique_names); }, py::arg("graph"), py::arg("keep_unique_names") = true) .def("_jit_pass_lint", LintGraph) .def( "_jit_pass_complete_shape_analysis", [](const std::shared_ptr& graph, const py::tuple& inputs, bool with_grad) { ArgumentSpecCreator arg_spec_creator(*graph); Stack stack; stack.reserve(inputs.size()); // captures? for (auto& obj : inputs) { stack.push_back(toTypeInferredIValue(obj)); } ArgumentSpec spec = arg_spec_creator.create(with_grad, stack); arg_spec_creator.specializeTypes(*graph, spec); // We only get partial specialization from the arg_spec_creator, but // we want full shape specialization. The alternative would be to // have a "complete type inference" function in ArguemntSpecCreator. auto g_inputs = graph->inputs(); for (const auto i : c10::irange(inputs.size())) { if (stack[i].isTensor()) { g_inputs[i]->setType(stack[i].type()); } } PropagateInputShapes(graph); }) .def( "_jit_interpret_graph", [](std::shared_ptr& graph, const py::tuple& inputs) { Stack stack; stack.reserve(inputs.size()); // captures? for (auto& obj : inputs) { stack.push_back(toTypeInferredIValue(obj)); } auto g_inputs = graph->inputs(); for (const auto i : c10::irange(inputs.size())) { if (stack[i].isTensor()) { g_inputs[i]->setType(stack[i].type()); } } Code code(graph, ""); InterpreterState(code).run(stack); return createPyObjectForStack(std::move(stack)); }, py::doc( "Interpret a JIT graph with given inputs without running any optimization passes on it")) .def( "_jit_trace_graph", [](std::shared_ptr& graph, const py::tuple& inputs) { Stack stack; stack.reserve(inputs.size()); // captures? for (auto& obj : inputs) { stack.push_back(toTypeInferredIValue(obj)); } auto g_inputs = graph->inputs(); for (const auto i : c10::irange(inputs.size())) { if (stack[i].isTensor()) { g_inputs[i]->setType(stack[i].type()); } } return TraceGraph(graph, stack); }) .def( "_jit_trace_module", [](Module& model, const py::tuple& inputs) { auto graph = model.get_method("forward").graph(); Stack stack; stack.reserve(inputs.size() + 1); // captures? push(stack, model._ivalue()); for (auto& obj : inputs) { stack.push_back(toTypeInferredIValue(obj)); } auto traced = TraceGraph(graph, stack); GRAPH_DUMP("Traced Graph", traced); // the easiest way to replace a graph in a module is // to remove all the nodes in the original graph // clone everything from the traced one graph->block()->clear(); graph->block()->cloneFrom(traced->block(), nullptr); GRAPH_DUMP("Copied Graph", graph); }) .def("_jit_pass_remove_expands", RemoveExpands) .def("_jit_pass_erase_number_types", EraseNumberTypes) .def("_jit_pass_inline_fork_wait", InlineForkWait) .def("_jit_pass_inline", Inline) .def( "_jit_pass_lower_graph", [](std::shared_ptr& graph, const Module& self) { return LowerGraph(*graph, self._ivalue()); }) .def("_jit_pass_loop_unrolling", UnrollLoops) .def("_jit_pass_constant_loop_unrolling", UnrollConstantLoops) .def( "_jit_pass_constant_propagation_immutable_types", [](std::shared_ptr& g) { return ConstantPropagationImmutableTypes(g); }) .def( "_jit_pass_constant_propagation", [](std::shared_ptr& g) { return ConstantPropagation(g); }, py::arg("graph")) .def("_jit_pass_erase_shape_information", EraseShapeInformation) .def( "_jit_object_is_non_holding", [](Node& n) { return toIValue(n.output())->toObject()->is_weak_compilation_ref(); }) .def( "_jit_erase_non_input_shape_information", [](std::shared_ptr& g) { std::vector input_types; for (Value* v : g->inputs()) { if (auto tt = v->type()->cast()) { input_types.push_back(tt); } else { input_types.push_back(nullptr); } } EraseShapeInformation(g); for (size_t i = 0; i < input_types.size(); ++i) { if (input_types[i]) { g->inputs().at(i)->setType(input_types[i]); } } }) .def( "_jit_pass_create_autodiff_subgraphs", [](const std::shared_ptr& graph, py::object threshold) { if (threshold.is(py::none())) { CreateAutodiffSubgraphs(graph); } else { CreateAutodiffSubgraphs(graph, py::cast(threshold)); } }, py::arg("graph"), py::arg("threshold") = py::none()) #if defined(BUILDING_TESTS) && !defined(USE_ROCM) .def( "_jit_run_cpp_tests", []() { // We have to release the GIL inside this method, because if we // happen to initialize the autograd engine in these tests, the // newly spawned worker threads will try to initialize their // PyThreadState*, and they need the GIL for this. pybind11::gil_scoped_release _no_gil; return runJITCPPTests(); }) .def("_jit_has_cpp_tests", []() { return true; }) .def("_has_tensorexpr_cpp_tests", []() { return true; }) #else .def("_jit_run_cpp_tests", []() { throw std::exception(); }) .def("_jit_has_cpp_tests", []() { return false; }) .def("_run_tensorexpr_cpp_tests", []() { throw std::exception(); }) .def("_has_tensorexpr_cpp_tests", []() { return false; }) #endif .def( "_jit_flatten", [](py::handle& obj) { auto res = python::flatten(obj); return std::make_pair(res.vars, res.desc); }) .def( "_jit_unflatten", [](const autograd::variable_list& vars, python::IODescriptor& desc) { return py::reinterpret_steal( python::unflatten(vars, desc)); }) .def("_jit_pass_canonicalize_graph_fuser_ops", CanonicalizeOps) .def("_jit_pass_decompose_ops", DecomposeOps) .def("_jit_pass_specialize_autogradzero", specializeAutogradZero) .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU) .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU) .def("_jit_can_fuse_on_cpu", &canFuseOnCPU) .def("_jit_can_fuse_on_gpu", &canFuseOnGPU) .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy) .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy) .def( "_jit_differentiate", [](Graph& g) { // the python binding slightly differs in semantics // it makes a copy of the input Graph, and works on that // jit::differentiate mutates the input Graph auto g_clone = g.copy(); return differentiate(g_clone); }) .def( "_jit_check_alias_annotation", [](const std::shared_ptr& g, const py::tuple& args, const std::string& unqualified_op_name) { auto stack = toTraceableStack(args); checkAliasAnnotation(g, std::move(stack), unqualified_op_name); }) #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH)) .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled) .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled) #else .def("_jit_set_llga_enabled", [](bool flag) { return false; }) .def("_jit_llga_enabled", []() { return false; }) #endif .def( "_jit_set_tracer_state_warn", [](bool new_warn) { jit::tracer::getTracerStateWarnMode() = new_warn; }) .def( "_jit_get_tracer_state_warn", []() { bool current_tracer_warn = jit::tracer::getTracerStateWarnMode(); return current_tracer_warn; }) .def( "_jit_set_nvfuser_skip_node_kind", // Args: // `op_name`: Symbol of op; // `flip`: flag indicating whether to flip the given op in the // skip list. // Returns: // a bool flag indicating if `op_name` was already in the skip // list. [](const std::string& op_name, bool flip = true) { return fuser::cuda::skipNode(op_name, flip); }) .def("_jit_set_nvfuser_enabled", &fuser::cuda::setEnabled) .def("_jit_nvfuser_can_be_enabled", &fuser::cuda::canBeEnabled) .def( "_jit_set_nvfuser_single_node_mode", [](bool flag) { return fuser::cuda::setSingletonFusion(flag); }) .def( "_jit_nvfuser_single_node_mode", []() { return fuser::cuda::getSingletonFusion(); }) .def( "_jit_set_nvfuser_horizontal_mode", [](bool flag) { return fuser::cuda::setHorizontalFusion(flag); }) .def( "_jit_nvfuser_horizontal_mode", []() { return fuser::cuda::getHorizontalFusion(); }) .def( "_jit_set_nvfuser_guard_mode", [](bool profiling_flag) { bool oldState = fuser::cuda::getCudaFusionGuardMode(); fuser::cuda::getCudaFusionGuardMode() = profiling_flag; return oldState; }) .def("_jit_nvfuser_enabled", &fuser::cuda::isEnabled) .def( "_jit_nvfuser_set_comparison_callback", [](bool run_fallback, py::function fn) { // If set, then the callback will be run after each nvfuser fusion // group is executed. Can be used for testing accuracy. // If run_fallback == True, then a fallback will be run and // unfused_outputs will be nonempty, showing the result if the // fusion didn't take place. Otherwise, unfused_outputs will // be empty auto fn_ptr = std::make_shared(fn); auto callback_lambda = [fn_ptr]( const Stack& fused_outputs, const Stack& unfused_outputs, const std::string& graph_ir) { py::gil_scoped_acquire acquire{}; (*fn_ptr)(fused_outputs, unfused_outputs, graph_ir); }; setCudaFuserComparisonCallback({run_fallback, callback_lambda}); }) .def( "_jit_nvfuser_clear_comparison_callback", []() { setCudaFuserComparisonCallback({false, nullptr}); }) .def( "_jit_set_profiling_mode", [](bool profiling_flag) { bool oldState = getProfilingMode(); getProfilingMode() = profiling_flag; return oldState; }) .def( "_jit_set_profiling_executor", [](bool profiling_flag) { bool oldState = getExecutorMode(); getExecutorMode() = profiling_flag; return oldState; }) .def( "_jit_set_num_profiled_runs", [](size_t num) { size_t old_num = getNumProfiledRuns(); getNumProfiledRuns() = num; return old_num; }) .def( "_jit_get_num_profiled_runs", [] { // pybind can't automatically bind to atomic size_t size_t num_runs = getNumProfiledRuns(); return num_runs; }) .def( "_jit_set_bailout_depth", [](size_t depth) { TORCH_WARN( "Use _jit_set_fusion_strategy, bailout depth is deprecated. Setting to (STATIC, ", depth, ")"); size_t old_depth = getBailoutDepth(); FusionStrategy strat = {{FusionBehavior::STATIC, depth}}; setFusionStrategy(strat); return old_depth; }) .def( "_jit_set_fusion_strategy", [](std::vector> strategy) { FusionStrategy vec_conv; for (const auto& pair : strategy) { if (pair.first == "STATIC") { vec_conv.emplace_back(FusionBehavior::STATIC, pair.second); } else if (pair.first == "DYNAMIC") { vec_conv.emplace_back(FusionBehavior::DYNAMIC, pair.second); } else { TORCH_INTERNAL_ASSERT( false, "FusionBehavior only supported 'STATIC' or 'DYNAMIC', got: ", pair.first); } } auto old_strategy = getFusionStrategy(); auto strat = fmap(old_strategy, [](std::pair behav) { return std::pair( behav.first == FusionBehavior::STATIC ? "STATIC" : "DYNAMIC", behav.second); }); setFusionStrategy(vec_conv); return strat; }) .def( "_jit_set_inline_everything_mode", [](bool enabled) { getInlineEverythingMode() = enabled; }) .def( "_jit_get_inline_everything_mode", []() { return getInlineEverythingMode(); }) .def( "_jit_get_logging_option", []() { return ::torch::jit::get_jit_logging_levels(); }) .def( "_jit_set_logging_option", [](std::string loggingOption) -> void { ::torch::jit::set_jit_logging_levels(loggingOption); }) .def( "_jit_set_logging_stream", [](std::string stream_name) -> void { if (stream_name == "stdout") { ::torch::jit::set_jit_logging_output_stream(std::cout); } else if (stream_name == "stderr") { ::torch::jit::set_jit_logging_output_stream(std::cerr); } else { std::cerr << "ERROR: only `stdout` and `stderr`" << "are supported as output options" << std::endl; } }) .def( "_storage_id", [](const at::Tensor& ten) -> int64_t { return reinterpret_cast( ten.storage().unsafeGetStorageImpl()); }) .def( "_jit_try_infer_type", [](py::object obj) -> InferredType { return tryToInferType(std::move(obj)); }) .def( "_jit_get_te_cuda_pointwise_loop_levels", []() -> int { using namespace torch::jit::tensorexpr; return getTECudaPointwiseLoopLevels(); }) .def( "_jit_set_te_cuda_pointwise_loop_levels", [](int level) { using namespace torch::jit::tensorexpr; return getTECudaPointwiseLoopLevels() = level; }) .def( "_jit_get_te_cuda_pointwise_block_count", []() -> int { using namespace torch::jit::tensorexpr; return getTECudaPointwiseBlockCount(); }) .def( "_jit_set_te_cuda_pointwise_block_count", [](int block_count) { using namespace torch::jit::tensorexpr; return getTECudaPointwiseBlockCount() = block_count; }) .def( "_jit_get_te_cuda_pointwise_block_size", []() -> int { using namespace torch::jit::tensorexpr; return getTECudaPointwiseBlockSize(); }) .def( "_jit_set_te_cuda_pointwise_block_size", [](int block_size) { using namespace torch::jit::tensorexpr; return getTECudaPointwiseBlockSize() = block_size; }) .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled) .def("_jit_texpr_fuser_enabled", &tensorExprFuserEnabled) .def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed) .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed) .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled) .def( "_jit_set_texpr_dynamic_shape_enabled", &setTensorExprDynamicShapeFusionEnabled) .def( "_jit_texpr_dynamic_shape_enabled", &tensorExprDynamicShapeFusionEnabled) .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled) .def( "_jit_set_te_generate_block_code", [](bool gen_block_code) { using namespace torch::jit::tensorexpr; return getTEGenerateBlockCode() = gen_block_code; }) .def( "_jit_get_te_generate_block_code", []() -> bool { using namespace torch::jit::tensorexpr; return getTEGenerateBlockCode(); }) .def( "_jit_get_te_must_use_llvm_cpu", []() -> bool { using namespace torch::jit::tensorexpr; return getTEMustUseLLVMOnCPU(); }) .def( "_jit_set_te_must_use_llvm_cpu", [](bool use_llvm) { using namespace torch::jit::tensorexpr; getTEMustUseLLVMOnCPU() = use_llvm; }) .def( "_jit_cat_wo_conditionals", [](bool optimize_cat) { using namespace torch::jit::tensorexpr; getCatWoConditionals() = optimize_cat; }) .def( "_jit_opt_conditionals", [](bool opt_conds) { using namespace torch::jit::tensorexpr; getOptConditionals() = opt_conds; }) .def( "_llvm_enabled", []() { #ifdef TORCH_ENABLE_LLVM return true; #else return false; #endif }) .def( "_jit_pass_fuse_tensorexprs", [](std::shared_ptr& g) { FuseTensorExprs(g); RemoveTensorTypeSpecializations(g); }) .def( "_jit_fuser_get_fused_kernel_code", [](Graph& g, const std::vector& inps) { return debugGetFusedKernelCode(g, inps); }) .def( "_jit_pass_remove_dropout", [](script::Module& module) { return removeDropout(module); }) .def( "_jit_pass_refine_tuple_types", [](std::shared_ptr& graph) { return RefineTupleTypes(graph); }) .def( "_jit_pass_transform_conv1d_to_conv2d", [](std::shared_ptr& graph) { return transformConv1dToConv2d(graph); }) .def( "_jit_pass_transform_conv1d_to_conv2d", [](script::Module& module) { return transformConv1dToConv2d(module); }) .def( "_jit_pass_insert_prepacked_ops", [](std::shared_ptr& graph) { return insertPrePackedOps(graph); }) .def( "_jit_pass_insert_prepacked_ops", [](script::Module& module) { return insertPrePackedOps(module); }) .def( "_jit_pass_fuse_clamp_w_prepacked_linear_conv", [](script::Module& module) { return fusePrePackedLinearConvWithClamp(module); }) .def( "_jit_pass_fold_prepacking_ops", [](script::Module& module) { return FoldPrePackingOps(module); }) .def( "_jit_pass_optimize_for_mobile", [](script::Module& module, std::set& optimization_blocklist, std::vector& preserved_methods) { return optimizeForMobile( module, optimization_blocklist, preserved_methods); }) .def( "_hack_do_not_use_clone_module_with_class", [](script::Module& module, std::vector& ignored_methods, std::vector& ignored_attributes) { const bool inplace = false; const std::unordered_set ignored_methods_set( ignored_methods.begin(), ignored_methods.end()); const std::unordered_set ignored_attributes_set( ignored_attributes.begin(), ignored_attributes.end()); return module.clone( inplace, ignored_methods_set, ignored_attributes_set); }) .def( "_jit_pass_vulkan_insert_prepacked_ops", [](std::shared_ptr& graph) { return vulkanInsertPrePackedOps(graph); }) .def( "_jit_pass_vulkan_insert_prepacked_ops", [](script::Module& module) { return vulkanInsertPrePackedOps(module); }) .def( "_jit_pass_vulkan_fuse_clamp_w_prepacked_conv", [](script::Module& module) { return vulkanFusePrePackedConvWithClamp(module); }) .def( "_jit_pass_vulkan_fold_prepacking_ops", [](script::Module& module) { return vulkanFoldPrePackingOps(module); }) .def( "_jit_pass_vulkan_optimize_for_mobile", [](script::Module& module, std::set& optimization_blocklist, std::vector& preserved_methods) { return vulkanOptimizeForMobile( module, optimization_blocklist, preserved_methods); }) .def( "_jit_pass_metal_insert_prepacked_ops", [](std::shared_ptr& graph) { return metalInsertPrePackedOps(graph); }) .def( "_jit_pass_metal_insert_prepacked_ops", [](script::Module& module) { return metalInsertPrePackedOps(module); }) .def( "_jit_pass_metal_fuse_clamp_w_prepacked_conv", [](script::Module& module) { return metalFusePrePackedConvWithClamp(module); }) .def( "_jit_pass_metal_fold_prepacking_ops", [](script::Module& module) { return metalFoldPrePackingOps(module); }) .def( "_jit_pass_metal_optimize_for_mobile", [](script::Module& module, std::vector& preserved_methods) { return metalOptimizeForMobile(module, preserved_methods); }) .def( "_jit_pass_filter_non_tensor_arguments", [](std::map params) { std::map retval; for (auto& kv : params) { if (kv.second.isTensor()) { retval[kv.first] = std::move(kv.second).toTensor(); } } return retval; }) .def("_jit_pass_batch_mm", BatchMM) .def("_jit_decay_packed_param_input_types", [](Graph& g) { for (Value* i : g.inputs()) { if (i->type() == getCustomClass( "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") || i->type() == getCustomClass( "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") || i->type() == getCustomClass( "__torch__.torch.classes.quantized.LinearPackedParamsBase")) { // Dummy CompleteTensorType to appease ONNX validator. i->setType(TensorType::create( at::kQInt8, c10::kCPU, std::vector{1}, std::vector{1}, c10::nullopt)); } } }); // NB: This isn't actually used for regular PyTorch symbolic tracing; // XLA is what needs this #define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); }) #define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); }) #define SYMNODE_BINARY(n) \ .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); }) auto symnode_class = py::class_(m, "_SymNode") // These DO NOT install magic methods; the SymInt/SymFloat wrapper in // Python is responsible for this SYMNODE_UNARY(clone) // Named these for consistency with inner python class, but maybe // should change the python side SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_) SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2( __sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub) SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow) SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY( eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt) SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min) SYMNODE_BINARY(max) SYMNODE_UNARY(ceil) SYMNODE_UNARY(floor) SYMNODE_UNARY(neg) // Intentionally don't set file line, as the // Python backtrace matters more here .def( "guard_int", [](c10::SymNode a) { return a->guard_int(nullptr, 0); }) .def( "__str__", [](c10::SymNode a) { return a->str(); }) .def("__repr__", [](c10::SymNode a) { return a->str(); }); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "CompleteArgumentSpec") .def("__repr__", [](CompleteArgumentSpec& self) { std::ostringstream s; s << self; return s.str(); }); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "ArgumentSpec"); py::class_(m, "Code") .def( "grad_executor_states", [](Code& c) { std::vector states; for (auto& e : c.grad_executors()) { states.emplace_back(e->getDebugState()); } return states; }) .def( "differentiable_op_executor_states", [](Code& c) { std::vector states; for (auto& e : c.diff_graph_op_executors()) { if (e->isOptimized()) { states.emplace_back(e->getDebugState()); } else { // we leave an empty entry for node that doesn't have an // optimized plan states.emplace_back(); } } return states; }) .def("num_bailouts", [](Code& c) { return c.num_bailouts(); }) .def("request_bailout", [](Code& c, size_t index) { c.request_bailout(index); }); py::class_(m, "ExecutionPlan") .def_property_readonly("graph", [](ExecutionPlan& s) { return s.graph; }) .def_property_readonly("code", [](ExecutionPlan& s) { return s.code; }); py::class_(m, "Gradient") .def_property_readonly("f", [](Gradient& m) { return m.f; }) .def_property_readonly("df", [](Gradient& m) { return m.df; }) .def_property_readonly( "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; }) .def_property_readonly( "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; }) .def_property_readonly( "df_input_captured_inputs", [](Gradient& m) { return m.df_input_captured_inputs; }) .def_property_readonly( "df_input_captured_outputs", [](Gradient& m) { return m.df_input_captured_outputs; }) .def_property_readonly( "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; }); py::class_(m, "GraphExecutorState") .def_property_readonly( "graph", [](GraphExecutorState& s) { return s.graph; }) .def_property_readonly( "execution_plans", [](GraphExecutorState& s) { return s.execution_plans; }) .def_property_readonly( "fallback", [](GraphExecutorState& s) { return s.fallback; }); py::class_(m, "PyTorchFileWriter") .def(py::init()) .def(py::init([](const py::object& buffer) { auto writer_func = [=](const void* data, size_t size) { // Writting an empty file is a noop if (size == 0) { return size; } auto memory_view = py::memoryview::from_memory( reinterpret_cast(data), size); buffer.attr("write")(std::move(memory_view)); return size; }; return std::make_unique(std::move(writer_func)); })) .def(py::init&>()) .def( "write_record", [](PyTorchStreamWriter& self, const std::string& name, const char* data, size_t size) { return self.writeRecord(name, data, size); }) .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile) .def("set_min_version", &PyTorchStreamWriter::setMinVersion) .def( "write_record", [](PyTorchStreamWriter& self, const std::string& name, uintptr_t data, size_t size) { return self.writeRecord( name, reinterpret_cast(data), size); }) .def("archive_name", &PyTorchStreamWriter::archiveName) .def( "get_all_written_records", &PyTorchStreamWriter::getAllWrittenRecords); py::enum_(m, "MobileOptimizerType") .value("CONV_BN_FUSION", MobileOptimizerType::CONV_BN_FUSION) .value( "INSERT_FOLD_PREPACK_OPS", MobileOptimizerType::INSERT_FOLD_PREPACK_OPS) .value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT) .value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU) .value( "HOIST_CONV_PACKED_PARAMS", MobileOptimizerType::HOIST_CONV_PACKED_PARAMS) .export_values(); // This allows PyTorchStreamReader to read from a Python buffer. It requires // that the buffer implement `seek()`, `tell()`, and `read()`. class BufferAdapter : public caffe2::serialize::ReadAdapterInterface { public: BufferAdapter(const py::object& buffer) : buffer_(buffer) { // Jump to the end of the buffer to get its size auto current = buffer.attr("tell")(); start_offset_ = py::cast(current); buffer.attr("seek")(current, py::module::import("os").attr("SEEK_END")); size_ = py::cast(buffer.attr("tell")()) - start_offset_; buffer.attr("seek")(current); // If we can read directly into a buffer, do that instead of an extra copy use_readinto_ = py::hasattr(buffer, "readinto"); } size_t size() const override { return size_; } THPObjectPtr getMemview(void* buf, size_t n) const { THPObjectPtr memview(PyMemoryView_FromMemory( reinterpret_cast(buf), n, PyBUF_WRITE)); if (!memview) { throw python_error(); } return memview; } size_t read(uint64_t pos, void* buf, size_t n, const char* what) const override { // Seek to desired position (NB: this has to be a Py_ssize_t or Python // throws a weird error) Py_ssize_t absolute_pos = start_offset_ + pos; buffer_.attr("seek")(absolute_pos); if (use_readinto_) { auto memview = getMemview(buf, n); auto res = PyObject_CallMethod(buffer_.ptr(), "readinto", "O", memview.get()); if (res) { int64_t i = static_cast(PyLong_AsLongLong(res)); if (i > 0) { return i; } } } // Read bytes into `buf` from the buffer std::string bytes = py::cast(buffer_.attr("read")(n)); std::copy( bytes.data(), bytes.data() + bytes.size(), reinterpret_cast(buf)); return bytes.size(); } py::object buffer_; size_t size_; size_t start_offset_; bool use_readinto_; }; py::class_>( m, "PyTorchFileReader") .def(py::init()) .def(py::init([](const py::object& buffer) { auto adapter = std::make_unique(buffer); return std::make_shared(std::move(adapter)); })) .def( "get_record", [](PyTorchStreamReader& self, const std::string& key) { at::DataPtr data; size_t size = 0; std::tie(data, size) = self.getRecord(key); return py::bytes(reinterpret_cast(data.get()), size); }) .def( "has_record", [](PyTorchStreamReader& self, const std::string& key) { return self.hasRecord(key); }) .def( "get_storage_from_record", [](PyTorchStreamReader& self, const std::string& key, size_t numel, py::object data_type_obj) { at::DataPtr data(std::get<0>(self.getRecord(key))); auto scalar_type = reinterpret_cast(data_type_obj.ptr())->scalar_type; c10::Storage storage( c10::Storage::use_byte_size_t(), numel * elementSize(scalar_type), std::move(data), /*allocator=*/nullptr, /*resizable=*/false); auto ptr = c10::make_intrusive( std::move(storage), at::DispatchKeySet(), at::CPU(scalar_type).typeMeta()); return at::Tensor(std::move(ptr)); }) .def("get_all_records", [](PyTorchStreamReader& self) { return self.getAllRecords(); }); // Used by torch.Package to coordinate deserialization of storages across // ScriptModules and eager modules py::class_< DeserializationStorageContext, std::shared_ptr>( m, "DeserializationStorageContext") .def(py::init<>()) .def( "get_storage", [](DeserializationStorageContext& self, const std::string& name, py::object data_type_obj) { c10::Storage storage = self.getStorage(name); auto scalar_type = reinterpret_cast(data_type_obj.ptr())->scalar_type; auto ptr = c10::make_intrusive( std::move(storage), at::DispatchKeySet(), at::CPU(scalar_type).typeMeta()); return at::Tensor(std::move(ptr)); }) .def( "add_storage", [](DeserializationStorageContext& self, const std::string& name, const at::Tensor& tensor) { return self.addStorage(name, tensor.storage()); }) .def("has_storage", &DeserializationStorageContext::hasStorage); m.def( "_get_schema", [](const std::string& op_name, const std::string& overload_name) { try { auto symbol = Symbol::fromQualString(op_name); auto operations = getAllOperatorsFor(symbol); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { return op->schema(); } } throw std::runtime_error("Found no matching schema"); } catch (const c10::Error& e) { auto msg = torch::get_cpp_stacktraces_enabled() ? e.what() : e.what_without_backtrace(); throw std::runtime_error(msg); } }); m.def( "_get_operation_overload", [](const std::string& op_name, const std::string& overload_name) { try { auto symbol = Symbol::fromQualString(op_name); auto operations = getAllOperatorsFor(symbol); bool allow_numbers_as_tensors = symbol.is_prims() || symbol.is_nvprims() || (symbol.is_aten() && torch::should_allow_numbers_as_tensors(symbol.toUnqualString())); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { auto func = py::cpp_function([op, symbol, allow_numbers_as_tensors]( py::args args, py::kwargs kwargs) { ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( {op}, symbol, args, kwargs, /*is_overload*/ true); }); auto func_dk = py::cpp_function( [op, symbol, allow_numbers_as_tensors]( c10::DispatchKey dk_, py::args args, py::kwargs kwargs) { c10::optional dk = c10::make_optional(dk_); ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( {op}, symbol, args, kwargs, /*is_overload*/ true, dk); }); return py::make_tuple( func, func_dk, py::cast(op->getTags().vec())); } } throw std::runtime_error("Found no matching operator overload"); } catch (const c10::Error& e) { auto msg = torch::get_cpp_stacktraces_enabled() ? e.what() : e.what_without_backtrace(); throw std::runtime_error(msg); } }); m.def( "_jit_get_operation", [](const std::string& op_name) { try { auto symbol = Symbol::fromQualString(op_name); auto operations = getAllOperatorsFor(symbol); TORCH_CHECK(!operations.empty(), "No such operator ", op_name); std::ostringstream docstring; docstring << "Automatically bound operator '" << op_name << "' with schema(s):\n"; for (const auto& op : operations) { docstring << " " << op->schema() << "\n"; } py::list overload_names; for (const auto& op : operations) { overload_names.append(py::str(op->schema().overload_name())); } bool allow_numbers_as_tensors = symbol.is_prims() || symbol.is_nvprims() || (symbol.is_aten() && torch::should_allow_numbers_as_tensors(symbol.toUnqualString())); auto func = py::cpp_function( [operations, symbol, allow_numbers_as_tensors]( py::args args, py::kwargs kwargs) { ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( operations, symbol, args, kwargs, false); }, py::name(symbol.toUnqualString()), py::doc(docstring.str().c_str())); return py::make_tuple(func, overload_names); } catch (const c10::Error& e) { auto msg = torch::get_cpp_stacktraces_enabled() ? e.what() : e.what_without_backtrace(); throw std::runtime_error(msg); } }, py::arg("qualified_name")); m.def( "parse_ir", [](const std::string& input, bool parse_tensor_constants) { auto graph = std::make_shared(); parseIR(input, &*graph, parse_tensor_constants); return graph; }, py::arg("input"), py::arg("parse_tensor_constants") = false); m.def("parse_schema", parseSchema); m.def("unify_type_list", [](const std::vector& types) { std::ostringstream s; auto type = unifyTypeList(types, s); if (!type) { throw std::runtime_error(s.str()); } return type.value(); }); py::enum_(m, "_SchemaArgType") .value("input", SchemaArgType::input) .value("output", SchemaArgType::output); py::class_(m, "_SchemaArgument") .def(py::init()) .def_readwrite("type", &SchemaArgument::type) .def_readwrite("index", &SchemaArgument::index); py::class_(m, "_SchemaInfo") .def(py::init()) .def("is_mutable", [](SchemaInfo& self) { return self.is_mutable(); }) .def( "is_mutable", [](SchemaInfo& self, const SchemaArgument& argument) { return self.is_mutable(argument); }) .def( "has_argument", [](SchemaInfo& self, const std::string& name) { return self.has_argument(name); }) .def( "is_mutable", [](SchemaInfo& self, const std::string& name) { return self.is_mutable(name); }) .def( "may_alias", [](SchemaInfo& self, const SchemaArgument& lhs, const SchemaArgument& rhs) { return self.may_alias(lhs, rhs); }) .def( "may_contain_alias", [](SchemaInfo& self, const SchemaArgument& lhs, const SchemaArgument& rhs) { return self.may_contain_alias(lhs, rhs); }) .def( "add_argument_value", [](SchemaInfo& self, const std::string& name, const py::object& value) { c10::optional i_value = toTypeInferredIValueOptional(value); if (i_value) { // For normalization purposes there is an inconsistency within // torch.fx that turns all arguments named "self" into "input". // Thus this check ensures that those arguments are checked // correctly. if (name == "input" && !self.hasInputArgumentNamed("input")) { self.addArgumentValue("self", *i_value); } else { self.addArgumentValue(name, *i_value); } } }) .def("add_argument_values", [](SchemaInfo& self, const py::dict& values) { std::unordered_map value_map; for (const auto& key_pair : values) { IValue key = toTypeInferredIValue(key_pair.first); TORCH_INTERNAL_ASSERT( key.isString(), "Add argument value keys types should be strings."); c10::optional value = toTypeInferredIValueOptional(key_pair.second); if (value) { // For normalization purposes there is an inconsistency within // torch.fx that // turns all arguments named "self" into "input". Thus this check // ensures that those arguments are checked correctly. if (key.toStringRef() == "input" && !self.hasInputArgumentNamed("input")) { self.addArgumentValue("self", *value); } else { value_map[key.toStringRef()] = *value; } } } self.addArgumentValues(value_map); }); py::class_(m, "FunctionSchema") .def_property_readonly( "name", [](FunctionSchema& self) { return self.name(); }) .def_property_readonly( "overload_name", [](FunctionSchema& self) { return self.overload_name(); }) .def_property_readonly( "arguments", [](FunctionSchema& self) { return self.arguments(); }) .def_property_readonly( "returns", [](FunctionSchema& self) { return self.returns(); }) .def( "is_backward_compatible_with", [](const FunctionSchema& self, const FunctionSchema& old_schema) { return self.isBackwardCompatibleWith(old_schema); }) .def( "check_forward_compatible_with", [](const FunctionSchema& self, const FunctionSchema& old_schema) { std::ostringstream out; auto result = self.isForwardCompatibleWith(old_schema, out); return std::make_pair(result, out.str()); }) .def( "__eq__", [](const FunctionSchema& self, const FunctionSchema& other) { return self == other; }) .def( "__str__", [](FunctionSchema& self) { std::stringstream ss; ss << self; return ss.str(); }) .def_property_readonly( "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); }); py::class_(m, "Argument") .def_property_readonly("name", [](Argument& self) { return self.name(); }) .def_property_readonly("type", [](Argument& self) { return self.type(); }) .def_property_readonly( "N", [](Argument& self) -> py::object { return (self.N()) ? py::cast(*self.N()) : py::none(); }) .def_property_readonly( "default_value", [](Argument& self) -> py::object { if (!self.default_value()) { return py::none(); } IValue v = *self.default_value(); return toPyObject(std::move(v)); }) .def( "has_default_value", [](Argument& self) -> py::bool_ { return self.default_value().has_value(); }) .def_property_readonly( "alias_info", [](Argument& self) { return self.alias_info(); }) .def_property_readonly( "is_out", [](Argument& self) { return self.is_out(); }) .def_property_readonly("kwarg_only", [](Argument& self) -> bool { return self.kwarg_only(); }); py::class_(m, "_AliasInfo") .def_property_readonly( "is_write", [](AliasInfo& self) { return self.isWrite(); }) .def_property_readonly( "before_set", [](AliasInfo& self) { std::set before_set_python; for (const auto& set : self.beforeSets()) { before_set_python.insert(py::str(set.toUnqualString())); } return before_set_python; }) .def_property_readonly("after_set", [](AliasInfo& self) { std::set after_set_python; for (const auto& set : self.afterSets()) { after_set_python.insert(py::str(set.toUnqualString())); } return after_set_python; }); m.def("_jit_get_all_schemas", []() { const std::vector>& operations = getAllOperators(); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); }); m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); const auto& operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); }); m.def("_is_tracing", []() { return jit::tracer::isTracing(); }); py::class_>( m, "Future") .def(py::init([](std::vector devices = {}) { return std::make_shared( c10::make_intrusive( PyObjectType::get(), std::move(devices))); })) .def( "done", // Intentionally not releasing GIL &PythonFutureWrapper::done) .def( "value", &PythonFutureWrapper::value, py::call_guard()) .def( "wait", &PythonFutureWrapper::wait, py::call_guard()) .def( "then", &PythonFutureWrapper::then, py::call_guard()) .def( "add_done_callback", &PythonFutureWrapper::add_done_callback, py::call_guard()) .def( "set_result", // Intentionally not releasing GIL &PythonFutureWrapper::markCompleted) .def( "_set_unwrap_func", // Intentionally not releasing GIL as this just does an assign [](PythonFutureWrapper& self, py::function unwrapFunc) { auto functionGuard = std::make_shared( std::move(unwrapFunc)); std::function pf = [functionGuard(std::move(functionGuard))]( const py::object& inp) { return functionGuard->func_(inp); }; self.unwrap_func = std::move(pf); }) .def( py::pickle( /* __getstate__ */ [](const PythonFutureWrapper& /* unused */) { TORCH_CHECK(false, "Can not pickle torch.futures.Future"); // Note that this return has no meaning since we always // throw, it's only here to satisfy Pybind API's // requirement. return py::make_tuple(); }, /* __setstate__ */ [](const py::tuple& /* unused */) { // NOLINT TORCH_CHECK(false, "Can not unpickle torch.futures.Future"); // Note that this return has no meaning since we always // throw, it's only here to satisfy PyBind's API // requirement. return nullptr; }), py::call_guard()); m.def("_is_alias_of", [](const py::object& self, const py::object& other) { c10::optional self_value = toTypeInferredIValueOptional(self); c10::optional other_value = toTypeInferredIValueOptional(other); // Only return true if we are certain that self and other are aliasing. if (!self_value || !other_value) { return false; } return self_value->isAliasOf(*other_value); }); m.def("_overlaps", [](const py::object& self, const py::object& other) { c10::optional self_value = toTypeInferredIValueOptional(self); c10::optional other_value = toTypeInferredIValueOptional(other); // Only return true if we are certain that self and other are overlapping. if (!self_value || !other_value) { return false; } return self_value->overlaps(*other_value); }); m.def("fork", [](const py::args& args, const py::kwargs& kwargs) { AT_ASSERT(args.size() >= 1); py::function f = py::cast(args[0]); py::tuple args_tup(args.size() - 1); for (const auto i : c10::irange(1, args.size())) { args_tup[i - 1] = args[i]; } if (jit::tracer::isTracing()) { auto graph = jit::tracer::getTracingState()->graph; auto fork_node = graph->insertNode(graph->create(prim::TracedFork, 1)); auto body_block = fork_node->addBlock(); Value* node_output = nullptr; py::object py_func_output; // Insert new trace ops into the fork op's sub-block WithInsertPoint guard(body_block); IValue output_ivalue; { tracer::WithNestedTracingFrame env_guard; // Run the user-supplied function py_func_output = f(*args_tup, **kwargs); // Convert the output of the user-supplied function to IValue. The type // information of this IValue is used both to record the correct type in // the trace. output_ivalue = toTypeInferredIValue(py_func_output); Value* out_val = jit::tracer::getValueTrace(output_ivalue); body_block->registerOutput(out_val); node_output = fork_node->output()->setType(FutureType::create(out_val->type())); } auto retval = c10::make_intrusive(output_ivalue.type()); // Record the ivalue in the tracer jit::tracer::setValueTrace(retval, node_output); // stuff the ivalue output in the Future retval->markCompleted(output_ivalue); return std::make_shared(retval); } else { auto result = toTypeInferredIValue(f(*args_tup, **kwargs)); auto retval = c10::make_intrusive(result.type()); retval->markCompleted(std::move(result)); return std::make_shared(retval); } }); m.def("wait", [](const std::shared_ptr& fut) { TORCH_CHECK(fut, "Future can't be None"); return fut->wait(); }); m.def( "_collect_all", [](const std::vector>& futures) -> std::shared_ptr { auto typePtr = futures.empty() || futures[0] == nullptr ? AnyType::get() : futures[0]->fut->elementType(); c10::List> asList( c10::FutureType::create(typePtr)); asList.reserve(futures.size()); for (const auto& f : futures) { TORCH_CHECK(f, "Future can't be None"); asList.push_back(f->fut); } return std::make_shared( c10::collectAll(asList), /* unwrap_func */ [futures](const py::object& /*unused*/) { // Throw errors when calling wait() on the returned Future if // any of the original futures would throw. // NB: PythonFutureWrapper takes an unwrap_func which serves as a // callback to evalute the value in the Future. RPC uses this // unwrap_func to check whether the returned py::object is a // RemoteException object, and re-throw the exception if it is. // By extracting the c10::ivalue::Future from PythonFutureWrapper // the unwrap_func on the original PythonFutureWrapper objects are // discarded, and hence it will return the RemoteException as an // object instead of re-throwing it. for (auto& fut : futures) { fut->wait(); } }); }, py::call_guard()); m.def("_jit_assert_is_instance", [](py::object obj, const TypePtr& type) { toIValue(std::move(obj), type); }); #if defined(C10_SUPPORTS_FATAL_SIGNAL_HANDLERS) m.def("_set_print_stack_traces_on_fatal_signal", [](bool print) { c10::FatalSignalHandler::getInstance().setPrintStackTracesOnFatalSignal( print); }); #endif // defined(C10_SUPPORTS_SIGNAL_HANDLER) initPythonCustomClassBindings(module); initPythonIRBindings(module); tracer::initPythonTracerBindings(module); initTreeViewBindings(module); initJitScriptBindings(module); initJitBackendBindings(module); initStaticModuleBindings(module); initTensorExprBindings(module); initNvFuserPythonBindings(module); setPrintHandler([](const std::string& str) { py::gil_scoped_acquire acquire; try { auto _stdout = py::module::import("sys").attr("stdout"); _stdout.attr("write")(str); } catch (py::error_already_set& e) { throw std::runtime_error(e.what()); } }); // On exit we need to reset the print handler to default one, // because otherwise prim::Print() instruction won't work for JIT modules. auto atexit = py::module_::import("atexit"); atexit.attr("register")( py::cpp_function([]() { setPrintHandler(getDefaultPrintHandler()); })); } } // namespace jit } // namespace torch