mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[MPS] Fix c10::metal::log_gamma
correctness on M4 (#145740)
To workaround a bug where `abs` method call seems to be ignored before calling log, which could be reproduced by running the following code (submitted as FB16415011 ) ```swift import Metal func run_shader<T: BinaryFloatingPoint> (library: MTLLibrary, kernel_name: String, type: T.Type, nelem: Int = 16) { guard let mfunc = library.makeFunction(name: kernel_name) else { fatalError("Can't find function") } let device = library.device guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") } guard let cmdBuffer = queue.makeCommandBuffer() else { fatalError("Can't make command buffer") } guard let computeEncoder = cmdBuffer.makeComputeCommandEncoder() else { fatalError("Can't make compute encoder") } guard let ibuf = device.makeBuffer(length:nelem * MemoryLayout<T>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") } let ibuf_data = ibuf.contents().assumingMemoryBound(to: T.self) for i in 0..<nelem { ibuf_data[i] = T(sin(Float(2 + i))) } guard let obuf = device.makeBuffer(length:nelem * MemoryLayout<T>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") } let obuf_data = obuf.contents().assumingMemoryBound(to: T.self) computeEncoder.setComputePipelineState(try! 
device.makeComputePipelineState(function: mfunc)) computeEncoder.setBuffer(obuf, offset:0, index: 0) computeEncoder.setBuffer(ibuf, offset:0, index: 1) computeEncoder.dispatchThreads(MTLSizeMake(nelem, 1, 1), threadsPerThreadgroup:MTLSizeMake(nelem, 1, 1)) computeEncoder.endEncoding() cmdBuffer.commit() cmdBuffer.waitUntilCompleted() print("Results for \(String(describing: T.self)):", terminator: " ") for i in 0..<nelem { print(obuf_data[i], terminator: " ") } print() } let shader_source = """ #include <metal_stdlib> template<typename T> float foo(T x) { const auto abs_x = ::metal::abs(static_cast<float>(x)); auto rc = ::metal::log(abs_x); return rc - ::metal::log(::metal::abs(abs_x * ::metal::sinpi(abs_x))); } kernel void half_kernel( device half* out_ptr0, constant half* in_ptr0, uint xindex [[thread_position_in_grid]] ) { auto inp = in_ptr0[xindex]; auto out = foo(inp); out_ptr0[xindex] = static_cast<half>(out); } kernel void float_kernel( device float* out_ptr0, constant float* in_ptr0, uint xindex [[thread_position_in_grid]] ) { auto inp = in_ptr0[xindex]; auto out = foo(inp); out_ptr0[xindex] = static_cast<float>(out); } """ let options = MTLCompileOptions() options.mathMode = .safe options.mathFloatingPointFunctions = .precise guard let device = MTLCopyAllDevices().first else { fatalError("Not Metal device found") } let library = try! device.makeLibrary(source:shader_source, options:options) run_shader(library:library, kernel_name:"half_kernel", type: Float16.self) run_shader(library:library, kernel_name:"float_kernel", type: Float.self) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/145740 Approved by: https://github.com/dcci
This commit is contained in:
committed by
PyTorch MergeBot
parent
60f98262f1
commit
3a23d75b37
@@ -305,8 +305,10 @@ float log_gamma(const T x) {
   }

   // Reflection formula
-  return LOG_PI - rc -
-      ::metal::log(::metal::abs(abs_x * ::metal::sinpi(abs_x)));
+  // Compute arg first to workaround Metal compiler bgg of sorts on M4
+  // See https://github.com/pytorch/pytorch/pull/145740 for more details
+  auto log_arg = abs_x * ::metal::abs(::metal::sinpi(abs_x));
+  return LOG_PI - rc - ::metal::log(log_arg);
 }

 } // namespace metal
|
Reference in New Issue
Block a user