Add a stable TORCH_LIBRARY to C shim (#148124)

This PR adds two main parts:
- shim.h stable C APIs into torch::Library APIs
- a higher level API in torch/csrc/stable/library.h that calls into this shim.h + otherwise is self contained

Goal: custom kernel writers should be able to call the apis in the directories above in order to register their library in a way that allows their custom extension to run with a different libtorch version than it was built with.

Subplots resolved:

- Do we want a whole separate StableLibrary or do we want to freeze torch::Library and add `m.stable_impl(cstring, void (*fn)(void **, int64_t, int64_t)` into it
    - Yes, we want a separate StableLibrary. We cannot freeze Library and it is NOT header only.
- Should I use unint64_t as the common denominator instead of void* to support 32bit architectures better?
    -  Yes, and done
- Should I add a stable `def` and `fragment` when those can be done in python?
    - I think we do want these --- and now they're done
- Where should library_stable_impl.cpp live? -- no longer relevant
- I need some solid test cases to make sure everything's going ok. I've intentionally thrown in a bunch of random dtypes into the signature, but I still haven't tested returning multiple things, returning nothing, complex dtypes, etc.
    - Have since tested all the torch library endpoints. the others can be tested in a followup to separate components that need to be in shim.h vs can be added later

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148124
Approved by: https://github.com/albanD, https://github.com/zou3519
This commit is contained in:
Jane Xu
2025-03-08 21:55:51 -08:00
committed by PyTorch MergeBot
parent 685fb37713
commit 327e07ac1d
14 changed files with 739 additions and 1 deletions

View File

@ -48,6 +48,7 @@ jit_core_headers = [
"torch/csrc/jit/frontend/schema_type_parser.h",
"torch/csrc/jit/frontend/error_report.h",
"torch/csrc/jit/frontend/tree.h",
"torch/csrc/stable/library.h",
"torch/custom_class.h",
"torch/custom_class_detail.h",
"torch/library.h",

View File

@ -67,6 +67,7 @@ INPUT = ../../../aten/src/ATen/ATen.h \
../../../torch/csrc/jit/runtime/custom_operator.h \
../../../torch/csrc/jit/serialization/import.h \
../../../torch/csrc/jit/api/module.h \
../../../torch/csrc/stable/library.h \
../../../torch/library.h \
../../../torch/custom_class.h
# Don't include .cpp files!

View File

@ -1274,6 +1274,7 @@ def main():
"include/c10/xpu/impl/*.h",
"include/torch/*.h",
"include/torch/csrc/*.h",
"include/torch/csrc/stable/*.h",
"include/torch/csrc/api/include/torch/*.h",
"include/torch/csrc/api/include/torch/data/*.h",
"include/torch/csrc/api/include/torch/data/dataloader/*.h",

View File

@ -0,0 +1,21 @@
import ctypes
from pathlib import Path
import torch
so_files = list(Path(__file__).parent.glob("_C*.so"))
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
# use ctypes.CDLL instead of load_library to be able to test the unload logic
# below code is reduced from the load_library code
with torch._ops.dl_open_guard():
loaded_lib = ctypes.CDLL(so_files[0])
from . import ops
__all__ = [
"loaded_lib",
"ops",
]

View File

@ -0,0 +1,127 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/stable/library.h>
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
void inline sgd_math(
float* param_ptr,
float* grad_ptr,
float* out_ptr,
const float weight_decay,
const double lr,
const bool maximize,
int64_t size
){
int64_t d = 0;
for (; d < size; d++) {
float grad_val = grad_ptr[d];
if (maximize) grad_val = -grad_val;
if (weight_decay != 0.0){
grad_val += param_ptr[d] * weight_decay;
}
out_ptr[d] = param_ptr[d] - grad_val * float(lr);
}
}
RAIIATH sgd_out_of_place(
const RAIIATH param,
const RAIIATH grad,
const float weight_decay,
const double lr,
const bool maximize) {
int64_t param_dim;
aoti_torch_get_dim(param.get(), &param_dim);
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
int32_t param_device_index;
aoti_torch_get_device_type(param.get(), &param_device_type);
aoti_torch_get_device_index(param.get(), &param_device_index);
AtenTensorHandle out;
aoti_torch_empty_strided(param_dim, param_sizes, param_strides, param_dtype, param_device_type, param_device_index, &out);
void* param_ptr;
aoti_torch_get_data_ptr(param.get(), &param_ptr);
void* grad_ptr;
aoti_torch_get_data_ptr(grad.get(), &grad_ptr);
void* out_ptr;
aoti_torch_get_data_ptr(out, &out_ptr);
auto param_fp_ptr = reinterpret_cast<float*>(param_ptr);
auto grad_fp_ptr = reinterpret_cast<float*>(grad_ptr);
auto out_fp_ptr = reinterpret_cast<float*>(out_ptr);
int64_t param_numel;
aoti_torch_get_numel(param.get(), &param_numel);
sgd_math(
param_fp_ptr,
grad_fp_ptr,
out_fp_ptr,
weight_decay,
lr,
maximize,
param_numel
);
return RAIIATH(out);
}
void boxed_sgd_out_of_place(StableIValue* stack, int64_t num_args, int64_t num_outputs) {
RAIIATH param(to<AtenTensorHandle>(stack[0]));
RAIIATH grad(to<AtenTensorHandle>(stack[1]));
auto weight_decay = to<double>(stack[2]);
auto lr = to<double>(stack[3]);
auto maximize = to<bool>(stack[4]);
RAIIATH raiiath_res = sgd_out_of_place(
std::move(param),
std::move(grad),
float(weight_decay),
lr,
maximize);
stack[num_args] = from(raiiath_res.release());
}
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
}
RAIIATH identity(RAIIATH t) {
return std::move(t);
}
void boxed_identity(StableIValue* stack, int64_t num_args, int64_t num_outputs) {
RAIIATH t(to<AtenTensorHandle>(stack[0]));
RAIIATH raiiath_res = identity(std::move(t));
stack[num_args] = from(raiiath_res.release());
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("identity(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
m.impl("identity", &boxed_identity);
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("identity", &boxed_identity);
}

View File

@ -0,0 +1,38 @@
import torch
from torch import Tensor
def sgd_out_of_place(param, grad, weight_decay, lr, maximize) -> Tensor:
"""
Computes a single step of SGD on a single parameter Tensor with grad.
Assumes:
- param and grad are the same shape and are 1D.
- param and grad are float and on CPU
Args:
param: a 1D tensor of floats
grad: a 1D tensor of floats
weight_decay: a python double between 0 and 1
lr: a python double
Returns:
a 1D float Tensor the same shape as param
"""
return torch.ops.libtorch_agnostic.sgd_out_of_place.default(
param, grad, weight_decay, lr, maximize
)
def identity(t) -> Tensor:
"""
Returns the input tensor
Args:
t: any Tensor
Returns:
a Tensor, the same as input.
"""
return torch.ops.libtorch_agnostic.identity.default(t)

View File

@ -0,0 +1,67 @@
import distutils.command.clean
import shutil
from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "libtorch_agnostic" / "csrc"
class clean(distutils.command.clean.clean):
def run(self):
# Run default behavior first
distutils.command.clean.clean.run(self)
# Remove extension
for path in (ROOT_DIR / "libtorch_agnostic").glob("**/*.so"):
path.unlink()
# Remove build and dist and egg-info directories
dirs = [
ROOT_DIR / "build",
ROOT_DIR / "dist",
ROOT_DIR / "libtorch_agnostic.egg-info",
]
for path in dirs:
if path.exists():
shutil.rmtree(str(path), ignore_errors=True)
def get_extension():
extra_compile_args = {
"cxx": ["-fdiagnostics-color=always"],
}
sources = list(CSRC_DIR.glob("**/*.cpp"))
return [
CppExtension(
"libtorch_agnostic._C",
sources=sorted(str(s) for s in sources),
py_limited_api=True,
extra_compile_args=extra_compile_args,
extra_link_args=[],
)
]
setup(
name="libtorch_agnostic",
version="0.0",
author="PyTorch Core Team",
description="Example of libtorch agnostic extension",
packages=find_packages(exclude=("test",)),
package_data={"libtorch_agnostic": ["*.dll", "*.dylib", "*.so"]},
install_requires=[
"torch",
],
ext_modules=get_extension(),
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean,
},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)

View File

@ -0,0 +1,77 @@
# Owner(s): ["module: cpp"]
import libtorch_agnostic # noqa: F401
import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class TestLibtorchAgnostic(TestCase):
@onlyCPU
def test_slow_sgd(self, device):
param = torch.rand(5, device=device)
grad = torch.rand_like(param)
weight_decay = 0.01
lr = 0.001
maximize = False
new_param = libtorch_agnostic.ops.sgd_out_of_place(
param, grad, weight_decay, lr, maximize
)
torch._fused_sgd_(
(param,),
(grad,),
(),
weight_decay=weight_decay,
momentum=0.0,
lr=lr,
dampening=0.0,
nesterov=False,
maximize=maximize,
is_first_step=False,
)
self.assertEqual(new_param, param)
@onlyCUDA
def test_identity_does_not_hog_memory(self, device):
def _run_identity(prior_mem):
t = torch.rand(32, 32, device=device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
identi_t = libtorch_agnostic.ops.identity(t)
assert identi_t is t
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_run_identity(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
@onlyCUDA
def test_z_delete_torch_lib(self, device):
# Why the z + CUDA? THIS TEST MUST BE RUN LAST
# We are testing that unloading the library properly deletes the registrations, so running this test
# earlier will cause all other tests in this file to fail
lib = libtorch_agnostic.loaded_lib
# code for unloading a library inspired from
# https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
lib_handle = lib._handle
lib.dlclose(lib_handle)
t = torch.tensor([-2.0, 0.5])
with self.assertRaises(RuntimeError):
libtorch_agnostic.ops.identity(
t
) # errors as identity shouldn't be registered anymore
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":
run_tests()

View File

@ -1074,9 +1074,12 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
if return_code != 0:
return return_code
if sys.platform != "win32":
exts_to_build = [(install_cmd, "no_python_abi_suffix_test")]
exts_to_build = [
(install_cmd, "no_python_abi_suffix_test"),
]
if TEST_CUDA:
exts_to_build.append((wheel_cmd, "python_agnostic_extension"))
exts_to_build.append((install_cmd, "libtorch_agnostic_extension"))
for cmd, extension_dir in exts_to_build:
return_code = shell(
cmd,

View File

@ -227,6 +227,71 @@ class TestCppExtensionAOT(common.TestCase):
if return_code != 0:
return return_code
@unittest.skipIf(not TEST_CUDA, "some aspects of this test require CUDA")
def test_libtorch_agnostic(self):
import libtorch_agnostic
# (1) first test that SGD CPU kernel works
param = torch.rand(5, device="cpu")
grad = torch.rand_like(param)
weight_decay = 0.01
lr = 0.001
maximize = False
new_param = libtorch_agnostic.ops.sgd_out_of_place(
param, grad, weight_decay, lr, maximize
)
torch._fused_sgd_(
(param,),
(grad,),
(),
weight_decay=weight_decay,
momentum=0.0,
lr=lr,
dampening=0.0,
nesterov=False,
maximize=maximize,
is_first_step=False,
)
self.assertEqual(new_param, param)
# (2) then test that we don't hog unnecessary memory
def _run_identity(prior_mem, device):
t = torch.rand(32, 32, device=device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
identi_t = libtorch_agnostic.ops.identity(t)
assert identi_t is t
device = torch.cuda.current_device()
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_run_identity(init_mem, device)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
# (3) last, test that unloading the torch library will deregister previous ops
lib = libtorch_agnostic.loaded_lib
# code for unloading a library inspired from
# https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
lib_handle = lib._handle
lib.dlclose(lib_handle)
t = torch.tensor([-2.0, 0.5])
with self.assertRaises(RuntimeError):
libtorch_agnostic.ops.identity(t)
# finally, clean up the folder
cmd = [sys.executable, "setup.py", "clean"]
return_code = shell(
cmd,
cwd=os.path.join("cpp_extensions", "libtorch_agnostic_extension"),
env=os.environ.copy(),
)
if return_code != 0:
return return_code
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):

View File

@ -108,6 +108,7 @@ includes = [
"aten/src/THC/CMakeLists.txt",
"torch/*",
"tools/autograd/templates/python_variable_methods.cpp",
"torch/csrc/stable/*",
]
includes = [os.path.join(proj_dir, include) for include in includes]

View File

@ -623,6 +623,55 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
const char* launch_prefix,
const char* kernel_name);
// helpers for converting between StableIValue and actual IValues
using StableIValue = uint64_t;
class TorchLibraryOpaque;
using TorchLibraryHandle = TorchLibraryOpaque*;
// stable corollary to torch::Library constructor with Kind::IMPL
// will create a new torch::Library object on the heap
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_impl(
const char* ns,
const char* k,
const char* file,
uint32_t line,
TorchLibraryHandle* ret_new_torch_lib);
// stable corollary to torch::Library constructor with Kind::DEF
// will create a new torch::Library object on the heap
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_def(
const char* ns,
const char* file,
uint32_t line,
TorchLibraryHandle* ret_new_torch_lib);
// stable corollary to torch::Library constructor with Kind::FRAGMENT
// will create a new torch::Library object on the heap
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_fragment(
const char* ns,
const char* file,
uint32_t line,
TorchLibraryHandle* ret_new_torch_lib);
// stable corollary to torch::Library method m.impl()
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, int64_t, int64_t));
// stable corollary to torch::Library method m.def()
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_library_def(TorchLibraryHandle self, const char* name);
// the above stable constructors for torch::Library add Library objects
// to the heap. if you are calling those functions directly, please use
// this function to free the Library's memory. The more user friendly
// alternative is to use StableLibrary, which will free its handle upon
// destruction
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_library_object(TorchLibraryHandle tlh);
#ifdef USE_CUDA
struct CUDAGuardOpaque;

View File

@ -1,10 +1,12 @@
#include <ATen/native/quantized/cpu/qlinear.h>
#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/GradMode.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
@ -13,10 +15,14 @@
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/stable/library.h>
#include <torch/library.h>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <vector>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -1300,3 +1306,135 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor(
*ret0 = new_tensor_handle(std::move(tmp_result));
});
}
class StableIValueBoxedKernel : public c10::OperatorKernel {
public:
StableIValueBoxedKernel(void (*fn)(StableIValue*, int64_t, int64_t))
: fn_(fn) {}
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
std::vector<StableIValue> ministack(num_arguments + num_returns);
for (const auto idx : c10::irange(num_arguments)) {
const c10::IValue& arg = torch::jit::pop(stack);
const auto ministack_idx = num_arguments - idx - 1;
if (arg.isInt()) {
ministack[ministack_idx] = from(arg.toInt());
} else if (arg.isDouble()) {
ministack[ministack_idx] = from(arg.toDouble());
} else if (arg.isBool()) {
ministack[ministack_idx] = from(arg.toBool());
} else if (arg.isNone()) {
ministack[ministack_idx] = from(nullptr);
} else if (arg.isTensor()) {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(arg.toTensor())));
ministack[ministack_idx] = from(ath);
} else {
TORCH_CHECK(false, "Other types of IValues not yet handled!");
}
}
// boxed function is going to take a stack of StableIValues, cast them to
// our schema values, and run the function and modify the StableIValue stack
fn_(ministack.data(), num_arguments, num_returns);
// read the output from the end of the stack and wrap that back into
// IValue from StableIValue
for (size_t idx = 0; idx < num_returns; idx++) {
const c10::TypePtr& ret_type = schema.returns()[idx].type();
if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
to<AtenTensorHandle>(ministack[num_arguments + idx]));
at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get());
torch::jit::push(stack, c10::IValue(out));
} else {
TORCH_CHECK(false, "Only Tensor return types are currently supported!");
}
}
}
private:
void (*fn_)(StableIValue*, int64_t, int64_t);
};
AOTITorchError aoti_torch_library_init_impl(
const char* ns,
const char* k,
const char* file,
uint32_t line,
TorchLibraryHandle* ret_new_torch_lib) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
*ret_new_torch_lib =
reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
torch::Library::Kind::IMPL,
std::string(ns),
c10::parseDispatchKey(std::string(k)),
file,
line));
});
}
AOTITorchError aoti_torch_library_init_def(
const char* ns,
const char* file,
uint32_t line,
TorchLibraryHandle* ret_new_torch_lib) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
*ret_new_torch_lib =
reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
torch::Library::Kind::DEF,
std::string(ns),
std::nullopt,
file,
line));
});
}
AOTITorchError aoti_torch_library_init_fragment(
const char* ns,
const char* file,
uint32_t line,
TorchLibraryHandle* ret_new_torch_lib) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
*ret_new_torch_lib =
reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
torch::Library::Kind::FRAGMENT,
std::string(ns),
std::nullopt,
file,
line));
});
}
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
TorchLibraryHandle self,
const char* name,
void (*fn)(StableIValue*, int64_t, int64_t)) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
reinterpret_cast<torch::Library*>(self)->impl(
name,
torch::CppFunction::makeFromBoxedFunctor(
std::make_unique<StableIValueBoxedKernel>(fn)));
});
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ reinterpret_cast<torch::Library*>(self)->def(name); });
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ delete reinterpret_cast<torch::Library*>(tlh); });
}

149
torch/csrc/stable/library.h Normal file
View File

@ -0,0 +1,149 @@
// this file can only have stable stuff! Akin to shim.h
// but unlike shim.h, this file can contain header-only C++
// code for better UX.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
template <typename T>
StableIValue from(T val) {
static_assert(
sizeof(T) <= sizeof(StableIValue),
"StableLibrary stack does not support parameter types larger than 64 bits.");
return *reinterpret_cast<StableIValue*>(&val);
}
template <typename T>
T to(StableIValue val) {
return *reinterpret_cast<T*>(&val);
}
// end to helpers for converting between StableIValue and actual IValues
class TORCH_API StableLibrary final {
private:
TorchLibraryHandle lib_;
public:
enum class Kind {
DEF,
IMPL,
FRAGMENT,
};
// constructor
/// \private
///
/// Use STABLE_TORCH_LIBRARY or STABLE_TORCH_LIBRARY_IMPL() instead of using
/// these constructors directly
StableLibrary(
Kind kind,
const char* ns,
const char* k,
const char* file,
uint32_t line) {
if (kind == Kind::IMPL) {
aoti_torch_library_init_impl(ns, k, file, line, &lib_);
} else if (kind == Kind::DEF) {
aoti_torch_library_init_def(ns, file, line, &lib_);
} else { // kind == FRAGMENT
aoti_torch_library_init_fragment(ns, file, line, &lib_);
}
}
// do not permit copy
StableLibrary(const StableLibrary&) = delete;
StableLibrary& operator=(const StableLibrary&) = delete;
// do not permit move
StableLibrary(StableLibrary&& other) = delete;
StableLibrary& operator=(StableLibrary&& other) = delete;
~StableLibrary() {
aoti_torch_delete_library_object(lib_);
}
StableLibrary& impl(
const char* name,
void (*fn)(StableIValue*, int64_t, int64_t)) {
aoti_torch_library_impl(lib_, name, fn);
return *this;
}
StableLibrary& def(const char* name) {
aoti_torch_library_def(lib_, name);
return *this;
}
};
class TORCH_API StableTorchLibraryInit final {
private:
using InitFn = void(StableLibrary&);
StableLibrary lib_;
public:
StableTorchLibraryInit(
StableLibrary::Kind kind,
InitFn* fn,
const char* ns,
const char* k,
const char* file,
uint32_t line)
: lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
// macros copied from c10/macros/Macros.h
#ifdef __COUNTER__
#define STABLE_UID __COUNTER__
#else
#define STABLE_UID __LINE__
#endif
#define STABLE_CONCATENATE_IMPL(s1, s2) s1##s2
#define STABLE_CONCATENATE(s1, s2) STABLE_CONCATENATE_IMPL(s1, s2)
// end of macros copied from c10/macros/Macros.h
#define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \
_STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID)
#define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid) \
static void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary&); \
static const StableTorchLibraryInit STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \
StableLibrary::Kind::IMPL, \
&STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \
#ns, \
#k, \
__FILE__, \
__LINE__); \
void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary & m)
#define STABLE_TORCH_LIBRARY(ns, m) \
static void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary&); \
static const StableTorchLibraryInit STABLE_TORCH_LIBRARY_static_init_##ns( \
StableLibrary::Kind::DEF, \
&STABLE_TORCH_LIBRARY_init_##ns, \
#ns, \
nullptr, \
__FILE__, \
__LINE__); \
void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary& m)
#define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \
_STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID)
#define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid) \
static void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary&); \
static const StableTorchLibraryInit STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \
StableLibrary::Kind::FRAGMENT, \
&STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
#ns, \
nullptr, \
__FILE__, \
__LINE__); \
void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary & m)