[STABLE ABI] Add copy_ operation. (#161895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161895
Approved by: https://github.com/janeyx99
Committed by: PyTorch MergeBot
Parent: eb11d172e3
Commit: f9074c7332
@@ -363,6 +363,15 @@ void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, ui
   stack[0] = from(res);
 }
 
+Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
+  return copy_(dst, src, non_blocking);
+}
+
+void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  Tensor tensor_res = my_copy_(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<bool>(stack[2]));
+  stack[0] = from(tensor_res);
+}
+
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
   m.def("my_empty_like(Tensor t) -> Tensor");
@@ -371,6 +380,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
   m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor");
   m.def("my_new_zeros_dtype_variant(Tensor t) -> Tensor");
+  m.def("my_copy_(Tensor dst, Tensor src, bool non_blocking) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -380,6 +390,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_is_cpu", &boxed_my_is_cpu);
   m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
   m.impl("my_new_zeros_dtype_variant", &boxed_my_new_zeros_dtype_variant);
+  m.impl("my_copy_", &boxed_my_copy_);
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
@@ -415,7 +426,6 @@ void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outp
   stack[0] = from(res);
 }
 
-
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)");
   m.def("my_amax(Tensor a) -> Tensor");
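For readers new to the boxed calling convention used above, here is a minimal sketch of how a caller would populate the StableIValue stack for boxed_my_copy_ and read back the result. It is illustrative only and not part of the diff; it assumes it sits in the same translation unit as the kernel above, so Tensor, StableIValue, and the from/to helpers are already in scope, and call_my_copy_boxed is a hypothetical name.

    // Illustrative only -- not part of the diff. Assumes Tensor, StableIValue,
    // and the from/to converters from the stable headers are in scope, as in
    // the extension file shown above.
    Tensor call_my_copy_boxed(Tensor dst, Tensor src, bool non_blocking) {
      // The boxed convention packs every argument into a StableIValue stack;
      // the kernel overwrites stack[0] with its single output.
      StableIValue stack[3] = {from(dst), from(src), from(non_blocking)};
      boxed_my_copy_(stack, /*num_args=*/3, /*num_outputs=*/1);
      return to<Tensor>(stack[0]);
    }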
@@ -242,6 +242,20 @@ def my_narrow(t, dim, start, length) -> Tensor:
     return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length)
 
 
+def my_copy_(dst, src, non_blocking) -> Tensor:
+    """
+    Returns tensor dst that is updated with src elements.
+
+    Args:
+        dst: Destination tensor
+        src: Source tensor
+        non_blocking: bool
+
+    Returns: Updated tensor
+    """
+    return torch.ops.libtorch_agnostic.my_copy_.default(dst, src, non_blocking)
+
+
 def test_device_guard(device_index) -> int:
     """
     Tests the DeviceGuard functionality by creating a device guard and returning an empty tensor.
@@ -345,6 +345,17 @@ if not IS_WINDOWS:
             ref_out = t.new_zeros((2, 5), dtype=torch.float)
             self.assertEqual(out, ref_out, exact_device=True)
 
+        def test_my_copy_(self, device):
+            import libtorch_agnostic
+
+            dst = torch.empty(2, 5, device=device)
+            src = torch.randn(2, 5, device=device)
+
+            result = libtorch_agnostic.ops.my_copy_(dst, src, False)
+            expected = src
+            self.assertEqual(result, expected)
+            self.assertEqual(result.data_ptr(), dst.data_ptr())
+
     instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
     if __name__ == "__main__":
@@ -211,4 +211,18 @@ inline torch::stable::Tensor zero_(torch::stable::Tensor& self) {
   return to<torch::stable::Tensor>(stack[0]);
 }
 
+// We expect this to be the stable version of the copy_ op with
+// identical semantics to the existing copy_ op.
+inline Tensor copy_(
+    Tensor& self,
+    const Tensor& src,
+    std::optional<bool> non_blocking = std::nullopt) {
+  const auto num_args = 3;
+  std::array<StableIValue, num_args> stack{
+      from(self), from(src), from(non_blocking.value_or(false))};
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_call_dispatcher("aten::copy_", "", stack.data()));
+  return to<Tensor>(stack[0]);
+}
+
 } // namespace torch::stable
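For illustration, a minimal sketch of how extension code built against the stable ABI might call the new torch::stable::copy_ wrapper, mirroring the my_copy_ example earlier in this diff. The include paths and the fill_from helper name are assumptions, not part of the change; non_blocking is left at its default.

    // Illustrative only -- not part of the diff. The include paths below are
    // assumed locations of the stable headers touched by this change.
    #include <torch/csrc/stable/ops.h>
    #include <torch/csrc/stable/tensor.h>

    // Hypothetical helper: fill `dst` from `src` through the stable ABI.
    // copy_ writes src's elements into dst in place and returns dst,
    // matching aten::copy_; non_blocking defaults to false.
    torch::stable::Tensor fill_from(torch::stable::Tensor dst,
                                    const torch::stable::Tensor& src) {
      return torch::stable::copy_(dst, src);
    }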