[METAL] inline bfloat min/max (#146588)

After a recent commit 36c6e09528a7e071edecde083254da70cba26c95 , building from source with `python setup.py develop` leads to an error due to multiple symbols for min/max:
```
FAILED: caffe2/aten/src/ATen/kernels_bfloat.metallib /Users/Irakli_Salia/Desktop/pytorch/build/caffe2/aten/src/ATen/kernels_bfloat.metallib
cd /Users/Irakli_Salia/Desktop/pytorch/build/caffe2/aten/src/ATen && xcrun metallib -o kernels_bfloat.metallib BinaryKernel_31.air Bucketization_31.air CrossKernel_31.air FusedOptimizerOps_31.air Gamma_31.air HistogramKernel_31.air Im2Col_31.air Indexing_31.air LinearAlgebra_31.air Quantized_31.air RMSNorm_31.air RenormKernel_31.air Repeat_31.air SpecialOps_31.air TriangularOps_31.air UnaryKernel_31.air UnfoldBackward_31.air UpSample_31.air
LLVM ERROR: multiple symbols ('_ZN3c105metal3minIDF16bEEN5metal9enable_ifIXgssr5metalE19is_floating_point_vIT_EES4_E4typeES4_S4_')!
```

This PR fixes that.

@malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146588
Approved by: https://github.com/FFFrog, https://github.com/Skylion007, https://github.com/malfet
This commit is contained in:
Isalia20
2025-02-06 17:57:31 +00:00
committed by PyTorch MergeBot
parent e2e265e27b
commit 7725d0ba12

View File

@ -108,13 +108,13 @@ template <typename T>
#if __METAL_VERSION__ >= 310
template <>
bfloat min(bfloat a, bfloat b) {
inline bfloat min(bfloat a, bfloat b) {
return bfloat(
::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b)));
}
template <>
bfloat max(bfloat a, bfloat b) {
inline bfloat max(bfloat a, bfloat b) {
return bfloat(
::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
}