mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Add STABLE_LIBRARY test for multiple returns (#149230)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149230 Approved by: https://github.com/albanD, https://github.com/zou3519 ghstack dependencies: #149052
This commit is contained in:
committed by
PyTorch MergeBot
parent
988827cdfb
commit
cccdf860e2
@ -185,3 +185,40 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_ones_like", &boxed_my_ones_like);
|
||||
}
|
||||
|
||||
std::tuple<RAIIATH, RAIIATH, bool> exp_neg_is_leaf(RAIIATH t1, RAIIATH t2, RAIIATH t3) {
|
||||
StableIValue stack1[1];
|
||||
stack1[0] = from(t1.release());
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack1);
|
||||
|
||||
StableIValue stack2[1];
|
||||
stack2[0] = from(t2.release());
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack2);
|
||||
|
||||
StableIValue stack3[1];
|
||||
stack3[0] = from(t3.release());
|
||||
aoti_torch_call_dispatcher("aten::is_leaf", "", stack3);
|
||||
|
||||
return std::make_tuple(
|
||||
RAIIATH(to<AtenTensorHandle>(stack1[0])),
|
||||
RAIIATH(to<AtenTensorHandle>(stack2[0])),
|
||||
to<bool>(stack3[0]));
|
||||
}
|
||||
|
||||
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
RAIIATH t1(to<AtenTensorHandle>(stack[0]));
|
||||
RAIIATH t2(to<AtenTensorHandle>(stack[1]));
|
||||
RAIIATH t3(to<AtenTensorHandle>(stack[2]));
|
||||
auto tuple = exp_neg_is_leaf(std::move(t1), std::move(t2), std::move(t3));
|
||||
stack[0] = from(std::get<0>(tuple).release());
|
||||
stack[1] = from(std::get<1>(tuple).release());
|
||||
stack[2] = from(std::get<2>(tuple));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) -> (Tensor, Tensor, bool)");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("exp_neg_is_leaf", &boxed_exp_neg_is_leaf);
|
||||
}
|
||||
|
@ -64,3 +64,19 @@ def my_ones_like(tensor, device) -> Tensor:
|
||||
like the input tensor
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my_ones_like.default(tensor, device)
|
||||
|
||||
|
||||
def exp_neg_is_leaf(t1, t2, t3) -> tuple[Tensor, Tensor, bool]:
|
||||
"""
|
||||
Returns a Tensor, Tensor, bool tuple corresponding to the respective inputs
|
||||
t1, t2, and t3.
|
||||
|
||||
Args:
|
||||
t1: Tensor
|
||||
t2: Tensor
|
||||
t3: Tensor
|
||||
|
||||
Returns:
|
||||
(exp(t1), neg(t2), is_leaf(t3))
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.exp_neg_is_leaf.default(t1, t2, t3)
|
||||
|
@ -52,6 +52,16 @@ class TestLibtorchAgnostic(TestCase):
|
||||
curr_mem = torch.cuda.memory_allocated(device)
|
||||
self.assertEqual(curr_mem, init_mem)
|
||||
|
||||
def test_exp_neg_is_leaf(self, device):
|
||||
t1 = torch.rand(2, 3, device=device)
|
||||
t2 = torch.rand(3, 2, device=device)
|
||||
t3 = torch.rand(2, device=device)
|
||||
|
||||
exp, neg, is_leaf = libtorch_agnostic.ops.exp_neg_is_leaf(t1, t2, t3)
|
||||
self.assertEqual(exp, torch.exp(t1))
|
||||
self.assertEqual(neg, torch.neg(t2))
|
||||
self.assertEqual(is_leaf, t3.is_leaf)
|
||||
|
||||
def test_my_abs(self, device):
|
||||
t = torch.rand(32, 16, device=device) - 0.5
|
||||
cpu_t = libtorch_agnostic.ops.my_abs(t)
|
||||
|
@ -299,6 +299,16 @@ class TestCppExtensionAOT(common.TestCase):
|
||||
curr_mem = torch.cuda.memory_allocated(device)
|
||||
self.assertEqual(curr_mem, init_mem)
|
||||
|
||||
# (4) test multiple returns
|
||||
t1 = torch.rand(2, 3, device="cuda")
|
||||
t2 = torch.rand(3, 2, device="cpu")
|
||||
t3 = torch.rand(2, device="cpu")
|
||||
|
||||
exp, neg, is_leaf = libtorch_agnostic.ops.exp_neg_is_leaf(t1, t2, t3)
|
||||
self.assertEqual(exp, torch.exp(t1))
|
||||
self.assertEqual(neg, torch.neg(t2))
|
||||
self.assertEqual(is_leaf, t3.is_leaf)
|
||||
|
||||
|
||||
@torch.testing._internal.common_utils.markDynamoStrictTest
|
||||
class TestPybindTypeCasters(common.TestCase):
|
||||
|
Reference in New Issue
Block a user