Port amax to stable ABI (#160214)
To enable porting torchaudio to the stable ABI, we need the `amax` operation to be accessible. This PR ports the op and provides tests that it behaves correctly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160214
Approved by: https://github.com/mikaylagawarecki
Committed by: PyTorch MergeBot
Parent: 1fbe230b0d
Commit: 0a5ab612dd
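As a brief illustration (not part of the diff; it assumes the same headers and using-declarations as the test extension changed below), stable-ABI extension code can now call amax directly:

    // Hypothetical helper, mirroring the pattern this PR tests:
    Tensor row_max(Tensor t) {
      return amax(t, /*dim=*/0, /*keepdim=*/false);
    }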
@@ -371,10 +371,31 @@ void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs
  stack[0] = from(res);
}
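// Note on the boxed wrappers below (an observation, not new behavior): the
// stable dispatcher passes arguments on a StableIValue stack; each wrapper
// unboxes its inputs with to<T>(), calls the typed kernel, and writes the
// result back into stack[0] with from().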
Tensor my_amax(Tensor t) {
  return amax(t, 0, false);
}

void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_amax(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

Tensor my_amax_vec(Tensor t) {
  std::vector<int64_t> v = {0, 1};
  return amax(t, v, false);
}

void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_amax_vec(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
  m.def("my_amax(Tensor a) -> Tensor");
  m.def("my_amax_vec(Tensor a) -> Tensor");
  m.def("my_is_cpu(Tensor t) -> bool");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {

@@ -414,6 +435,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_default_constructor", &boxed_test_default_constructor);
  m.impl("my_amax", &boxed_my_amax);
  m.impl("my_amax_vec", &boxed_my_amax_vec);
}
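// An aside (inferred from the code above, not stated in this diff): a
// CompositeExplicitAutograd registration lets one kernel serve all backends
// here, since my_amax and my_amax_vec simply re-dispatch to amax, which
// handles device dispatch itself.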
// Test functions for torch::stable::accelerator APIs
@@ -167,6 +167,30 @@ def my_zero_(t) -> Tensor:
    return torch.ops.libtorch_agnostic.my_zero_.default(t)
def my_amax(t) -> Tensor:
    """
    Returns t.amax(dim=0)

    Args:
        t: Tensor

    Returns: amax(t, dim=0)
    """
    return torch.ops.libtorch_agnostic.my_amax.default(t)


def my_amax_vec(t) -> Tensor:
    """
    Returns t.amax(dim=(0, 1))

    Args:
        t: Tensor

    Returns: amax(t, dim=(0, 1))
    """
    return torch.ops.libtorch_agnostic.my_amax_vec.default(t)


def fill_infinity(t) -> Tensor:
    """
    Fills the tensor with inf.
@@ -209,6 +209,20 @@ if not IS_WINDOWS:
        self.assertEqual(id(out), id(t))
        self.assertEqual(out, torch.zeros_like(t))

    def test_my_amax(self, device):
        import libtorch_agnostic

        t = torch.rand(2, 7, device=device)
        out = libtorch_agnostic.ops.my_amax(t)
        self.assertEqual(out, torch.amax(t, 0))

    def test_my_amax_vec(self, device):
        import libtorch_agnostic

        t = torch.rand(2, 7, 5, device=device)
        out = libtorch_agnostic.ops.my_amax_vec(t)
        self.assertEqual(out, torch.amax(t, (0, 1)))

    def test_my_is_cpu(self, device):
        import libtorch_agnostic
@@ -14,6 +14,7 @@
extern "C" {
#endif

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int32_t keepdim, AtenTensorHandle* ret0);
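// Note the C shim conventions visible in this declaration: the dim list is
// passed as pointer + length, the boolean keepdim is widened to int32_t, and
// the result Tensor is returned through the ret0 out-parameter handle.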
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
@@ -68,6 +68,41 @@ inline Tensor pad(
  return Tensor(ret0);
}
// We expect the following two functions to be stable versions of the
// amax.default op, with semantics identical to the existing amax.default op.
// If `keepdim` is true, the result has the same number of dimensions as
// `self`, with each reduced dimension having size 1. Otherwise, the reduced
// dimensions are removed, so the result has correspondingly fewer dimensions
// than `self`.

// This overload computes the maximum value of each slice of `self` along a
// single dimension `dim`.
inline Tensor amax(Tensor& self, int64_t dim, bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_aten_amax(self.get(), &dim, 1, keepdim, &ret));
  return Tensor(ret);
}
// This overload computes the maximum value of each slice of `self`, reducing
// over all the dimensions in the vector `dims`. The amax.default op takes a
// SymInt[] as the dims argument; however, dims is typed as
// std::vector<int64_t> here because (1) IntArrayRef is not yet header-only
// and (2) SymInt is not yet header-only.
inline Tensor amax(
    Tensor& self,
    std::vector<int64_t> dims,
    bool keepdim = false) {
  AtenTensorHandle ret = nullptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax(
      self.get(),
      dims.data(),
      static_cast<int64_t>(dims.size()),
      keepdim,
      &ret));
  return Tensor(ret);
}
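// For illustration only (not part of this diff): a minimal sketch of calling
// both overloads, assuming `t` is a 3-D Tensor obtained elsewhere:
//
//   Tensor m0 = amax(t, 0);                       // rank drops by one
//   std::vector<int64_t> dims = {0, 1};
//   Tensor m01 = amax(t, dims, /*keepdim=*/true); // dims 0 and 1 kept, size 1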
// We expect this to be the stable version of the transpose op with identical
// semantics to the existing transpose.int op.
inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
@@ -185,4 +185,5 @@ aten_shimified_ops: dict[str, dict[str, list[str]]] = {
    "aten.fill_.Scalar": {},
    "aten.pad.default": {},
    "aten.narrow.default": {},
    "aten.amax.default": {},
}