Add transpose to torch/csrc/stable (#158160)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158160
Approved by: https://github.com/janeyx99
Author: Mikayla Gawarecki
Date: 2025-07-16 08:39:30 -07:00
Committed by: PyTorch MergeBot
Parent: 3cb11877aa
Commit: e311886e3d
4 changed files with 58 additions and 0 deletions
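
Note: only three of the four changed files are shown below. The headline change, the stable transpose wrapper added under torch/csrc/stable, is not among the visible hunks. A minimal sketch of the likely declaration, inferred purely from the call site transpose(t, dim0, dim1) in the extension code (the exact signature and const-ness are assumptions):

// Sketch (assumption, inferred from the call site below): a stable-ABI
// transpose wrapper declared in torch/csrc/stable/ops.h. The real wrapper
// presumably routes through the aoti_torch C shim.
Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1);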

File: libtorch_agnostic extension, C++ source

@@ -1,6 +1,7 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <optional>
@@ -254,3 +255,21 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("is_contiguous", &boxed_is_contiguous);
}

// Unboxed kernel: forwards to the new stable transpose op.
Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
  return transpose(t, dim0, dim1);
}

// Boxed wrapper: unpacks the three arguments from the StableIValue stack,
// calls the kernel, and writes the result back to stack[0].
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_transpose", &boxed_my_transpose);
}
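
The three-step pattern above (unboxed kernel, boxed wrapper, schema plus kernel registration) generalizes to any op exposed through the stable ABI. A hypothetical sketch, not part of this PR, applying the same pattern to an identity op using only the APIs already exercised above:

// Hypothetical example (not in this PR): the same registration pattern on a
// minimal op. Arguments arrive as StableIValues on the stack; the single
// result is written back to stack[0].
Tensor my_identity(Tensor t) {
  return t;
}

void boxed_my_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  stack[0] = from(my_identity(to<Tensor>(stack[0])));
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_identity(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("my_identity", &boxed_my_identity);
}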

File: libtorch_agnostic extension, Python ops wrapper

@@ -116,3 +116,15 @@ def is_contiguous(t) -> bool:
    Returns: is_contiguous(t)
    """
    return torch.ops.libtorch_agnostic.is_contiguous.default(t)

def my_transpose(t, dim0, dim1) -> Tensor:
    """
    Returns t.transpose(dim0, dim1)

    Args:
        t: Tensor
        dim0: first dimension to swap
        dim1: second dimension to swap

    Returns: my_transpose(t, dim0, dim1)
    """
    return torch.ops.libtorch_agnostic.my_transpose.default(t, dim0, dim1)
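
Once the extension is built and imported, the wrapper can be called directly; a minimal usage sketch (the shape assert is illustrative, not from the PR):

import torch
import libtorch_agnostic  # importing the package registers the ops

t = torch.rand(2, 7)
out = libtorch_agnostic.ops.my_transpose(t, 0, 1)
assert out.shape == (7, 2)  # dims 0 and 1 swapped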

File: libtorch_agnostic extension, tests

@@ -173,6 +173,16 @@ if not IS_WINDOWS:
            curr_mem = torch.cuda.memory_allocated(device)
            self.assertEqual(curr_mem, init_mem)

        def test_my_transpose(self, device):
            import libtorch_agnostic

            t = torch.rand(2, 7, device=device)
            out = libtorch_agnostic.ops.my_transpose(t, 0, 1)
            self.assertEqual(out, torch.transpose(t, 0, 1))

            # dim 2 is out of range for a 2-D tensor; the stable ABI surfaces
            # the failure as a RuntimeError.
            with self.assertRaisesRegex(RuntimeError, "API call failed"):
                libtorch_agnostic.ops.my_transpose(t, 1, 2)

    instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

if __name__ == "__main__":