Use schema as source of truth + support ones_like/empty_like (#149052)

This change does two important things:

(a) Instead of relying on the IValue tag as the source of truth, we use the schema as the source of truth. This matters because IValue storage is overloaded, so a value can be converted ambiguously and incorrectly: for example, a MemoryFormat looks like an int inside an IValue and would get converted to an int64_t instead of a MemoryFormat!

(b) This PR expands support to many more types so that far more schemas are covered, e.g., Optional, Device, dtype, etc. The main win from this PR is that aoti_torch_call_dispatcher can now call TensorFactory ops like ones_like/empty_like!
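To make the ambiguity in (a) concrete, here is a minimal illustration (a sketch only, not part of this commit):

```cpp
#include <ATen/core/ivalue.h>
#include <c10/core/MemoryFormat.h>

// An IValue built from a MemoryFormat is tagged as an Int, so a converter keyed
// on the IValue tag alone cannot tell it apart from a plain int argument; only
// the operator schema can.
void memory_format_looks_like_int() {
  c10::IValue iv(c10::MemoryFormat::Contiguous);
  bool looks_like_int = iv.isInt();  // true: the payload is stored as an int64_t
  (void)looks_like_int;
}
```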

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149052
Approved by: https://github.com/albanD
Jane Xu
2025-03-17 16:33:38 -07:00
committed by PyTorch MergeBot
parent ebabd0efdd
commit 988827cdfb
7 changed files with 343 additions and 90 deletions

View File

@ -0,0 +1,34 @@
# LibTorch Stable ABI
This note will eventually contain more details on how to use the APIs in torch/csrc/stable. For the moment, it contains a table of internal representations:
1. type in custom extension: type used within the end user's custom library.
2. StableIValue representation: a stable conversion of the type used to liaise between the user's code and libtorch.so in an ABI-stable manner.
3. type in libtorch: type used within libtorch.so (or any code binary-locked with libtorch).
4. Schema Type: type as described by the schema, which we treat as the source of truth both for ATen ops in native_functions.yaml and for user-defined custom operators registered to the dispatcher via TORCH_LIBRARY or torch.library.
| type in custom extension | StableIValue representation | type in libtorch | Schema Type |
| -------- | ------- | ------- | ------- |
| std::optional\<S> | \*reinterpret_cast\<(StableIValue\*)\*>, pointer to a StableIValue recursively defined | std::optional\<T> | Type? |
| std::nullopt | \*reinterpret_cast\<nullptr_t\*> | IValue() | None |
| RAIIATH | \*reinterpret_cast\<uint64_t\*> of AtenTensorHandle | at::Tensor | Tensor |
| int32_t | \*reinterpret_cast\<uint64_t\*> | at::ScalarType | ScalarType |
| int32_t | \*reinterpret_cast\<uint64_t\*> | at::Layout | Layout |
| int32_t | \*reinterpret_cast\<uint64_t\*> | at::MemoryFormat | MemoryFormat |
| bool | \*reinterpret_cast\<uint64_t\*> | bool | bool |
| int64_t | \*reinterpret_cast\<uint64_t\*> | int64_t | int |
| double | \*reinterpret_cast\<uint64_t\*> | double | float |
| ? | ? | c10::Device | Device |
| ? | ? | c10::Stream | Stream |
| ? | ? | c10::complex<double> | complex |
| ? | ? | at::Scalar | Scalar |
| ? | ? | std::string/const char*/ivalue::ConstantString | str |
| ? | ? | at::Storage | Storage |
| ? | ? | at::Generator | Generator |
| ? | ? | c10::List\<T> | Type[] |
| ? | ? | ivalue::Tuple\<T> | (Type, ...) |
| ? | ? | c10::SymInt | SymInt |
| ? | ? | c10::SymFloat | SymFloat |
| ? | ? | c10::SymBool | SymBool |
| ? | ? | at::QScheme | QScheme |
Our confidently supported types are the ones in the table that have completed rows. For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as a StableIValue, since the default reinterpret_cast will succeed. You can work with StableIValue abstractions in your custom kernel even for types such as c10::Device that have no standardized representation in custom extensions: a custom operator can take a StableIValue device as an argument and pass it straight through to an ATen operator via aoti_torch_call_dispatcher, as sketched below.
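For example, the following sketch (the `ones_on` operator and its schema are hypothetical; the pattern mirrors the libtorch_agnostic test extension) forwards an opaque `Device` argument straight to `aten::ones_like`:

```cpp
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/library.h>

#include <cstdint>
#include <optional>

// Boxed kernel for a hypothetical schema "ones_on(Tensor t, Device d) -> Tensor".
// The kernel never interprets the device: it wraps the opaque StableIValue in a
// std::optional and lets the dispatcher convert it according to the schema.
void boxed_ones_on(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  StableIValue args[6];
  args[0] = stack[0];                       // Tensor handle, consumed by the call
  args[1] = from(std::nullopt);             // dtype: default (keep input dtype)
  args[2] = from(std::nullopt);             // layout: default
  args[3] = from(std::optional(stack[1]));  // device: passed through untouched
  args[4] = from(std::nullopt);             // pin_memory: default
  args[5] = from(std::nullopt);             // memory_format: default
  aoti_torch_call_dispatcher("aten::ones_like", "", args);
  stack[0] = args[0];                       // newly created Tensor handle
}
```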

View File

@ -2,6 +2,8 @@
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/stable/library.h>
#include <optional>
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
void inline sgd_math(
@ -147,3 +149,39 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_abs", &boxed_my_abs);
}
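// my_ones_like builds the 6-slot argument stack expected by
// aten::ones_like(Tensor self, ScalarType? dtype, Layout? layout, Device? device,
// bool? pin_memory, MemoryFormat? memory_format) and calls it through the
// stable dispatcher entry point aoti_torch_call_dispatcher.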
RAIIATH my_ones_like(RAIIATH t, StableIValue device) {
const auto num_args = 6;
StableIValue stack[num_args];
int32_t t_dtype;
aoti_torch_get_dtype(t.get(), &t_dtype);
auto mf = aoti_torch_memory_format_contiguous_format();
stack[0] = from(t.release());
stack[1] = from(std::optional(t_dtype)); // dtype
stack[2] = from(std::nullopt); // layout
stack[3] = from(std::optional(device)); // device
stack[4] = from(std::optional(false)); // pin_memory
stack[5] = from(std::optional(mf)); // memory_format
aoti_torch_call_dispatcher("aten::ones_like", "", stack);
return RAIIATH(to<AtenTensorHandle>(stack[0]));
}
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
RAIIATH t(to<AtenTensorHandle>(stack[0]));
StableIValue device = stack[1];
RAIIATH raiiath_res = my_ones_like(std::move(t), device);
stack[0] = from(raiiath_res.release());
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_ones_like(Tensor t, Device d) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_ones_like", &boxed_my_ones_like);
}

View File

@ -49,3 +49,18 @@ def my_abs(t) -> Tensor:
        a Tensor
    """
    return torch.ops.libtorch_agnostic.my_abs.default(t)


def my_ones_like(tensor, device) -> Tensor:
    """
    Returns a new Tensor with the same properties as the input tensor, but filled with ones

    Args:
        tensor: any Tensor
        device: a target device (device string or torch.device)

    Returns:
        a Tensor of all ones with the same dtype and shape as the input
        tensor, placed on the given device
    """
    return torch.ops.libtorch_agnostic.my_ones_like.default(tensor, device)

View File

@ -53,7 +53,7 @@ class TestLibtorchAgnostic(TestCase):
self.assertEqual(curr_mem, init_mem)
def test_my_abs(self, device):
t = torch.rand(32, 16, device=device)
t = torch.rand(32, 16, device=device) - 0.5
cpu_t = libtorch_agnostic.ops.my_abs(t)
self.assertEqual(cpu_t, torch.abs(t))
@ -69,6 +69,23 @@ class TestLibtorchAgnostic(TestCase):
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
def test_my_ones_like(self, device):
t = torch.rand(3, 1, device=device) - 0.5
cpu_t = libtorch_agnostic.ops.my_ones_like(t, "cpu")
self.assertEqual(cpu_t, torch.ones_like(t, device="cpu"))
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_ones_like(t, device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.ones_like(t, device=device))
if t.is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
@onlyCUDA
def test_z_delete_torch_lib(self, device):
# Why the z + CUDA? THIS TEST MUST BE RUN LAST

View File

@ -270,22 +270,34 @@ class TestCppExtensionAOT(common.TestCase):
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
# (3) test calling our dispatcher on ones_like
t = torch.rand(32, 16, device=device)
cpu_t = libtorch_agnostic.ops.my_abs(t)
self.assertEqual(cpu_t, torch.abs(t))
# (3a) test calling our dispatcher on easy API like abs
t = torch.rand(32, 16, device=device) - 0.5
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_abs(t)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.abs(t))
if t.is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
# (3b) and on factory API like ones_like
cpu_t = libtorch_agnostic.ops.my_ones_like(t, "cpu")
self.assertEqual(cpu_t, torch.ones_like(t, device="cpu"))
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_ones_like(t, t.device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.ones_like(t, device=t.device))
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
@torch.testing._internal.common_utils.markDynamoStrictTest

View File

@ -1298,6 +1298,128 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
});
}
static StableIValue from_ivalue(
const c10::TypePtr& type,
const c10::IValue& ivalue) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
return from(ath);
}
case c10::TypeKind::IntType: {
return from(ivalue.toInt());
}
case c10::TypeKind::FloatType: {
return from(ivalue.toDouble());
}
case c10::TypeKind::BoolType: {
return from(ivalue.toBool());
}
case c10::TypeKind::ScalarTypeType: {
return from(ivalue.toScalarType());
}
case c10::TypeKind::DeviceObjType: {
return from(ivalue.toDevice());
}
case c10::TypeKind::LayoutType: {
return from(ivalue.toLayout());
}
case c10::TypeKind::MemoryFormatType: {
return from(ivalue.toMemoryFormat());
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with from<std::optional<T>> function in
// torch/csrc/stable/library.h
if (ivalue.isNone()) {
return from(std::nullopt);
}
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
return from(sivp);
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from IValue to StableIValue for schema type: ",
type->str());
}
}
}
static c10::IValue to_ivalue(
const c10::TypePtr& type,
const StableIValue stable_ivalue) {
switch (type->kind()) {
case c10::TypeKind::TensorType: {
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
to<AtenTensorHandle>(stable_ivalue));
at::Tensor arg = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get());
return (c10::IValue(arg));
}
case c10::TypeKind::IntType: {
return c10::IValue(to<int64_t>(stable_ivalue));
}
case c10::TypeKind::FloatType: {
return c10::IValue(to<double>(stable_ivalue));
}
case c10::TypeKind::BoolType: {
return c10::IValue(to<bool>(stable_ivalue));
}
case c10::TypeKind::ScalarTypeType: {
return c10::IValue(to<c10::ScalarType>(stable_ivalue));
}
case c10::TypeKind::DeviceObjType: {
return c10::IValue(to<c10::Device>(stable_ivalue));
}
case c10::TypeKind::LayoutType: {
return c10::IValue(to<c10::Layout>(stable_ivalue));
}
case c10::TypeKind::MemoryFormatType: {
return c10::IValue(to<c10::MemoryFormat>(stable_ivalue));
}
case c10::TypeKind::OptionalType: {
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
// ideally, if we had the C++ type corresponding to inner_type, which we
// will denote as inner_type::t (does not actually exist), we would be
// able to follow the patterned semantic of every other case here in one
// line:
//
// return c10::IValue(to<std::optional<inner_type::t>>(stable_ivalue));
//
// BUT we do NOT have that type inner_type::t readily available, so we
// will manually unwrap and recursively call. This implementation MUST
// be kept in sync with the to<T> function in
// torch/csrc/stable/library.h
if (stable_ivalue == from(std::nullopt)) {
return c10::IValue();
}
auto sivp = to<StableIValue*>(stable_ivalue);
auto ival = to_ivalue(inner_type, *sivp);
delete sivp;
return ival;
}
default: {
TORCH_CHECK(
false,
"Not yet supported conversion from StableIValue to IValue for schema type: ",
type->str());
}
}
}
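// Illustration only (not part of this change): a rough sketch of how the two
// helpers above compose for an optional argument. It would need to live in this
// translation unit, since both from_ivalue and to_ivalue are static. An "int?"
// argument holding 3 allocates one inner StableIValue in from_ivalue and frees
// it again in to_ivalue; a None argument never touches the heap.
void optional_roundtrip_sketch() {
  c10::TypePtr opt_int = c10::OptionalType::create(c10::IntType::get());
  StableIValue some = from_ivalue(opt_int, c10::IValue(int64_t(3)));
  StableIValue none = from_ivalue(opt_int, c10::IValue());
  TORCH_CHECK(to_ivalue(opt_int, some).toInt() == 3); // frees the heap StableIValue
  TORCH_CHECK(to_ivalue(opt_int, none).isNone());
}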
class StableIValueBoxedKernel : public c10::OperatorKernel {
public:
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
@ -1314,23 +1436,9 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
std::vector<StableIValue> ministack(std::max(num_arguments, num_returns));
for (const auto idx : c10::irange(num_arguments)) {
const c10::IValue& arg = torch::jit::pop(stack);
const auto ministack_idx = num_arguments - idx - 1;
if (arg.isInt()) {
ministack[ministack_idx] = from(arg.toInt());
} else if (arg.isDouble()) {
ministack[ministack_idx] = from(arg.toDouble());
} else if (arg.isBool()) {
ministack[ministack_idx] = from(arg.toBool());
} else if (arg.isNone()) {
ministack[ministack_idx] = from(nullptr);
} else if (arg.isTensor()) {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(arg.toTensor())));
ministack[ministack_idx] = from(ath);
} else {
TORCH_CHECK(false, "Other types of IValues not yet handled!");
}
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
}
// boxed function is going to take a stack of StableIValues, cast them to
@ -1341,15 +1449,7 @@ class StableIValueBoxedKernel : public c10::OperatorKernel {
// IValue from StableIValue
for (size_t idx = 0; idx < num_returns; idx++) {
const c10::TypePtr& ret_type = schema.returns()[idx].type();
if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
to<AtenTensorHandle>(ministack[idx]));
at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get());
torch::jit::push(stack, c10::IValue(out));
} else {
TORCH_CHECK(false, "Only Tensor return types are currently supported!");
}
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
}
}
@ -1430,42 +1530,6 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
{ delete reinterpret_cast<torch::Library*>(tlh); });
}
static c10::IValue to_ivalue(
const c10::TypePtr& arg_type,
const StableIValue stable_ivalue) {
switch (arg_type->kind()) {
case c10::TypeKind::TensorType: {
// stable_ivalue must be an ATH
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
to<AtenTensorHandle>(stable_ivalue));
at::Tensor arg = *torch::aot_inductor::tensor_handle_to_tensor_pointer(
ret_raiiath.get());
return (c10::IValue(arg));
}
case c10::TypeKind::IntType: {
return c10::IValue(to<int64_t>(stable_ivalue));
}
case c10::TypeKind::FloatType: {
return c10::IValue(to<double>(stable_ivalue));
}
case c10::TypeKind::BoolType: {
return c10::IValue(to<bool>(stable_ivalue));
}
case c10::TypeKind::ScalarTypeType: {
return c10::IValue(to<c10::ScalarType>(stable_ivalue));
}
case c10::TypeKind::LayoutType: {
return c10::IValue(to<c10::Layout>(stable_ivalue));
}
case c10::TypeKind::MemoryFormatType: {
return c10::IValue(to<c10::MemoryFormat>(stable_ivalue));
}
default: {
TORCH_CHECK(false, "Not yet supported argument type: ", arg_type->str());
}
}
}
AOTITorchError aoti_torch_call_dispatcher(
const char* opName,
const char* overloadName,
@ -1493,23 +1557,9 @@ AOTITorchError aoti_torch_call_dispatcher(
// there should then be num_returns IValues on the stack, which
// we will convert to StableIValue and repopulate user input stack
for (const auto idx : c10::irange(num_returns)) {
const c10::IValue& ret = torch::jit::pop(ivalue_stack);
const auto stack_idx = num_returns - idx - 1;
if (ret.isInt()) {
stack[stack_idx] = from(ret.toInt());
} else if (ret.isDouble()) {
stack[stack_idx] = from(ret.toDouble());
} else if (ret.isBool()) {
stack[stack_idx] = from(ret.toBool());
} else if (ret.isNone()) {
stack[stack_idx] = from(nullptr);
} else if (ret.isTensor()) {
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
std::move(const_cast<at::Tensor&>(ret.toTensor())));
stack[stack_idx] = from(ath);
} else {
TORCH_CHECK(false, "Other types of IValue returns not yet handled!");
}
const c10::TypePtr& ret_type = schema.returns()[idx].type();
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
}
});
}

View File

@ -4,12 +4,23 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <optional>
// use anonymous namespace to avoid collisions between differing
// versions of this file that may be included by different sources
namespace {
// helpers for converting between StableIValue and actual IValues
template <typename T>
namespace detail {
// utility functions to detect optional
template <typename V>
struct is_optional : std::false_type {};
template <typename V>
struct is_optional<std::optional<V>> : std::true_type {};
} // namespace detail
template <
typename T,
std::enable_if_t<!detail::is_optional<T>::value, bool> = true>
StableIValue from(T val) {
static_assert(
sizeof(T) <= sizeof(StableIValue),
@ -17,10 +28,86 @@ StableIValue from(T val) {
return *reinterpret_cast<StableIValue*>(&val);
}
// Specialization for std::nullopt_t
template <>
StableIValue from(std::nullopt_t val) {
return from(nullptr);
}
// Specialization for std::optional
// [Handling std::optional]
// When the schema is represented by an optional type, say int?, then we
// expect the custom extension representation to be a std::optional<int>
// (critically NOT int!). In order for all parameters to be stably parsed and
// handled by our dispatcher, we route custom extension parameters through
// boxed kernels, meaning that every value eventually makes its way to an IValue:
//
// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue
//
// When the custom extension value is a literal that can be trivially
// cast to StableIValue, e.g., an int, a float, a pointer, this route is
// ...trivial. The specialization below is for the case when the custom
// extension value would NOT fit within a StableIValue: a std::optional.
//
// If the std::optional has no value, it is treated as std::nullopt,
// whose StableIValue representation is from(nullptr). Otherwise, we:
// 1. unwrap the std::optional<T>
// 2. recursively convert its value of type T to a StableIValue
// 3. allocate heap space for said StableIValue
// 4. convert the resulting StableIValue* into a StableIValue
//
// note that this allocates heap memory! which we expect to be cleaned
// up in the to_ivalue() function defined in shim_common.cpp. We
// purposefully hide this implementation detail from the user so that
// all the user needs to know is:
//
// The schema requests an optional (T?) so I must call `from` on a
// std::optional<T> or a std::nullopt.
template <typename T>
StableIValue from(std::optional<T> val) {
if (!val.has_value()) {
return from(std::nullopt);
}
StableIValue* heap_val = new StableIValue(from(val.value()));
return from(heap_val);
}
template <
typename T,
std::enable_if_t<!detail::is_optional<T>::value, bool> = true>
T to(StableIValue val) {
return *reinterpret_cast<T*>(&val);
}
template <
typename T,
std::enable_if_t<std::is_same_v<T, std::nullopt_t>, bool> = true>
T to(StableIValue val) {
// val should be equivalent to from(nullptr)
return std::nullopt;
}
// Specialization for std::optional, see [Handling std::optional] above
// as the semantic is the same but in reverse direction as we go from
// IValue --(from_ivalue)-> StableIValue --(to<T>)-> T in custom extension
template <
typename T,
std::enable_if_t<detail::is_optional<T>::value, bool> = true>
T to(StableIValue val) {
using V = typename T::value_type;
auto sivp = to<StableIValue*>(val);
// sivp is either nullptr or a pointer to a StableIValue
if (sivp == nullptr) {
return {};
}
auto inner_val = to<V>(*sivp);
// free the memory associated with StableIValue* sivp
delete sivp;
return std::make_optional(inner_val);
}
// end to helpers for converting between StableIValue and actual IValues
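// Illustration only (not part of this header): the extension-facing half of the
// optional contract. from() heap-allocates the inner StableIValue when the
// optional is engaged; to<std::optional<T>>() unwraps and frees it again.
std::optional<int64_t> optional_int_roundtrip(std::optional<int64_t> maybe_n) {
  StableIValue packed = from(maybe_n);
  return to<std::optional<int64_t>>(packed);
}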
class StableLibrary final {