mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add a stable TORCH_LIBRARY to C shim (#148124)
This PR adds two main parts: - shim.h stable C APIs into torch::Library APIs - a higher level API in torch/csrc/stable/library.h that calls into this shim.h + otherwise is self contained Goal: custom kernel writers should be able to call the apis in the directories above in order to register their library in a way that allows their custom extension to run with a different libtorch version than it was built with. Subplots resolved: - Do we want a whole separate StableLibrary or do we want to freeze torch::Library and add `m.stable_impl(cstring, void (*fn)(void **, int64_t, int64_t)` into it - Yes, we want a separate StableLibrary. We cannot freeze Library and it is NOT header only. - Should I use unint64_t as the common denominator instead of void* to support 32bit architectures better? - Yes, and done - Should I add a stable `def` and `fragment` when those can be done in python? - I think we do want these --- and now they're done - Where should library_stable_impl.cpp live? -- no longer relevant - I need some solid test cases to make sure everything's going ok. I've intentionally thrown in a bunch of random dtypes into the signature, but I still haven't tested returning multiple things, returning nothing, complex dtypes, etc. - Have since tested all the torch library endpoints. the others can be tested in a followup to separate components that need to be in shim.h vs can be added later Pull Request resolved: https://github.com/pytorch/pytorch/pull/148124 Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/atalman
This commit is contained in:
committed by
PyTorch MergeBot
parent
4d10da731b
commit
971606befa
1
.gitignore
vendored
1
.gitignore
vendored
@ -64,6 +64,7 @@ test/generated_type_hints_smoketest.py
|
||||
test/htmlcov
|
||||
test/cpp_extensions/install/
|
||||
test/cpp_extensions/open_registration_extension/install
|
||||
test/cpp_extensions/libtorch_agnostic_extension/install
|
||||
test/kernel.errors.txt
|
||||
third_party/build/
|
||||
third_party/nccl/
|
||||
|
@ -48,6 +48,7 @@ jit_core_headers = [
|
||||
"torch/csrc/jit/frontend/schema_type_parser.h",
|
||||
"torch/csrc/jit/frontend/error_report.h",
|
||||
"torch/csrc/jit/frontend/tree.h",
|
||||
"torch/csrc/stable/library.h",
|
||||
"torch/custom_class.h",
|
||||
"torch/custom_class_detail.h",
|
||||
"torch/library.h",
|
||||
|
@ -67,6 +67,7 @@ INPUT = ../../../aten/src/ATen/ATen.h \
|
||||
../../../torch/csrc/jit/runtime/custom_operator.h \
|
||||
../../../torch/csrc/jit/serialization/import.h \
|
||||
../../../torch/csrc/jit/api/module.h \
|
||||
../../../torch/csrc/stable/library.h \
|
||||
../../../torch/library.h \
|
||||
../../../torch/custom_class.h
|
||||
# Don't include .cpp files!
|
||||
|
1
setup.py
1
setup.py
@ -1274,6 +1274,7 @@ def main():
|
||||
"include/c10/xpu/impl/*.h",
|
||||
"include/torch/*.h",
|
||||
"include/torch/csrc/*.h",
|
||||
"include/torch/csrc/stable/*.h",
|
||||
"include/torch/csrc/api/include/torch/*.h",
|
||||
"include/torch/csrc/api/include/torch/data/*.h",
|
||||
"include/torch/csrc/api/include/torch/data/dataloader/*.h",
|
||||
|
@ -0,0 +1,21 @@
|
||||
import ctypes
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
so_files = list(Path(__file__).parent.glob("_C*.so"))
|
||||
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
|
||||
|
||||
# use ctypes.CDLL instead of load_library to be able to test the unload logic
|
||||
# below code is reduced from the load_library code
|
||||
with torch._ops.dl_open_guard():
|
||||
loaded_lib = ctypes.CDLL(so_files[0])
|
||||
|
||||
from . import ops
|
||||
|
||||
|
||||
__all__ = [
|
||||
"loaded_lib",
|
||||
"ops",
|
||||
]
|
@ -0,0 +1,127 @@
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/inductor/aoti_runtime/utils.h>
|
||||
#include <torch/csrc/stable/library.h>
|
||||
|
||||
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
|
||||
|
||||
void inline sgd_math(
|
||||
float* param_ptr,
|
||||
float* grad_ptr,
|
||||
float* out_ptr,
|
||||
const float weight_decay,
|
||||
const double lr,
|
||||
const bool maximize,
|
||||
int64_t size
|
||||
){
|
||||
int64_t d = 0;
|
||||
for (; d < size; d++) {
|
||||
float grad_val = grad_ptr[d];
|
||||
if (maximize) grad_val = -grad_val;
|
||||
if (weight_decay != 0.0){
|
||||
grad_val += param_ptr[d] * weight_decay;
|
||||
}
|
||||
out_ptr[d] = param_ptr[d] - grad_val * float(lr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
RAIIATH sgd_out_of_place(
|
||||
const RAIIATH param,
|
||||
const RAIIATH grad,
|
||||
const float weight_decay,
|
||||
const double lr,
|
||||
const bool maximize) {
|
||||
|
||||
int64_t param_dim;
|
||||
aoti_torch_get_dim(param.get(), ¶m_dim);
|
||||
|
||||
int64_t *param_sizes;
|
||||
int64_t *param_strides;
|
||||
aoti_torch_get_sizes(param.get(), ¶m_sizes);
|
||||
aoti_torch_get_strides(param.get(), ¶m_strides);
|
||||
|
||||
int32_t param_dtype;
|
||||
aoti_torch_get_dtype(param.get(), ¶m_dtype);
|
||||
|
||||
int32_t param_device_type;
|
||||
int32_t param_device_index;
|
||||
aoti_torch_get_device_type(param.get(), ¶m_device_type);
|
||||
aoti_torch_get_device_index(param.get(), ¶m_device_index);
|
||||
|
||||
AtenTensorHandle out;
|
||||
aoti_torch_empty_strided(param_dim, param_sizes, param_strides, param_dtype, param_device_type, param_device_index, &out);
|
||||
|
||||
void* param_ptr;
|
||||
aoti_torch_get_data_ptr(param.get(), ¶m_ptr);
|
||||
void* grad_ptr;
|
||||
aoti_torch_get_data_ptr(grad.get(), &grad_ptr);
|
||||
void* out_ptr;
|
||||
aoti_torch_get_data_ptr(out, &out_ptr);
|
||||
|
||||
auto param_fp_ptr = reinterpret_cast<float*>(param_ptr);
|
||||
auto grad_fp_ptr = reinterpret_cast<float*>(grad_ptr);
|
||||
auto out_fp_ptr = reinterpret_cast<float*>(out_ptr);
|
||||
|
||||
int64_t param_numel;
|
||||
aoti_torch_get_numel(param.get(), ¶m_numel);
|
||||
|
||||
sgd_math(
|
||||
param_fp_ptr,
|
||||
grad_fp_ptr,
|
||||
out_fp_ptr,
|
||||
weight_decay,
|
||||
lr,
|
||||
maximize,
|
||||
param_numel
|
||||
);
|
||||
|
||||
return RAIIATH(out);
|
||||
}
|
||||
|
||||
|
||||
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
RAIIATH param(to<AtenTensorHandle>(stack[0]));
|
||||
RAIIATH grad(to<AtenTensorHandle>(stack[1]));
|
||||
auto weight_decay = to<double>(stack[2]);
|
||||
auto lr = to<double>(stack[3]);
|
||||
auto maximize = to<bool>(stack[4]);
|
||||
|
||||
RAIIATH raiiath_res = sgd_out_of_place(
|
||||
std::move(param),
|
||||
std::move(grad),
|
||||
float(weight_decay),
|
||||
lr,
|
||||
maximize);
|
||||
|
||||
stack[0] = from(raiiath_res.release());
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
|
||||
m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
|
||||
}
|
||||
|
||||
RAIIATH identity(RAIIATH t) {
|
||||
return std::move(t);
|
||||
}
|
||||
|
||||
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
RAIIATH t(to<AtenTensorHandle>(stack[0]));
|
||||
RAIIATH raiiath_res = identity(std::move(t));
|
||||
stack[0] = from(raiiath_res.release());
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("identity(Tensor t) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
|
||||
m.impl("identity", &boxed_identity);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
m.impl("identity", &boxed_identity);
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def sgd_out_of_place(param, grad, weight_decay, lr, maximize) -> Tensor:
|
||||
"""
|
||||
Computes a single step of SGD on a single parameter Tensor with grad.
|
||||
|
||||
Assumes:
|
||||
- param and grad are the same shape and are 1D.
|
||||
- param and grad are float and on CPU
|
||||
|
||||
Args:
|
||||
param: a 1D tensor of floats
|
||||
grad: a 1D tensor of floats
|
||||
weight_decay: a python double between 0 and 1
|
||||
lr: a python double
|
||||
|
||||
Returns:
|
||||
a 1D float Tensor the same shape as param
|
||||
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.sgd_out_of_place.default(
|
||||
param, grad, weight_decay, lr, maximize
|
||||
)
|
||||
|
||||
|
||||
def identity(t) -> Tensor:
|
||||
"""
|
||||
Returns the input tensor
|
||||
|
||||
Args:
|
||||
t: any Tensor
|
||||
|
||||
Returns:
|
||||
a Tensor, the same as input.
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.identity.default(t)
|
67
test/cpp_extensions/libtorch_agnostic_extension/setup.py
Normal file
67
test/cpp_extensions/libtorch_agnostic_extension/setup.py
Normal file
@ -0,0 +1,67 @@
|
||||
import distutils.command.clean
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).parent
|
||||
CSRC_DIR = ROOT_DIR / "libtorch_agnostic" / "csrc"
|
||||
|
||||
|
||||
class clean(distutils.command.clean.clean):
|
||||
def run(self):
|
||||
# Run default behavior first
|
||||
distutils.command.clean.clean.run(self)
|
||||
|
||||
# Remove extension
|
||||
for path in (ROOT_DIR / "libtorch_agnostic").glob("**/*.so"):
|
||||
path.unlink()
|
||||
# Remove build and dist and egg-info directories
|
||||
dirs = [
|
||||
ROOT_DIR / "build",
|
||||
ROOT_DIR / "dist",
|
||||
ROOT_DIR / "libtorch_agnostic.egg-info",
|
||||
]
|
||||
for path in dirs:
|
||||
if path.exists():
|
||||
shutil.rmtree(str(path), ignore_errors=True)
|
||||
|
||||
|
||||
def get_extension():
|
||||
extra_compile_args = {
|
||||
"cxx": ["-fdiagnostics-color=always"],
|
||||
}
|
||||
|
||||
sources = list(CSRC_DIR.glob("**/*.cpp"))
|
||||
|
||||
return [
|
||||
CppExtension(
|
||||
"libtorch_agnostic._C",
|
||||
sources=sorted(str(s) for s in sources),
|
||||
py_limited_api=True,
|
||||
extra_compile_args=extra_compile_args,
|
||||
extra_link_args=[],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
setup(
|
||||
name="libtorch_agnostic",
|
||||
version="0.0",
|
||||
author="PyTorch Core Team",
|
||||
description="Example of libtorch agnostic extension",
|
||||
packages=find_packages(exclude=("test",)),
|
||||
package_data={"libtorch_agnostic": ["*.dll", "*.dylib", "*.so"]},
|
||||
install_requires=[
|
||||
"torch",
|
||||
],
|
||||
ext_modules=get_extension(),
|
||||
cmdclass={
|
||||
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
|
||||
"clean": clean,
|
||||
},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
)
|
@ -0,0 +1,77 @@
|
||||
# Owner(s): ["module: cpp"]
|
||||
|
||||
import libtorch_agnostic # noqa: F401
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
onlyCUDA,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestLibtorchAgnostic(TestCase):
|
||||
@onlyCPU
|
||||
def test_slow_sgd(self, device):
|
||||
param = torch.rand(5, device=device)
|
||||
grad = torch.rand_like(param)
|
||||
weight_decay = 0.01
|
||||
lr = 0.001
|
||||
maximize = False
|
||||
|
||||
new_param = libtorch_agnostic.ops.sgd_out_of_place(
|
||||
param, grad, weight_decay, lr, maximize
|
||||
)
|
||||
torch._fused_sgd_(
|
||||
(param,),
|
||||
(grad,),
|
||||
(),
|
||||
weight_decay=weight_decay,
|
||||
momentum=0.0,
|
||||
lr=lr,
|
||||
dampening=0.0,
|
||||
nesterov=False,
|
||||
maximize=maximize,
|
||||
is_first_step=False,
|
||||
)
|
||||
self.assertEqual(new_param, param)
|
||||
|
||||
@onlyCUDA
|
||||
def test_identity_does_not_hog_memory(self, device):
|
||||
def _run_identity(prior_mem):
|
||||
t = torch.rand(32, 32, device=device)
|
||||
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
|
||||
identi_t = libtorch_agnostic.ops.identity(t)
|
||||
assert identi_t is t
|
||||
|
||||
init_mem = torch.cuda.memory_allocated(device)
|
||||
|
||||
for _ in range(3):
|
||||
_run_identity(init_mem)
|
||||
curr_mem = torch.cuda.memory_allocated(device)
|
||||
self.assertEqual(curr_mem, init_mem)
|
||||
|
||||
@onlyCUDA
|
||||
def test_z_delete_torch_lib(self, device):
|
||||
# Why the z + CUDA? THIS TEST MUST BE RUN LAST
|
||||
# We are testing that unloading the library properly deletes the registrations, so running this test
|
||||
# earlier will cause all other tests in this file to fail
|
||||
lib = libtorch_agnostic.loaded_lib
|
||||
|
||||
# code for unloading a library inspired from
|
||||
# https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
|
||||
lib_handle = lib._handle
|
||||
lib.dlclose(lib_handle)
|
||||
|
||||
t = torch.tensor([-2.0, 0.5])
|
||||
with self.assertRaises(RuntimeError):
|
||||
libtorch_agnostic.ops.identity(
|
||||
t
|
||||
) # errors as identity shouldn't be registered anymore
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -931,10 +931,10 @@ def install_cpp_extensions(cpp_extensions_test_dir, env=os.environ):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def extend_python_path(install_directory):
|
||||
def extend_python_path(install_directories):
|
||||
python_path = os.environ.get("PYTHONPATH", "")
|
||||
try:
|
||||
os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
|
||||
os.environ["PYTHONPATH"] = os.pathsep.join(install_directories + [python_path])
|
||||
yield
|
||||
finally:
|
||||
os.environ["PYTHONPATH"] = python_path
|
||||
@ -1074,9 +1074,12 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
|
||||
if return_code != 0:
|
||||
return return_code
|
||||
if sys.platform != "win32":
|
||||
exts_to_build = [(install_cmd, "no_python_abi_suffix_test")]
|
||||
exts_to_build = [
|
||||
(install_cmd, "no_python_abi_suffix_test"),
|
||||
]
|
||||
if TEST_CUDA:
|
||||
exts_to_build.append((wheel_cmd, "python_agnostic_extension"))
|
||||
exts_to_build.append((install_cmd, "libtorch_agnostic_extension"))
|
||||
for cmd, extension_dir in exts_to_build:
|
||||
return_code = shell(
|
||||
cmd,
|
||||
@ -1094,17 +1097,24 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
|
||||
test_directory + "/test_cpp_extensions_aot.py",
|
||||
test_directory + "/" + test_module + ".py",
|
||||
)
|
||||
|
||||
try:
|
||||
cpp_extensions = os.path.join(test_directory, "cpp_extensions")
|
||||
install_directory = ""
|
||||
install_directories = []
|
||||
# install directory is the one that is named site-packages
|
||||
for root, directories, _ in os.walk(os.path.join(cpp_extensions, "install")):
|
||||
for directory in directories:
|
||||
if "-packages" in directory:
|
||||
install_directory = os.path.join(root, directory)
|
||||
install_directories.append(os.path.join(root, directory))
|
||||
|
||||
assert install_directory, "install_directory must not be empty"
|
||||
with extend_python_path(install_directory):
|
||||
for root, directories, _ in os.walk(
|
||||
os.path.join(cpp_extensions, "libtorch_agnostic_extension", "install")
|
||||
):
|
||||
for directory in directories:
|
||||
if "-packages" in directory:
|
||||
install_directories.append(os.path.join(root, directory))
|
||||
|
||||
with extend_python_path(install_directories):
|
||||
return run_test(ShardedTest(test_module, 1, 1), test_directory, options)
|
||||
finally:
|
||||
if os.path.exists(test_directory + "/" + test_module + ".py"):
|
||||
@ -1136,7 +1146,7 @@ def _test_autoload(test_directory, options, enable=True):
|
||||
|
||||
try:
|
||||
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable))
|
||||
with extend_python_path(install_directory):
|
||||
with extend_python_path([install_directory]):
|
||||
cmd = [sys.executable, "test_autoload.py"]
|
||||
return_code = shell(cmd, cwd=test_directory, env=os.environ)
|
||||
return return_code
|
||||
@ -1152,7 +1162,7 @@ def run_test_with_openreg(test_module, test_directory, options):
|
||||
if return_code != 0:
|
||||
return return_code
|
||||
|
||||
with extend_python_path(install_dir):
|
||||
with extend_python_path([install_dir]):
|
||||
return run_test(test_module, test_directory, options)
|
||||
|
||||
|
||||
|
@ -227,6 +227,49 @@ class TestCppExtensionAOT(common.TestCase):
|
||||
if return_code != 0:
|
||||
return return_code
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "some aspects of this test require CUDA")
|
||||
def test_libtorch_agnostic(self):
|
||||
import libtorch_agnostic
|
||||
|
||||
# (1) first test that SGD CPU kernel works
|
||||
param = torch.rand(5, device="cpu")
|
||||
grad = torch.rand_like(param)
|
||||
weight_decay = 0.01
|
||||
lr = 0.001
|
||||
maximize = False
|
||||
|
||||
new_param = libtorch_agnostic.ops.sgd_out_of_place(
|
||||
param, grad, weight_decay, lr, maximize
|
||||
)
|
||||
torch._fused_sgd_(
|
||||
(param,),
|
||||
(grad,),
|
||||
(),
|
||||
weight_decay=weight_decay,
|
||||
momentum=0.0,
|
||||
lr=lr,
|
||||
dampening=0.0,
|
||||
nesterov=False,
|
||||
maximize=maximize,
|
||||
is_first_step=False,
|
||||
)
|
||||
self.assertEqual(new_param, param)
|
||||
|
||||
# (2) then test that we don't hog unnecessary memory
|
||||
def _run_identity(prior_mem, device):
|
||||
t = torch.rand(32, 32, device=device)
|
||||
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
|
||||
identi_t = libtorch_agnostic.ops.identity(t)
|
||||
assert identi_t is t
|
||||
|
||||
device = torch.cuda.current_device()
|
||||
init_mem = torch.cuda.memory_allocated(device)
|
||||
|
||||
for _ in range(3):
|
||||
_run_identity(init_mem, device)
|
||||
curr_mem = torch.cuda.memory_allocated(device)
|
||||
self.assertEqual(curr_mem, init_mem)
|
||||
|
||||
|
||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||
class TestPybindTypeCasters(common.TestCase):
|
||||
|
@ -108,6 +108,7 @@ includes = [
|
||||
"aten/src/THC/CMakeLists.txt",
|
||||
"torch/*",
|
||||
"tools/autograd/templates/python_variable_methods.cpp",
|
||||
"torch/csrc/stable/*",
|
||||
]
|
||||
|
||||
includes = [os.path.join(proj_dir, include) for include in includes]
|
||||
|
@ -626,6 +626,57 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(
|
||||
const char* launch_prefix,
|
||||
const char* kernel_name);
|
||||
|
||||
// helpers for converting between StableIValue and actual IValues
|
||||
using StableIValue = uint64_t;
|
||||
|
||||
class TorchLibraryOpaque;
|
||||
using TorchLibraryHandle = TorchLibraryOpaque*;
|
||||
|
||||
// stable corollary to torch::Library constructor with Kind::IMPL
|
||||
// will create a new torch::Library object on the heap
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_impl(
|
||||
const char* ns,
|
||||
const char* k,
|
||||
const char* file,
|
||||
uint32_t line,
|
||||
TorchLibraryHandle* ret_new_torch_lib);
|
||||
|
||||
// stable corollary to torch::Library constructor with Kind::DEF
|
||||
// will create a new torch::Library object on the heap
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_def(
|
||||
const char* ns,
|
||||
const char* file,
|
||||
uint32_t line,
|
||||
TorchLibraryHandle* ret_new_torch_lib);
|
||||
|
||||
// stable corollary to torch::Library constructor with Kind::FRAGMENT
|
||||
// will create a new torch::Library object on the heap
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_fragment(
|
||||
const char* ns,
|
||||
const char* file,
|
||||
uint32_t line,
|
||||
TorchLibraryHandle* ret_new_torch_lib);
|
||||
|
||||
// stable corollary to torch::Library method m.impl(), should be
|
||||
// called from StableLibrary
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
||||
TorchLibraryHandle self,
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t));
|
||||
|
||||
// stable corollary to torch::Library method m.def(), should be
|
||||
// called from StableLibrary
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_library_def(TorchLibraryHandle self, const char* schema);
|
||||
|
||||
// the above stable constructors for torch::Library add Library objects
|
||||
// to the heap. if you are calling those functions directly, please use
|
||||
// this function to free the Library's memory. The more user friendly
|
||||
// alternative is to use StableLibrary, which will free its handle upon
|
||||
// destruction
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_delete_library_object(TorchLibraryHandle tlh);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
||||
struct CUDAGuardOpaque;
|
||||
|
@ -1,10 +1,12 @@
|
||||
#include <ATen/native/quantized/cpu/qlinear.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/DispatchKey.h>
|
||||
#include <c10/core/GradMode.h>
|
||||
#include <c10/core/Layout.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/inductor/aoti_runtime/utils.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/mkldnn_tensor.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/oss_proxy_executor.h>
|
||||
@ -13,10 +15,14 @@
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
#include <torch/csrc/inductor/inductor_ops.h>
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/library.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -1291,3 +1297,135 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
|
||||
t->zero_();
|
||||
});
|
||||
}
|
||||
|
||||
class StableIValueBoxedKernel : public c10::OperatorKernel {
|
||||
public:
|
||||
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
|
||||
: fn_(fn) {}
|
||||
|
||||
void operator()(
|
||||
const c10::OperatorHandle& op,
|
||||
c10::DispatchKeySet keyset,
|
||||
torch::jit::Stack* stack) {
|
||||
const auto& schema = op.schema();
|
||||
const auto num_returns = schema.returns().size();
|
||||
const auto num_arguments = schema.arguments().size();
|
||||
|
||||
std::vector<StableIValue> ministack(std::max(num_arguments, num_returns));
|
||||
|
||||
for (const auto idx : c10::irange(num_arguments)) {
|
||||
const c10::IValue& arg = torch::jit::pop(stack);
|
||||
const auto ministack_idx = num_arguments - idx - 1;
|
||||
if (arg.isInt()) {
|
||||
ministack[ministack_idx] = from(arg.toInt());
|
||||
} else if (arg.isDouble()) {
|
||||
ministack[ministack_idx] = from(arg.toDouble());
|
||||
} else if (arg.isBool()) {
|
||||
ministack[ministack_idx] = from(arg.toBool());
|
||||
} else if (arg.isNone()) {
|
||||
ministack[ministack_idx] = from(nullptr);
|
||||
} else if (arg.isTensor()) {
|
||||
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
|
||||
std::move(const_cast<at::Tensor&>(arg.toTensor())));
|
||||
ministack[ministack_idx] = from(ath);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Other types of IValues not yet handled!");
|
||||
}
|
||||
}
|
||||
|
||||
// boxed function is going to take a stack of StableIValues, cast them to
|
||||
// our schema values, and run the function and modify the StableIValue stack
|
||||
fn_(ministack.data(), num_arguments, num_returns);
|
||||
|
||||
// read the output from the end of the stack and wrap that back into
|
||||
// IValue from StableIValue
|
||||
for (size_t idx = 0; idx < num_returns; idx++) {
|
||||
const c10::TypePtr& ret_type = schema.returns()[idx].type();
|
||||
if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
|
||||
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
|
||||
to<AtenTensorHandle>(ministack[idx]));
|
||||
at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
|
||||
ret_raiiath.get());
|
||||
torch::jit::push(stack, c10::IValue(out));
|
||||
} else {
|
||||
TORCH_CHECK(false, "Only Tensor return types are currently supported!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void (*fn_)(StableIValue*, uint64_t, uint64_t);
|
||||
};
|
||||
|
||||
AOTITorchError aoti_torch_library_init_impl(
|
||||
const char* ns,
|
||||
const char* k,
|
||||
const char* file,
|
||||
uint32_t line,
|
||||
TorchLibraryHandle* ret_new_torch_lib) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
*ret_new_torch_lib =
|
||||
reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
|
||||
torch::Library::Kind::IMPL,
|
||||
std::string(ns),
|
||||
c10::parseDispatchKey(std::string(k)),
|
||||
file,
|
||||
line));
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_library_init_def(
|
||||
const char* ns,
|
||||
const char* file,
|
||||
uint32_t line,
|
||||
TorchLibraryHandle* ret_new_torch_lib) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
*ret_new_torch_lib =
|
||||
reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
|
||||
torch::Library::Kind::DEF,
|
||||
std::string(ns),
|
||||
std::nullopt,
|
||||
file,
|
||||
line));
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_library_init_fragment(
|
||||
const char* ns,
|
||||
const char* file,
|
||||
uint32_t line,
|
||||
TorchLibraryHandle* ret_new_torch_lib) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
*ret_new_torch_lib =
|
||||
reinterpret_cast<TorchLibraryOpaque*>(new torch::Library(
|
||||
torch::Library::Kind::FRAGMENT,
|
||||
std::string(ns),
|
||||
std::nullopt,
|
||||
file,
|
||||
line));
|
||||
});
|
||||
}
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
|
||||
TorchLibraryHandle self,
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
reinterpret_cast<torch::Library*>(self)->impl(
|
||||
name,
|
||||
torch::CppFunction::makeFromBoxedFunctor(
|
||||
std::make_unique<StableIValueBoxedKernel>(fn)));
|
||||
});
|
||||
}
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
|
||||
{ reinterpret_cast<torch::Library*>(self)->def(name); });
|
||||
}
|
||||
|
||||
AOTI_TORCH_EXPORT AOTITorchError
|
||||
aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
|
||||
{ delete reinterpret_cast<torch::Library*>(tlh); });
|
||||
}
|
||||
|
179
torch/csrc/stable/library.h
Normal file
179
torch/csrc/stable/library.h
Normal file
@ -0,0 +1,179 @@
|
||||
// this file can only have stable stuff! Akin to shim.h
|
||||
// but unlike shim.h, this file can contain header-only C++
|
||||
// code for better UX.
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
|
||||
// use anonymous namespace to avoid collisions between differing
|
||||
// versions of this file that may be included by different sources
|
||||
namespace {
|
||||
|
||||
// helpers for converting between StableIValue and actual IValues
|
||||
template <typename T>
|
||||
StableIValue from(T val) {
|
||||
static_assert(
|
||||
sizeof(T) <= sizeof(StableIValue),
|
||||
"StableLibrary stack does not support parameter types larger than 64 bits.");
|
||||
return *reinterpret_cast<StableIValue*>(&val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T to(StableIValue val) {
|
||||
return *reinterpret_cast<T*>(&val);
|
||||
}
|
||||
// end to helpers for converting between StableIValue and actual IValues
|
||||
|
||||
class StableLibrary final {
|
||||
private:
|
||||
TorchLibraryHandle lib_;
|
||||
|
||||
public:
|
||||
enum class Kind {
|
||||
DEF,
|
||||
IMPL,
|
||||
FRAGMENT,
|
||||
};
|
||||
|
||||
// constructor
|
||||
/// \private
|
||||
///
|
||||
/// Use STABLE_TORCH_LIBRARY or STABLE_TORCH_LIBRARY_IMPL() instead of using
|
||||
/// these constructors directly
|
||||
StableLibrary(
|
||||
Kind kind,
|
||||
const char* ns,
|
||||
const char* k,
|
||||
const char* file,
|
||||
uint32_t line) {
|
||||
if (kind == Kind::IMPL) {
|
||||
aoti_torch_library_init_impl(ns, k, file, line, &lib_);
|
||||
} else if (kind == Kind::DEF) {
|
||||
aoti_torch_library_init_def(ns, file, line, &lib_);
|
||||
} else { // kind == FRAGMENT
|
||||
aoti_torch_library_init_fragment(ns, file, line, &lib_);
|
||||
}
|
||||
}
|
||||
|
||||
// do not permit copy
|
||||
StableLibrary(const StableLibrary&) = delete;
|
||||
StableLibrary& operator=(const StableLibrary&) = delete;
|
||||
|
||||
// do not permit move
|
||||
StableLibrary(StableLibrary&& other) = delete;
|
||||
StableLibrary& operator=(StableLibrary&& other) = delete;
|
||||
|
||||
~StableLibrary() {
|
||||
aoti_torch_delete_library_object(lib_);
|
||||
}
|
||||
|
||||
// corresponds to a limited, stable version of torch::library::impl()
|
||||
// Inputs:
|
||||
// name: the name of the function to implement
|
||||
// fn: a boxed function with schema
|
||||
// (StableIValue* stack, uint64_t num_inputs, uint64_t num_outputs) ->
|
||||
// void
|
||||
// fn should follow the calling convention of our boxed kernels that convert
|
||||
// to IValues. fn will be called with a StableIValue* array of length
|
||||
// max(num_inputs, num_outputs), where the first num_inputs entries are
|
||||
// populated with inputs. fn is responsible for stealing the memory of the
|
||||
// inputs, in effect "popping" them off the stack, and then populating the
|
||||
// stack with StableIValue outputs. Concretely, fn should:
|
||||
// 1. read StableIValue inputs from the given stack
|
||||
// 2. convert the inputs to the proper types
|
||||
// 3. call the function corresponding to name with the inputs
|
||||
// 4. convert the outputs to StableIValues
|
||||
// 5. populate the now empty stack with StableIValue outputs
|
||||
// If the operation corresponding to name takes in 4 inputs and returns 2
|
||||
// outputs, fn should expect stack to contain 4 StableIValues:
|
||||
// [stable_arg1, stable_arg2, stable_arg3, stable_arg4]
|
||||
// to end, fn should fill the stack with 2 StableIValues representing outputs:
|
||||
// [stable_ret1, stable_ret2, -, -]
|
||||
StableLibrary& impl(
|
||||
const char* name,
|
||||
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
|
||||
aoti_torch_library_impl(lib_, name, fn);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// corresponds to a limited, stable version of torch::library::def()
|
||||
StableLibrary& def(const char* schema) {
|
||||
aoti_torch_library_def(lib_, schema);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
class StableTorchLibraryInit final {
|
||||
private:
|
||||
using InitFn = void(StableLibrary&);
|
||||
StableLibrary lib_;
|
||||
|
||||
public:
|
||||
StableTorchLibraryInit(
|
||||
StableLibrary::Kind kind,
|
||||
InitFn* fn,
|
||||
const char* ns,
|
||||
const char* k,
|
||||
const char* file,
|
||||
uint32_t line)
|
||||
: lib_(kind, ns, k, file, line) {
|
||||
fn(lib_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// macros copied from c10/macros/Macros.h
|
||||
#ifdef __COUNTER__
|
||||
#define STABLE_UID __COUNTER__
|
||||
#else
|
||||
#define STABLE_UID __LINE__
|
||||
#endif
|
||||
|
||||
#define STABLE_CONCATENATE_IMPL(s1, s2) s1##s2
|
||||
#define STABLE_CONCATENATE(s1, s2) STABLE_CONCATENATE_IMPL(s1, s2)
|
||||
// end of macros copied from c10/macros/Macros.h
|
||||
|
||||
#define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \
|
||||
_STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID)
|
||||
|
||||
#define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid) \
|
||||
static void STABLE_CONCATENATE( \
|
||||
STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary&); \
|
||||
static const StableTorchLibraryInit STABLE_CONCATENATE( \
|
||||
STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \
|
||||
StableLibrary::Kind::IMPL, \
|
||||
&STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \
|
||||
#ns, \
|
||||
#k, \
|
||||
__FILE__, \
|
||||
__LINE__); \
|
||||
void STABLE_CONCATENATE( \
|
||||
STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary & m)
|
||||
|
||||
#define STABLE_TORCH_LIBRARY(ns, m) \
|
||||
static void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary&); \
|
||||
static const StableTorchLibraryInit STABLE_TORCH_LIBRARY_static_init_##ns( \
|
||||
StableLibrary::Kind::DEF, \
|
||||
&STABLE_TORCH_LIBRARY_init_##ns, \
|
||||
#ns, \
|
||||
nullptr, \
|
||||
__FILE__, \
|
||||
__LINE__); \
|
||||
void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary& m)
|
||||
|
||||
#define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \
|
||||
_STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID)
|
||||
|
||||
#define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid) \
|
||||
static void STABLE_CONCATENATE( \
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary&); \
|
||||
static const StableTorchLibraryInit STABLE_CONCATENATE( \
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \
|
||||
StableLibrary::Kind::FRAGMENT, \
|
||||
&STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
|
||||
#ns, \
|
||||
nullptr, \
|
||||
__FILE__, \
|
||||
__LINE__); \
|
||||
void STABLE_CONCATENATE( \
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary & m)
|
Reference in New Issue
Block a user