Add a stable TORCH_LIBRARY to C shim (#148124)
This PR adds two main parts:
- shim.h: stable C APIs into the torch::Library APIs
- a higher-level API in torch/csrc/stable/library.h that calls into this shim.h and is otherwise self-contained

Goal: custom kernel writers should be able to call the APIs in the directories above to register their library in a way that allows their custom extension to run with a different libtorch version than the one it was built with.

Subplots resolved:
- Do we want a whole separate StableLibrary, or do we want to freeze torch::Library and add `m.stable_impl(cstring, void (*fn)(void**, int64_t, int64_t))` to it?
  - Yes, we want a separate StableLibrary. We cannot freeze Library, and it is NOT header-only.
- Should I use uint64_t as the common denominator instead of void* to better support 32-bit architectures?
  - Yes, and done.
- Should I add a stable `def` and `fragment` when those can be done in Python?
  - I think we do want these, and now they're done.
- Where should library_stable_impl.cpp live?
  - No longer relevant.
- I need some solid test cases to make sure everything's going OK. I've intentionally thrown a bunch of random dtypes into the signature, but I still haven't tested returning multiple things, returning nothing, complex dtypes, etc.
  - Have since tested all the torch library endpoints. The others can be tested in a followup to separate the components that need to be in shim.h from those that can be added later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148124
Approved by: https://github.com/albanD, https://github.com/zou3519
Committed by PyTorch MergeBot. Parent: 685fb37713. Commit: 327e07ac1d.
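For orientation before the diff, here is a minimal sketch of the registration flow this PR enables, condensed from the libtorch_agnostic test extension added below. The namespace "my_ext", the op "my_identity", and the kernel body are placeholders, not part of this PR.

// Sketch only: a boxed kernel reads its num_args inputs from the StableIValue
// stack and writes its results starting at stack[num_args].
#include <torch/csrc/stable/library.h>

void boxed_my_identity(StableIValue* stack, int64_t num_args, int64_t num_outputs) {
  AtenTensorHandle t = to<AtenTensorHandle>(stack[0]);
  stack[num_args] = from(t);  // illustrative: pass the same handle back out
}

STABLE_TORCH_LIBRARY(my_ext, m) {
  m.def("my_identity(Tensor t) -> Tensor");  // schema string, as in torch::Library
}

STABLE_TORCH_LIBRARY_IMPL(my_ext, CPU, m) {
  m.impl("my_identity", &boxed_my_identity);
}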
@@ -48,6 +48,7 @@ jit_core_headers = [
    "torch/csrc/jit/frontend/schema_type_parser.h",
    "torch/csrc/jit/frontend/error_report.h",
    "torch/csrc/jit/frontend/tree.h",
    "torch/csrc/stable/library.h",
    "torch/custom_class.h",
    "torch/custom_class_detail.h",
    "torch/library.h",
@@ -67,6 +67,7 @@ INPUT = ../../../aten/src/ATen/ATen.h \
        ../../../torch/csrc/jit/runtime/custom_operator.h \
        ../../../torch/csrc/jit/serialization/import.h \
        ../../../torch/csrc/jit/api/module.h \
        ../../../torch/csrc/stable/library.h \
        ../../../torch/library.h \
        ../../../torch/custom_class.h
# Don't include .cpp files!
setup.py
@@ -1274,6 +1274,7 @@ def main():
                "include/c10/xpu/impl/*.h",
                "include/torch/*.h",
                "include/torch/csrc/*.h",
                "include/torch/csrc/stable/*.h",
                "include/torch/csrc/api/include/torch/*.h",
                "include/torch/csrc/api/include/torch/data/*.h",
                "include/torch/csrc/api/include/torch/data/dataloader/*.h",
@@ -0,0 +1,21 @@
import ctypes
from pathlib import Path

import torch


so_files = list(Path(__file__).parent.glob("_C*.so"))
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"

# use ctypes.CDLL instead of load_library to be able to test the unload logic
# below code is reduced from the load_library code
with torch._ops.dl_open_guard():
    loaded_lib = ctypes.CDLL(so_files[0])

from . import ops


__all__ = [
    "loaded_lib",
    "ops",
]
@@ -0,0 +1,127 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/stable/library.h>

using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

void inline sgd_math(
    float* param_ptr,
    float* grad_ptr,
    float* out_ptr,
    const float weight_decay,
    const double lr,
    const bool maximize,
    int64_t size
) {
  int64_t d = 0;
  for (; d < size; d++) {
    float grad_val = grad_ptr[d];
    if (maximize) grad_val = -grad_val;
    if (weight_decay != 0.0) {
      grad_val += param_ptr[d] * weight_decay;
    }
    out_ptr[d] = param_ptr[d] - grad_val * float(lr);
  }
}


RAIIATH sgd_out_of_place(
    const RAIIATH param,
    const RAIIATH grad,
    const float weight_decay,
    const double lr,
    const bool maximize) {

  int64_t param_dim;
  aoti_torch_get_dim(param.get(), &param_dim);

  int64_t* param_sizes;
  int64_t* param_strides;
  aoti_torch_get_sizes(param.get(), &param_sizes);
  aoti_torch_get_strides(param.get(), &param_strides);

  int32_t param_dtype;
  aoti_torch_get_dtype(param.get(), &param_dtype);

  int32_t param_device_type;
  int32_t param_device_index;
  aoti_torch_get_device_type(param.get(), &param_device_type);
  aoti_torch_get_device_index(param.get(), &param_device_index);

  AtenTensorHandle out;
  aoti_torch_empty_strided(param_dim, param_sizes, param_strides, param_dtype, param_device_type, param_device_index, &out);

  void* param_ptr;
  aoti_torch_get_data_ptr(param.get(), &param_ptr);
  void* grad_ptr;
  aoti_torch_get_data_ptr(grad.get(), &grad_ptr);
  void* out_ptr;
  aoti_torch_get_data_ptr(out, &out_ptr);

  auto param_fp_ptr = reinterpret_cast<float*>(param_ptr);
  auto grad_fp_ptr = reinterpret_cast<float*>(grad_ptr);
  auto out_fp_ptr = reinterpret_cast<float*>(out_ptr);

  int64_t param_numel;
  aoti_torch_get_numel(param.get(), &param_numel);

  sgd_math(
    param_fp_ptr,
    grad_fp_ptr,
    out_fp_ptr,
    weight_decay,
    lr,
    maximize,
    param_numel
  );

  return RAIIATH(out);
}


void boxed_sgd_out_of_place(StableIValue* stack, int64_t num_args, int64_t num_outputs) {
  RAIIATH param(to<AtenTensorHandle>(stack[0]));
  RAIIATH grad(to<AtenTensorHandle>(stack[1]));
  auto weight_decay = to<double>(stack[2]);
  auto lr = to<double>(stack[3]);
  auto maximize = to<bool>(stack[4]);

  RAIIATH raiiath_res = sgd_out_of_place(
    std::move(param),
    std::move(grad),
    float(weight_decay),
    lr,
    maximize);

  stack[num_args] = from(raiiath_res.release());
}

STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
  m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
  m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
}

RAIIATH identity(RAIIATH t) {
  return std::move(t);
}

void boxed_identity(StableIValue* stack, int64_t num_args, int64_t num_outputs) {
  RAIIATH t(to<AtenTensorHandle>(stack[0]));
  RAIIATH raiiath_res = identity(std::move(t));
  stack[num_args] = from(raiiath_res.release());
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("identity(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
  m.impl("identity", &boxed_identity);
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
  m.impl("identity", &boxed_identity);
}
@@ -0,0 +1,38 @@
import torch
from torch import Tensor


def sgd_out_of_place(param, grad, weight_decay, lr, maximize) -> Tensor:
    """
    Computes a single step of SGD on a single parameter Tensor with grad.

    Assumes:
    - param and grad are the same shape and are 1D.
    - param and grad are float and on CPU

    Args:
        param: a 1D tensor of floats
        grad: a 1D tensor of floats
        weight_decay: a python double between 0 and 1
        lr: a python double

    Returns:
        a 1D float Tensor the same shape as param

    """
    return torch.ops.libtorch_agnostic.sgd_out_of_place.default(
        param, grad, weight_decay, lr, maximize
    )


def identity(t) -> Tensor:
    """
    Returns the input tensor

    Args:
        t: any Tensor

    Returns:
        a Tensor, the same as input.
    """
    return torch.ops.libtorch_agnostic.identity.default(t)
test/cpp_extensions/libtorch_agnostic_extension/setup.py (new file)
@@ -0,0 +1,67 @@
import distutils.command.clean
import shutil
from pathlib import Path

from setuptools import find_packages, setup

from torch.utils.cpp_extension import BuildExtension, CppExtension


ROOT_DIR = Path(__file__).parent
CSRC_DIR = ROOT_DIR / "libtorch_agnostic" / "csrc"


class clean(distutils.command.clean.clean):
    def run(self):
        # Run default behavior first
        distutils.command.clean.clean.run(self)

        # Remove extension
        for path in (ROOT_DIR / "libtorch_agnostic").glob("**/*.so"):
            path.unlink()
        # Remove build and dist and egg-info directories
        dirs = [
            ROOT_DIR / "build",
            ROOT_DIR / "dist",
            ROOT_DIR / "libtorch_agnostic.egg-info",
        ]
        for path in dirs:
            if path.exists():
                shutil.rmtree(str(path), ignore_errors=True)


def get_extension():
    extra_compile_args = {
        "cxx": ["-fdiagnostics-color=always"],
    }

    sources = list(CSRC_DIR.glob("**/*.cpp"))

    return [
        CppExtension(
            "libtorch_agnostic._C",
            sources=sorted(str(s) for s in sources),
            py_limited_api=True,
            extra_compile_args=extra_compile_args,
            extra_link_args=[],
        )
    ]


setup(
    name="libtorch_agnostic",
    version="0.0",
    author="PyTorch Core Team",
    description="Example of libtorch agnostic extension",
    packages=find_packages(exclude=("test",)),
    package_data={"libtorch_agnostic": ["*.dll", "*.dylib", "*.so"]},
    install_requires=[
        "torch",
    ],
    ext_modules=get_extension(),
    cmdclass={
        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
        "clean": clean,
    },
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
@@ -0,0 +1,77 @@
# Owner(s): ["module: cpp"]

import libtorch_agnostic  # noqa: F401

import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestLibtorchAgnostic(TestCase):
    @onlyCPU
    def test_slow_sgd(self, device):
        param = torch.rand(5, device=device)
        grad = torch.rand_like(param)
        weight_decay = 0.01
        lr = 0.001
        maximize = False

        new_param = libtorch_agnostic.ops.sgd_out_of_place(
            param, grad, weight_decay, lr, maximize
        )
        torch._fused_sgd_(
            (param,),
            (grad,),
            (),
            weight_decay=weight_decay,
            momentum=0.0,
            lr=lr,
            dampening=0.0,
            nesterov=False,
            maximize=maximize,
            is_first_step=False,
        )
        self.assertEqual(new_param, param)

    @onlyCUDA
    def test_identity_does_not_hog_memory(self, device):
        def _run_identity(prior_mem):
            t = torch.rand(32, 32, device=device)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            identi_t = libtorch_agnostic.ops.identity(t)
            assert identi_t is t

        init_mem = torch.cuda.memory_allocated(device)

        for _ in range(3):
            _run_identity(init_mem)
        curr_mem = torch.cuda.memory_allocated(device)
        self.assertEqual(curr_mem, init_mem)

    @onlyCUDA
    def test_z_delete_torch_lib(self, device):
        # Why the z + CUDA? THIS TEST MUST BE RUN LAST
        # We are testing that unloading the library properly deletes the registrations, so running this test
        # earlier will cause all other tests in this file to fail
        lib = libtorch_agnostic.loaded_lib

        # code for unloading a library inspired from
        # https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
        lib_handle = lib._handle
        lib.dlclose(lib_handle)

        t = torch.tensor([-2.0, 0.5])
        with self.assertRaises(RuntimeError):
            libtorch_agnostic.ops.identity(
                t
            )  # errors as identity shouldn't be registered anymore


instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

if __name__ == "__main__":
    run_tests()
@@ -1074,9 +1074,12 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
     if return_code != 0:
         return return_code
     if sys.platform != "win32":
-        exts_to_build = [(install_cmd, "no_python_abi_suffix_test")]
+        exts_to_build = [
+            (install_cmd, "no_python_abi_suffix_test"),
+        ]
         if TEST_CUDA:
             exts_to_build.append((wheel_cmd, "python_agnostic_extension"))
+        exts_to_build.append((install_cmd, "libtorch_agnostic_extension"))
         for cmd, extension_dir in exts_to_build:
             return_code = shell(
                 cmd,
@@ -227,6 +227,71 @@ class TestCppExtensionAOT(common.TestCase):
        if return_code != 0:
            return return_code

    @unittest.skipIf(not TEST_CUDA, "some aspects of this test require CUDA")
    def test_libtorch_agnostic(self):
        import libtorch_agnostic

        # (1) first test that SGD CPU kernel works
        param = torch.rand(5, device="cpu")
        grad = torch.rand_like(param)
        weight_decay = 0.01
        lr = 0.001
        maximize = False

        new_param = libtorch_agnostic.ops.sgd_out_of_place(
            param, grad, weight_decay, lr, maximize
        )
        torch._fused_sgd_(
            (param,),
            (grad,),
            (),
            weight_decay=weight_decay,
            momentum=0.0,
            lr=lr,
            dampening=0.0,
            nesterov=False,
            maximize=maximize,
            is_first_step=False,
        )
        self.assertEqual(new_param, param)

        # (2) then test that we don't hog unnecessary memory
        def _run_identity(prior_mem, device):
            t = torch.rand(32, 32, device=device)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            identi_t = libtorch_agnostic.ops.identity(t)
            assert identi_t is t

        device = torch.cuda.current_device()
        init_mem = torch.cuda.memory_allocated(device)

        for _ in range(3):
            _run_identity(init_mem, device)
        curr_mem = torch.cuda.memory_allocated(device)
        self.assertEqual(curr_mem, init_mem)

        # (3) last, test that unloading the torch library will deregister previous ops
        lib = libtorch_agnostic.loaded_lib

        # code for unloading a library inspired from
        # https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
        lib_handle = lib._handle
        lib.dlclose(lib_handle)

        t = torch.tensor([-2.0, 0.5])
        with self.assertRaises(RuntimeError):
            libtorch_agnostic.ops.identity(t)

        # finally, clean up the folder
        cmd = [sys.executable, "setup.py", "clean"]
        return_code = shell(
            cmd,
            cwd=os.path.join("cpp_extensions", "libtorch_agnostic_extension"),
            env=os.environ.copy(),
        )
        if return_code != 0:
            return return_code


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):
@@ -108,6 +108,7 @@ includes = [
    "aten/src/THC/CMakeLists.txt",
    "torch/*",
    "tools/autograd/templates/python_variable_methods.cpp",
    "torch/csrc/stable/*",
]

includes = [os.path.join(proj_dir, include) for include in includes]
@@ -623,6 +623,55 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
    const char* launch_prefix,
    const char* kernel_name);

// helpers for converting between StableIValue and actual IValues
using StableIValue = uint64_t;

class TorchLibraryOpaque;
using TorchLibraryHandle = TorchLibraryOpaque*;

// stable corollary to torch::Library constructor with Kind::IMPL
// will create a new torch::Library object on the heap
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_impl(
    const char* ns,
    const char* k,
    const char* file,
    uint32_t line,
    TorchLibraryHandle* ret_new_torch_lib);

// stable corollary to torch::Library constructor with Kind::DEF
// will create a new torch::Library object on the heap
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_def(
    const char* ns,
    const char* file,
    uint32_t line,
    TorchLibraryHandle* ret_new_torch_lib);

// stable corollary to torch::Library constructor with Kind::FRAGMENT
// will create a new torch::Library object on the heap
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_fragment(
    const char* ns,
    const char* file,
    uint32_t line,
    TorchLibraryHandle* ret_new_torch_lib);

// stable corollary to torch::Library method m.impl()
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
    TorchLibraryHandle self,
    const char* name,
    void (*fn)(StableIValue*, int64_t, int64_t));

// stable corollary to torch::Library method m.def()
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_library_def(TorchLibraryHandle self, const char* name);

// the above stable constructors for torch::Library add Library objects
// to the heap. if you are calling those functions directly, please use
// this function to free the Library's memory. The more user friendly
// alternative is to use StableLibrary, which will free its handle upon
// destruction
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_library_object(TorchLibraryHandle tlh);

#ifdef USE_CUDA

struct CUDAGuardOpaque;
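For illustration, a rough sketch of driving these C entry points directly, without the StableLibrary wrapper defined in torch/csrc/stable/library.h. The names "myops", "my_op", and my_boxed_kernel are placeholders, error codes are left unchecked, and the handle must stay alive for as long as the registrations should remain in effect.

// Hypothetical direct use of the C shim (not part of this PR's test code).
void my_boxed_kernel(StableIValue* stack, int64_t num_args, int64_t num_outputs);

static TorchLibraryHandle my_lib = nullptr;

void register_myops() {
  aoti_torch_library_init_def("myops", __FILE__, __LINE__, &my_lib);
  aoti_torch_library_def(my_lib, "my_op(Tensor t) -> Tensor");
  aoti_torch_library_impl(my_lib, "my_op", &my_boxed_kernel);
}

void unregister_myops() {
  // Deleting the Library object also drops the registrations it owns.
  aoti_torch_delete_library_object(my_lib);
  my_lib = nullptr;
}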
@@ -1,10 +1,12 @@
#include <ATen/native/quantized/cpu/qlinear.h>
#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/GradMode.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
@@ -13,10 +15,14 @@
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/stable/library.h>
#include <torch/library.h>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -1300,3 +1306,135 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor(
    *ret0 = new_tensor_handle(std::move(tmp_result));
  });
}

class StableIValueBoxedKernel : public c10::OperatorKernel {
 public:
  StableIValueBoxedKernel(void (*fn)(StableIValue*, int64_t, int64_t))
      : fn_(fn) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    const auto& schema = op.schema();
    const auto num_returns = schema.returns().size();
    const auto num_arguments = schema.arguments().size();

    std::vector<StableIValue> ministack(num_arguments + num_returns);

    for (const auto idx : c10::irange(num_arguments)) {
      const c10::IValue& arg = torch::jit::pop(stack);
      const auto ministack_idx = num_arguments - idx - 1;
      if (arg.isInt()) {
        ministack[ministack_idx] = from(arg.toInt());
      } else if (arg.isDouble()) {
        ministack[ministack_idx] = from(arg.toDouble());
      } else if (arg.isBool()) {
        ministack[ministack_idx] = from(arg.toBool());
      } else if (arg.isNone()) {
        ministack[ministack_idx] = from(nullptr);
      } else if (arg.isTensor()) {
        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
            std::move(const_cast<at::Tensor&>(arg.toTensor())));
        ministack[ministack_idx] = from(ath);
      } else {
        TORCH_CHECK(false, "Other types of IValues not yet handled!");
      }
    }

    // boxed function is going to take a stack of StableIValues, cast them to
    // our schema values, and run the function and modify the StableIValue stack
    fn_(ministack.data(), num_arguments, num_returns);

    // read the output from the end of the stack and wrap that back into
    // IValue from StableIValue
    for (size_t idx = 0; idx < num_returns; idx++) {
      const c10::TypePtr& ret_type = schema.returns()[idx].type();
      if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
        auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
            to<AtenTensorHandle>(ministack[num_arguments + idx]));
        at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
            ret_raiiath.get());
        torch::jit::push(stack, c10::IValue(out));
      } else {
        TORCH_CHECK(false, "Only Tensor return types are currently supported!");
      }
    }
  }

 private:
  void (*fn_)(StableIValue*, int64_t, int64_t);
};

AOTITorchError aoti_torch_library_init_impl(
    const char* ns,
    const char* k,
    const char* file,
    uint32_t line,
    TorchLibraryHandle* ret_new_torch_lib) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    *ret_new_torch_lib =
        reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
            torch::Library::Kind::IMPL,
            std::string(ns),
            c10::parseDispatchKey(std::string(k)),
            file,
            line));
  });
}

AOTITorchError aoti_torch_library_init_def(
    const char* ns,
    const char* file,
    uint32_t line,
    TorchLibraryHandle* ret_new_torch_lib) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    *ret_new_torch_lib =
        reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
            torch::Library::Kind::DEF,
            std::string(ns),
            std::nullopt,
            file,
            line));
  });
}

AOTITorchError aoti_torch_library_init_fragment(
    const char* ns,
    const char* file,
    uint32_t line,
    TorchLibraryHandle* ret_new_torch_lib) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    *ret_new_torch_lib =
        reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
            torch::Library::Kind::FRAGMENT,
            std::string(ns),
            std::nullopt,
            file,
            line));
  });
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
    TorchLibraryHandle self,
    const char* name,
    void (*fn)(StableIValue*, int64_t, int64_t)) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    reinterpret_cast<torch::Library*>(self)->impl(
        name,
        torch::CppFunction::makeFromBoxedFunctor(
            std::make_unique<StableIValueBoxedKernel>(fn)));
  });
}

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { reinterpret_cast<torch::Library*>(self)->def(name); });
}

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
      { delete reinterpret_cast<torch::Library*>(tlh); });
}
torch/csrc/stable/library.h (new file)
@@ -0,0 +1,149 @@
// this file can only have stable stuff! Akin to shim.h
// but unlike shim.h, this file can contain header-only C++
// code for better UX.

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

template <typename T>
StableIValue from(T val) {
  static_assert(
      sizeof(T) <= sizeof(StableIValue),
      "StableLibrary stack does not support parameter types larger than 64 bits.");
  return *reinterpret_cast<StableIValue*>(&val);
}

template <typename T>
T to(StableIValue val) {
  return *reinterpret_cast<T*>(&val);
}
// end to helpers for converting between StableIValue and actual IValues

class TORCH_API StableLibrary final {
 private:
  TorchLibraryHandle lib_;

 public:
  enum class Kind {
    DEF,
    IMPL,
    FRAGMENT,
  };

  // constructor
  /// \private
  ///
  /// Use STABLE_TORCH_LIBRARY or STABLE_TORCH_LIBRARY_IMPL() instead of using
  /// these constructors directly
  StableLibrary(
      Kind kind,
      const char* ns,
      const char* k,
      const char* file,
      uint32_t line) {
    if (kind == Kind::IMPL) {
      aoti_torch_library_init_impl(ns, k, file, line, &lib_);
    } else if (kind == Kind::DEF) {
      aoti_torch_library_init_def(ns, file, line, &lib_);
    } else { // kind == FRAGMENT
      aoti_torch_library_init_fragment(ns, file, line, &lib_);
    }
  }

  // do not permit copy
  StableLibrary(const StableLibrary&) = delete;
  StableLibrary& operator=(const StableLibrary&) = delete;

  // do not permit move
  StableLibrary(StableLibrary&& other) = delete;
  StableLibrary& operator=(StableLibrary&& other) = delete;

  ~StableLibrary() {
    aoti_torch_delete_library_object(lib_);
  }

  StableLibrary& impl(
      const char* name,
      void (*fn)(StableIValue*, int64_t, int64_t)) {
    aoti_torch_library_impl(lib_, name, fn);
    return *this;
  }

  StableLibrary& def(const char* name) {
    aoti_torch_library_def(lib_, name);
    return *this;
  }
};

class TORCH_API StableTorchLibraryInit final {
 private:
  using InitFn = void(StableLibrary&);
  StableLibrary lib_;

 public:
  StableTorchLibraryInit(
      StableLibrary::Kind kind,
      InitFn* fn,
      const char* ns,
      const char* k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

// macros copied from c10/macros/Macros.h
#ifdef __COUNTER__
#define STABLE_UID __COUNTER__
#else
#define STABLE_UID __LINE__
#endif

#define STABLE_CONCATENATE_IMPL(s1, s2) s1##s2
#define STABLE_CONCATENATE(s1, s2) STABLE_CONCATENATE_IMPL(s1, s2)
// end of macros copied from c10/macros/Macros.h

#define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \
  _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID)

#define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid)                              \
  static void STABLE_CONCATENATE(                                              \
      STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary&);      \
  static const StableTorchLibraryInit STABLE_CONCATENATE(                      \
      STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(               \
      StableLibrary::Kind::IMPL,                                               \
      &STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid),  \
      #ns,                                                                     \
      #k,                                                                      \
      __FILE__,                                                                \
      __LINE__);                                                               \
  void STABLE_CONCATENATE(                                                     \
      STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary & m)

#define STABLE_TORCH_LIBRARY(ns, m)                                            \
  static void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary&);                  \
  static const StableTorchLibraryInit STABLE_TORCH_LIBRARY_static_init_##ns(   \
      StableLibrary::Kind::DEF,                                                \
      &STABLE_TORCH_LIBRARY_init_##ns,                                         \
      #ns,                                                                     \
      nullptr,                                                                 \
      __FILE__,                                                                \
      __LINE__);                                                               \
  void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary& m)

#define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \
  _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID)

#define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid)                             \
  static void STABLE_CONCATENATE(                                              \
      STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary&);        \
  static const StableTorchLibraryInit STABLE_CONCATENATE(                      \
      STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)(                 \
      StableLibrary::Kind::FRAGMENT,                                           \
      &STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid),    \
      #ns,                                                                     \
      nullptr,                                                                 \
      __FILE__,                                                                \
      __LINE__);                                                               \
  void STABLE_CONCATENATE(                                                     \
      STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary & m)
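As a small illustration of the from/to helpers defined at the top of this header, a usage sketch (not part of the header; values are arbitrary and assume a typical 64-bit build):

// Pack and unpack plain values through the 64-bit StableIValue slot, the way
// a boxed kernel would when reading its stack.
inline void stableivalue_roundtrip_example() {
  StableIValue packed_lr = from(0.001);   // double -> 64-bit slot
  double lr = to<double>(packed_lr);      // back to 0.001
  StableIValue packed_flag = from(true);
  bool maximize = to<bool>(packed_flag);  // back to true
  (void)lr;
  (void)maximize;
}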