Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add new_zeros dtype variant to the shim and as a stable op (#161597)

In case we want this before 2.9.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161597
Approved by: https://github.com/mikaylagawarecki
Committed by: PyTorch MergeBot
Parent: 05d0f11dbd
Commit: 63632fc7ee
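For orientation before the diff: the change lets ABI-stable extension code request a new_zeros result with an explicit dtype, mirroring the existing new_empty dtype variant. Below is a minimal sketch of how a caller might use it; the include path and the std::nullopt default are assumptions, not taken from the PR, and only the three-argument call mirrors the diff that follows.

// Sketch only: calling the new stable new_zeros dtype variant from a
// libtorch-agnostic extension, following my_new_empty_dtype_variant below.
// The include path and the dtype-less default are assumptions.
#include <torch/csrc/stable/library.h>  // assumed to provide Tensor, new_zeros

#include <optional>
#include <vector>

Tensor float_zeros_2x5(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
  // Explicit dtype, as in the diff: [2, 5] float32 zeros on t's device.
  return new_zeros(t, sizes, std::make_optional(torch::headeronly::ScalarType::Float));
}

Tensor zeros_2x5_like(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
  // Assumed default: passing std::nullopt should inherit t's dtype,
  // matching aten::new_zeros semantics when dtype=None.
  return new_zeros(t, sizes, std::nullopt);
}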
@@ -343,7 +343,7 @@ void boxed_my_narrow(
 
 Tensor my_new_empty_dtype_variant(Tensor t) {
   std::vector<int64_t> sizes = {2, 5};
-  auto dtype = std::make_optional(at::ScalarType::BFloat16);
+  auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
   return new_empty(t, sizes, dtype);
 }
 
@@ -352,6 +352,17 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
   stack[0] = from(res);
 }
 
+Tensor my_new_zeros_dtype_variant(Tensor t) {
+  std::vector<int64_t> sizes = {2, 5};
+  auto dtype = std::make_optional(at::ScalarType::Float);
+  return new_zeros(t, sizes, dtype);
+}
+
+void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
+  stack[0] = from(res);
+}
+
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
   m.def("my_empty_like(Tensor t) -> Tensor");
@@ -359,6 +370,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_pad(Tensor t) -> Tensor");
   m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
   m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor");
+  m.def("my_new_zeros_dtype_variant(Tensor t) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -367,6 +379,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("fill_infinity", &boxed_fill_infinity);
   m.impl("my_is_cpu", &boxed_my_is_cpu);
   m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
+  m.impl("my_new_zeros_dtype_variant", &boxed_my_new_zeros_dtype_variant);
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
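The boxed wrappers above follow the StableIValue calling convention used throughout this extension: arguments arrive on a flat stack, are unpacked with to<T>(), and the result is written back with from(). As a rough illustration of how the same pattern extends to an op with several arguments, here is a hypothetical boxed wrapper for the my_narrow schema declared above. It is a sketch inferred only from the wrappers visible in this diff, not the actual boxed_my_narrow in the repository.

// Hypothetical sketch, not the real boxed_my_narrow: shows the StableIValue
// stack convention (to<T>() to unpack, from() to pack) for a multi-argument op.
// Schema: my_narrow(Tensor t, int dim, int start, int length) -> Tensor
void boxed_my_narrow_sketch(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor t = to<Tensor>(stack[0]);
  int64_t dim = to<int64_t>(stack[1]);
  int64_t start = to<int64_t>(stack[2]);
  int64_t length = to<int64_t>(stack[3]);

  Tensor res = my_narrow(t, dim, start, length);

  // The single output is written back into slot 0 of the same stack.
  stack[0] = from(res);
}

The remaining hunks below add the corresponding Python-side wrapper and a device-parametrized test.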
@@ -295,3 +295,15 @@ def my_new_empty_dtype_variant(t) -> Tensor:
     Returns: New empty tensor with shape [2, 5] and dtype bfloat16
     """
     return torch.ops.libtorch_agnostic.my_new_empty_dtype_variant.default(t)
+
+
+def my_new_zeros_dtype_variant(t) -> Tensor:
+    """
+    Returns a new tensor filled with 0s with shape [2, 5] and dtype Float
+
+    Args:
+        t: Input tensor used as a reference for device and other properties
+
+    Returns: New zeros tensor
+    """
+    return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)
@@ -337,6 +337,14 @@ if not IS_WINDOWS:
             finally:
                 torch.use_deterministic_algorithms(deterministic)
 
+        def test_my_new_zeros_dtype_variant(self, device):
+            import libtorch_agnostic
+
+            t = torch.randn(3, 4, device=device)
+            out = libtorch_agnostic.ops.my_new_zeros_dtype_variant(t)
+            ref_out = t.new_zeros((2, 5), dtype=torch.float)
+            self.assertEqual(out, ref_out, exact_device=True)
+
     instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
 if __name__ == "__main__":