mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #164991
This commit is contained in:
committed by
PyTorch MergeBot
parent
3806e9767b
commit
e4454947e2
@ -311,10 +311,9 @@ void boxed_fill_infinity(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Tensor my_pad(Tensor t) {
|
Tensor my_pad(Tensor t) {
|
||||||
std::vector<int64_t> padding = {1, 2, 2, 1};
|
|
||||||
std::string mode = "constant";
|
std::string mode = "constant";
|
||||||
double value = 0.0;
|
double value = 0.0;
|
||||||
return pad(t, padding, mode, value);
|
return pad(t, {1, 2, 2, 1}, mode, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void boxed_my_pad(
|
void boxed_my_pad(
|
||||||
@ -342,6 +341,9 @@ void boxed_my_narrow(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Tensor my_new_empty_dtype_variant(Tensor t) {
|
Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||||
|
// Still using a std::vector below even though people can just pass in an
|
||||||
|
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
|
||||||
|
// directly.
|
||||||
std::vector<int64_t> sizes = {2, 5};
|
std::vector<int64_t> sizes = {2, 5};
|
||||||
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
|
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
|
||||||
return new_empty(t, sizes, dtype);
|
return new_empty(t, sizes, dtype);
|
||||||
@ -353,9 +355,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
|
|||||||
}
|
}
|
||||||
|
|
||||||
Tensor my_new_zeros_dtype_variant(Tensor t) {
|
Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||||
std::vector<int64_t> sizes = {2, 5};
|
|
||||||
auto dtype = std::make_optional(at::ScalarType::Float);
|
auto dtype = std::make_optional(at::ScalarType::Float);
|
||||||
return new_zeros(t, sizes, dtype);
|
return new_zeros(t, {2, 5}, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||||
@ -429,8 +430,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
|
|||||||
}
|
}
|
||||||
|
|
||||||
Tensor my_amax_vec(Tensor t) {
|
Tensor my_amax_vec(Tensor t) {
|
||||||
std::vector<int64_t> v = {0,1};
|
return amax(t, {0,1}, false);
|
||||||
return amax(t, v, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||||
|
@ -5,10 +5,10 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
|
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
|
||||||
#include <torch/headeronly/core/ScalarType.h>
|
#include <torch/headeronly/core/ScalarType.h>
|
||||||
|
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
|
||||||
|
|
||||||
namespace torch::stable {
|
namespace torch::stable {
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ inline torch::stable::Tensor narrow(
|
|||||||
// only dtype information.
|
// only dtype information.
|
||||||
inline torch::stable::Tensor new_empty(
|
inline torch::stable::Tensor new_empty(
|
||||||
const torch::stable::Tensor& self,
|
const torch::stable::Tensor& self,
|
||||||
std::vector<int64_t> size,
|
torch::headeronly::IntHeaderOnlyArrayRef size,
|
||||||
std::optional<c10::ScalarType> dtype = std::nullopt) {
|
std::optional<c10::ScalarType> dtype = std::nullopt) {
|
||||||
int32_t device_type;
|
int32_t device_type;
|
||||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
|
||||||
@ -98,7 +98,7 @@ inline torch::stable::Tensor new_empty(
|
|||||||
// only dtype information.
|
// only dtype information.
|
||||||
inline torch::stable::Tensor new_zeros(
|
inline torch::stable::Tensor new_zeros(
|
||||||
const torch::stable::Tensor& self,
|
const torch::stable::Tensor& self,
|
||||||
std::vector<int64_t> size,
|
torch::headeronly::IntHeaderOnlyArrayRef size,
|
||||||
std::optional<c10::ScalarType> dtype = std::nullopt) {
|
std::optional<c10::ScalarType> dtype = std::nullopt) {
|
||||||
int32_t device_type;
|
int32_t device_type;
|
||||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
|
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
|
||||||
@ -134,12 +134,10 @@ inline torch::stable::Tensor new_zeros(
|
|||||||
|
|
||||||
// We expect this to be the stable version of the pad.default op.
|
// We expect this to be the stable version of the pad.default op.
|
||||||
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
|
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
|
||||||
// use std::vector<int64_t> because
|
// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only.
|
||||||
// (1) IntArrayRef is not yet header-only
|
|
||||||
// (2) SymInt is not yet header-only
|
|
||||||
inline torch::stable::Tensor pad(
|
inline torch::stable::Tensor pad(
|
||||||
const torch::stable::Tensor& self,
|
const torch::stable::Tensor& self,
|
||||||
std::vector<int64_t> pad,
|
torch::headeronly::IntHeaderOnlyArrayRef pad,
|
||||||
const std::string& mode = "constant",
|
const std::string& mode = "constant",
|
||||||
double value = 0.0) {
|
double value = 0.0) {
|
||||||
AtenTensorHandle ret0 = nullptr;
|
AtenTensorHandle ret0 = nullptr;
|
||||||
@ -171,11 +169,10 @@ inline torch::stable::Tensor amax(
|
|||||||
// This function is an overload to compute the maximum value along each slice of
|
// This function is an overload to compute the maximum value along each slice of
|
||||||
// `self` reducing over all the dimensions in the vector `dims`. The
|
// `self` reducing over all the dimensions in the vector `dims`. The
|
||||||
// amax.default op takes in a SymInt[] as the dims argument, however dims is
|
// amax.default op takes in a SymInt[] as the dims argument, however dims is
|
||||||
// typed as use std::vector<int64_t> here because (1) IntArrayRef is not yet
|
// typed as use IntHeaderOnlyArrayRef here because SymInt is not yet header-only
|
||||||
// header-only (2) SymInt is not yet header-only
|
|
||||||
inline torch::stable::Tensor amax(
|
inline torch::stable::Tensor amax(
|
||||||
const torch::stable::Tensor& self,
|
const torch::stable::Tensor& self,
|
||||||
std::vector<int64_t> dims,
|
torch::headeronly::IntHeaderOnlyArrayRef dims,
|
||||||
bool keepdim = false) {
|
bool keepdim = false) {
|
||||||
AtenTensorHandle ret = nullptr;
|
AtenTensorHandle ret = nullptr;
|
||||||
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
|
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
|
||||||
|
Reference in New Issue
Block a user