Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add new_empty (with dtype argument only) to torch::stable (#159508)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159508
Approved by: https://github.com/janeyx99
ghstack dependencies: #160557
Committed by: PyTorch MergeBot
Parent: 543896fcf3
Commit: 78a8e6a671
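
For orientation before the diff: the new torch::stable::new_empty lets ABI-stable extension code allocate a fresh, uninitialized tensor with an explicit dtype while inheriting device and other properties from a reference tensor. Below is a minimal sketch of the call pattern, mirroring the kernel added in this commit; the using-declaration and the exact overload are assumptions based on that kernel, not a definitive signature.

    #include <optional>
    #include <vector>

    #include <torch/csrc/stable/ops.h>     // new_empty (added by this stack)
    #include <torch/csrc/stable/tensor.h>  // stable Tensor handle
    #include <torch/headeronly/core/ScalarType.h>

    using torch::stable::Tensor;  // assumed, as in the extension kernel below

    // Allocate an uninitialized [2, 5] bfloat16 tensor on ref's device.
    Tensor make_bf16_buffer(const Tensor& ref) {
      std::vector<int64_t> sizes = {2, 5};
      auto dtype = std::make_optional(at::ScalarType::BFloat16);
      return new_empty(ref, sizes, dtype);  // omitting dtype should keep ref's dtype
    }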
@@ -4,6 +4,7 @@
 #include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
 #include <torch/headeronly/util/Exception.h>
+#include <torch/headeronly/core/ScalarType.h>
 
 #ifdef LAE_USE_CUDA
 #include <cuda_runtime.h>
@@ -340,12 +341,24 @@ void boxed_my_narrow(
   stack[0] = from(res);
 }
 
+Tensor my_new_empty_dtype_variant(Tensor t) {
+  std::vector<int64_t> sizes = {2, 5};
+  auto dtype = std::make_optional(at::ScalarType::BFloat16);
+  return new_empty(t, sizes, dtype);
+}
+
+void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
+  stack[0] = from(res);
+}
+
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
   m.def("my_empty_like(Tensor t) -> Tensor");
   m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
   m.def("my_pad(Tensor t) -> Tensor");
   m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
+  m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -353,6 +366,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("my_empty_like", &boxed_empty_like);
   m.impl("fill_infinity", &boxed_fill_infinity);
   m.impl("my_is_cpu", &boxed_my_is_cpu);
+  m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
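
Taken together, the C++ hunks follow the extension's existing pattern: STABLE_TORCH_LIBRARY_FRAGMENT declares the operator schema, STABLE_TORCH_LIBRARY_IMPL binds the boxed kernel under the CompositeExplicitAutograd key, and the boxed wrapper marshals arguments through the StableIValue stack via to<Tensor>/from. The Python wrapper in the next hunk then resolves the op through torch.ops.libtorch_agnostic.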
@@ -283,3 +283,15 @@ def test_get_current_device_index() -> int:
     Returns: Current device index as an integer
     """
     return torch.ops.libtorch_agnostic.test_get_current_device_index.default()
+
+
+def my_new_empty_dtype_variant(t) -> Tensor:
+    """
+    Returns a new empty tensor with shape [2, 5] and dtype bfloat16
+
+    Args:
+        t: Input tensor used as a reference for device and other properties
+
+    Returns: New empty tensor with shape [2, 5] and dtype bfloat16
+    """
+    return torch.ops.libtorch_agnostic.my_new_empty_dtype_variant.default(t)
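
A quick usage sketch for the wrapper just added, assuming the extension builds and imports as libtorch_agnostic the way the test suite below uses it:

    import torch
    import libtorch_agnostic  # importing the built extension registers the op

    t = torch.randn(3, 4)
    out = libtorch_agnostic.ops.my_new_empty_dtype_variant(t)
    assert out.shape == (2, 5)
    assert out.dtype == torch.bfloat16
    assert out.device == t.device  # device comes from the reference tensor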
@@ -190,7 +190,7 @@ if not IS_WINDOWS:
 
             deterministic = torch.are_deterministic_algorithms_enabled()
             try:
-                # set use_deterministic_algorithms to fill unintialized memory
+                # set use_deterministic_algorithms to fill uninitialized memory
                 torch.use_deterministic_algorithms(True)
 
                 t = torch.rand(2, 7, device=device)
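
The comment fixed above describes a real mechanism that the new test below also leans on: with deterministic algorithms enabled (and torch.utils.deterministic.fill_uninitialized_memory at its default of True), empty-style allocations are filled with a fixed pattern (NaN for floating-point dtypes), so two otherwise uninitialized tensors can be compared. A standalone illustration:

    import torch

    torch.use_deterministic_algorithms(True)
    # Under deterministic mode, torch.empty hands back NaN-filled storage for
    # floating-point dtypes instead of arbitrary allocator bytes.
    a = torch.empty(2, 3, dtype=torch.bfloat16)
    print(bool(a.isnan().all()))  # True
    torch.use_deterministic_algorithms(False)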
@@ -322,6 +322,21 @@ if not IS_WINDOWS:
             finally:
                 torch.cuda.set_device(prev_device)
 
+        def test_my_new_empty_dtype_variant(self, device):
+            import libtorch_agnostic
+
+            deterministic = torch.are_deterministic_algorithms_enabled()
+            try:
+                # set use_deterministic_algorithms to fill uninitialized memory
+                torch.use_deterministic_algorithms(True)
+                t = torch.randn(3, 4, device=device)
+                out = libtorch_agnostic.ops.my_new_empty_dtype_variant(t)
+                ref_out = t.new_empty((2, 5), dtype=torch.bfloat16)
+
+                self.assertEqual(out, ref_out, exact_device=True)
+            finally:
+                torch.use_deterministic_algorithms(deterministic)
+
     instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
     if __name__ == "__main__":