Port amax to stable ABI (#160214)

To enable porting torchaudio to the stable ABI, we need the `amax` operation to be accessible. This PR ports the op and provides tests that it behaves correctly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160214
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Sam Anklesaria
2025-08-19 17:24:53 +00:00
committed by PyTorch MergeBot
parent 1fbe230b0d
commit 0a5ab612dd
6 changed files with 99 additions and 1 deletions

View File

@ -371,10 +371,31 @@ void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs
stack[0] = from(res);
}
Tensor my_amax(Tensor t) {
return amax(t, 0, false);
}
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_amax_vec(Tensor t) {
std::vector<int64_t> v = {0,1};
return amax(t, v, false);
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax_vec(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
m.def("my_amax(Tensor a) -> Tensor");
m.def("my_amax_vec(Tensor a) -> Tensor");
m.def("my_is_cpu(Tensor t) -> bool");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
@ -414,6 +435,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_default_constructor", &boxed_test_default_constructor);
m.impl("my_amax", &boxed_my_amax);
m.impl("my_amax_vec", &boxed_my_amax_vec);
}
// Test functions for torch::stable::accelerator APIs

View File

@ -167,6 +167,30 @@ def my_zero_(t) -> Tensor:
return torch.ops.libtorch_agnostic.my_zero_.default(t)
def my_amax(t) -> Tensor:
"""
Returns t.amax()
Args:
t: Tensor
Returns: amax(t)
"""
return torch.ops.libtorch_agnostic.my_amax.default(t)
def my_amax_vec(t) -> Tensor:
"""
Returns t.amax()
Args:
t: Tensor
Returns: amax(t)
"""
return torch.ops.libtorch_agnostic.my_amax_vec.default(t)
def fill_infinity(t) -> Tensor:
"""
Fills the tensor with inf.

View File

@ -209,6 +209,20 @@ if not IS_WINDOWS:
self.assertEqual(id(out), id(t))
self.assertEqual(out, torch.zeros_like(t))
def test_my_amax(self, device):
import libtorch_agnostic
t = torch.rand(2, 7, device=device)
out = libtorch_agnostic.ops.my_amax(t)
self.assertEqual(out, torch.amax(t, 0))
def test_my_amax_vec(self, device):
import libtorch_agnostic
t = torch.rand(2, 7, 5, device=device)
out = libtorch_agnostic.ops.my_amax_vec(t)
self.assertEqual(out, torch.amax(t, (0, 1)))
def test_my_is_cpu(self, device):
import libtorch_agnostic