mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPSInductor] Specify max_total_threads_per_threadgroup (#150247)
Specify `max_total_threads_per_threadgroup` when generating a reduction kernel; otherwise the compiler can unroll loops so much that the kernel cannot be launched with the intended threadgroup size. Extend `c10::metal::max` to accept different dtypes. Together, these changes fix `test_large_broadcast_reduction`. TODO: - Explore different threadgroup_sizes for best perf. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150247 Approved by: https://github.com/jansel, https://github.com/dcci ghstack dependencies: #150246
This commit is contained in:
committed by
PyTorch MergeBot
parent
52135db69a
commit
965784eb9b
@ -92,9 +92,10 @@ template <typename T>
|
||||
return ::metal::isunordered(a, b) ? NAN : ::metal::max(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
::metal::enable_if_t<::metal::is_integral_v<T>, T> max(T a, T b) {
|
||||
return ::metal::max(a, b);
|
||||
template <typename T, typename U>
|
||||
::metal::enable_if_t<::metal::is_integral_v<T>&& ::metal::is_integral_v<U>, T>
|
||||
max(T a, U b) {
|
||||
return ::metal::max(a, static_cast<T>(b));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -102,9 +103,10 @@ template <typename T>
|
||||
return ::metal::isunordered(a, b) ? NAN : ::metal::min(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
::metal::enable_if_t<::metal::is_integral_v<T>, T> min(T a, T b) {
|
||||
return ::metal::min(a, b);
|
||||
template <typename T, typename U>
|
||||
::metal::enable_if_t<::metal::is_integral_v<T>&& ::metal::is_integral_v<U>, T>
|
||||
min(T a, U b) {
|
||||
return ::metal::min(a, static_cast<T>(b));
|
||||
}
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
|
||||
@ -195,6 +195,7 @@ for test_name in [
|
||||
"test_inf",
|
||||
"test_isinf",
|
||||
"test_isinf2",
|
||||
"test_large_broadcast_reduction",
|
||||
"test_layer_norm",
|
||||
"test_lgamma",
|
||||
"test_linear_float64",
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import sympy
|
||||
@ -677,6 +678,14 @@ class MetalKernel(SIMDKernel):
|
||||
)
|
||||
if self.inside_reduction:
|
||||
code.writeline("#include <c10/metal/reduction_utils.h>")
|
||||
if self.inside_reduction:
|
||||
total_reduction_size = math.prod(
|
||||
t.numel for t in self.range_trees if t.is_reduction
|
||||
)
|
||||
threadgroup_size = min(total_reduction_size, self.max_threadgroup_size)
|
||||
code.writeline(
|
||||
f"[[max_total_threads_per_threadgroup({threadgroup_size})]]"
|
||||
)
|
||||
code.writeline("kernel void generated_kernel(")
|
||||
with code.indent():
|
||||
for outer, inner in self.args.output_buffers.items():
|
||||
|
||||
Reference in New Issue
Block a user