From 275a7c5dbb2cab6410f5af0ef038f1280370928c Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Date: Sun, 9 Mar 2025 20:44:55 +0000
Subject: [PATCH] Revert "Add a stable TORCH_LIBRARY to C shim (#148124)"

This reverts commit 327e07ac1dc3351bb5f0ad436760b83590c400aa.

Reverted https://github.com/pytorch/pytorch/pull/148124 on behalf of https://github.com/malfet due to Sorry for reverting your PR, but somehow it caused test failures in newly introduced tests, see https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=pull%20%2F%20linux-focal-cuda12.6-py3.10-gcc11-sm89%20%2F%20test%20(default%2C%201&mergeLF=true ([comment](https://github.com/pytorch/pytorch/pull/148124#issuecomment-2709057833))
---
 build_variables.bzl                           |   1 -
 docs/cpp/source/Doxyfile                      |   1 -
 setup.py                                      |   1 -
 .../libtorch_agnostic/__init__.py             |  21 ---
 .../csrc/libtorch_agnostic_kernel.cpp         | 127 ---------------
 .../libtorch_agnostic/ops.py                  |  38 -----
 .../libtorch_agnostic_extension/setup.py      |  67 --------
 .../test/test_libtorch_agnostic.py            |  77 ---------
 test/run_test.py                              |   5 +-
 test/test_cpp_extensions_aot.py               |  65 --------
 tools/amd_build/build_amd.py                  |   1 -
 torch/csrc/inductor/aoti_torch/c/shim.h       |  49 ------
 .../csrc/inductor/aoti_torch/shim_common.cpp  | 138 ----------------
 torch/csrc/stable/library.h                   | 149 ------------------
 14 files changed, 1 insertion(+), 739 deletions(-)
 delete mode 100644 test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/__init__.py
 delete mode 100644 test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp
 delete mode 100644 test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
 delete mode 100644 test/cpp_extensions/libtorch_agnostic_extension/setup.py
 delete mode 100644 test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
 delete mode 100644 torch/csrc/stable/library.h

diff --git a/build_variables.bzl b/build_variables.bzl
index 3f72fd70a7ee..5de4bd214775 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -48,7 +48,6 @@ jit_core_headers = [
     "torch/csrc/jit/frontend/schema_type_parser.h",
     "torch/csrc/jit/frontend/error_report.h",
     "torch/csrc/jit/frontend/tree.h",
-    "torch/csrc/stable/library.h",
     "torch/custom_class.h",
     "torch/custom_class_detail.h",
     "torch/library.h",
diff --git a/docs/cpp/source/Doxyfile b/docs/cpp/source/Doxyfile
index 01cd27663372..8df07dbbb976 100644
--- a/docs/cpp/source/Doxyfile
+++ b/docs/cpp/source/Doxyfile
@@ -67,7 +67,6 @@ INPUT = ../../../aten/src/ATen/ATen.h \
         ../../../torch/csrc/jit/runtime/custom_operator.h \
         ../../../torch/csrc/jit/serialization/import.h \
         ../../../torch/csrc/jit/api/module.h \
-        ../../../torch/csrc/stable/library.h \
         ../../../torch/library.h \
         ../../../torch/custom_class.h
 # Don't include .cpp files!
diff --git a/setup.py b/setup.py
index 61ee9363fc26..b9ff12313e9a 100644
--- a/setup.py
+++ b/setup.py
@@ -1274,7 +1274,6 @@ def main():
                 "include/c10/xpu/impl/*.h",
                 "include/torch/*.h",
                 "include/torch/csrc/*.h",
-                "include/torch/csrc/stable/*.h",
                 "include/torch/csrc/api/include/torch/*.h",
                 "include/torch/csrc/api/include/torch/data/*.h",
                 "include/torch/csrc/api/include/torch/data/dataloader/*.h",
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/__init__.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/__init__.py
deleted file mode 100644
index 7fa8732335cf..000000000000
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import ctypes
-from pathlib import Path
-
-import torch
-
-
-so_files = list(Path(__file__).parent.glob("_C*.so"))
-assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
-
-# use ctypes.CDLL instead of load_library to be able to test the unload logic
-# below code is reduced from the load_library code
-with torch._ops.dl_open_guard():
-    loaded_lib = ctypes.CDLL(so_files[0])
-
-from . import ops
-
-
-__all__ = [
-    "loaded_lib",
-    "ops",
-]
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp
deleted file mode 100644
index 3321def1f621..000000000000
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/libtorch_agnostic_kernel.cpp
+++ /dev/null
@@ -1,127 +0,0 @@
-#include <torch/csrc/inductor/aoti_runtime/utils.h>
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
-#include <torch/csrc/stable/library.h>
-
-using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
-
-void inline sgd_math(
-  float* param_ptr,
-  float* grad_ptr,
-  float* out_ptr,
-  const float weight_decay,
-  const double lr,
-  const bool maximize,
-  int64_t size
-){
-  int64_t d = 0;
-  for (; d < size; d++) {
-    float grad_val = grad_ptr[d];
-    if (maximize) grad_val = -grad_val;
-    if (weight_decay != 0.0){
-      grad_val += param_ptr[d] * weight_decay;
-    }
-    out_ptr[d] = param_ptr[d] - grad_val * float(lr);
-  }
-}
-
-
-RAIIATH sgd_out_of_place(
-    const RAIIATH param,
-    const RAIIATH grad,
-    const float weight_decay,
-    const double lr,
-    const bool maximize) {
-
-  int64_t param_dim;
-  aoti_torch_get_dim(param.get(), &param_dim);
-
-  int64_t *param_sizes;
-  int64_t *param_strides;
-  aoti_torch_get_sizes(param.get(), &param_sizes);
-  aoti_torch_get_strides(param.get(), &param_strides);
-
-  int32_t param_dtype;
-  aoti_torch_get_dtype(param.get(), &param_dtype);
-
-  int32_t param_device_type;
-  int32_t param_device_index;
-  aoti_torch_get_device_type(param.get(), &param_device_type);
-  aoti_torch_get_device_index(param.get(), &param_device_index);
-
-  AtenTensorHandle out;
-  aoti_torch_empty_strided(param_dim, param_sizes, param_strides, param_dtype, param_device_type, param_device_index, &out);
-
-  void* param_ptr;
-  aoti_torch_get_data_ptr(param.get(), &param_ptr);
-  void* grad_ptr;
-  aoti_torch_get_data_ptr(grad.get(), &grad_ptr);
-  void* out_ptr;
-  aoti_torch_get_data_ptr(out, &out_ptr);
-
-  auto param_fp_ptr = reinterpret_cast<float*>(param_ptr);
-  auto grad_fp_ptr = reinterpret_cast<float*>(grad_ptr);
-  auto out_fp_ptr = reinterpret_cast<float*>(out_ptr);
-
-  int64_t param_numel;
-  aoti_torch_get_numel(param.get(), &param_numel);
-
-  sgd_math(
-    param_fp_ptr,
-    grad_fp_ptr,
-    out_fp_ptr,
-    weight_decay,
-    lr,
-    maximize,
-    param_numel
-  );
-
-  return RAIIATH(out);
-}
-
-
-void boxed_sgd_out_of_place(StableIValue* stack, int64_t num_args, int64_t num_outputs) {
-  RAIIATH param(to<AtenTensorHandle>(stack[0]));
-  RAIIATH grad(to<AtenTensorHandle>(stack[1]));
-  auto weight_decay = to<double>(stack[2]);
-  auto lr = to<double>(stack[3]);
-  auto maximize = to<bool>(stack[4]);
-
-  RAIIATH raiiath_res = sgd_out_of_place(
-    std::move(param),
-    std::move(grad),
-    float(weight_decay),
-    lr,
-    maximize);
-
-  stack[num_args] = from(raiiath_res.release());
-}
-
-STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
-  m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
-}
-
-STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
-  m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
-}
-
-RAIIATH identity(RAIIATH t) {
-  return std::move(t);
-}
-
-void boxed_identity(StableIValue* stack, int64_t num_args, int64_t num_outputs) {
-  RAIIATH t(to<AtenTensorHandle>(stack[0]));
-  RAIIATH raiiath_res = identity(std::move(t));
-  stack[num_args] = from(raiiath_res.release());
-}
-
-STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
-  m.def("identity(Tensor t) -> Tensor");
-}
-
-STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
-  m.impl("identity", &boxed_identity);
-}
-
-STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
-  m.impl("identity", &boxed_identity);
-}
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
deleted file mode 100644
index 2a76d0f4b17c..000000000000
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-from torch import Tensor
-
-
-def sgd_out_of_place(param, grad, weight_decay, lr, maximize) -> Tensor:
-    """
-    Computes a single step of SGD on a single parameter Tensor with grad.
-
-    Assumes:
-    - param and grad are the same shape and are 1D.
-    - param and grad are float and on CPU
-
-    Args:
-        param: a 1D tensor of floats
-        grad: a 1D tensor of floats
-        weight_decay: a python double between 0 and 1
-        lr: a python double
-
-    Returns:
-        a 1D float Tensor the same shape as param
-
-    """
-    return torch.ops.libtorch_agnostic.sgd_out_of_place.default(
-        param, grad, weight_decay, lr, maximize
-    )
-
-
-def identity(t) -> Tensor:
-    """
-    Returns the input tensor
-
-    Args:
-        t: any Tensor
-
-    Returns:
-        a Tensor, the same as input.
- """ - return torch.ops.libtorch_agnostic.identity.default(t) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_extension/setup.py deleted file mode 100644 index 5cd18f5579f9..000000000000 --- a/test/cpp_extensions/libtorch_agnostic_extension/setup.py +++ /dev/null @@ -1,67 +0,0 @@ -import distutils.command.clean -import shutil -from pathlib import Path - -from setuptools import find_packages, setup - -from torch.utils.cpp_extension import BuildExtension, CppExtension - - -ROOT_DIR = Path(__file__).parent -CSRC_DIR = ROOT_DIR / "libtorch_agnostic" / "csrc" - - -class clean(distutils.command.clean.clean): - def run(self): - # Run default behavior first - distutils.command.clean.clean.run(self) - - # Remove extension - for path in (ROOT_DIR / "libtorch_agnostic").glob("**/*.so"): - path.unlink() - # Remove build and dist and egg-info directories - dirs = [ - ROOT_DIR / "build", - ROOT_DIR / "dist", - ROOT_DIR / "libtorch_agnostic.egg-info", - ] - for path in dirs: - if path.exists(): - shutil.rmtree(str(path), ignore_errors=True) - - -def get_extension(): - extra_compile_args = { - "cxx": ["-fdiagnostics-color=always"], - } - - sources = list(CSRC_DIR.glob("**/*.cpp")) - - return [ - CppExtension( - "libtorch_agnostic._C", - sources=sorted(str(s) for s in sources), - py_limited_api=True, - extra_compile_args=extra_compile_args, - extra_link_args=[], - ) - ] - - -setup( - name="libtorch_agnostic", - version="0.0", - author="PyTorch Core Team", - description="Example of libtorch agnostic extension", - packages=find_packages(exclude=("test",)), - package_data={"libtorch_agnostic": ["*.dll", "*.dylib", "*.so"]}, - install_requires=[ - "torch", - ], - ext_modules=get_extension(), - cmdclass={ - "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), - "clean": clean, - }, - options={"bdist_wheel": {"py_limited_api": "cp39"}}, -) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py deleted file mode 100644 index a33099175f1e..000000000000 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ /dev/null @@ -1,77 +0,0 @@ -# Owner(s): ["module: cpp"] - -import libtorch_agnostic # noqa: F401 - -import torch -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, - onlyCPU, - onlyCUDA, -) -from torch.testing._internal.common_utils import run_tests, TestCase - - -class TestLibtorchAgnostic(TestCase): - @onlyCPU - def test_slow_sgd(self, device): - param = torch.rand(5, device=device) - grad = torch.rand_like(param) - weight_decay = 0.01 - lr = 0.001 - maximize = False - - new_param = libtorch_agnostic.ops.sgd_out_of_place( - param, grad, weight_decay, lr, maximize - ) - torch._fused_sgd_( - (param,), - (grad,), - (), - weight_decay=weight_decay, - momentum=0.0, - lr=lr, - dampening=0.0, - nesterov=False, - maximize=maximize, - is_first_step=False, - ) - self.assertEqual(new_param, param) - - @onlyCUDA - def test_identity_does_not_hog_memory(self, device): - def _run_identity(prior_mem): - t = torch.rand(32, 32, device=device) - self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) - identi_t = libtorch_agnostic.ops.identity(t) - assert identi_t is t - - init_mem = torch.cuda.memory_allocated(device) - - for _ in range(3): - _run_identity(init_mem) - curr_mem = torch.cuda.memory_allocated(device) - self.assertEqual(curr_mem, 
-
-    @onlyCUDA
-    def test_z_delete_torch_lib(self, device):
-        # Why the z + CUDA? THIS TEST MUST BE RUN LAST
-        # We are testing that unloading the library properly deletes the registrations, so running this test
-        # earlier will cause all other tests in this file to fail
-        lib = libtorch_agnostic.loaded_lib
-
-        # code for unloading a library inspired from
-        # https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
-        lib_handle = lib._handle
-        lib.dlclose(lib_handle)
-
-        t = torch.tensor([-2.0, 0.5])
-        with self.assertRaises(RuntimeError):
-            libtorch_agnostic.ops.identity(
-                t
-            )  # errors as identity shouldn't be registered anymore
-
-
-instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/test/run_test.py b/test/run_test.py
index c75503682a19..ef4c7bada1bd 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -1074,12 +1074,9 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
     if return_code != 0:
         return return_code
     if sys.platform != "win32":
-        exts_to_build = [
-            (install_cmd, "no_python_abi_suffix_test"),
-        ]
+        exts_to_build = [(install_cmd, "no_python_abi_suffix_test")]
         if TEST_CUDA:
            exts_to_build.append((wheel_cmd, "python_agnostic_extension"))
-            exts_to_build.append((install_cmd, "libtorch_agnostic_extension"))
        for cmd, extension_dir in exts_to_build:
            return_code = shell(
                cmd,
diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py
index e938ee28f1d3..2f4a2b9d12dc 100644
--- a/test/test_cpp_extensions_aot.py
+++ b/test/test_cpp_extensions_aot.py
@@ -227,71 +227,6 @@ class TestCppExtensionAOT(common.TestCase):
         if return_code != 0:
             return return_code
 
-    @unittest.skipIf(not TEST_CUDA, "some aspects of this test require CUDA")
-    def test_libtorch_agnostic(self):
-        import libtorch_agnostic
-
-        # (1) first test that SGD CPU kernel works
-        param = torch.rand(5, device="cpu")
-        grad = torch.rand_like(param)
-        weight_decay = 0.01
-        lr = 0.001
-        maximize = False
-
-        new_param = libtorch_agnostic.ops.sgd_out_of_place(
-            param, grad, weight_decay, lr, maximize
-        )
-        torch._fused_sgd_(
-            (param,),
-            (grad,),
-            (),
-            weight_decay=weight_decay,
-            momentum=0.0,
-            lr=lr,
-            dampening=0.0,
-            nesterov=False,
-            maximize=maximize,
-            is_first_step=False,
-        )
-        self.assertEqual(new_param, param)
-
-        # (2) then test that we don't hog unnecessary memory
-        def _run_identity(prior_mem, device):
-            t = torch.rand(32, 32, device=device)
-            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
-            identi_t = libtorch_agnostic.ops.identity(t)
-            assert identi_t is t
-
-        device = torch.cuda.current_device()
-        init_mem = torch.cuda.memory_allocated(device)
-
-        for _ in range(3):
-            _run_identity(init_mem, device)
-        curr_mem = torch.cuda.memory_allocated(device)
-        self.assertEqual(curr_mem, init_mem)
-
-        # (3) last, test that unloading the torch library will deregister previous ops
-        lib = libtorch_agnostic.loaded_lib
-
-        # code for unloading a library inspired from
-        # https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
-        lib_handle = lib._handle
-        lib.dlclose(lib_handle)
-
-        t = torch.tensor([-2.0, 0.5])
-        with self.assertRaises(RuntimeError):
-            libtorch_agnostic.ops.identity(t)
-
-        # finally, clean up the folder
-        cmd = [sys.executable, "setup.py", "clean"]
-        return_code = shell(
-            cmd,
-            cwd=os.path.join("cpp_extensions", "libtorch_agnostic_extension"),
-            env=os.environ.copy(),
-        )
-        if return_code != 0:
-            return return_code
-
 
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestPybindTypeCasters(common.TestCase):
diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py
index 5a9aaf0aa6f7..17463e84aa08 100755
--- a/tools/amd_build/build_amd.py
+++ b/tools/amd_build/build_amd.py
@@ -108,7 +108,6 @@ includes = [
     "aten/src/THC/CMakeLists.txt",
     "torch/*",
     "tools/autograd/templates/python_variable_methods.cpp",
-    "torch/csrc/stable/*",
 ]
 
 includes = [os.path.join(proj_dir, include) for include in includes]
diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h
index 179d437455f3..47a2290100a3 100644
--- a/torch/csrc/inductor/aoti_torch/c/shim.h
+++ b/torch/csrc/inductor/aoti_torch/c/shim.h
@@ -623,55 +623,6 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
     const char* launch_prefix,
     const char* kernel_name);
 
-// helpers for converting between StableIValue and actual IValues
-using StableIValue = uint64_t;
-
-class TorchLibraryOpaque;
-using TorchLibraryHandle = TorchLibraryOpaque*;
-
-// stable corollary to torch::Library constructor with Kind::IMPL
-// will create a new torch::Library object on the heap
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_impl(
-    const char* ns,
-    const char* k,
-    const char* file,
-    uint32_t line,
-    TorchLibraryHandle* ret_new_torch_lib);
-
-// stable corollary to torch::Library constructor with Kind::DEF
-// will create a new torch::Library object on the heap
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_def(
-    const char* ns,
-    const char* file,
-    uint32_t line,
-    TorchLibraryHandle* ret_new_torch_lib);
-
-// stable corollary to torch::Library constructor with Kind::FRAGMENT
-// will create a new torch::Library object on the heap
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_fragment(
-    const char* ns,
-    const char* file,
-    uint32_t line,
-    TorchLibraryHandle* ret_new_torch_lib);
-
-// stable corollary to torch::Library method m.impl()
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
-    TorchLibraryHandle self,
-    const char* name,
-    void (*fn)(StableIValue*, int64_t, int64_t));
-
-// stable corollary to torch::Library method m.def()
-AOTI_TORCH_EXPORT AOTITorchError
-aoti_torch_library_def(TorchLibraryHandle self, const char* name);
-
-// the above stable constructors for torch::Library add Library objects
-// to the heap. if you are calling those functions directly, please use
-// this function to free the Library's memory. The more user friendly
-// alternative is to use StableLibrary, which will free its handle upon
-// destruction
-AOTI_TORCH_EXPORT AOTITorchError
-aoti_torch_delete_library_object(TorchLibraryHandle tlh);
-
 #ifdef USE_CUDA
 
 struct CUDAGuardOpaque;
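For readers skimming the revert, the shim.h block removed above exposed a plain C lifecycle. A minimal sketch of how a caller could have driven it directly, based only on the declarations deleted above (the namespace "my_ext" and the schema string are illustrative, not part of the original PR; error checking elided):

```cpp
// Sketch only: drives the raw C API deleted above.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

static TorchLibraryHandle lib = nullptr;

void register_my_op(void (*boxed_fn)(StableIValue*, int64_t, int64_t)) {
  // heap-allocates a torch::Library of Kind::DEF for namespace "my_ext"
  aoti_torch_library_init_def("my_ext", __FILE__, __LINE__, &lib);
  // declare the schema, then attach the boxed kernel
  aoti_torch_library_def(lib, "my_op(Tensor t) -> Tensor");
  aoti_torch_library_impl(lib, "my_op", boxed_fn);
  // NOTE: keep the handle alive for as long as the registrations should
  // exist; aoti_torch_delete_library_object(lib) frees the Library and,
  // with it, the registrations (StableLibrary does this on destruction).
}
```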
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index ac04970a3ac3..68cf75883eaf 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -1,12 +1,10 @@
 #include
 #include
-#include
 #include
 #include
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -15,14 +13,10 @@
 #include
 #include
 #include
-#include
-#include
 #include
 #include
-#include
 #include
 #include
-#include
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
@@ -1306,135 +1300,3 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor(
     *ret0 = new_tensor_handle(std::move(tmp_result));
   });
 }
-
-class StableIValueBoxedKernel : public c10::OperatorKernel {
- public:
-  StableIValueBoxedKernel(void (*fn)(StableIValue*, int64_t, int64_t))
-      : fn_(fn) {}
-
-  void operator()(
-      const c10::OperatorHandle& op,
-      c10::DispatchKeySet keyset,
-      torch::jit::Stack* stack) {
-    const auto& schema = op.schema();
-    const auto num_returns = schema.returns().size();
-    const auto num_arguments = schema.arguments().size();
-
-    std::vector<StableIValue> ministack(num_arguments + num_returns);
-
-    for (const auto idx : c10::irange(num_arguments)) {
-      const c10::IValue& arg = torch::jit::pop(stack);
-      const auto ministack_idx = num_arguments - idx - 1;
-      if (arg.isInt()) {
-        ministack[ministack_idx] = from(arg.toInt());
-      } else if (arg.isDouble()) {
-        ministack[ministack_idx] = from(arg.toDouble());
-      } else if (arg.isBool()) {
-        ministack[ministack_idx] = from(arg.toBool());
-      } else if (arg.isNone()) {
-        ministack[ministack_idx] = from(nullptr);
-      } else if (arg.isTensor()) {
-        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
-            std::move(const_cast<at::Tensor&>(arg.toTensor())));
-        ministack[ministack_idx] = from(ath);
-      } else {
-        TORCH_CHECK(false, "Other types of IValues not yet handled!");
-      }
-    }
-
-    // boxed function is going to take a stack of StableIValues, cast them to
-    // our schema values, and run the function and modify the StableIValue stack
-    fn_(ministack.data(), num_arguments, num_returns);
-
-    // read the output from the end of the stack and wrap that back into
-    // IValue from StableIValue
-    for (size_t idx = 0; idx < num_returns; idx++) {
-      const c10::TypePtr& ret_type = schema.returns()[idx].type();
-      if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
-        auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
-            to<AtenTensorHandle>(ministack[num_arguments + idx]));
-        at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
-            ret_raiiath.get());
-        torch::jit::push(stack, c10::IValue(out));
-      } else {
-        TORCH_CHECK(false, "Only Tensor return types are currently supported!");
-      }
-    }
-  }
-
- private:
-  void (*fn_)(StableIValue*, int64_t, int64_t);
-};
-
-AOTITorchError aoti_torch_library_init_impl(
-    const char* ns,
-    const char* k,
-    const char* file,
-    uint32_t line,
-    TorchLibraryHandle* ret_new_torch_lib) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    *ret_new_torch_lib =
-        reinterpret_cast<TorchLibraryHandle>(new torch::Library(
-            torch::Library::Kind::IMPL,
-            std::string(ns),
-            c10::parseDispatchKey(std::string(k)),
-            file,
-            line));
-  });
-}
-
-AOTITorchError aoti_torch_library_init_def(
-    const char* ns,
-    const char* file,
-    uint32_t line,
-    TorchLibraryHandle* ret_new_torch_lib) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    *ret_new_torch_lib =
-        reinterpret_cast<TorchLibraryHandle>(new torch::Library(
-            torch::Library::Kind::DEF,
-            std::string(ns),
-            std::nullopt,
-            file,
-            line));
-  });
-}
-
-AOTITorchError aoti_torch_library_init_fragment(
-    const char* ns,
-    const char* file,
-    uint32_t line,
-    TorchLibraryHandle* ret_new_torch_lib) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    *ret_new_torch_lib =
-        reinterpret_cast<TorchLibraryHandle>(new torch::Library(
-            torch::Library::Kind::FRAGMENT,
-            std::string(ns),
-            std::nullopt,
-            file,
-            line));
-  });
-}
-
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
-    TorchLibraryHandle self,
-    const char* name,
-    void (*fn)(StableIValue*, int64_t, int64_t)) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    reinterpret_cast<torch::Library*>(self)->impl(
-        name,
-        torch::CppFunction::makeFromBoxedFunctor(
-            std::make_unique<StableIValueBoxedKernel>(fn)));
-  });
-}
-
-AOTI_TORCH_EXPORT AOTITorchError
-aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
-      { reinterpret_cast<torch::Library*>(self)->def(name); });
-}
-
-AOTI_TORCH_EXPORT AOTITorchError
-aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
-      { delete reinterpret_cast<torch::Library*>(tlh); });
-}
diff --git a/torch/csrc/stable/library.h b/torch/csrc/stable/library.h
deleted file mode 100644
index 4e145816e26a..000000000000
--- a/torch/csrc/stable/library.h
+++ /dev/null
@@ -1,149 +0,0 @@
-// this file can only have stable stuff! Akin to shim.h
-// but unlike shim.h, this file can contain header-only C++
-// code for better UX.
-
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
-
-template <typename T>
-StableIValue from(T val) {
-  static_assert(
-      sizeof(T) <= sizeof(StableIValue),
-      "StableLibrary stack does not support parameter types larger than 64 bits.");
-  return *reinterpret_cast<StableIValue*>(&val);
-}
-
-template <typename T>
-T to(StableIValue val) {
-  return *reinterpret_cast<T*>(&val);
-}
-// end to helpers for converting between StableIValue and actual IValues
-
-class TORCH_API StableLibrary final {
- private:
-  TorchLibraryHandle lib_;
-
- public:
-  enum class Kind {
-    DEF,
-    IMPL,
-    FRAGMENT,
-  };
-
-  // constructor
-  /// \private
-  ///
-  /// Use STABLE_TORCH_LIBRARY or STABLE_TORCH_LIBRARY_IMPL() instead of using
-  /// these constructors directly
-  StableLibrary(
-      Kind kind,
-      const char* ns,
-      const char* k,
-      const char* file,
-      uint32_t line) {
-    if (kind == Kind::IMPL) {
-      aoti_torch_library_init_impl(ns, k, file, line, &lib_);
-    } else if (kind == Kind::DEF) {
-      aoti_torch_library_init_def(ns, file, line, &lib_);
-    } else { // kind == FRAGMENT
-      aoti_torch_library_init_fragment(ns, file, line, &lib_);
-    }
-  }
-
-  // do not permit copy
-  StableLibrary(const StableLibrary&) = delete;
-  StableLibrary& operator=(const StableLibrary&) = delete;
-
-  // do not permit move
-  StableLibrary(StableLibrary&& other) = delete;
-  StableLibrary& operator=(StableLibrary&& other) = delete;
-
-  ~StableLibrary() {
-    aoti_torch_delete_library_object(lib_);
-  }
-
-  StableLibrary& impl(
-      const char* name,
-      void (*fn)(StableIValue*, int64_t, int64_t)) {
-    aoti_torch_library_impl(lib_, name, fn);
-    return *this;
-  }
-
-  StableLibrary& def(const char* name) {
-    aoti_torch_library_def(lib_, name);
-    return *this;
-  }
-};
-
-class TORCH_API StableTorchLibraryInit final {
- private:
-  using InitFn = void(StableLibrary&);
-  StableLibrary lib_;
-
- public:
-  StableTorchLibraryInit(
-      StableLibrary::Kind kind,
-      InitFn* fn,
-      const char* ns,
-      const char* k,
-      const char* file,
-      uint32_t line)
-      : lib_(kind, ns, k, file, line) {
-    fn(lib_);
-  }
-};
-
-// macros copied from c10/macros/Macros.h
-#ifdef __COUNTER__
-#define STABLE_UID __COUNTER__
-#else
-#define STABLE_UID __LINE__
-#endif
-
-#define STABLE_CONCATENATE_IMPL(s1, s2) s1##s2
-#define STABLE_CONCATENATE(s1, s2) STABLE_CONCATENATE_IMPL(s1, s2)
-// end of macros copied from c10/macros/Macros.h
-
-#define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \
-  _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID)
-
-#define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid)                             \
-  static void STABLE_CONCATENATE(                                             \
-      STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary&);     \
-  static const StableTorchLibraryInit STABLE_CONCATENATE(                     \
-      STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(              \
-      StableLibrary::Kind::IMPL,                                              \
-      &STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \
-      #ns,                                                                    \
-      #k,                                                                     \
-      __FILE__,                                                               \
-      __LINE__);                                                              \
-  void STABLE_CONCATENATE(                                                    \
-      STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary & m)
-
-#define STABLE_TORCH_LIBRARY(ns, m)                                           \
-  static void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary&);                 \
-  static const StableTorchLibraryInit STABLE_TORCH_LIBRARY_static_init_##ns(  \
-      StableLibrary::Kind::DEF,                                               \
-      &STABLE_TORCH_LIBRARY_init_##ns,                                        \
-      #ns,                                                                    \
-      nullptr,                                                                \
-      __FILE__,                                                               \
-      __LINE__);                                                              \
-  void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary& m)
-
-#define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \
-  _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID)
-
-#define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid)                            \
-  static void STABLE_CONCATENATE(                                             \
-      STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary&);       \
-  static const StableTorchLibraryInit STABLE_CONCATENATE(                     \
-      STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)(                \
-      StableLibrary::Kind::FRAGMENT,                                          \
-      &STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid),   \
-      #ns,                                                                    \
-      nullptr,                                                                \
-      __FILE__,                                                               \
-      __LINE__);                                                              \
-  void STABLE_CONCATENATE(                                                    \
-      STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary & m)
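For reference while reviewing: this is the shape of extension code that the reverted header enabled, condensed from the deleted libtorch_agnostic_kernel.cpp above. It is a minimal sketch assuming the pre-revert torch/csrc/stable/library.h API; the namespace "my_ext" is illustrative only, and the include paths mirror those assumed in the kernel file.

```cpp
// Minimal sketch of the extension-side pattern the reverted header enabled.
#include <torch/csrc/inductor/aoti_runtime/utils.h>  // RAIIAtenTensorHandle
#include <torch/csrc/stable/library.h>

using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

// Boxed kernel: arguments arrive as StableIValues at stack[0..num_args),
// and outputs are written starting at stack[num_args]. Only Tensor returns
// were supported by the deleted StableIValueBoxedKernel.
void boxed_identity(StableIValue* stack, int64_t num_args, int64_t num_outputs) {
  RAIIATH t(to<AtenTensorHandle>(stack[0]));  // unpack the Tensor argument
  stack[num_args] = from(t.release());        // hand ownership back on the stack
}

// Stable analogues of TORCH_LIBRARY / TORCH_LIBRARY_IMPL: only the C ABI of
// shim.h crosses the libtorch boundary, so the compiled extension stays
// agnostic to the libtorch version it runs against.
STABLE_TORCH_LIBRARY(my_ext, m) {
  m.def("identity(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(my_ext, CPU, m) {
  m.impl("identity", &boxed_identity);
}
```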