Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991
This commit is contained in:
Jane Xu
2025-10-16 13:12:44 -07:00
committed by PyTorch MergeBot
parent 3806e9767b
commit e4454947e2
2 changed files with 13 additions and 16 deletions

View File

@ -311,10 +311,9 @@ void boxed_fill_infinity(
} }
Tensor my_pad(Tensor t) { Tensor my_pad(Tensor t) {
std::vector<int64_t> padding = {1, 2, 2, 1};
std::string mode = "constant"; std::string mode = "constant";
double value = 0.0; double value = 0.0;
return pad(t, padding, mode, value); return pad(t, {1, 2, 2, 1}, mode, value);
} }
void boxed_my_pad( void boxed_my_pad(
@ -342,6 +341,9 @@ void boxed_my_narrow(
} }
Tensor my_new_empty_dtype_variant(Tensor t) { Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
std::vector<int64_t> sizes = {2, 5}; std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16); auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype); return new_empty(t, sizes, dtype);
@ -353,9 +355,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
} }
Tensor my_new_zeros_dtype_variant(Tensor t) { Tensor my_new_zeros_dtype_variant(Tensor t) {
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(at::ScalarType::Float); auto dtype = std::make_optional(at::ScalarType::Float);
return new_zeros(t, sizes, dtype); return new_zeros(t, {2, 5}, dtype);
} }
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
@ -429,8 +430,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
} }
Tensor my_amax_vec(Tensor t) { Tensor my_amax_vec(Tensor t) {
std::vector<int64_t> v = {0,1}; return amax(t, {0,1}, false);
return amax(t, v, false);
} }
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {

View File

@ -5,10 +5,10 @@
#include <cstdint> #include <cstdint>
#include <optional> #include <optional>
#include <string> #include <string>
#include <vector>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h> #include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
#include <torch/headeronly/core/ScalarType.h> #include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
namespace torch::stable { namespace torch::stable {
@ -60,7 +60,7 @@ inline torch::stable::Tensor narrow(
// only dtype information. // only dtype information.
inline torch::stable::Tensor new_empty( inline torch::stable::Tensor new_empty(
const torch::stable::Tensor& self, const torch::stable::Tensor& self,
std::vector<int64_t> size, torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<c10::ScalarType> dtype = std::nullopt) { std::optional<c10::ScalarType> dtype = std::nullopt) {
int32_t device_type; int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type)); TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@ -98,7 +98,7 @@ inline torch::stable::Tensor new_empty(
// only dtype information. // only dtype information.
inline torch::stable::Tensor new_zeros( inline torch::stable::Tensor new_zeros(
const torch::stable::Tensor& self, const torch::stable::Tensor& self,
std::vector<int64_t> size, torch::headeronly::IntHeaderOnlyArrayRef size,
std::optional<c10::ScalarType> dtype = std::nullopt) { std::optional<c10::ScalarType> dtype = std::nullopt) {
int32_t device_type; int32_t device_type;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type)); TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
@ -134,12 +134,10 @@ inline torch::stable::Tensor new_zeros(
// We expect this to be the stable version of the pad.default op. // We expect this to be the stable version of the pad.default op.
// pad.default takes in a SymInt[] as the pad argument however pad is typed as // pad.default takes in a SymInt[] as the pad argument however pad is typed as
// use std::vector<int64_t> because // torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only.
// (1) IntArrayRef is not yet header-only
// (2) SymInt is not yet header-only
inline torch::stable::Tensor pad( inline torch::stable::Tensor pad(
const torch::stable::Tensor& self, const torch::stable::Tensor& self,
std::vector<int64_t> pad, torch::headeronly::IntHeaderOnlyArrayRef pad,
const std::string& mode = "constant", const std::string& mode = "constant",
double value = 0.0) { double value = 0.0) {
AtenTensorHandle ret0 = nullptr; AtenTensorHandle ret0 = nullptr;
@ -171,11 +169,10 @@ inline torch::stable::Tensor amax(
// This function is an overload to compute the maximum value along each slice of // This function is an overload to compute the maximum value along each slice of
// `self` reducing over all the dimensions in the vector `dims`. The // `self` reducing over all the dimensions in the vector `dims`. The
// amax.default op takes in a SymInt[] as the dims argument, however dims is // amax.default op takes in a SymInt[] as the dims argument, however dims is
// typed as use std::vector<int64_t> here because (1) IntArrayRef is not yet // typed as use IntHeaderOnlyArrayRef here because SymInt is not yet header-only
// header-only (2) SymInt is not yet header-only
inline torch::stable::Tensor amax( inline torch::stable::Tensor amax(
const torch::stable::Tensor& self, const torch::stable::Tensor& self,
std::vector<int64_t> dims, torch::headeronly::IntHeaderOnlyArrayRef dims,
bool keepdim = false) { bool keepdim = false) {
AtenTensorHandle ret = nullptr; AtenTensorHandle ret = nullptr;
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax( TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(