Add TORCH_BOX helper for STABLE_TORCH_LIBRARY_IMPL

ghstack-source-id: 4a3f08fb5188ab3a8e42c01c99ca4b116508ef59
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167582
Author: Jane Xu
Date: 2025-11-11 17:03:36 -08:00
Parent: b2360baa7a
Commit: 17b46fce2c
2 changed files with 208 additions and 223 deletions
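
Summary: before this change, every kernel registered through STABLE_TORCH_LIBRARY_IMPL needed a hand-written boxed wrapper that converted each StableIValue on the stack to and from the kernel's argument and return types. The new TORCH_BOX(&kernel) helper generates that wrapper at compile time from the kernel's signature. A minimal before/after sketch, using the identity op from the diff below (the include path is an assumption; everything else is taken from the diff):

#include <torch/csrc/stable/library.h> // assumed header providing TORCH_BOX + the STABLE_* macros

using torch::stable::Tensor;

Tensor identity(Tensor t) {
  return t;
}

// Before: a manual boxed wrapper, unboxing and re-boxing StableIValues by hand.
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
  stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("identity(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
  // Before: m.impl("identity", &boxed_identity);
  // After: the wrapper is generated from identity's signature.
  m.impl("identity", TORCH_BOX(&identity));
}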

Changed file 1 of 2:

@@ -37,7 +37,7 @@ using torch::stable::Tensor;
Tensor sgd_out_of_place(
const Tensor param,
const Tensor grad,
-const float weight_decay,
+const double weight_decay,
const double lr,
const bool maximize) {
STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");
@@ -56,7 +56,7 @@ Tensor sgd_out_of_place(
reinterpret_cast<float*>(param.data_ptr()),
reinterpret_cast<float*>(grad.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
-weight_decay,
+float(weight_decay),
lr,
maximize,
param.numel()
@@ -65,44 +65,29 @@ Tensor sgd_out_of_place(
return out;
}
-void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-Tensor res = sgd_out_of_place(
-torch::stable::detail::to<Tensor>(stack[0]),
-torch::stable::detail::to<Tensor>(stack[1]),
-float(torch::stable::detail::to<double>(stack[2])),
-torch::stable::detail::to<double>(stack[3]),
-torch::stable::detail::to<bool>(stack[4]));
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
m.def("sgd_out_of_place(Tensor param, Tensor grad, float weight_decay, float lr, bool maximize) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("sgd_out_of_place", &boxed_sgd_out_of_place);
m.impl("sgd_out_of_place", TORCH_BOX(&sgd_out_of_place));
}
Tensor identity(Tensor t) {
return t;
}
-void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("identity(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CUDA, m) {
m.impl("identity", &boxed_identity);
m.impl("identity", TORCH_BOX(&identity));
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
m.impl("identity", &boxed_identity);
m.impl("identity", TORCH_BOX(&identity));
}
Tensor my_abs(Tensor t) {
@@ -113,17 +98,12 @@ Tensor my_abs(Tensor t) {
return torch::stable::detail::to<Tensor>(stack[0]);
}
-void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(tensor_res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_abs(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_abs", &boxed_my_abs);
m.impl("my_abs", TORCH_BOX(&my_abs));
}
Tensor my_ones_like(Tensor t, StableIValue device) {
@@ -144,17 +124,12 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
return torch::stable::detail::to<Tensor>(stack[0]);
}
-void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_ones_like(Tensor t, Device d) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_ones_like", &boxed_my_ones_like);
m.impl("my_ones_like", TORCH_BOX(&my_ones_like));
}
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
@@ -176,19 +151,12 @@ std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3
torch::stable::detail::to<bool>(stack_is_leaf[0]));
}
-void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
-stack[0] = torch::stable::detail::from(std::get<0>(tuple));
-stack[1] = torch::stable::detail::from(std::get<1>(tuple));
-stack[2] = torch::stable::detail::from(std::get<2>(tuple));
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("exp_neg_is_leaf", &boxed_exp_neg_is_leaf);
m.impl("exp_neg_is_leaf", TORCH_BOX(&exp_neg_is_leaf));
}
Tensor neg_exp(Tensor t) {
@@ -199,17 +167,12 @@ Tensor neg_exp(Tensor t) {
return torch::stable::detail::to<Tensor>(stack[0]);
}
-void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("neg_exp(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("neg_exp", &boxed_neg_exp);
m.impl("neg_exp", TORCH_BOX(&neg_exp));
}
Tensor divide_neg_exp(Tensor t) {
@@ -228,108 +191,53 @@ Tensor divide_neg_exp(Tensor t) {
return torch::stable::detail::to<Tensor>(stack_div[0]);
}
-void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("divide_neg_exp(Tensor t) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("divide_neg_exp", &boxed_divide_neg_exp);
m.impl("divide_neg_exp", TORCH_BOX(&divide_neg_exp));
}
bool is_contiguous(Tensor t) {
return t.is_contiguous();
}
-void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("is_contiguous(Tensor t) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("is_contiguous", &boxed_is_contiguous);
m.impl("is_contiguous", TORCH_BOX(&is_contiguous));
}
Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
return transpose(t, dim0, dim1);
}
-void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor my_empty_like(Tensor t) {
return empty_like(t);
}
-void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
bool my_is_cpu(Tensor t) {
return t.is_cpu();
}
-void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor fill_infinity(Tensor t) {
auto value = std::numeric_limits<float>::infinity();
return fill_(t, value);
}
-void boxed_fill_infinity(
-StableIValue* stack,
-uint64_t num_args,
-uint64_t num_outputs) {
-auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor my_pad(Tensor t) {
std::string mode = "constant";
double value = 0.0;
return pad(t, {1, 2, 2, 1}, mode, value);
}
-void boxed_my_pad(
-StableIValue* stack,
-uint64_t num_args,
-uint64_t num_outputs) {
-auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
return narrow(t, dim, start, length);
}
-void boxed_my_narrow(
-StableIValue* stack,
-uint64_t num_args,
-uint64_t num_outputs) {
-auto res = my_narrow(
-torch::stable::detail::to<Tensor>(stack[0]),
-torch::stable::detail::to<int64_t>(stack[1]),
-torch::stable::detail::to<int64_t>(stack[2]),
-torch::stable::detail::to<int64_t>(stack[3]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
@@ -341,40 +249,19 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
return new_empty(t, sizes, dtype);
}
-void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor my_new_zeros_dtype_variant(Tensor t) {
auto dtype = std::make_optional(at::ScalarType::Float);
return new_zeros(t, {2, 5}, dtype);
}
-void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
return copy_(dst, src, non_blocking);
}
-void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
-stack[0] = torch::stable::detail::from(tensor_res);
-}
Tensor my_clone(Tensor t) {
return clone(t);
}
-void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(tensor_res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
m.def("my_empty_like(Tensor t) -> Tensor");
@@ -388,57 +275,39 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_transpose", &boxed_my_transpose);
m.impl("my_empty_like", &boxed_empty_like);
m.impl("fill_infinity", &boxed_fill_infinity);
m.impl("my_is_cpu", &boxed_my_is_cpu);
m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
m.impl("my_new_zeros_dtype_variant", &boxed_my_new_zeros_dtype_variant);
m.impl("my_copy_", &boxed_my_copy_);
m.impl("my_clone", &boxed_my_clone);
m.impl("my_transpose", TORCH_BOX(&my_transpose));
m.impl("my_empty_like", TORCH_BOX(&my_empty_like));
m.impl("fill_infinity", TORCH_BOX(&fill_infinity));
m.impl("my_is_cpu", TORCH_BOX(&my_is_cpu));
m.impl("my_new_empty_dtype_variant", TORCH_BOX(&my_new_empty_dtype_variant));
m.impl("my_new_zeros_dtype_variant", TORCH_BOX(&my_new_zeros_dtype_variant));
m.impl("my_copy_", TORCH_BOX(&my_copy_));
m.impl("my_clone", TORCH_BOX(&my_clone));
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
m.impl("my_pad", &boxed_my_pad);
m.impl("my_narrow", &boxed_my_narrow);
m.impl("my_pad", TORCH_BOX(&my_pad));
m.impl("my_narrow", TORCH_BOX(&my_narrow));
}
Tensor my_zero_(Tensor t) {
return zero_(t);
}
-void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor my_amax(Tensor t) {
return amax(t, 0, false);
}
-void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
Tensor my_amax_vec(Tensor t) {
return amax(t, {0,1}, false);
}
-void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
m.def("my_amax(Tensor a) -> Tensor");
m.def("my_amax_vec(Tensor a) -> Tensor");
m.def("my_is_cpu(Tensor t) -> bool");
-}
-STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
-m.impl("my_zero_", &boxed_my_zero_);
+m.def("test_default_constructor(bool undefined) -> bool");
}
bool test_default_constructor(bool defined) {
@@ -460,22 +329,12 @@ bool test_default_constructor(bool defined) {
return out.defined();
}
-void boxed_test_default_constructor(
-StableIValue* stack,
-uint64_t num_args,
-uint64_t num_outputs) {
-bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
-stack[0] = torch::stable::detail::from(res);
-}
-STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
-m.def("test_default_constructor(bool undefined) -> bool");
-}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_default_constructor", &boxed_test_default_constructor);
m.impl("my_amax", &boxed_my_amax);
m.impl("my_amax_vec", &boxed_my_amax_vec);
m.impl("my_zero_", TORCH_BOX(&my_zero_));
m.impl("my_amax", TORCH_BOX(&my_amax));
m.impl("my_amax_vec", TORCH_BOX(&my_amax_vec));
m.impl("test_default_constructor", TORCH_BOX(&test_default_constructor));
}
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
@@ -484,23 +343,11 @@ std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
}
-void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
-// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
-// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
-auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
-stack[0] = torch::stable::detail::from(res);
-}
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
}
-void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
-}
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
// This function tests that my__foreach_mul can take in std::initializer_lists
// in addition to std::vectors.
@@ -511,11 +358,6 @@ std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
}
-void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
@@ -523,9 +365,9 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
m.impl("make_tensor_clones_and_call_foreach", TORCH_BOX(&make_tensor_clones_and_call_foreach));
}
// Test functions for torch::stable::accelerator APIs
@@ -546,14 +388,6 @@ int64_t test_device_guard(int64_t device_index) {
return currentDevice;
}
-void boxed_test_device_guard(
-StableIValue* stack,
-uint64_t num_args,
-uint64_t num_outputs) {
-int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
-stack[0] = torch::stable::detail::from(res);
-}
int64_t test_device_guard_set_index() {
using torch::stable::accelerator::DeviceGuard;
@@ -565,14 +399,6 @@ int64_t test_device_guard_set_index() {
return currentDevice;
}
-void boxed_test_device_guard_set_index(
-StableIValue* stack,
-uint64_t num_args,
-uint64_t num_outputs) {
-int64_t res = test_device_guard_set_index();
-stack[0] = torch::stable::detail::from(res);
-}
int64_t test_stream(int32_t device_index) {
STD_TORCH_CHECK(
device_index >= std::numeric_limits<int32_t>::min() &&
@@ -582,26 +408,10 @@ int64_t test_stream(int32_t device_index) {
return torch::stable::accelerator::getCurrentStream(device_index).id();
}
-void boxed_test_stream(
-StableIValue* stack,
-uint64_t num_args,
-uint64_t num_outputs) {
-int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
-stack[0] = torch::stable::detail::from(res);
-}
int64_t test_get_current_device_index() {
return torch::stable::accelerator::getCurrentDeviceIndex();
}
-void boxed_test_get_current_device_index(
-StableIValue* stack,
-uint64_t num_args,
-uint64_t num_outputs) {
-int64_t res = test_get_current_device_index();
-stack[0] = torch::stable::detail::from(res);
-}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("test_device_guard(int device_index) -> int");
m.def("test_device_guard_set_index() -> int");
@@ -610,10 +420,10 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_device_guard", &boxed_test_device_guard);
m.impl("test_device_guard_set_index", &boxed_test_device_guard_set_index);
m.impl("test_stream", &boxed_test_stream);
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
m.impl("test_device_guard", TORCH_BOX(&test_device_guard));
m.impl("test_device_guard_set_index", TORCH_BOX(&test_device_guard_set_index));
m.impl("test_stream", TORCH_BOX(&test_stream));
m.impl("test_get_current_device_index", TORCH_BOX(&test_get_current_device_index));
}
#endif // LAE_USE_CUDA
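
Aside on the first hunk above: sgd_out_of_place's weight_decay parameter changed from float to double because schema-level float travels as a C++ double on the StableIValue stack, and TORCH_BOX unboxes exactly the declared parameter types, so the narrowing to float now happens inside the kernel. The generated wrapper behaves roughly like the deleted boxed_sgd_out_of_place plus arity checks — a sketch, not the literal generated code:

// Approximate behavior of TORCH_BOX(&sgd_out_of_place); the function name is illustrative.
void sketch_boxed_sgd(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 5, "schema/kernel arg count mismatch");
  STD_TORCH_CHECK(num_outputs == 1, "schema/kernel output count mismatch");
  Tensor res = sgd_out_of_place(
      torch::stable::detail::to<Tensor>(stack[0]),
      torch::stable::detail::to<Tensor>(stack[1]),
      torch::stable::detail::to<double>(stack[2]), // schema float -> C++ double
      torch::stable::detail::to<double>(stack[3]),
      torch::stable::detail::to<bool>(stack[4]));
  stack[0] = torch::stable::detail::from(res);
}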

Changed file 2 of 2:

@@ -6,6 +6,7 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/headeronly/macros/Macros.h>
+#include <torch/headeronly/util/Metaprogramming.h>
// Technically, this file doesn't use anything from stableivalue_conversions.h,
// but we need to include it here as the contents of stableivalue_conversions.h
@@ -116,8 +117,182 @@ class StableTorchLibraryInit final {
}
};
+// type mapper: since to<HeaderOnlyArrayRef<T>> cannot exist,
+// we map that to to<std::vector<T>> to preserve ownership semantics.
+// note that unbox_type_t is used to convert ParamTypes, so that
+// the tuple holding the arguments will have proper ownership too.
+template <typename T>
+struct UnboxType {
+using type = T;
+};
+template <typename T>
+struct UnboxType<torch::headeronly::HeaderOnlyArrayRef<T>> {
+using type = std::vector<T>;
+};
+template <typename T>
+using unbox_type_t = typename UnboxType<T>::type;
+template <class... T, std::size_t... I>
+std::tuple<T...> unbox_to_tuple_impl(
+StableIValue* stack,
+std::index_sequence<I...>) {
+return std::make_tuple(to<T>(stack[I])...);
+}
+template <class... T>
+std::tuple<T...> unbox_to_tuple(StableIValue* stack) {
+return unbox_to_tuple_impl<T...>(
+stack, std::make_index_sequence<sizeof...(T)>());
+}
+template <class... T, std::size_t... I>
+void box_from_tuple_impl(
+StableIValue* stack,
+std::tuple<T...> vals,
+std::index_sequence<I...>) {
+((stack[I] = from<T>(std::get<I>(vals))), ...);
+}
+template <class... T>
+void box_from_tuple(StableIValue* stack, std::tuple<T...> vals) {
+box_from_tuple_impl<T...>(
+stack, vals, std::make_index_sequence<sizeof...(T)>());
+}
+template <
+typename ReturnType,
+typename ParameterTypeList,
+typename FuncT,
+FuncT* func>
+struct boxer_impl {
+static_assert(
+torch::headeronly::guts::false_t<ReturnType>::value,
+"Unsupported function schema for TORCH_BOX.");
+};
+// Multiple returns
+template <
+typename... ReturnTypes,
+typename... ParameterTypes,
+typename FuncT,
+FuncT* func>
+struct boxer_impl<
+std::tuple<ReturnTypes...>,
+torch::headeronly::guts::typelist::typelist<ParameterTypes...>,
+FuncT,
+func> {
+static void boxed_fn(
+StableIValue* stack,
+uint64_t num_args,
+uint64_t num_outputs) {
+STD_TORCH_CHECK(
+num_args == sizeof...(ParameterTypes),
+"Registered schema has ",
+num_args,
+" args, but the kernel to box has ",
+sizeof...(ParameterTypes));
+STD_TORCH_CHECK(
+num_outputs == sizeof...(ReturnTypes),
+"Registered schema has ",
+num_outputs,
+" outputs, but the kernel to box has ",
+sizeof...(ReturnTypes));
+std::tuple<unbox_type_t<ParameterTypes>...> args =
+unbox_to_tuple<unbox_type_t<ParameterTypes>...>(stack);
+auto res = std::apply(func, args);
+box_from_tuple<ReturnTypes...>(stack, res);
+}
+};
+// Single return
+template <
+typename ReturnType,
+typename... ParameterTypes,
+typename FuncT,
+FuncT* func>
+struct boxer_impl<
+ReturnType,
+torch::headeronly::guts::typelist::typelist<ParameterTypes...>,
+FuncT,
+func> {
+static void boxed_fn(
+StableIValue* stack,
+uint64_t num_args,
+uint64_t num_outputs) {
+STD_TORCH_CHECK(
+num_args == sizeof...(ParameterTypes),
+"Registered schema has ",
+num_args,
+" args, but the kernel to box has ",
+sizeof...(ParameterTypes));
+STD_TORCH_CHECK(
+num_outputs == 1,
+"Registered schema has ",
+num_outputs,
+" outputs, but the kernel to box has ",
+1);
+std::tuple<unbox_type_t<ParameterTypes>...> args =
+unbox_to_tuple<unbox_type_t<ParameterTypes>...>(stack);
+auto res = std::apply(func, args);
+stack[0] = from<ReturnType>(res);
+}
+};
+// No/void return
+template <typename... ParameterTypes, typename FuncT, FuncT* func>
+struct boxer_impl<
+void,
+torch::headeronly::guts::typelist::typelist<ParameterTypes...>,
+FuncT,
+func> {
+static void boxed_fn(
+StableIValue* stack,
+uint64_t num_args,
+uint64_t num_outputs) {
+STD_TORCH_CHECK(
+num_args == sizeof...(ParameterTypes),
+"Registered schema has ",
+num_args,
+" args, but the kernel to box has ",
+sizeof...(ParameterTypes));
+STD_TORCH_CHECK(
+num_outputs == 0,
+"Registered schema has ",
+num_outputs,
+" outputs, but the kernel to box has ",
+0);
+std::tuple<unbox_type_t<ParameterTypes>...> args =
+unbox_to_tuple<unbox_type_t<ParameterTypes>...>(stack);
+std::apply(func, args);
+}
+};
+template <typename FuncT, FuncT* func>
+struct boxer {
+using FunctionTraits =
+torch::headeronly::guts::infer_function_traits_t<FuncT>;
+static void boxed_fn(
+StableIValue* stack,
+uint64_t num_args,
+uint64_t num_outputs) {
+boxer_impl<
+typename FunctionTraits::return_type,
+typename FunctionTraits::parameter_types,
+FuncT,
+func>::boxed_fn(stack, num_args, num_outputs);
+}
+};
HIDDEN_NAMESPACE_END(torch, stable, detail)
+#define TORCH_BOX(func) \
+torch::stable::detail::boxer< \
+std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
+(func)>::boxed_fn
// macros copied from c10/macros/Macros.h
#ifdef __COUNTER__
#define STABLE_UID __COUNTER__
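
For intuition about the machinery above: unbox_to_tuple/box_from_tuple are the standard index_sequence pack-expansion trick, and boxer_impl picks a specialization based on the kernel's return type (tuple, single value, or void). A self-contained toy version of the same pattern — standalone names, nothing below is PyTorch API:

#include <cstdint>
#include <tuple>
#include <utility>

// Stand-in for StableIValue: a 64-bit payload, purely for illustration.
using Slot = uint64_t;

template <typename T>
T to(Slot s) { return static_cast<T>(s); } // toy unboxing

template <typename T>
Slot from(T v) { return static_cast<Slot>(v); } // toy boxing

// Unbox stack[0..N-1] into a std::tuple<T...> by expanding an index_sequence.
template <class... T, std::size_t... I>
std::tuple<T...> unbox(Slot* stack, std::index_sequence<I...>) {
  return std::make_tuple(to<T>(stack[I])...);
}

// A boxed trampoline for any int64_t(int64_t, int64_t) kernel, built the same way
// boxer_impl builds boxed_fn: unbox to a tuple, std::apply, box the result.
template <auto* func>
void boxed(Slot* stack) {
  auto args = unbox<int64_t, int64_t>(stack, std::make_index_sequence<2>());
  stack[0] = from(std::apply(func, args)); // single return lands in slot 0
}

int64_t add(int64_t a, int64_t b) { return a + b; }

int main() {
  Slot stack[2] = {from<int64_t>(2), from<int64_t>(3)};
  boxed<&add>(stack);
  return stack[0] == 5 ? 0 : 1; // exits 0 when the boxed call produced 5
}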