Use schema as source of truth + support ones_like/empty_like (#149052)
This change does two important things:

(a) Instead of relying on the IValue type as the source of truth, we use the schema as the source of truth. This matters because IValue types are overloaded and can be converted ambiguously and incorrectly: for example, a MemoryFormat looks like an int and would get converted to an int64_t rather than a MemoryFormat.

(b) This PR expands support to many more types, and therefore many more schemas, e.g., Optional, Device, dtype, etc.

The main win from this PR is that aoti_torch_call_dispatcher can now call TensorFactory ops like ones_like/empty_like.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149052
Approved by: https://github.com/albanD
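As a rough sketch of point (a): the conversion keys off the declared schema type rather than the IValue's runtime tag. The snippet below is illustrative only (convert_arg is a hypothetical name; the real helper added by this PR is from_ivalue in shim_common.cpp) and assumes the shim's StableIValue type and from() helpers are in scope:

```cpp
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>

// Sketch: pick the conversion based on the schema type, not the IValue tag.
StableIValue convert_arg(const c10::TypePtr& schema_type, const c10::IValue& arg) {
  switch (schema_type->kind()) {
    case c10::TypeKind::MemoryFormatType:
      // An IValue holding a MemoryFormat also satisfies isInt(), so tag-based
      // dispatch would silently produce an int64_t here.
      return from(arg.toMemoryFormat());
    case c10::TypeKind::IntType:
      return from(arg.toInt());
    default:
      TORCH_CHECK(false, "sketch does not handle schema type: ", schema_type->str());
  }
}
```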
docs/source/notes/libtorch_stable_abi.md (new file, 34 lines added)
@@ -0,0 +1,34 @@

# LibTorch Stable ABI

This note will eventually contain more details on how to use the APIs in torch/csrc/stable. For the moment, it contains a table of internal representations:

1. type in custom extension: the type used within the end user's custom library.
2. StableIValue representation: the stable conversion of the type, used to liaise between the user's custom extension and libtorch.so in an ABI-stable manner.
3. type in libtorch: the type used within libtorch.so (or any code binary-locked with libtorch).
4. Schema Type: the type as described by the schema, which we hail as the source of truth for both ATen ops in native_functions.yaml and for user-defined custom operators registered to the dispatcher via TORCH_LIBRARY or torch.library.
| type in custom extension | StableIValue representation | type in libtorch | Schema Type |
| -------- | ------- | ------- | ------- |
| std::optional\<S> | \*reinterpret_cast\<(StableIValue\*)\*>, pointer to a StableIValue recursively defined | std::optional\<T> | Type? |
| std::nullopt | \*reinterpret_cast\<nullptr_t\*> | IValue() | None |
| RAIIATH | \*reinterpret_cast\<uint64_t\*> of AtenTensorHandle | at::Tensor | Tensor |
| int32_t | \*reinterpret_cast\<uint64_t\*> | at::ScalarType | ScalarType |
| int32_t | \*reinterpret_cast\<uint64_t\*> | at::Layout | Layout |
| int32_t | \*reinterpret_cast\<uint64_t\*> | at::MemoryFormat | MemoryFormat |
| bool | \*reinterpret_cast\<uint64_t\*> | bool | bool |
| int64_t | \*reinterpret_cast\<uint64_t\*> | int64_t | int |
| double | \*reinterpret_cast\<uint64_t\*> | double | float |
| ? | ? | c10::Device | Device |
| ? | ? | c10::Stream | Stream |
| ? | ? | c10::complex<double> | complex |
| ? | ? | at::Scalar | Scalar |
| ? | ? | std::string/const char*/ivalue::ConstantString | str |
| ? | ? | at::Storage | Storage |
| ? | ? | at::Generator | Generator |
| ? | ? | c10::List\<T> | Type[] |
| ? | ? | ivalue::Tuple\<T> | (Type, ...) |
| ? | ? | c10::SymInt | SymInt |
| ? | ? | c10::SymFloat | SymFloat |
| ? | ? | c10::SymBool | SymBool |
| ? | ? | at::QScheme | QScheme |

Our confidently supported types are the ones in the table that have completed rows. For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. You can work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with aoti_torch_call_dispatcher.
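As a quick illustration of the completed rows above, here is a minimal sketch (assuming torch/csrc/stable/library.h is included so the from/to helpers are in scope) of round-tripping a couple of 64-bit-representable values through StableIValue. The device passthrough pattern mentioned in the last paragraph is shown in full in the libtorch_agnostic extension changes below.

```cpp
#include <cstdint>
#include <torch/csrc/stable/library.h>

// Sketch only: pack extension-side values into StableIValues and back.
void stable_ivalue_roundtrip_sketch() {
  StableIValue packed_flag = from(true);        // bool row of the table
  StableIValue packed_count = from(int64_t{3}); // int64_t row of the table

  bool flag = to<bool>(packed_flag);            // back to the extension-side type
  int64_t count = to<int64_t>(packed_count);
  (void)flag;
  (void)count;
}
```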
@@ -2,6 +2,8 @@
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/stable/library.h>

#include <optional>

using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

void inline sgd_math(
@@ -147,3 +149,39 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_abs", &boxed_my_abs);
}

RAIIATH my_ones_like(RAIIATH t, StableIValue device) {
  const auto num_args = 6;
  StableIValue stack[num_args];

  int32_t t_dtype;
  aoti_torch_get_dtype(t.get(), &t_dtype);
  auto mf = aoti_torch_memory_format_contiguous_format();

  stack[0] = from(t.release());
  stack[1] = from(std::optional(t_dtype)); // dtype
  stack[2] = from(std::nullopt); // layout
  stack[3] = from(std::optional(device)); // device
  stack[4] = from(std::optional(false)); // pin_memory
  stack[5] = from(std::optional(mf)); // memory_format

  aoti_torch_call_dispatcher("aten::ones_like", "", stack);

  return RAIIATH(to<AtenTensorHandle>(stack[0]));
}

void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  RAIIATH t(to<AtenTensorHandle>(stack[0]));
  StableIValue device = stack[1];

  RAIIATH raiiath_res = my_ones_like(std::move(t), device);
  stack[0] = from(raiiath_res.release());
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_ones_like(Tensor t, Device d) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_ones_like", &boxed_my_ones_like);
}
@@ -49,3 +49,18 @@ def my_abs(t) -> Tensor:
        a Tensor
    """
    return torch.ops.libtorch_agnostic.my_abs.default(t)


def my_ones_like(tensor, device) -> Tensor:
    """
    Returns a new Tensor like the input tensor, but with all ones

    Args:
        tensor: any Tensor
        device: a device string

    Returns:
        a ones Tensor with the same dtype and shape and other attributes
        like the input tensor
    """
    return torch.ops.libtorch_agnostic.my_ones_like.default(tensor, device)
@@ -53,7 +53,7 @@ class TestLibtorchAgnostic(TestCase):
            self.assertEqual(curr_mem, init_mem)

    def test_my_abs(self, device):
-        t = torch.rand(32, 16, device=device)
+        t = torch.rand(32, 16, device=device) - 0.5
        cpu_t = libtorch_agnostic.ops.my_abs(t)
        self.assertEqual(cpu_t, torch.abs(t))
@@ -69,6 +69,23 @@ class TestLibtorchAgnostic(TestCase):
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)

    def test_my_ones_like(self, device):
        t = torch.rand(3, 1, device=device) - 0.5
        cpu_t = libtorch_agnostic.ops.my_ones_like(t, "cpu")
        self.assertEqual(cpu_t, torch.ones_like(t, device="cpu"))

        def _make_cuda_tensors(prior_mem):
            cuda_t = libtorch_agnostic.ops.my_ones_like(t, device)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            self.assertEqual(cuda_t, torch.ones_like(t, device=device))

        if t.is_cuda:
            init_mem = torch.cuda.memory_allocated(device)
            for _ in range(3):
                _make_cuda_tensors(init_mem)
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)

    @onlyCUDA
    def test_z_delete_torch_lib(self, device):
        # Why the z + CUDA? THIS TEST MUST BE RUN LAST
@@ -270,17 +270,29 @@ class TestCppExtensionAOT(common.TestCase):
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)

        # (3) test calling our dispatcher on ones_like
        t = torch.rand(32, 16, device=device)
        cpu_t = libtorch_agnostic.ops.my_abs(t)
        self.assertEqual(cpu_t, torch.abs(t))
        # (3a) test calling our dispatcher on easy API like abs
        t = torch.rand(32, 16, device=device) - 0.5

        def _make_cuda_tensors(prior_mem):
            cuda_t = libtorch_agnostic.ops.my_abs(t)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            self.assertEqual(cuda_t, torch.abs(t))

        if t.is_cuda:
            init_mem = torch.cuda.memory_allocated(device)
            for _ in range(3):
                _make_cuda_tensors(init_mem)
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)

        # (3b) and on factory API like ones_like
        cpu_t = libtorch_agnostic.ops.my_ones_like(t, "cpu")
        self.assertEqual(cpu_t, torch.ones_like(t, device="cpu"))

        def _make_cuda_tensors(prior_mem):
            cuda_t = libtorch_agnostic.ops.my_ones_like(t, t.device)
            self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
            self.assertEqual(cuda_t, torch.ones_like(t, device=t.device))

        init_mem = torch.cuda.memory_allocated(device)
        for _ in range(3):
            _make_cuda_tensors(init_mem)
@@ -1298,6 +1298,128 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
  });
}

static StableIValue from_ivalue(
    const c10::TypePtr& type,
    const c10::IValue& ivalue) {
  switch (type->kind()) {
    case c10::TypeKind::TensorType: {
      AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
          std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
      return from(ath);
    }
    case c10::TypeKind::IntType: {
      return from(ivalue.toInt());
    }
    case c10::TypeKind::FloatType: {
      return from(ivalue.toDouble());
    }
    case c10::TypeKind::BoolType: {
      return from(ivalue.toBool());
    }
    case c10::TypeKind::ScalarTypeType: {
      return from(ivalue.toScalarType());
    }
    case c10::TypeKind::DeviceObjType: {
      return from(ivalue.toDevice());
    }
    case c10::TypeKind::LayoutType: {
      return from(ivalue.toLayout());
    }
    case c10::TypeKind::MemoryFormatType: {
      return from(ivalue.toMemoryFormat());
    }
    case c10::TypeKind::OptionalType: {
      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();

      // ideally, if we had the C++ type corresponding to inner_type, which we
      // will denote as inner_type::t (does not actually exist), we would be
      // able to follow the patterned semantic of every other case here in one
      // line:
      //
      // return from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
      //
      // BUT we do NOT have that type inner_type::t readily available, so we
      // will manually unwrap and recursively call. This implementation MUST
      // be kept in sync with from<std::optional<T>> function in
      // torch/csrc/stable/library.h
      if (ivalue.isNone()) {
        return from(std::nullopt);
      }
      StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
      return from(sivp);
    }
    default: {
      TORCH_CHECK(
          false,
          "Not yet supported conversion from IValue to StableIValue for schema type: ",
          type->str());
    }
  }
}

static c10::IValue to_ivalue(
    const c10::TypePtr& type,
    const StableIValue stable_ivalue) {
  switch (type->kind()) {
    case c10::TypeKind::TensorType: {
      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
          to<AtenTensorHandle>(stable_ivalue));
      at::Tensor arg = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
          ret_raiiath.get());
      return (c10::IValue(arg));
    }
    case c10::TypeKind::IntType: {
      return c10::IValue(to<int64_t>(stable_ivalue));
    }
    case c10::TypeKind::FloatType: {
      return c10::IValue(to<double>(stable_ivalue));
    }
    case c10::TypeKind::BoolType: {
      return c10::IValue(to<bool>(stable_ivalue));
    }
    case c10::TypeKind::ScalarTypeType: {
      return c10::IValue(to<c10::ScalarType>(stable_ivalue));
    }
    case c10::TypeKind::DeviceObjType: {
      return c10::IValue(to<c10::Device>(stable_ivalue));
    }
    case c10::TypeKind::LayoutType: {
      return c10::IValue(to<c10::Layout>(stable_ivalue));
    }
    case c10::TypeKind::MemoryFormatType: {
      return c10::IValue(to<c10::MemoryFormat>(stable_ivalue));
    }
    case c10::TypeKind::OptionalType: {
      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();

      // ideally, if we had the C++ type corresponding to inner_type, which we
      // will denote as inner_type::t (does not actually exist), we would be
      // able to follow the patterned semantic of every other case here in one
      // line:
      //
      // return c10::IValue(to<std::optional<inner_type::t>>(stable_ivalue));
      //
      // BUT we do NOT have that type inner_type::t readily available, so we
      // will manually unwrap and recursively call. This implementation MUST
      // be kept in sync with the to<T> function in
      // torch/csrc/stable/library.h
      if (stable_ivalue == from(std::nullopt)) {
        return c10::IValue();
      }
      auto sivp = to<StableIValue*>(stable_ivalue);
      auto ival = to_ivalue(inner_type, *sivp);
      delete sivp;
      return ival;
    }
    default: {
      TORCH_CHECK(
          false,
          "Not yet supported conversion from StableIValue to IValue for schema type: ",
          type->str());
    }
  }
}

class StableIValueBoxedKernel : public c10::OperatorKernel {
 public:
  StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
@@ -1314,23 +1436,9 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
    std::vector<StableIValue> ministack(std::max(num_arguments, num_returns));

    for (const auto idx : c10::irange(num_arguments)) {
-      const c10::IValue& arg = torch::jit::pop(stack);
      const auto ministack_idx = num_arguments - idx - 1;
-      if (arg.isInt()) {
-        ministack[ministack_idx] = from(arg.toInt());
-      } else if (arg.isDouble()) {
-        ministack[ministack_idx] = from(arg.toDouble());
-      } else if (arg.isBool()) {
-        ministack[ministack_idx] = from(arg.toBool());
-      } else if (arg.isNone()) {
-        ministack[ministack_idx] = from(nullptr);
-      } else if (arg.isTensor()) {
-        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
-            std::move(const_cast<at::Tensor&>(arg.toTensor())));
-        ministack[ministack_idx] = from(ath);
-      } else {
-        TORCH_CHECK(false, "Other types of IValues not yet handled!");
-      }
+      const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
+      ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
    }

    // boxed function is going to take a stack of StableIValues, cast them to
@@ -1341,15 +1449,7 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
    // IValue from StableIValue
    for (size_t idx = 0; idx < num_returns; idx++) {
      const c10::TypePtr& ret_type = schema.returns()[idx].type();
-      if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
-        auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
-            to<AtenTensorHandle>(ministack[idx]));
-        at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
-            ret_raiiath.get());
-        torch::jit::push(stack, c10::IValue(out));
-      } else {
-        TORCH_CHECK(false, "Only Tensor return types are currently supported!");
-      }
+      torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
    }
  }
@@ -1430,42 +1530,6 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
      { delete reinterpret_cast<torch::Library*>(tlh); });
}

-static c10::IValue to_ivalue(
-    const c10::TypePtr& arg_type,
-    const StableIValue stable_ivalue) {
-  switch (arg_type->kind()) {
-    case c10::TypeKind::TensorType: {
-      // stable_ivalue must be an ATH
-      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
-          to<AtenTensorHandle>(stable_ivalue));
-      at::Tensor arg = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
-          ret_raiiath.get());
-      return (c10::IValue(arg));
-    }
-    case c10::TypeKind::IntType: {
-      return c10::IValue(to<int64_t>(stable_ivalue));
-    }
-    case c10::TypeKind::FloatType: {
-      return c10::IValue(to<double>(stable_ivalue));
-    }
-    case c10::TypeKind::BoolType: {
-      return c10::IValue(to<bool>(stable_ivalue));
-    }
-    case c10::TypeKind::ScalarTypeType: {
-      return c10::IValue(to<c10::ScalarType>(stable_ivalue));
-    }
-    case c10::TypeKind::LayoutType: {
-      return c10::IValue(to<c10::Layout>(stable_ivalue));
-    }
-    case c10::TypeKind::MemoryFormatType: {
-      return c10::IValue(to<c10::MemoryFormat>(stable_ivalue));
-    }
-    default: {
-      TORCH_CHECK(false, "Not yet supported argument type: ", arg_type->str());
-    }
-  }
-}
-
AOTITorchError aoti_torch_call_dispatcher(
    const char* opName,
    const char* overloadName,
@@ -1493,23 +1557,9 @@ AOTITorchError aoti_torch_call_dispatcher(
    // there should then be num_returns IValues on the stack, which
    // we will convert to StableIValue and repopulate user input stack
    for (const auto idx : c10::irange(num_returns)) {
-      const c10::IValue& ret = torch::jit::pop(ivalue_stack);
      const auto stack_idx = num_returns - idx - 1;
-      if (ret.isInt()) {
-        stack[stack_idx] = from(ret.toInt());
-      } else if (ret.isDouble()) {
-        stack[stack_idx] = from(ret.toDouble());
-      } else if (ret.isBool()) {
-        stack[stack_idx] = from(ret.toBool());
-      } else if (ret.isNone()) {
-        stack[stack_idx] = from(nullptr);
-      } else if (ret.isTensor()) {
-        AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
-            std::move(const_cast<at::Tensor&>(ret.toTensor())));
-        stack[stack_idx] = from(ath);
-      } else {
-        TORCH_CHECK(false, "Other types of IValue returns not yet handled!");
-      }
+      const c10::TypePtr& ret_type = schema.returns()[idx].type();
+      stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
    }
  });
}
@@ -4,12 +4,23 @@

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

+#include <optional>
+
// use anonymous namespace to avoid collisions between differing
// versions of this file that may be included by different sources
namespace {

// helpers for converting between StableIValue and actual IValues
-template <typename T>
+namespace detail {
+// utility functions to detect optional
+template <typename V>
+struct is_optional : std::false_type {};
+template <typename V>
+struct is_optional<std::optional<V>> : std::true_type {};
+} // namespace detail
+
+template <
+    typename T,
+    std::enable_if_t<!detail::is_optional<T>::value, bool> = true>
StableIValue from(T val) {
  static_assert(
      sizeof(T) <= sizeof(StableIValue),
@@ -17,10 +28,86 @@ StableIValue from(T val) {
  return *reinterpret_cast<StableIValue*>(&val);
}

// Specialization for std::nullopt_t
template <>
StableIValue from(std::nullopt_t val) {
  return from(nullptr);
}

// Specialization for std::optional
// [Handling std::optional]
// When the schema is represented by an optional type, say int?, then we
// expect the custom extension representation to be a std::optional<int>
// (critically NOT int!). In order for all parameters to be stably parsed and
// handled by our dispatcher, we liaison custom extension parameters through
// boxed kernels, meaning that every value will make its way to be an IValue:
//
// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue
//
// When the custom extension value is a literal that can be trivially
// casted to StableIValue, e.g., an int, a float, a pointer, this route is
// ...trivial. The below specialization is for a case when the custom
// extension value would NOT fit within a StableIValue: a std::optional.
//
// If the std::optional has no value, it is treated as std::nullopt,
// whose StableIValue representation is from(nullptr). Otherwise, we:
// 1. unwrap the std::optional<T>
// 2. recursively convert its value of type T to a StableIValue
// 3. allocate heap space for said StableIValue
// 4. convert the resulting StableIValue* into a StableIValue
//
// note that this allocates heap memory! which we expect to be cleaned
// up in the to_ivalue() function defined in shim_common.cpp. We
// purposefully hide this implementation detail from the user so that
// all the user needs to know is:
//
// The schema requests an optional (T?) so I must call `from` on a
// std::optional<T> or a std::nullopt.
template <typename T>
StableIValue from(std::optional<T> val) {
  if (!val.has_value()) {
    return from(std::nullopt);
  }
  StableIValue* heap_val = new StableIValue(from(val.value()));
  return from(heap_val);
}

template <
    typename T,
    std::enable_if_t<!detail::is_optional<T>::value, bool> = true>
T to(StableIValue val) {
  return *reinterpret_cast<T*>(&val);
}

template <
    typename T,
    std::enable_if_t<std::is_same_v<T, std::nullopt_t>, bool> = true>
T to(StableIValue val) {
  // val should be equivalent to from(nullptr)
  return std::nullopt;
}

// Specialization for std::optional, see [Handling std::optional] above
// as the semantic is the same but in reverse direction as we go from
// IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
template <
    typename T,
    std::enable_if_t<detail::is_optional<T>::value, bool> = true>
T to(StableIValue val) {
  using V = typename T::value_type;
  auto sivp = to<StableIValue*>(val);

  // sivp is either nullptr or a pointer to a StableIValue
  if (sivp == nullptr) {
    return {};
  }
  auto inner_val = to<V>(*sivp);

  // free the memory associated with StableIValue* sivp
  delete sivp;

  return std::make_optional(inner_val);
}
// end to helpers for converting between StableIValue and actual IValues
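A usage sketch of the optional round trip described in the [Handling std::optional] comment above, assuming the from/to helpers from this header are in scope; the heap allocation made by from is freed by the matching to:

```cpp
#include <cstdint>
#include <optional>

// Sketch only: wrap and unwrap an optional through StableIValue.
void optional_roundtrip_sketch() {
  // from() heap-allocates a StableIValue holding the packed int64_t.
  StableIValue siv = from(std::optional<int64_t>(7));
  // to<std::optional<int64_t>>() reads the inner value and deletes the heap slot.
  std::optional<int64_t> back = to<std::optional<int64_t>>(siv);
  // back now holds 7; from(std::nullopt) round-trips to an empty optional the same way.
  (void)back;
}
```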

class StableLibrary final {