Add new_zeros dtype variant to the shim and as a stable op (#161597)

In case we want this before 2.9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161597
Approved by: https://github.com/mikaylagawarecki
Jane Xu authored on 2025-08-28 13:57:24 +00:00; committed by PyTorch MergeBot
parent 05d0f11dbd
commit 63632fc7ee
6 changed files with 74 additions and 1 deletion
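For context on how the pieces below fit together: after this change, a libtorch-agnostic extension can allocate a zero-filled tensor through the stable ABI. A minimal usage sketch, assuming the stable headers touched in this diff (the include paths, the torch::stable::Tensor namespace, and the make_float_zeros helper are illustrative assumptions, not part of the commit):

#include <optional>
#include <vector>
#include <torch/csrc/stable/ops.h>     // assumed path for the stable new_zeros
#include <torch/csrc/stable/tensor.h>  // assumed path for the stable Tensor

using torch::stable::Tensor;

// Returns a [4, 4] float32 zeros tensor on the same device as t, with the
// dtype explicitly overridden rather than inherited from t.
Tensor make_float_zeros(const Tensor& t) {
  std::vector<int64_t> sizes = {4, 4};
  auto dtype = std::make_optional(torch::headeronly::ScalarType::Float);
  return new_zeros(t, sizes, dtype);
}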

View File

@@ -343,7 +343,7 @@ void boxed_my_narrow(
Tensor my_new_empty_dtype_variant(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
-  auto dtype = std::make_optional(at::ScalarType::BFloat16);
+  auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
  return new_empty(t, sizes, dtype);
}

@@ -352,6 +352,17 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
  stack[0] = from(res);
}

Tensor my_new_zeros_dtype_variant(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(at::ScalarType::Float);
  return new_zeros(t, sizes, dtype);
}

void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
  m.def("my_empty_like(Tensor t) -> Tensor");
@@ -359,6 +370,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_pad(Tensor t) -> Tensor");
  m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
  m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor");
  m.def("my_new_zeros_dtype_variant(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -367,6 +379,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("fill_infinity", &boxed_fill_infinity);
  m.impl("my_is_cpu", &boxed_my_is_cpu);
  m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
  m.impl("my_new_zeros_dtype_variant", &boxed_my_new_zeros_dtype_variant);
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
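A note on the pattern above: kernels registered through STABLE_TORCH_LIBRARY_IMPL use a boxed calling convention. Arguments arrive as a stack of StableIValues, to<T>() unboxes them, and outputs are written back over the start of the stack via from(). A minimal sketch of that shape for a hypothetical one-tensor-in, one-tensor-out op (my_unary_op is illustrative, not part of this commit):

// Unboxed implementation, declared elsewhere in the extension.
Tensor my_unary_op(Tensor t);

// Boxed wrapper: stack[0..num_args) holds the boxed inputs; results are
// re-boxed and written back starting at stack[0].
void boxed_my_unary_op(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor res = my_unary_op(to<Tensor>(stack[0]));
  stack[0] = from(res);
}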

View File

@@ -295,3 +295,15 @@ def my_new_empty_dtype_variant(t) -> Tensor:
    Returns: New empty tensor with shape [2, 5] and dtype bfloat16
    """
    return torch.ops.libtorch_agnostic.my_new_empty_dtype_variant.default(t)


def my_new_zeros_dtype_variant(t) -> Tensor:
    """
    Returns a new tensor filled with 0s with shape [2, 5] and dtype Float

    Args:
        t: Input tensor used as a reference for device and other properties

    Returns: New zeros tensor
    """
    return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)

View File

@@ -337,6 +337,14 @@ if not IS_WINDOWS:
            finally:
                torch.use_deterministic_algorithms(deterministic)

        def test_my_new_zeros_dtype_variant(self, device):
            import libtorch_agnostic

            t = torch.randn(3, 4, device=device)
            out = libtorch_agnostic.ops.my_new_zeros_dtype_variant(t)
            ref_out = t.new_zeros((2, 5), dtype=torch.float)
            self.assertEqual(out, ref_out, exact_device=True)

    instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@@ -18,6 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, con
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_zeros(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
#ifdef __cplusplus
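In this shim's C ABI, optional arguments are passed as nullable pointers, so a null dtype, layout, device, or pin_memory asks for the default, which for new_zeros means inheriting the property from self. A sketch of a direct call that keeps all of self's properties (assumes a valid AtenTensorHandle named self; error handling elided):

// Zeros of shape [2, 5] inheriting dtype, layout, and device from self.
int64_t size[] = {2, 5};
AtenTensorHandle out = nullptr;
AOTITorchError err = aoti_torch_aten_new_zeros(
    self, size, /*size_len_=*/2,
    /*dtype=*/nullptr, /*layout=*/nullptr, /*device=*/nullptr,
    /*device_index_=*/0,  // presumably ignored when device is null
    /*pin_memory=*/nullptr, &out);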

View File

@@ -90,6 +90,44 @@ inline Tensor new_empty(
  return Tensor(ret0);
}

// We expect this to be a stable version of the new_zeros op that takes in
// only dtype information.
inline Tensor new_zeros(
    const Tensor& self,
    std::vector<int64_t> size,
    std::optional<c10::ScalarType> dtype = std::nullopt) {
  int32_t device_type;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
  int32_t device_index;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_device_index(self.get(), &device_index));
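
  // Resolve the dtype code the shim expects: from() packs the ScalarType
  // into a StableIValue and to<int32_t>() unpacks it as the integer dtype
  // code; with no explicit dtype, fall back to self's dtype.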
  int32_t target_dtype;
  if (dtype.has_value()) {
    target_dtype = to<int32_t>(from(dtype.value()));
  } else {
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
  }
  int32_t layout;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));
  AtenTensorHandle ath;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_zeros(
      self.get(),
      size.data(),
      static_cast<int64_t>(size.size()),
      &target_dtype,
      &layout,
      &device_type,
      device_index,
      nullptr, // pin_memory (nullptr for default)
      &ath));
  return Tensor(ath);
}

// We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
// use std::vector<int64_t> because

View File

@@ -187,4 +187,5 @@ aten_shimified_ops: dict[str, dict[str, list[str]]] = {
    "aten.narrow.default": {},
    "aten.amax.default": {},
    "aten.new_empty.default": {},
    "aten.new_zeros.default": {},
}