Add new_empty (with dtype argument only) to torch::stable (#159508)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159508
Approved by: https://github.com/janeyx99
ghstack dependencies: #160557
Mikayla Gawarecki
2025-08-19 13:54:31 -07:00
committed by PyTorch MergeBot
parent 543896fcf3
commit 78a8e6a671
9 changed files with 111 additions and 1 deletion
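In short: the commit exposes aten.new_empty through the AOTI C shim and adds a header-only torch::stable::new_empty that takes a size and an optional dtype, reading device, device index, layout (and, when dtype is omitted, dtype) back from the source tensor. A minimal usage sketch from an extension author's point of view (editor's example, not part of the diff; bf16_buffer and same_dtype_buffer are hypothetical names):

#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>

using torch::stable::Tensor;

// Explicit dtype: an uninitialized [2, 5] bfloat16 tensor on t's device/layout.
Tensor bf16_buffer(const Tensor& t) {
  return new_empty(t, {2, 5}, std::make_optional(at::ScalarType::BFloat16));
}

// dtype omitted: the new tensor inherits t's dtype.
Tensor same_dtype_buffer(const Tensor& t) {
  return new_empty(t, {2, 5});
}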

View File

@@ -4,6 +4,7 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/core/ScalarType.h>

#ifdef LAE_USE_CUDA
#include <cuda_runtime.h>
@@ -340,12 +341,24 @@ void boxed_my_narrow(
  stack[0] = from(res);
}

Tensor my_new_empty_dtype_variant(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(at::ScalarType::BFloat16);
  return new_empty(t, sizes, dtype);
}

void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
  m.def("my_empty_like(Tensor t) -> Tensor");
  m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
  m.def("my_pad(Tensor t) -> Tensor");
  m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
  m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -353,6 +366,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_empty_like", &boxed_empty_like);
  m.impl("fill_infinity", &boxed_fill_infinity);
  m.impl("my_is_cpu", &boxed_my_is_cpu);
  m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {

View File

@@ -283,3 +283,15 @@ def test_get_current_device_index() -> int:
    Returns: Current device index as an integer
    """
    return torch.ops.libtorch_agnostic.test_get_current_device_index.default()


def my_new_empty_dtype_variant(t) -> Tensor:
    """
    Returns a new empty tensor with shape [2, 5] and dtype bfloat16.

    Args:
        t: Input tensor used as a reference for device and other properties

    Returns: New empty tensor with shape [2, 5] and dtype bfloat16
    """
    return torch.ops.libtorch_agnostic.my_new_empty_dtype_variant.default(t)

View File

@@ -190,7 +190,7 @@ if not IS_WINDOWS:
            deterministic = torch.are_deterministic_algorithms_enabled()
            try:
                # set use_deterministic_algorithms to fill uninitialized memory
                torch.use_deterministic_algorithms(True)
                t = torch.rand(2, 7, device=device)
@@ -322,6 +322,21 @@ if not IS_WINDOWS:
            finally:
                torch.cuda.set_device(prev_device)

        def test_my_new_empty_dtype_variant(self, device):
            import libtorch_agnostic

            deterministic = torch.are_deterministic_algorithms_enabled()
            try:
                # set use_deterministic_algorithms to fill uninitialized memory
                torch.use_deterministic_algorithms(True)
                t = torch.randn(3, 4, device=device)
                out = libtorch_agnostic.ops.my_new_empty_dtype_variant(t)
                ref_out = t.new_empty((2, 5), dtype=torch.bfloat16)
                self.assertEqual(out, ref_out, exact_device=True)
            finally:
                torch.use_deterministic_algorithms(deterministic)

instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@@ -220,6 +220,9 @@ aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_layout(AtenTensorHandle tensor, int32_t* ret_layout);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
    AtenTensorHandle tensor,
    int64_t* ret_storage_offset);

View File

@@ -17,6 +17,7 @@ extern "C" {
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int32_t keepdim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
#ifdef __cplusplus
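For reference, a minimal sketch of driving this new shim entry point directly (editor's example, not part of the diff; it assumes the usual shim conventions, where null option pointers mean "use the default"):

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>

// Allocate an uninitialized [2, 5] bfloat16 tensor; layout and device are left
// at their defaults (i.e., derived from `self`, as in eager new_empty).
AOTITorchError make_bf16(AtenTensorHandle self, AtenTensorHandle* out) {
  int64_t sizes[] = {2, 5};
  int32_t dtype = aoti_torch_dtype_bfloat16();
  return aoti_torch_aten_new_empty(
      self, sizes, /*size_len_=*/2, &dtype,
      /*layout=*/nullptr, /*device=*/nullptr, /*device_index_=*/0,
      /*pin_memory=*/nullptr, out);
}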

View File

@@ -389,6 +389,15 @@ AOTITorchError aoti_torch_get_device_index(
  });
}

AOTITorchError aoti_torch_get_layout(
    AtenTensorHandle tensor,
    int32_t* ret_layout) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
    *ret_layout = static_cast<int32_t>(t->layout());
  });
}

AOTITorchError aoti_torch_get_storage_offset(
    AtenTensorHandle tensor,
    int64_t* ret_storage_offset) {

View File

@@ -8,6 +8,7 @@
#include <vector>

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/headeronly/core/ScalarType.h>

using torch::stable::Tensor;

@@ -51,6 +52,44 @@ inline Tensor narrow(Tensor& self, int64_t dim, int64_t start, int64_t length) {
  return Tensor(ret0);
}

// We expect this to be a stable version of the new_empty op that takes in
// only dtype information.
inline Tensor new_empty(
    const Tensor& self,
    std::vector<int64_t> size,
    std::optional<c10::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));

  int32_t device_index;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(self.get(), &device_index));

  int32_t target_dtype;
  if (dtype.has_value()) {
    target_dtype = to<int32_t>(from(dtype.value()));
  } else {
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
  }

  int32_t layout;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));

  AtenTensorHandle ret0;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_empty(
      self.get(),
      size.data(),
      static_cast<int64_t>(size.size()),
      &target_dtype,
      &layout,
      &device_type,
      device_index,
      nullptr, // pin_memory (nullptr for default)
      &ret0));
  return Tensor(ret0);
}
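
// Editor's note (example, not part of this commit): only dtype is overridable
// here, matching the "dtype argument only" scope of the PR title; layout,
// device type, and device index are always re-derived from `self`, and
// pin_memory is left at its default (nullptr). A caller that wants the source
// dtype simply omits the argument:
//
//   Tensor out = new_empty(self, {2, 5});  // same dtype/device/layout as self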
// We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
// use std::vector<int64_t> because

View File

@@ -186,4 +186,5 @@ aten_shimified_ops: dict[str, dict[str, list[str]]] = {
    "aten.pad.default": {},
    "aten.narrow.default": {},
    "aten.amax.default": {},
    "aten.new_empty.default": {},
}

View File

@@ -24,6 +24,7 @@ from torchgen.model import (
    OperatorName,
    OptionalType,
    Type,
    Variant,
)
from torchgen.utils import FileManager, mapMaybe
@@ -396,7 +397,22 @@ def gen_static_dispatch_backend_call(
) -> str:
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    if backend_index is None:
        # Check whether this is a symint function that only has method variants.
        if sig.symint and f.func.has_symint():
            has_function_variant = Variant.function in f.variants
            if not has_function_variant:
                # Ops with a function variant can use the at::{name}_symint
                # spelling (e.g., narrow -> at::narrow_symint), but method-only
                # symint ops must go through the at::symint:: namespace, which
                # uses the base name without the _symint suffix
                # (e.g., new_empty -> at::symint::new_empty<c10::SymInt>).
                base_name = cpp_sig.name().removesuffix("_symint")
                return f"at::symint::{base_name}<c10::SymInt>"
        return f"at::{cpp_sig.name()}"
    else:
        return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"