What started as a simple fix for `mps_convolution_backward_input` turned into a pretty significant refactor with a number of fixes:
- Updated `mps_conv_use_channels_last` to return channels last output if either input or weights are channels last
- Use the same primitive throughout `Convolution.mm` to determine whether the output should be allocated in channels last format or not
But doing only those two resulted in a crash in `test_memory_format_nn_Conv2d_mps_float32` when the weights are channels last and a bias is present:
```
% python -c "import torch;print(torch.nn.functional.conv2d(torch.rand(2, 4, 3, 4,device='mps'), torch.rand(5, 4, 3, 3,device='mps').to(memory_format=torch.channels_last), torch.rand(5,device='mps')))"
/AppleInternal/Library/BuildRoots/4~B5E4ugDCh2RsPWAjMEoPu8LC5w1yXEwd7XweDhg/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:3619: failed assertion `Error: MLIR pass manager failed'
zsh: abort python -c
```
Which requires a more thorough redesign/cleanup, namely:
- Do not alter the layout based on MacOS version, but rather do additional copies on MacOS-14 if the inputs/output or weight are in channels last format (done by defining `std::optional<Tensor> output_c;` that contains a contiguous copy of the output tensor)
- Introduced `input_suggested_layout`, which is set to ChannelsLast if and only if the input is channels last and the code is running on MacOS-15+
- Delete unused `memory_layout` and `group` arguments from `fill_depthwise_conv_desc`
- Fix bias broadcasting logic for channels last
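To illustrate, here is a minimal sketch mirroring the repro above; the layout expectation is an assumption based on the updated `mps_conv_use_channels_last` behaviour and applies on MacOS-15+:
```python
import torch

# The previously crashing case: channels-last weights plus a bias term.
x = torch.rand(2, 4, 3, 4, device="mps")
w = torch.rand(5, 4, 3, 3, device="mps").to(memory_format=torch.channels_last)
b = torch.rand(5, device="mps")

out = torch.nn.functional.conv2d(x, w, b)

# With the refactor, the output is expected to be channels last whenever
# either the input or the weights are channels last (on MacOS-15+; on
# MacOS-14 additional contiguous copies are made instead).
print(out.is_contiguous(memory_format=torch.channels_last))
```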
As a result, in addition to adding one more regression test, this change removes `expectedFailures` from:
- `TestModule.test_memory_format` for `Conv2d`, `ConvTranspose2d`, `LazyConv1d`, `LazyConvTranspose1d`
- `test_require_stride_expanded_dynamic_shapes`
- `test_mutable_custom_op_fixed_layout2` for MacOS-14
Fixes https://github.com/pytorch/pytorch/issues/161905
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162776
Approved by: https://github.com/Skylion007
Everything should go through generalized kernels, and Metal kernels should work with the same sizes and strides as the CPU or CUDA backends to avoid problems with `torch.compile`, which relies on the meta kernels to tell what its output is going to look like.
To avoid returning tensors with a different layout depending on whether the `upper` parameter is true or false, templatize `factorDiagonalBlock`, `applyTRSM` and `applySYRK` to take upper/lower (actually row-wise vs column-wise) as a template argument, and call the appropriate templates from the host.
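As a hedged sketch of the user-visible invariant this aims at (not the actual test code), the factor returned on MPS should have the same strides for either value of `upper`, matching CPU:
```python
import torch

# Hypothetical check: the Cholesky factor on MPS should have the same
# sizes/strides regardless of `upper`, matching the CPU/meta behaviour
# that torch.compile relies on.
a = torch.rand(4, 8, 8)
a = a @ a.mT + 8 * torch.eye(8)  # make it positive definite
for upper in (False, True):
    cpu = torch.linalg.cholesky(a, upper=upper)
    mps = torch.linalg.cholesky(a.to("mps"), upper=upper)
    assert cpu.stride() == mps.stride(), (cpu.stride(), mps.stride())
```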
TODOs:
- Rename the `upper` parameter to something more sensible and add comments
- Use `simd_groupsize` instead of a hardcoded 32 everywhere
Fixes https://github.com/pytorch/pytorch/issues/156658
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157014
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #157179
This fixes a regression introduced by https://github.com/pytorch/pytorch/issues/150629#issue-2970312779, which I should have reviewed more thoroughly.
- Defined `_fused_rms_norm`, added an MPS-only implementation for it, and dispatched to it from `rms_norm_symint`, which is registered as `CompositeImplicitAutograd`, i.e. it is not supposed to do any computation over Tensors, only dispatch to other ops
- Registered `_fused_rms_norm` as a fallback in `torch/_inductor/lowering.py`
- Added a unit test to avoid such regressions in the future (a sketch of the idea is shown below)
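A minimal sketch of that kind of check, with assumed shapes and `eps` (not the actual test code):
```python
import torch

# Hypothetical regression check: compiled rms_norm on MPS should match eager.
x = torch.rand(16, 64, device="mps")
w = torch.rand(64, device="mps")

eager = torch.nn.functional.rms_norm(x, (64,), weight=w, eps=1e-6)
compiled = torch.compile(torch.nn.functional.rms_norm)(x, (64,), weight=w, eps=1e-6)
torch.testing.assert_close(eager, compiled)
```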
TODO:
- Get rid of this op, change `rms_norm_symint` definition to `CompositeExplicitAutograd` and implement backward function in `tools/autograd/derivatives.yaml`
- Benchmark the compiler and re-enable the decomposition below when the compiled code is faster
```python
import torch
from torch._decomp import register_decomposition

aten = torch.ops.aten

@register_decomposition(aten._fused_rms_norm)
def rms_norm_fused(
    self: torch.Tensor, ndim: int, weight: torch.Tensor, eps: float
) -> torch.Tensor:
    # Reduce over the trailing `ndim` dimensions
    dtr = [self.dim() - i - 1 for i in range(ndim)]
    return self * weight * (self.pow(2).mean(dtr, keepdim=True).add(eps).rsqrt())
```
Fixes https://github.com/pytorch/pytorch/issues/150629
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150661
Approved by: https://github.com/manuelcandales, https://github.com/jansel
There are only 118 failures at the moment; mark them all with xfail to avoid new regressions.
Add `xfail_if_mps_unimplemented` decorator to distinguish between tests that call unimplemented eager op vs ones that fail for some other reason.
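A rough sketch of what such a decorator could look like (this is a hypothetical implementation, not the one added by the PR):
```python
import functools

def xfail_if_mps_unimplemented(test_fn):
    # Hypothetical sketch: a test that fails because an eager op is missing on
    # MPS is treated as an expected failure; any other error propagates, and an
    # unexpected pass is surfaced so the marker can be removed.
    @functools.wraps(test_fn)
    def wrapper(self, *args, **kwargs):
        try:
            test_fn(self, *args, **kwargs)
        except NotImplementedError:
            self.skipTest("eager op not yet implemented for MPS")
        else:
            raise RuntimeError(
                f"{test_fn.__name__} passed unexpectedly; remove the marker"
            )
    return wrapper
```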
Added an `aten._scaled_dot_product_attention_math_for_mps` fallback to make test behavior consistent between MacOS-15 (where the fallback is in place) and MacOS-14.
Weird MacOS-14 specific skips:
- test_torchinductor.py::GPUTests::test_cat_extern_kernel_mps
- test_torchinductor.py::GPUTests::test_sort_transpose_mps (likely an eager bug)
- test_torchinductor.py::GPUTests::test_unaligned_input_mps
Numerous MacOS-13 skips, including a few eager hard crashes; for example, running `test_torchinductor.py::GPUTests::test_scatter5_mps` causes
```
/AppleInternal/Library/BuildRoots/c651a45f-806e-11ed-a221-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArrayScatter.mm:309: failed assertion `Rank of destination array (1) must be greater than or equal to inner-most dimension of indices array (3)'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150821
Approved by: https://github.com/ZainRizvi, https://github.com/dcci
ghstack dependencies: #151224, #151246, #151272, #151282, #151288
By using Metal `as_type`, which according to the documentation does exactly that:
> Metal adds an as_type<type-id> operator to allow any scalar or vector data type (that is not a pointer) to be reinterpreted as another scalar or vector data type of the same size. The bits in the operand are returned directly without modification as the new type. The usual type promotion for function arguments is not performed.
Using `reinterpret_cast` created a potential silent correctness error when dtypes of different sizes were bitcast to each other
Add an explicit cast to src_type to avoid errors due to type promotion (i.e. something like `(x+1).view(dtype=torch.float16)` would work correctly in eager mode for the int16 dtype, but would fail in compile, as arithmetic operations will promote int16 to int32).
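In Python terms, the eager behaviour the compiled kernel now has to match (a sketch built around the `int16` example above, not a test from the PR):
```python
import torch

# In eager mode, the intermediate of (x + 1) stays int16, so the 16-bit
# bitcast to float16 is well defined.  Without an explicit cast back to the
# source type, a compiled Metal kernel could promote the arithmetic to int32
# and bitcast 32 bits instead, silently changing the result.
x = torch.arange(4, dtype=torch.int16, device="mps")
eager = (x + 1).view(dtype=torch.float16)
compiled = torch.compile(lambda t: (t + 1).view(dtype=torch.float16))(x)
torch.testing.assert_close(eager, compiled)
```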
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151272
Approved by: https://github.com/dcci
ghstack dependencies: #151224, #151246
To avoid accuracy issues when small reductions are unrolled, cast half to float during the `load` op, as `op_math_t<half>` is indeed float.
This fixes `test_unroll_small_reduction` for reduced-precision types.
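A hedged sketch of the kind of accuracy check this addresses (shapes and tolerances are assumptions):
```python
import torch

# Hypothetical check: a small (unrolled) reduction in float16 should stay
# close to the float32 reference once the loads are upcast to float.
x = torch.rand(8, 3, dtype=torch.float16, device="mps")
compiled_sum = torch.compile(lambda t: t.sum(dim=-1))
torch.testing.assert_close(
    compiled_sum(x), x.float().sum(dim=-1).half(), rtol=1e-3, atol=1e-3
)
```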
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151246
Approved by: https://github.com/dcci
ghstack dependencies: #151224
By adding `pass` in front of the comment for the fake set_device call.
This fixes `TestGPU.test_zero_element_mutation_mps`, which previously failed with
```
torch._inductor.exc.InductorError: RuntimeError: Failed to import /var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmp2emka_sx/7k/c7kmnwhb363ysalhewglr3cwtej6tiz3t4ppqa4bvhubaokmlprw.py
IndentationError: expected an indented block after 'with' statement on line 38 (c7kmnwhb363ysalhewglr3cwtej6tiz3t4ppqa4bvhubaokmlprw.py, line 40)
```
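For context, a minimal standalone illustration of the Python rule the generated wrapper tripped over (not the actual generated code):
```python
# A `with` block whose body is only a comment is a syntax error; emitting a
# `pass` before the comment makes the generated file importable again.
broken = "with ctx:\n    # fake call to set_device, comment only\nresult = 1\n"
fixed = "with ctx:\n    pass  # fake call to set_device\nresult = 1\n"

try:
    compile(broken, "<generated>", "exec")
except IndentationError as err:
    print("broken wrapper:", err)

compile(fixed, "<generated>", "exec")  # compiles fine
```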
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151224
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci
If the input is channels last, then MPS will return a channels-last output.
This fixes `GPUTests.test_convolution_4_mps` from test_torchinductor.py, which previously failed with
```
AssertionError: expected size 3==3, stride 1==192 at dim=1; expected size 12==12, stride 48==16 at dim=2; expected size 16==16, stride 3==1 at dim=3
```
As the FakeTensor implementation of conv returned a `Contiguous` rather than a `ChannelsLast` layout on MacOS-15 or later.
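A sketch of the behaviour being aligned, assuming MacOS-15+ where MPS eager returns a channels-last output for channels-last input (shapes are illustrative):
```python
import torch

# The meta/FakeTensor kernel now has to report the same (channels last)
# layout that MPS eager actually produces for channels-last input,
# otherwise inductor's ExternKernel size/stride asserts fire.
x = torch.rand(2, 3, 12, 16, device="mps").to(memory_format=torch.channels_last)
w = torch.rand(3, 3, 3, 3, device="mps")
conv = lambda t: torch.nn.functional.conv2d(t, w, padding=1)

eager = conv(x)
compiled = torch.compile(conv)(x)
assert eager.stride() == compiled.stride()
```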
This doesn't seem to be very well documented, so I will try to document the call path of the `ExternKernel` invocation for `aten::convolution`:
- First inductor decomp defined here is called
c93e4b8290/torch/_inductor/kernel/conv.py (L424-L425)
- Then it goes through the FakeTensor decomposition implemented here
320914f1b6/torch/_subclasses/fake_impls.py (L739-L740)
- Finally it goes down to convolution meta registrations implemented here
320914f1b6/torch/_meta_registrations.py (L2416-L2417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151042
Approved by: https://github.com/dcci
When generating a reduction kernel, tell the compiler the intended threadgroup size; otherwise it can unroll loops so much that the kernel can no longer be launched with that threadgroup size.
Extend `c10:🤘:max` to accept different dtypes
Together this fixes `test_large_broadcast_reduction`
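A rough sketch of the kind of shape `test_large_broadcast_reduction` exercises (shapes and tolerances here are assumptions, not the actual test):
```python
import torch

# Hypothetical repro shape: a broadcasted elementwise op feeding a large
# reduction, compiled for MPS and compared against eager.
a = torch.rand(2048, 1, device="mps")
b = torch.rand(1, 2048, device="mps")
fn = lambda x, y: (x + y).sum(dim=-1)
torch.testing.assert_close(torch.compile(fn)(a, b), fn(a, b), rtol=1e-4, atol=1e-4)
```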
TODO:
- Explore different threadgroup_sizes for best perf
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150247
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150246