Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991
Jane Xu authored on 2025-10-16 13:12:44 -07:00, committed by PyTorch MergeBot
parent 3806e9767b
commit e4454947e2
2 changed files with 13 additions and 16 deletions
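The change below widens several torch::stable op signatures from std::vector<int64_t> to torch::headeronly::IntHeaderOnlyArrayRef, so call sites can pass either an existing vector or a braced initializer list without building a temporary vector. As a rough illustration of why both call styles compile, here is a minimal, self-contained sketch; IntSpan is a hypothetical stand-in for the real HeaderOnlyArrayRef (which has more constructors and helpers), not PyTorch code.

#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <vector>

// Hypothetical stand-in for torch::headeronly::IntHeaderOnlyArrayRef: a
// non-owning view over a contiguous run of int64_t.
class IntSpan {
 public:
  /* implicit */ IntSpan(const std::vector<int64_t>& vec)
      : data_(vec.data()), size_(vec.size()) {}
  /* implicit */ IntSpan(std::initializer_list<int64_t> list)
      : data_(list.begin()), size_(list.size()) {}

  const int64_t* begin() const { return data_; }
  const int64_t* end() const { return data_ + size_; }
  size_t size() const { return size_; }

 private:
  const int64_t* data_;
  size_t size_;
};

// A callee written against the view type accepts both containers below.
int64_t sum(IntSpan dims) {
  int64_t total = 0;
  for (int64_t d : dims) {
    total += d;
  }
  return total;
}

int main() {
  std::vector<int64_t> dims = {0, 1};
  std::cout << sum(dims) << "\n";    // existing std::vector call sites keep working
  std::cout << sum({2, 5}) << "\n";  // braced lists now convert implicitly, no temporary vector
  return 0;
}

The two constructors above correspond to the two call styles visible in the kernel diffs that follow: existing std::vector arguments keep working, and braced lists like {1, 2, 2, 1} or {0, 1} are accepted directly.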

View File

@@ -311,10 +311,9 @@ void boxed_fill_infinity(
 }
 
 Tensor my_pad(Tensor t) {
-  std::vector<int64_t> padding = {1, 2, 2, 1};
   std::string mode = "constant";
   double value = 0.0;
-  return pad(t, padding, mode, value);
+  return pad(t, {1, 2, 2, 1}, mode, value);
 }
 
 void boxed_my_pad(
@@ -342,6 +341,9 @@ void boxed_my_narrow(
 }
 
 Tensor my_new_empty_dtype_variant(Tensor t) {
+  // Still using a std::vector below even though people can just pass in an
+  // initializer list (which will be implicitly converted to a HeaderOnlyArrayRef)
+  // directly.
   std::vector<int64_t> sizes = {2, 5};
   auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
   return new_empty(t, sizes, dtype);
@@ -353,9 +355,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
 }
 
 Tensor my_new_zeros_dtype_variant(Tensor t) {
-  std::vector<int64_t> sizes = {2, 5};
   auto dtype = std::make_optional(at::ScalarType::Float);
-  return new_zeros(t, sizes, dtype);
+  return new_zeros(t, {2, 5}, dtype);
 }
 
 void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@@ -429,8 +430,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
 }
 
 Tensor my_amax_vec(Tensor t) {
-  std::vector<int64_t> v = {0,1};
-  return amax(t, v, false);
+  return amax(t, {0,1}, false);
 }
 
 void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {

View File

@@ -5,10 +5,10 @@
 #include <cstdint>
 #include <optional>
 #include <string>
-#include <vector>
 
 #include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
 #include <torch/headeronly/core/ScalarType.h>
+#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
 
 namespace torch::stable {
@@ -60,7 +60,7 @@ inline torch::stable::Tensor narrow(
 // only dtype information.
 inline torch::stable::Tensor new_empty(
     const torch::stable::Tensor& self,
-    std::vector<int64_t> size,
+    torch::headeronly::IntHeaderOnlyArrayRef size,
     std::optional<c10::ScalarType> dtype = std::nullopt) {
   int32_t device_type;
   TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@@ -98,7 +98,7 @@ inline torch::stable::Tensor new_empty(
 // only dtype information.
 inline torch::stable::Tensor new_zeros(
     const torch::stable::Tensor& self,
-    std::vector<int64_t> size,
+    torch::headeronly::IntHeaderOnlyArrayRef size,
     std::optional<c10::ScalarType> dtype = std::nullopt) {
   int32_t device_type;
   TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@@ -134,12 +134,10 @@ inline torch::stable::Tensor new_zeros(
 // We expect this to be the stable version of the pad.default op.
 // pad.default takes in a SymInt[] as the pad argument however pad is typed as
-// use std::vector<int64_t> because
-// (1) IntArrayRef is not yet header-only
-// (2) SymInt is not yet header-only
+// torch::headeronly::IntHeaderOnlyArrayRef because SymInt is not yet header-only.
 inline torch::stable::Tensor pad(
     const torch::stable::Tensor& self,
-    std::vector<int64_t> pad,
+    torch::headeronly::IntHeaderOnlyArrayRef pad,
     const std::string& mode = "constant",
     double value = 0.0) {
   AtenTensorHandle ret0 = nullptr;
@@ -171,11 +169,10 @@ inline torch::stable::Tensor amax(
 // This function is an overload to compute the maximum value along each slice of
 // `self` reducing over all the dimensions in the vector `dims`. The
 // amax.default op takes in a SymInt[] as the dims argument, however dims is
-// typed as use std::vector<int64_t> here because (1) IntArrayRef is not yet
-// header-only (2) SymInt is not yet header-only
+// typed as IntHeaderOnlyArrayRef here because SymInt is not yet header-only.
 inline torch::stable::Tensor amax(
     const torch::stable::Tensor& self,
-    std::vector<int64_t> dims,
+    torch::headeronly::IntHeaderOnlyArrayRef dims,
     bool keepdim = false) {
   AtenTensorHandle ret = nullptr;
   TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(