[MPS] Fix c0:🤘:log_gamma correctness on M4 (#145740)

To workaround a bug where `abs` method call seems to be ignored before calling log, which could be reproduced by running the following code (submitted as FB16415011 )
```swift
import Metal

func run_shader<T: BinaryFloatingPoint> (library: MTLLibrary, kernel_name: String, type: T.Type, nelem: Int = 16) {
  guard let mfunc = library.makeFunction(name: kernel_name) else { fatalError("Can't find function") }
  let device = library.device
  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  guard let cmdBuffer = queue.makeCommandBuffer() else { fatalError("Can't make command buffer") }
  guard let computeEncoder = cmdBuffer.makeComputeCommandEncoder() else { fatalError("Can't make compute encoder") }
  guard let ibuf = device.makeBuffer(length:nelem * MemoryLayout<T>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }
  let ibuf_data = ibuf.contents().assumingMemoryBound(to: T.self)
  for i in 0..<nelem {
    ibuf_data[i] = T(sin(Float(2 + i)))
  }
  guard let obuf = device.makeBuffer(length:nelem * MemoryLayout<T>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }
  let obuf_data = obuf.contents().assumingMemoryBound(to: T.self)

  computeEncoder.setComputePipelineState(try! device.makeComputePipelineState(function: mfunc))
  computeEncoder.setBuffer(obuf, offset:0, index: 0)
  computeEncoder.setBuffer(ibuf, offset:0, index: 1)
  computeEncoder.dispatchThreads(MTLSizeMake(nelem, 1, 1), threadsPerThreadgroup:MTLSizeMake(nelem, 1, 1))
  computeEncoder.endEncoding()
  cmdBuffer.commit()
  cmdBuffer.waitUntilCompleted()

  print("Results for \(String(describing: T.self)):", terminator: " ")
  for i in 0..<nelem {
    print(obuf_data[i], terminator: " ")
  }
  print()
}

let shader_source = """
#include <metal_stdlib>

template<typename T>
float foo(T x) {
  const auto abs_x = :🤘:abs(static_cast<float>(x));
  auto rc = :🤘:log(abs_x);

  return rc - :🤘:log(:🤘:abs(abs_x * :🤘:sinpi(abs_x)));
}

kernel void half_kernel(
    device half* out_ptr0,
    constant half* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
  auto inp = in_ptr0[xindex];
  auto out = foo(inp);
  out_ptr0[xindex] = static_cast<half>(out);
}

kernel void float_kernel(
    device float* out_ptr0,
    constant float* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
  auto inp = in_ptr0[xindex];
  auto out = foo(inp);
  out_ptr0[xindex] = static_cast<float>(out);
}
"""
let options = MTLCompileOptions()
options.mathMode = .safe
options.mathFloatingPointFunctions = .precise

guard let device = MTLCopyAllDevices().first else { fatalError("Not Metal device found") }
let library = try! device.makeLibrary(source:shader_source, options:options)
run_shader(library:library, kernel_name:"half_kernel", type: Float16.self)
run_shader(library:library, kernel_name:"float_kernel", type: Float.self)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145740
Approved by: https://github.com/dcci
This commit is contained in:
Nikita Shulga
2025-01-27 13:21:40 -08:00
committed by PyTorch MergeBot
parent 60f98262f1
commit 3a23d75b37

View File

@ -305,8 +305,10 @@ float log_gamma(const T x) {
}
// Reflection formula
return LOG_PI - rc -
::metal::log(::metal::abs(abs_x * ::metal::sinpi(abs_x)));
// Compute arg first to workaround Metal compiler bgg of sorts on M4
// See https://github.com/pytorch/pytorch/pull/145740 for more details
auto log_arg = abs_x * ::metal::abs(::metal::sinpi(abs_x));
return LOG_PI - rc - ::metal::log(log_arg);
}
} // namespace metal