mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][Ez]: Add NT unary op macro (#140213)
* Adds a macro to simplify adding more unary ops to NT. * Adds sqrt support to NT Pull Request resolved: https://github.com/pytorch/pytorch/pull/140213 Approved by: https://github.com/jbschlosser
This commit is contained in:
committed by
PyTorch MergeBot
parent
069a71023b
commit
034b105d53
@ -1791,7 +1791,7 @@
|
||||
variants: function, method
|
||||
structured_delegate: cos.out
|
||||
dispatch:
|
||||
NestedTensorCPU, NestedTensorCUDA: cos_nested
|
||||
NestedTensorCPU, NestedTensorCUDA: NestedTensor_cos
|
||||
tags: [core, pointwise]
|
||||
|
||||
- func: cos_(Tensor(a!) self) -> Tensor(a!)
|
||||
@ -5321,7 +5321,7 @@
|
||||
dispatch:
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr
|
||||
SparseCPU, SparseCUDA: sin_sparse
|
||||
NestedTensorCPU, NestedTensorCUDA: sin_nested
|
||||
NestedTensorCPU, NestedTensorCUDA: NestedTensor_sin
|
||||
tags: [core, pointwise]
|
||||
|
||||
- func: sin_(Tensor(a!) self) -> Tensor(a!)
|
||||
@ -5819,6 +5819,7 @@
|
||||
structured_delegate: sqrt.out
|
||||
variants: function, method
|
||||
dispatch:
|
||||
NestedTensorCPU, NestedTensorCUDA: NestedTensor_sqrt
|
||||
SparseCPU, SparseCUDA: sqrt_sparse
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr
|
||||
tags: [core, pointwise]
|
||||
|
@ -15,10 +15,29 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
Tensor NestedTensor_abs(const Tensor& self) {
|
||||
return map_nt(self, at::abs);
|
||||
#define DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(op_name) \
|
||||
Tensor NestedTensor_##op_name(const Tensor& self) { \
|
||||
return map_nt(self, at::op_name); \
|
||||
}
|
||||
|
||||
// Use the macro to define operations concisely
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(abs)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(sgn)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(logical_not)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(isinf)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(isposinf)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(isneginf)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(isnan)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(relu)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(silu)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(sin)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(sqrt)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(cos)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(neg)
|
||||
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(tanh)
|
||||
|
||||
#undef DEFINE_TORCH_NESTED_TENSOR_UNARY_OP
|
||||
|
||||
Tensor& NestedTensor_abs_(Tensor& self) {
|
||||
auto self_ptr = get_nested_tensor_impl(self);
|
||||
check_numel_equals_buffer_size(self_ptr);
|
||||
@ -96,10 +115,6 @@ Tensor& NestedTensor_where_out(const Tensor& condition, const Tensor& self, cons
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor NestedTensor_sgn(const Tensor& self) {
|
||||
return map_nt(self, at::sgn);
|
||||
}
|
||||
|
||||
Tensor& NestedTensor_sgn_(Tensor& self) {
|
||||
auto self_ptr = get_nested_tensor_impl(self);
|
||||
check_numel_equals_buffer_size(self_ptr);
|
||||
@ -116,25 +131,6 @@ Tensor& NestedTensor_logical_not_(Tensor& self){
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor NestedTensor_logical_not(const Tensor& self) {
|
||||
return map_nt(self, at::logical_not);
|
||||
}
|
||||
|
||||
Tensor NestedTensor_isinf(const Tensor& self) {
|
||||
return map_nt(self, at::isinf);
|
||||
}
|
||||
|
||||
Tensor NestedTensor_isposinf(const Tensor& self) {
|
||||
return map_nt(self, at::isposinf);
|
||||
}
|
||||
|
||||
Tensor NestedTensor_isneginf(const Tensor& self) {
|
||||
return map_nt(self, at::isneginf);
|
||||
}
|
||||
|
||||
Tensor NestedTensor_isnan(const Tensor& self) {
|
||||
return map_nt(self, at::isnan);
|
||||
}
|
||||
|
||||
Tensor& NestedTensor_relu_(Tensor& self) {
|
||||
auto self_ptr = get_nested_tensor_impl(self);
|
||||
@ -144,10 +140,6 @@ Tensor& NestedTensor_relu_(Tensor& self) {
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor NestedTensor_relu(const Tensor& self) {
|
||||
return map_nt(self, at::relu);
|
||||
}
|
||||
|
||||
Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) {
|
||||
auto self_ptr = get_nested_tensor_impl(self);
|
||||
check_numel_equals_buffer_size(self_ptr);
|
||||
@ -172,10 +164,6 @@ Tensor& NestedTensor_tanh_(Tensor& self) {
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor NestedTensor_tanh(const Tensor& self) {
|
||||
return map_nt(self, at::tanh);
|
||||
}
|
||||
|
||||
Tensor& NestedTensor_neg_(Tensor& self) {
|
||||
auto self_ptr = get_nested_tensor_impl(self);
|
||||
check_numel_equals_buffer_size(self_ptr);
|
||||
@ -184,20 +172,12 @@ Tensor& NestedTensor_neg_(Tensor& self) {
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor NestedTensor_neg(const Tensor& self) {
|
||||
return map_nt(self, at::neg);
|
||||
}
|
||||
|
||||
Tensor& zero_nested_(Tensor& self) {
|
||||
const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
|
||||
self_buf.fill_(0);
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor NestedTensor_silu(const Tensor& self){
|
||||
return map_nt(self, at::silu);
|
||||
}
|
||||
|
||||
Tensor& NestedTensor_silu_(Tensor& self){
|
||||
auto self_ptr = get_nested_tensor_impl(self);
|
||||
check_numel_equals_buffer_size(self_ptr);
|
||||
@ -206,14 +186,6 @@ Tensor& NestedTensor_silu_(Tensor& self){
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor sin_nested(const Tensor& self) {
|
||||
return map_nt(self, at::sin);
|
||||
}
|
||||
|
||||
Tensor cos_nested(const Tensor& self) {
|
||||
return map_nt(self, at::cos);
|
||||
}
|
||||
|
||||
Tensor _pin_memory_nested(const Tensor& self, std::optional<Device> device) {
|
||||
auto* nt_input = get_nested_tensor_impl(self);
|
||||
const auto& input_buffer = nt_input->get_unsafe_storage_as_tensor();
|
||||
|
@ -1289,6 +1289,7 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
|
||||
subtest(torch.isposinf, name="isposinf"),
|
||||
subtest(torch.isneginf, name="isneginf"),
|
||||
subtest(torch.isnan, name="isnan"),
|
||||
subtest(torch.sqrt, name="sqrt"),
|
||||
],
|
||||
)
|
||||
def test_unary_funcs(self, device, func):
|
||||
|
Reference in New Issue
Block a user