[MPSInductor] Specify max_total_threads_per_threadgroup (#150247)

When generating reduction kernel, otherwise compiler can unroll loops too much that kernel could not be launched for the intended threadgroup size

Extend `c10::metal::max` to accept different dtypes

Together this fixes `test_large_broadcast_reduction`

TODO:
  - Explore different threadgroup_sizes for best perf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150247
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #150246
This commit is contained in:
Nikita Shulga
2025-03-29 12:34:11 -07:00
committed by PyTorch MergeBot
parent 52135db69a
commit 965784eb9b
3 changed files with 18 additions and 6 deletions

View File

@ -92,9 +92,10 @@ template <typename T>
return ::metal::isunordered(a, b) ? NAN : ::metal::max(a, b);
}
template <typename T>
::metal::enable_if_t<::metal::is_integral_v<T>, T> max(T a, T b) {
return ::metal::max(a, b);
template <typename T, typename U>
::metal::enable_if_t<::metal::is_integral_v<T>&& ::metal::is_integral_v<U>, T>
max(T a, U b) {
return ::metal::max(a, static_cast<T>(b));
}
template <typename T>
@ -102,9 +103,10 @@ template <typename T>
return ::metal::isunordered(a, b) ? NAN : ::metal::min(a, b);
}
template <typename T>
::metal::enable_if_t<::metal::is_integral_v<T>, T> min(T a, T b) {
return ::metal::min(a, b);
template <typename T, typename U>
::metal::enable_if_t<::metal::is_integral_v<T>&& ::metal::is_integral_v<U>, T>
min(T a, U b) {
return ::metal::min(a, static_cast<T>(b));
}
#if __METAL_VERSION__ >= 310

View File

@ -195,6 +195,7 @@ for test_name in [
"test_inf",
"test_isinf",
"test_isinf2",
"test_large_broadcast_reduction",
"test_layer_norm",
"test_lgamma",
"test_linear_float64",

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import itertools
import math
from typing import Any, Optional, TYPE_CHECKING
import sympy
@ -677,6 +678,14 @@ class MetalKernel(SIMDKernel):
)
if self.inside_reduction:
code.writeline("#include <c10/metal/reduction_utils.h>")
if self.inside_reduction:
total_reduction_size = math.prod(
t.numel for t in self.range_trees if t.is_reduction
)
threadgroup_size = min(total_reduction_size, self.max_threadgroup_size)
code.writeline(
f"[[max_total_threads_per_threadgroup({threadgroup_size})]]"
)
code.writeline("kernel void generated_kernel(")
with code.indent():
for outer, inner in self.args.output_buffers.items():