mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[MPS] Move expm1 op to Metal (#155611)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155611 Approved by: https://github.com/malfet
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							44df7cf28d
						
					
				
				
					commit
					013cf1e330
				
			| @ -29,6 +29,31 @@ struct exp_functor { | ||||
|   } | ||||
| }; | ||||
|  | ||||
// Element-wise exp(x) - 1 with one overload per input category
// (floating point, integral, complex).
struct expm1_functor {
  // Floating-point inputs: evaluated in float and cast back to T.
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    // When |x| is well below 1, exp(x) is so close to 1 that computing
    // exp(x) - 1 cancels most of the significant bits (the relative error
    // grows roughly like eps / |x|). Route all such inputs through the
    // dedicated expm1f routine; further out the plain subtraction costs only
    // a couple of bits.
    if (::metal::fabs(x) < 0.25f) {
      return static_cast<T>(c10::metal::expm1f(static_cast<float>(x)));
    } else {
      return static_cast<T>(exp_(static_cast<float>(x)) - 1.0f);
    }
  }

  // Integral inputs: the result is always returned as float. No special
  // handling is needed because a nonzero integer is never close to 0.
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return exp_(static_cast<float>(x)) - 1;
  }

  // Complex inputs, with x.x used as the real part a and x.y as the
  // imaginary part b.
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // sqrt(dot(x, x)) is the magnitude |x|.
    if (::precise::sqrt(dot(x, x)) < 1e-2f) {
      // Near the origin the real part e^a * cos(b) - 1 is rewritten as
      // expm1(a + log(cos(b))) to avoid subtracting 1 from a value close
      // to 1; the imaginary part is e^a * sin(b).
      return T(
          c10::metal::expm1f(x.x + ::precise::log(precise::cos(x.y))),
          exp_(x.x) * precise::sin(x.y));
    } else {
      return exp_(x) - T(1.0f, 0.0f);
    }
  }
};
|  | ||||
| struct sigmoid_functor { | ||||
|   template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true> | ||||
|   inline T operator()(const T x) { | ||||
| @ -511,6 +536,7 @@ REGISTER_UNARY_OP(abs, half, half); | ||||
|   REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0);         \ | ||||
|   REGISTER_UNARY_OP(erfinv, DTYPE1, DTYPE0);       \ | ||||
|   REGISTER_UNARY_OP(exp, DTYPE1, DTYPE0);          \ | ||||
|   REGISTER_UNARY_OP(expm1, DTYPE1, DTYPE0);        \ | ||||
|   REGISTER_UNARY_OP(sigmoid, DTYPE1, DTYPE0);      \ | ||||
|   REGISTER_UNARY_OP(exp2, DTYPE1, DTYPE0);         \ | ||||
|   REGISTER_UNARY_OP(log, DTYPE1, DTYPE0);          \ | ||||
| @ -547,6 +573,7 @@ INSTANTIATE_UNARY_KERNELS2(float, long); | ||||
| #define INSTANTIATE_UNARY_KERNELS_VEC2(DTYPE)     \ | ||||
|   REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2);     \ | ||||
|   REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2);     \ | ||||
|   REGISTER_UNARY_OP(expm1, DTYPE##2, DTYPE##2);   \ | ||||
|   REGISTER_UNARY_OP(sigmoid, DTYPE##2, DTYPE##2); \ | ||||
|   REGISTER_UNARY_OP(abs, DTYPE##2, DTYPE##2);     \ | ||||
|   REGISTER_UNARY_OP(exp2, DTYPE##2, DTYPE##2);    \ | ||||
|  | ||||
| @ -25,6 +25,7 @@ static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) { | ||||
| } | ||||
|  | ||||
| REGISTER_UNARY_TI_DISPATCH(exp); | ||||
| REGISTER_UNARY_TI_DISPATCH(expm1); | ||||
| REGISTER_UNARY_TI_DISPATCH(erf); | ||||
| REGISTER_UNARY_TI_DISPATCH(erfc); | ||||
| REGISTER_UNARY_TI_DISPATCH(erfinv); | ||||
|  | ||||
| @ -26,7 +26,6 @@ | ||||
| #include <ATen/ops/cumsum_native.h> | ||||
| #include <ATen/ops/erf_native.h> | ||||
| #include <ATen/ops/exp2_native.h> | ||||
| #include <ATen/ops/expm1_native.h> | ||||
| #include <ATen/ops/frac_native.h> | ||||
| #include <ATen/ops/imag.h> | ||||
| #include <ATen/ops/logical_not_native.h> | ||||
| @ -254,14 +253,6 @@ TORCH_IMPL_FUNC(frac_out_mps)(const Tensor& self, const Tensor& output) { | ||||
|   }); | ||||
| } | ||||
|  | ||||
| TORCH_IMPL_FUNC(expm1_out_mps)(const Tensor& self, const Tensor& output) { | ||||
|   mps::unary_op(self, output, "expm1_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { | ||||
|     MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType]; | ||||
|     MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor name:nil]; | ||||
|     return [mpsGraph subtractionWithPrimaryTensor:ePowTensor secondaryTensor:oneTensor name:nil]; | ||||
|   }); | ||||
| } | ||||
|  | ||||
| static void logit_mps_impl(const Tensor& self, std::optional<double> eps, Tensor& output, const std::string& op_name) { | ||||
|   std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]"; | ||||
|  | ||||
|  | ||||
| @ -2628,8 +2628,7 @@ | ||||
|   structured: True | ||||
|   structured_inherits: TensorIteratorBase | ||||
|   dispatch: | ||||
|     CPU, CUDA: expm1_out | ||||
|     MPS: expm1_out_mps | ||||
|     CPU, CUDA, MPS: expm1_out | ||||
|     SparseCPU, SparseCUDA: expm1_sparse_out | ||||
|     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr_out | ||||
|   tags: pointwise | ||||
|  | ||||
							
								
								
									
										97
									
								
								c10/metal/expm1f.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								c10/metal/expm1f.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,97 @@ | ||||
| // Copy-and-pasted from: | ||||
| // https://github.com/ml-explore/mlx/blob/99c33d011d63174f50cea37c3eede002958be6d3/mlx/backend/metal/kernels/expm1f.h | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include <metal_math> | ||||
|  | ||||
| // Original license copied below: | ||||
| //  Copyright (c) 2015-2023 Norbert Juffa | ||||
| //  All rights reserved. | ||||
| // | ||||
| //  Redistribution and use in source and binary forms, with or without | ||||
| //  modification, are permitted provided that the following conditions | ||||
| //  are met: | ||||
| // | ||||
| //  1. Redistributions of source code must retain the above copyright | ||||
| //     notice, this list of conditions and the following disclaimer. | ||||
| // | ||||
| //  2. Redistributions in binary form must reproduce the above copyright | ||||
| //     notice, this list of conditions and the following disclaimer in the | ||||
| //     documentation and/or other materials provided with the distribution. | ||||
| // | ||||
| //  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | ||||
| //  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | ||||
| //  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | ||||
| //  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | ||||
| //  HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | ||||
| //  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | ||||
| //  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | ||||
| //  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | ||||
| //  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||||
| //  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||||
| //  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||||
|  | ||||
namespace c10 {
namespace metal {

/* Compute exponential base e minus 1. Maximum ulp error = 0.997458

   i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1.
   Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5).
   With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy,
   when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r.

   NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2)

   No overflow/underflow handling is done here; see expm1f() for the checked
   entry point.

   Defined inline because this header may be included from more than one
   translation unit.
*/
inline float expm1f_scaled_unchecked(float a, float b) {
  float f, j, r, s, t, u, v, x, y;
  int i;

  // exp(a) = 2**i * exp(f); i = rintf (a / log(2))
  // Adding and then subtracting 1.5 * 2**23 rounds a / log(2) to an integer.
  j = ::metal::fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23
  j = j - 12582912.0f; // 0x1.8p23
  i = (int)j;
  // f = a - j * log(2). log(2) is split into a high and a low part so that
  // the reduced argument keeps full precision; using the high part alone
  // leaves f off by j * 1.43e-6.
  f = ::metal::fma(j, -6.93145752e-1f, a); // -0x1.62e400p-1 (log(2) high)
  f = ::metal::fma(j, -1.42860677e-6f, f); // -0x1.7f7d1cp-20 (log(2) low)

  // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
  s = f * f;
  if (a == 0.0f)
    s = a; // ensure -0 is passed through
  // err = 0.997458  ulp1 = 11081805
  r = 1.97350979e-4f; // 0x1.9de000p-13
  r = ::metal::fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10
  r = ::metal::fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7
  r = ::metal::fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5
  r = ::metal::fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3
  r = ::metal::fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2
  // v = expm1(f), or expm1(f) + 0.5 in the j == 1 special case
  u = (j == 1) ? (f + 0.5f) : f;
  v = ::metal::fma(r, s, u);
  // General case: 2 * (v * t + (t - s)) with t = 0.5 * b * 2**i, s = 0.5 * b
  s = 0.5f * b;
  t = ::metal::ldexp(s, i);
  y = t - s;
  x = (t - y) - s; // double-float canonicalization of difference
  r = ::metal::fma(v, t, x) + y;
  r = r + r;
  // Special cases from the header comment: the result is taken directly
  // from v and the scale factor b is not applied.
  if (j == 0)
    r = v;
  if (j == 1)
    r = v + v;
  return r;
}

/* Compute exponential base e minus 1. max ulp err = 0.99746 */
inline float expm1f(float a) {
  float r;

  r = expm1f_scaled_unchecked(a, 1.0f);
  /* handle severe overflow and underflow */
  if (::metal::abs(a - 1.0f) > 88.0f) {
    // 2**a squared minus 1: overflows to +inf for large positive a and
    // collapses to -1 for large negative a.
    r = ::metal::pow(2.0f, a);
    r = ::metal::fma(r, r, -1.0f);
  }
  return r;
}

} // namespace metal
} // namespace c10
| @ -1,5 +1,6 @@ | ||||
| // Implementation of special math functions for Metal | ||||
| #pragma once | ||||
| #include <c10/metal/expm1f.h> | ||||
| #include <c10/metal/utils.h> | ||||
| #include <metal_stdlib> | ||||
|  | ||||
|  | ||||
| @ -61,6 +61,7 @@ if torch.backends.mps.is_available(): | ||||
|             "empty_permuted", | ||||
|             "empty_strided", | ||||
|             "exp", | ||||
|             "expm1", | ||||
|             "exp2", | ||||
|             "expand", | ||||
|             "expand_as", | ||||
| @ -209,7 +210,6 @@ if torch.backends.mps.is_available(): | ||||
|             "einsum", | ||||
|             "eq", | ||||
|             "equal", | ||||
|             "expm1", | ||||
|             "eye", | ||||
|             "fft.fft", | ||||
|             "fft.fft2", | ||||
|  | ||||
		Reference in New Issue
	
	Block a user