736 Commits

Author SHA1 Message Date
46eeef9130 [MPS][BE] Surface syntax errors shader compilation (#144648)
Before this change
```python
>>> import torch
>>> torch.mps._compile_shader('What')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/malfet/miniconda3/envs/py311/lib/python3.11/site-packages/torch/mps/__init__.py", line 157, in _compile_shader
    return torch._C._mps_compileShader(source)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Failed to create metal library, error: Error Domain=MTLLibraryErrorDomain Code=3 "program_source:1:1: error: unknown type name 'What'
What
^
program_source:1:5: error: expected unqualified-id
What
    ^
" UserInfo={NSLocalizedDescription=program_source:1:1: error: unknown type name 'What'
What
^
program_source:1:5: error: expected unqualified-id
What
    ^
}
```
After this change
```python
>>> import torch
>>> torch.mps._compile_shader('What')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/malfet/git/pytorch/pytorch/torch/mps/__init__.py", line 157, in _compile_shader
    return torch._C._mps_compileShader(source)
SyntaxError: program_source:1:1: error: unknown type name 'What'
What
^
program_source:1:5: error: expected unqualified-id
What
    ^
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144648
Approved by: https://github.com/Skylion007
ghstack dependencies: #144647
2025-01-13 02:03:19 +00:00
92ddb3d3d3 [MPS] Expose MPSProfiler::start/stopCapture to Python (#144561)
I.e. when `MTL_CAPTURE_ENABLED` environment variable is set to 1, one should be able to invoke wrap the code with `torch.mps.profiler.capture_metal` to generate gputrace for shaders invoked inside the context manager.

For example, code below:
```python
import torch
import os

def foo(x):
   return x[:,::2].sin() + x[:, 1::2].cos()

if __name__ == "__main__":
    os.environ["MTL_CAPTURE_ENABLED"] = "1"
    x = torch.rand(32, 1024, device="mps")

    with torch.mps.profiler.metal_capture("compiled_shader"):
        torch.compile(foo)(x)
```
should capture the execution of a `torch.compile` generated shader
<img width="734" alt="image" src="https://github.com/user-attachments/assets/718ff64e-103b-4b11-b66c-c89cfc770b5d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144561
Approved by: https://github.com/manuelcandales
ghstack dependencies: #144559, #144560
2025-01-11 02:05:36 +00:00
e56768f030 [MPS] Fix bitwise shifts for uint8 (#144251)
Previosly all bitwise operations were aliased to the same type, but this is wrong for shift ops

Rather than building an overly complex logic, let's just instantiate using shared `scalarToMetalTypeString` helper function

Fixes https://github.com/pytorch/pytorch/issues/144190
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144251
Approved by: https://github.com/Skylion007
ghstack dependencies: #144249, #144250
2025-01-06 18:27:16 +00:00
ebeb433e73 [BE] Fix + parametrize test_min_max_nan_propagation (#144250)
- `dtype` was not passed as argument to `torch.rand` before
- Condition bfloat16 testing on MacOS14+
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144250
Approved by: https://github.com/Skylion007
ghstack dependencies: #144249
2025-01-06 17:49:41 +00:00
11a0663eeb [BE] Parametrize test_min_max (#144249)
It's better to have one unit test per dtype rather a combined one
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144249
Approved by: https://github.com/Skylion007
2025-01-06 17:49:41 +00:00
0dc1e6be19 [mps/BE] Fix linter warning/advice. (#144199)
Two spaces before an inline comment according to E261.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144199
Approved by: https://github.com/Skylion007, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-01-04 20:15:41 +00:00
811c714911 Fix nan propagation for minimum() and maximum() in MPS (#144086)
Fixes #143976

- Moves minimum and maximum operations to use the NaN propagating call into MPSGraph instead of the default one.
 - Adds test for the NaN propagating case to `test_mps.py`.
- Adjusts the inductor metal backend implementation for minimum and maximum to also respect the nan propagation.

Additions by @malfet:
 - Introduce MPSGraph+PyTorchFixups interface following [Customizing existing classes](https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/ProgrammingWithObjectiveC/CustomizingExistingClasses/CustomizingExistingClasses.html) tutorial and implement `minimumWithNaNPropagationAndIntFallbackWithPrimaryTensor:` as `minimumWithNaNPropagationWithPrimaryTensor:` segfaults when called for integral types

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144086
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <nshulga@meta.com>
2025-01-04 18:48:24 +00:00
6f2451c2e9 [MPS] Add aten::angle (#143449)
This adds an MPS backend implementation for `aten::angle` and `aten::angle_out` (mentioned in issue #77764), following the example #78408.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143449
Approved by: https://github.com/malfet
2025-01-04 15:38:40 +00:00
301c457032 [MPS] Fix nllnd_loss_backward crash with different dtypes (#144170)
Otherwise, invoking with torch.half inputs, but float weights will result in
```
(mpsFileLoc): /AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.divide' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %16 = "mps.divide"(%15, %arg2) : (tensor<5x5xf16>, tensor<1xf32>) -> tensor<*xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.divide' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %16 = "mps.divide"(%15, %arg2) : (tensor<5x5xf16>, tensor<1xf32>) -> tensor<*xf32>
2025-01-03 14:13:18.747151-0800 python[87772:4027380] /AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm, line 975: error 'original module failed verification'
/AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:975: failed assertion `original module failed verification'
```

Test plan: `python -mpytest test/inductor/test_torchinductor.py -k test_nll_loss_backward_mps` should not crash
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144170
Approved by: https://github.com/kit1980, https://github.com/Skylion007
ghstack dependencies: #144167, #144162, #144083, #144084
2025-01-04 15:24:55 +00:00
22580f160e Multinomial sampling fix on mps for non contiguous tensors (#141515)
Fixes #141457

As for the tests. I looked in `test/test_mps.py` but I saw that `test_multinomial` function is disabled. Glad to add test where needed if there is some place where multinomial function is tested on metal.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141515
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-01-04 01:21:37 +00:00
a93e75d1e2 [MPS] Handle implicit cpu-scalar-to-gpu transfer (#144055)
Followup after https://github.com/pytorch/pytorch/pull/143934, this check is no longer necessary and fixes a subset of inductor tests

Before `pytest test/inductor/test_torchinductor.py -k _mps` reports 463
failed, 291 passed, 32 skipped after 456 failed, 298 passed, 32 skipped
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144055
Approved by: https://github.com/Skylion007
2025-01-02 17:12:39 +00:00
c27c788e35 [MPS] Fix torch.add(x,y, alpha=2) crash (#143949)
TODO: as followup PR replace this weird logic with shaders

Fixes https://github.com/pytorch/pytorch/issues/143932

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143949
Approved by: https://github.com/Skylion007
ghstack dependencies: #143948
2024-12-30 17:16:29 +00:00
3054aae493 [MPS] Fix fmin/fmax for scalar argument (#143934)
CPU scalar promotion to GPU is allowed for CUDA and shoudl be allowed for MPS as well (at the very least it should not crash)

Fixes https://github.com/pytorch/pytorch/issues/143933 https://github.com/pytorch/pytorch/issues/142203
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143934
Approved by: https://github.com/Skylion007
2024-12-28 17:07:19 +00:00
33c27be017 Workaround for gather_out in MPS backend (#135543)
Avoids an underlying issue in reshape op in MPS that gets triggered when the input has multiple dimensions but the shape can be squeezed into 1D. The underlying issue is going to get fixed eventually.

Fixes https://github.com/pytorch/pytorch/issues/135240

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135543
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-12-19 18:01:01 +00:00
d8c8ba2440 Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964
Approved by: https://github.com/justinchuby, https://github.com/albanD
2024-12-18 23:02:30 +00:00
24a18d76c8 [MPS] Use metal shaders for all view ops (#143375)
Before this PR Metal  shaders were used to scatter/gather 1-5 dimensional tensors.
This PR introduces generalized ones that could be used for any dimensionality and as results  gets rid of 700+ lines complex and untested code that might not even work as expected.
Generalized gather shader looks as follows
```metal
kernel void gather_kernel_n(uint linear_index           [[thread_position_in_grid]],
                            constant void * src_        [[buffer(0)]],
                            device void * dst_          [[buffer(1)]],
                            constant uint32_t * size    [[buffer(2)]],
                            constant uint32_t * stride  [[buffer(3)]],
                            constant uint32_t & numel   [[buffer(4)]],
                            constant int32_t & ndim     [[buffer(5)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    uint64_t src_offs = 0;
    auto src_idx = linear_index;
    for(int dim = ndim - 1; dim >= 0; --dim) {{
      src_offs += stride[dim] * (src_idx % size[dim]);
      src_idx /= size[dim];
    }}

    dst[linear_index] = cast<{1}>(src[src_offs]);
}}
```

Which, according to the following benchmark
```python
from timeit import default_timer

import torch
import torch.utils.cpp_extension
from torch.utils.benchmark import Measurement, Timer

t = Timer(
    stmt=f"y.copy_(x);torch.mps.synchronize()",
    setup=f"x=torch.rand(4, 5, 16, 64, 33, 24, dtype=torch.float32, device='mps')[:,:,:,:24,:24,];y=torch.empty(x.shape, device=x.device, dtype=x.dtype)",
    language="python", timer=default_timer
)
print(t.blocked_autorange())
```
Is almost twice as fast as previous implementation (i.e. on Mac Book M2 Pro it returns 2.9ms for MPS version vs 1.5ms for shader one

On MacOS Sequoia [`gatherWithUpdatesTensor: indicesTensor:...`](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/gather(withupdatestensor:indicestensor:axis:batchdimensions:name:)?language=objc) crashes if invoked with complex data type, as one can see by running the code below
```swift
import Metal
import MetalPerformanceShadersGraph

func gatherComplexMPS(device: MTLDevice,
                inp_buf: MTLBuffer, idx_buf: MTLBuffer,
                out_buf: MTLBuffer,
                inp_elem: Int, upd_elem: Int) {
  let graph = MPSGraph()
  let inputPlaceholder = graph.placeholder(shape: [inp_elem as NSNumber], dataType: .complexFloat32, name: nil)
  let indicesPlaceholder = graph.placeholder(shape: [upd_elem as NSNumber], dataType: .int64, name: nil)
  let outNode = graph.gather(withUpdatesTensor: inputPlaceholder, indicesTensor: indicesPlaceholder, axis: 0, batchDimensions: 0, name: nil)
  let mpsInputBuffer = MPSGraphTensorData(inp_buf, shape: [inp_elem as NSNumber], dataType: .complexFloat32)
  let mpsIndicesBuffer = MPSGraphTensorData(idx_buf, shape: [upd_elem as NSNumber], dataType: .int64)
  let mpsOutputBuffer = MPSGraphTensorData(out_buf, shape: [inp_elem as NSNumber], dataType: .complexFloat32)
  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  graph.run(with: queue, feeds: [inputPlaceholder: mpsInputBuffer,
                               indicesPlaceholder: mpsIndicesBuffer ],
            targetOperations: nil, resultsDictionary: [outNode: mpsOutputBuffer])
}

func makeBufferWithValues<T>(device: MTLDevice, values: [T]) -> MTLBuffer {
  guard let buf = device.makeBuffer(length: values.count * MemoryLayout<T>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }
  let buf_data = buf.contents().assumingMemoryBound(to: T.self)
  for i in 0..<values.count {
    buf_data[i] = values[i]
  }
  return buf
}

guard let device = MTLCopyAllDevices().first else { fatalError("Not Metal device found") }
print("Using device \(device.name)")

let inp_buf = makeBufferWithValues(device: device, values: [1.0, 2.0 , 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
let idx_buf = makeBufferWithValues(device: device, values: [0, 1, 2, 3])
guard let out_buf = device.makeBuffer(length:8 * MemoryLayout<Float>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }

gatherComplexMPS(device: device, inp_buf: inp_buf, idx_buf: idx_buf, out_buf: out_buf, inp_elem: 4, upd_elem: 4)
```

Fixes https://github.com/pytorch/pytorch/issues/143140
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143375
Approved by: https://github.com/albanD
2024-12-18 16:15:46 +00:00
afa313e669 Extend bmm tiling to work up to 2^32 elem in any single output dim (#143095)
The previous tiling implementation worked for up to 2^32 total elements per single batch entry. This extends the functionality to support the dimensions encountered in ComfyUI (output shape: 1,72250,72250).

Fixes #141909
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143095
Approved by: https://github.com/kulinseth
2024-12-17 16:03:46 +00:00
c1d4d9d3cf [MPS] Support torch.accelerator.synchronize() on mps (#143171)
# Motivation
Support `torch.accelerator.synchronize()` on mps. The root cause is that MPS doesn't support lazy initialization. So we must check if the current accelerator supports device lazy initialization rather than early return.

# Additional Context
Add a mps UT to test code change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143171
Approved by: https://github.com/albanD
2024-12-16 02:18:32 +00:00
8a04018329 [MPS] Fix conv backward for channels last (cont) (#143196)
This is a continuation of https://github.com/pytorch/pytorch/issues/140902 but extends the same logic to input.

Looks like existing channels-last logic just produced incorrect results on pre MacOS-15 versions and fails on MacOS-15, so removing it feels like a right idea

Fixes https://github.com/pytorch/pytorch/issues/142344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143196
Approved by: https://github.com/manuelcandales
2024-12-13 21:32:42 +00:00
e0c8abda76 Fix potentially undefined behaviour in index_put sample input (#143116)
From the [docs](https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html) for index_put_:

> If accumulate is True, the elements in values are added to self. If accumulate is False, the behavior is undefined if indices contain duplicate elements.

Currently the sample inputs for `index_put` generates 2 indices. Because they are generated randomly, they could be the same leading to undefined behaviour if `accumulate=False`.

This PR changes the input generation to only generate a single index if `accumulate=False` preventing duplicate indices and undefined behaviour.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143116
Approved by: https://github.com/albanD
2024-12-13 17:59:01 +00:00
95b17f6346 [MPS] Add CompileShader method (#141478)
This allows one to do something like that
```python
import torch
x = torch.ones(10, device="mps")
m = torch.mps._compile_shader("""
   kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
     x[idx] += idx;
   }
")
m.foo(x)
```

And in general enables writing custom operators using Metal shaders purely in Python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141478
Approved by: https://github.com/manuelcandales
2024-12-11 02:00:51 +00:00
393cf46f42 Revert "[MPS] Add CompileShader method (#141478)"
This reverts commit 0478fee42db16a0477add1d0a644ce713f31a875.

Reverted https://github.com/pytorch/pytorch/pull/141478 on behalf of https://github.com/malfet due to Broke doctests, by trying to run MPS example on Linux ([comment](https://github.com/pytorch/pytorch/pull/141478#issuecomment-2533351909))
2024-12-11 00:37:10 +00:00
0478fee42d [MPS] Add CompileShader method (#141478)
This allows one to do something like that
```python
import torch
x = torch.ones(10, device="mps")
m = torch.mps._compile_shader("""
   kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
     x[idx] += idx;
   }
")
m.foo(x)
```

And in general enables writing custom operators using Metal shaders purely in Python
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141478
Approved by: https://github.com/manuelcandales
2024-12-10 22:43:17 +00:00
bee445c3a3 [MPS] Support torch.Event for MPS (#142468)
# Motivation
Support `torch.Event` on mps backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142468
Approved by: https://github.com/malfet
2024-12-10 21:17:25 +00:00
d6481333ad [MPS] Add scatter_reduce.two (#141948)
Which has been request 20+ times on https://github.com/pytorch/pytorch/issues/77764 is just a flavor of out-of-box scatter-reduce, so all this op does is redispatches existing implementation.
Unsupported dtype/reduction type combinations:
 - min/max for int64
 - min/max for int32 on MacOS-14 or older

Following swift code demonstrates problem with scatterAlongAxis MPS call
```swift
import Metal
import MetalPerformanceShadersGraph

func scatterMPS(device: MTLDevice,
                inp_buf: MTLBuffer, upd_buf: MTLBuffer,
                idx_buf: MTLBuffer, out_buf: MTLBuffer,
                inp_elem: Int, upd_elem: Int) {
  let graph = MPSGraph()
  let inputPlaceholder = graph.placeholder(shape: [inp_elem as NSNumber], dataType: .int64, name: nil)
  let updatesPlaceholder = graph.placeholder(shape: [upd_elem as NSNumber], dataType: .int64, name: nil)
  let indicesPlaceholder = graph.placeholder(shape: [upd_elem as NSNumber], dataType: .int64, name: nil)
  let outNode = graph.scatterAlongAxis(0, data: inputPlaceholder, updates: updatesPlaceholder, indices: indicesPlaceholder, mode: .min, name: nil)
  let mpsInputBuffer = MPSGraphTensorData(inp_buf, shape: [inp_elem as NSNumber], dataType: .int64)
  let mpsUpdatesBuffer = MPSGraphTensorData(upd_buf, shape: [upd_elem as NSNumber], dataType: .int64)
  let mpsIndicesBuffer = MPSGraphTensorData(idx_buf, shape: [upd_elem as NSNumber], dataType: .int64)
  let mpsOutputBuffer = MPSGraphTensorData(out_buf, shape: [inp_elem as NSNumber], dataType: .int64)
  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  graph.run(with: queue, feeds: [inputPlaceholder: mpsInputBuffer,
                               updatesPlaceholder: mpsUpdatesBuffer,
                               indicesPlaceholder: mpsIndicesBuffer ],
            targetOperations: nil, resultsDictionary: [outNode: mpsOutputBuffer])
}

func makeBufferWithValues(device: MTLDevice, values: [Int64]) -> MTLBuffer {
  guard let buf = device.makeBuffer(length: values.count * MemoryLayout<Int64>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }
  let buf_data = buf.contents().assumingMemoryBound(to: Int64.self)
  for i in 0..<values.count {
    buf_data[i] = values[i]
  }
  return buf
}

guard let device = MTLCopyAllDevices().first else { fatalError("Not Metal device found") }
print("Using device \(device.name)")

let inp_elem = 4
let upd_elem = 4
let inp_buf = makeBufferWithValues(device: device, values: [1, 2, 3, 4])
let upd_buf = makeBufferWithValues(device: device, values: [Int64.max - 1, Int64.max - 2 , Int64.max >> 16 , 11])
let idx_buf = makeBufferWithValues(device: device, values: [0, 1, 2, 3])
guard let out_buf = device.makeBuffer(length:inp_elem * MemoryLayout<Int64>.size, options: [.storageModeShared]) else { fatalError("Can't alloc") }

scatterMPS(device: device,
           inp_buf: inp_buf, upd_buf: upd_buf,
           idx_buf: idx_buf, out_buf: out_buf,
           inp_elem: inp_elem, upd_elem: upd_elem)

let obuf_data = out_buf.contents().assumingMemoryBound(to: Int64.self)
for i in 0..<inp_elem {
    print("out_buf[\(i)] = \(obuf_data[i])")
}
```
that prints `4294967294, 4294967293, 4294967295, 4` instead of expected `1, 2, 3, 4`
Where `torch.tensor([[1, 9223372036854775806], [2, 9223372036854775805], [3, 140737488355327], [4, 11]], dtype=torch.int64, device='mps').max(1)` yields an expected results
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141948
Approved by: https://github.com/manuelcandales
2024-12-04 04:56:43 +00:00
90f19fee8a [MPS] Convert channels_last_3d to contiguous for input tensor in nn.Conv3d (#141780)
When the input tensor to Conv3d is in the channels_last_3d memory format the Conv3d op will generate incorrect output (see example image in #141471). This PR checks if the op is 3d, and then attempts to convert the input tensor to contiguous.

Added a regression test that verifies the output by running the same op on the CPU.

I'm unsure if Conv3d supports the channels last memory format after #128393. If it does, we should consider updating the logic to utilize this as it would be more efficient. Perhaps @DenisVieriu97 knows or has more context?

Fixes #141471
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141780
Approved by: https://github.com/malfet
2024-12-01 18:36:53 +00:00
4d5c096a55 [MPS] Add autocast rule for SDPA (#141776)
Fixes #141774

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141776
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-11-29 03:34:03 +00:00
65166d86a3 [MPS] Add regression test for sync deadlock (#141296)
See https://github.com/pytorch/pytorch/pull/140725#issuecomment-2492434870
Running `torch.mps.synchronize()` after metal kernel resulted in infinite wait inside `[_MTLCommandBuffer waitUntilCompleted]`
```
(lldb) bt
* thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGSTOP
  * frame #0: 0x00000001aa919084 Metal`pthread_cond_wait + 12
    frame #1: 0x00000001aa78b1b4 Metal`-[_MTLCommandBuffer waitUntilCompleted] + 84
    frame #2: 0x00000001032bf358 libtorch_python.dylib`torch::mps::MPSModule_deviceSynchronize(_object*, _object*) + 40
    frame #3: 0x0000000100e94c20 Python`cfunction_vectorcall_NOARGS + 100
    frame #4: 0x0000000100e389b8 Python`PyObject_Vectorcall + 92
    frame #5: 0x0000000100f61e38 Python`_PyEval_EvalFrameDefault + 19040
    frame #6: 0x0000000100f5d180 Python`PyEval_EvalCode + 200
    frame #7: 0x0000000100fcd1a4 Python`run_eval_code_obj + 104
    frame #8: 0x0000000100fccbe4 Python`run_mod + 168
    frame #9: 0x0000000100fcb518 Python`pyrun_file + 164
    frame #10: 0x0000000100fca854 Python`_PyRun_SimpleFileObject + 256
    frame #11: 0x0000000100fca4e8 Python`_PyRun_AnyFileObject + 80
    frame #12: 0x0000000100ff2028 Python`pymain_run_file_obj + 164
    frame #13: 0x0000000100ff1ce4 Python`pymain_run_file + 72
    frame #14: 0x0000000100ff0f74 Python`Py_RunMain + 988
    frame #15: 0x0000000100ff1564 Python`pymain_main + 304
    frame #16: 0x0000000100ff1604 Python`Py_BytesMain + 40
    frame #17: 0x000000019f630274 dyld`start + 2840
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141296
Approved by: https://github.com/huydhn
2024-11-22 00:56:33 +00:00
5e54cf3687 Revert "Fix MPS synchronize by waiting for root buffer to complete (#140725)"
This reverts commit 9bc9d4cdb4355a385a7d7959f07d04d1648d6904.

Reverted https://github.com/pytorch/pytorch/pull/140725 on behalf of https://github.com/malfet due to It causes deadlocks when I try to run something benchmark from  https://github.com/pytorch/pytorch/pull/127242 ([comment](https://github.com/pytorch/pytorch/pull/140725#issuecomment-2492416501))
2024-11-21 21:56:22 +00:00
a8794fd7df [MPS] Fix conv backward pass for channels last (#141009)
Looks like a regression caused by use of strided API, but adding the test revealed (at least in CI), that on Ventura it worked but returned garbage results, so fixed by removing all the logic about channels last (as it's irrelevant for strided API case and placeholder already turns tensor into a correct one)

This also allows one to remove `mem_format_key` and `ns_shape_key` (it was redundant even back then, as `mem_format_key` + `getTensorsStringKey(grad_output_t)` already uniquely identified the operation)

Fixes https://github.com/pytorch/pytorch/issues/140902

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141009
Approved by: https://github.com/manuelcandales
2024-11-20 19:50:31 +00:00
9bc9d4cdb4 Fix MPS synchronize by waiting for root buffer to complete (#140725)
Makes https://github.com/pytorch/pytorch/issues/139550#issuecomment-2468860559 work

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140725
Approved by: https://github.com/malfet, https://github.com/kulinseth
2024-11-19 23:10:24 +00:00
9c88b08ac9 [BE] Replace skipIfMPS with expectedFailureMPS (#139940)
Functionally two decorators are very similar, but one should rely on expectedFailure as much as possible to get signal when something is fixed.
- Move `product_version` variable from `test_mps` to common_utils, but call it `MACOS_VERSION`
- Introduce `skipIfMPSOnMacOS13`  to decorate the hard crashes that happens only on MacOS13 (which at this point will not get any fixes and will be deprecated soon)
- Add `device_type='mps'` to all `skipIfMPS` per https://github.com/pytorch/pytorch/issues/140560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139940
Approved by: https://github.com/janeyx99, https://github.com/huydhn
2024-11-15 03:48:37 +00:00
b0d681417c [MPS] Reintroduce support for convolutions with output_channels > 65536 (#140726)
This reintroduces support for high channel sizes for convs. The guard for macOS versions < 15.1 is still present to prevent reintroducing #129207.

I'm unsure about the specific macOS version support, but I'm assuming this was fixed in 15.1, and I'm relying on signals from ci for verification. I'm expecting the new test will fail for macOS versions < 15.1, and the old test will start failing for > 15.0. I've added xfails for this and extended the version helpers to support 15.1+.

Fixes #140722
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140726
Approved by: https://github.com/malfet
2024-11-14 20:09:01 +00:00
cd6ace1d15 [EZ] Delete unused xfailIfMacOS14_4Plus (#140735)
Issue was fixed by https://github.com/pytorch/pytorch/pull/130038 but decorator remained in place

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140735
Approved by: https://github.com/kit1980, https://github.com/atalman
2024-11-14 20:08:48 +00:00
9d93c27025 Implement unfold_backward on MPS (#135411)
This PR adds native implementation of unfold_backward as metal shader, mostly copy-n-paste of algorithms used in CUDA and CPU implementations, i.e. considering `out = in.unfold(dim, size, step)`, then following holds true:
* `out.shape[dim] == (in.shape[dim] - size) / step + 1`
* `out.shape[-1] == size`
* `out.ndim == in.ndim + 1`
`unfold_backward` Metal kernel  receives `grad_in` and returns `grad_out` such that:
* `grad_in.shape == out.shape`
* `grad_out.shape == in.shape`

For each index in `grad_out` find the elements contributing to it and sum them up. Such algorithm requires no synchronization between threads.
That is `grad_out[...,out_dim_idx,...]` accumulates all values `grad_in[...,in_dim_idx,...,in_last_idx]`, where `in_dim_idx` is range [`(out_dim_idx - size) / step`, `out_dim_idx / step`] clamped to (0, `in_dim_size`) and `in_last_idx` are equal `out_dim_idx - in_dim_idx * step` . Accumulation step is skipped if `in_last_idx` is outside of [0, size] range.

This operator has been requested 16 times on https://github.com/pytorch/pytorch/issues/77764

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135411
Approved by: https://github.com/manuelcandales

Co-authored-by: Manuel Candales <42380156+manuelcandales@users.noreply.github.com>
2024-11-13 23:04:15 +00:00
cb71bcc542 Replace clone.detach with detach.clone (#140264)
Fixes #64532

As state in issue, replace `clone.detach` by `detach.clone`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264
Approved by: https://github.com/soulitzer
2024-11-13 07:01:02 +00:00
f77eb07662 Split int4wo weight packing (#139611)
Fixes https://github.com/pytorch/ao/issues/1117.

This PR is to seperate int4wo weight packing between CPU and other devices, to help implement `INT4CPULayout` in torchao based on https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756.

Now, for CPU, the input `weight` of `_convert_weight_to_int4pack_for_cpu` is [n, k] int32, output is [n, k / 2] uint8. The input packed weight of `_weight_int4pack_mm_for_cpu` is [n, k / 2] uint8.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139611
Approved by: https://github.com/jerryzh168
2024-11-12 10:12:50 +00:00
f5ffd55a32 [MPS] Add torch.special.i1 op (#140196)
By more-or-less copy-n-pasting 58b661cda2/aten/src/ATen/native/cuda/Math.cuh (L576)

Enable respective tests in test_mps.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140196
Approved by: https://github.com/Skylion007
2024-11-11 16:57:53 +00:00
103cbd7231 [MPS] Restrict MSELoss to floating types (#139960)
Becuase if invoked with long type it crahses deep in MPSGraph framework and to keep parity with CPU

Add test that validates that if dtype is not floating, both CPU and MPS implementations will error out
Fix function name for `mse_loss_out_mps` as `__func__` for any structured op implementation is `impl`

Fixes https://github.com/pytorch/pytorch/issues/139723
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139960
Approved by: https://github.com/kimishpatel
ghstack dependencies: #139961, #139959
2024-11-08 00:28:54 +00:00
44df6522ee add Half/BFloat16 support for grid_sample on CPU (#134812)
Fix https://github.com/pytorch/pytorch/issues/127224.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134812
Approved by: https://github.com/Skylion007, https://github.com/mingfeima
2024-11-06 14:02:08 +00:00
ca43ecd599 Flip default on weights_only (#137602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137602
Approved by: https://github.com/malfet, https://github.com/albanD
ghstack dependencies: #138936, #139221, #139433, #139541
2024-11-04 18:30:29 +00:00
51adab0829 [MPS] Fix reduction ops outputs for empty tensors (#139446)
By adding a switch for all reduction types, that either sets it to given value or raises runtime error.
Before this change, reduction ops returned uninitialized values in many case

Fixes https://github.com/pytorch/pytorch/issues/139400

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139446
Approved by: https://github.com/Skylion007
2024-11-01 17:32:12 +00:00
6e85266a47 [MPS] Fixes SiLU on non-contiguous tensors (#139006)
Similar to #123049, however, `SiLU` also produces random values, `0.0`, or `NaN` as results if input tensor is not contiguous on prior to macOS 15.0.
Orignally the problem was found at jy0205/Pyramid-Flow#113.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139006
Approved by: https://github.com/malfet
2024-10-30 15:44:59 +00:00
38645e8a3e Revert "Fix unbind_copy and add its decomposition (#134319)"
This reverts commit 8aedc649bdd0789b0ea9b9348d552fb1b0e437ff.

Reverted https://github.com/pytorch/pytorch/pull/134319 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but this is still failing the same test on ExecuTorch ([comment](https://github.com/pytorch/pytorch/pull/134319#issuecomment-2443209139))
2024-10-29 04:54:37 +00:00
652a2ab93e [BE] Skip print(foo) tests (#139009)
Skipped `test_exponential` and `test_multinomial` because simply printing the result of an operator does not constitute a test. The testing framework does not attempt to interpret the output.
Modify `test_print_non_contiguous` to get tensors string representation, which is an equivalent operation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139009
Approved by: https://github.com/Skylion007
2024-10-27 18:04:03 +00:00
a3de067975 [PyTorch] Use 128-bit vectors for ARM64 (#137426)
The correct vector length for ARM64 is 128 bits (16
bytes). We were previously using double this, apparently just because
that would be the same length as AVX2.

Differential Revision: [D63984039](https://our.internmc.facebook.com/intern/diff/D63984039/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137426
Approved by: https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #138486, #138542, #138655, #138716, #138744
2024-10-26 00:20:35 +00:00
1b31248933 [EZ] Fix typo in test_mps.py (#138738)
s/emedding_weight/embedding_weight/

Stolen from 074766d9b4

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138738
Approved by: https://github.com/atalman
2024-10-23 22:15:35 +00:00
8aedc649bd Fix unbind_copy and add its decomposition (#134319)
* Fixes https://github.com/pytorch/pytorch/issues/130829

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134319
Approved by: https://github.com/amjames, https://github.com/eellison
2024-10-23 19:13:44 +00:00
1bc73f3157 Add decomposition for permute_copy (#130944)
* Extracted from #129476

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130944
Approved by: https://github.com/amjames, https://github.com/eellison
2024-10-23 17:42:11 +00:00
de16159e56 [MPS] Fix sliced cast (#138314)
This fixes internal crash due to the invalid bufer size computation if sliced API is used

Not sure what was the purpose of
```c++
IntArrayRef baseShape;
if (src.is_view()) {
  baseShape = src._base().sizes();
} else {
  baseShape = getIMPSAllocator()->getBufferShape(src.storage().data());
}
int flattenedShaped = 1;
for (const auto i : c10::irange(baseShape.size())) {
  flattenedShaped *= baseShape[i];
}
```
As flattenShaped could be much easier computed as `[srcBuf
lengh]/src.element_size()`, and even if `srcBuf` is padded it's a safe thing to do.

When someone allocated buffer to hold say uint8 and that view-casted it
to float16, attempt to compute `baseShape` returned sizes of original
tensor in its data type, rather than size in new dtypes

Fixes https://github.com/pytorch/pytorch/issues/137800
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138314
Approved by: https://github.com/albanD, https://github.com/DenisVieriu97
2024-10-19 05:17:09 +00:00