#include #include #include #include #include #include #include #include namespace torch::utils { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) static std::array thp_qscheme_array; void initializeQSchemes() { auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); if (!torch_module) { throw python_error(); } for (const auto i : c10::irange(at::COMPILE_TIME_NUM_QSCHEMES)) { auto qscheme = static_cast(i); PyObject* qscheme_obj = THPQScheme_New(qscheme, toString(qscheme)); thp_qscheme_array[static_cast(qscheme)] = qscheme_obj; Py_INCREF(qscheme_obj); if (PyModule_AddObject( torch_module, toString(qscheme).c_str(), qscheme_obj) != 0) { throw python_error(); } } } PyObject* getTHPQScheme(at::QScheme qscheme) { auto qscheme_ = thp_qscheme_array[static_cast(qscheme)]; if (!qscheme_) { throw std::invalid_argument("unsupported QScheme"); } return qscheme_; } } // namespace torch::utils