mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add transpose to torch/csrc/stable (#158160)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158160 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
3cb11877aa
commit
e311886e3d
@ -1,6 +1,7 @@
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
|
||||
#include <optional>
|
||||
|
||||
@ -254,3 +255,21 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("is_contiguous", &boxed_is_contiguous);
|
||||
}
|
||||
|
||||
Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
|
||||
return transpose(t, dim0, dim1);
|
||||
}
|
||||
|
||||
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
|
||||
|
||||
stack[0] = from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_transpose", &boxed_my_transpose);
|
||||
}
|
||||
|
@ -116,3 +116,15 @@ def is_contiguous(t) -> bool:
|
||||
Returns: is_contiguous(t)
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.is_contiguous.default(t)
|
||||
|
||||
|
||||
def my_transpose(t, dim0, dim1) -> Tensor:
|
||||
"""
|
||||
Returns t.transpose(dim0, dim1)
|
||||
|
||||
Args:
|
||||
t: Tensor
|
||||
|
||||
Returns: my_transpose(t, dim0, dim1)
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my_transpose.default(t, dim0, dim1)
|
||||
|
@ -173,6 +173,16 @@ if not IS_WINDOWS:
|
||||
curr_mem = torch.cuda.memory_allocated(device)
|
||||
self.assertEqual(curr_mem, init_mem)
|
||||
|
||||
def test_my_transpose(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t = torch.rand(2, 7, device=device)
|
||||
out = libtorch_agnostic.ops.my_transpose(t, 0, 1)
|
||||
self.assertEqual(out, torch.transpose(t, 0, 1))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "API call failed"):
|
||||
libtorch_agnostic.ops.my_transpose(t, 1, 2)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user