[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)

By using the `welford_combine` primitive in the loop
This fixes `GPUTests.test_multilayer_var_lowp_mps`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152
Approved by: https://github.com/jansel
ghstack dependencies: #151042, #150824, #151151
This commit is contained in:
Nikita Shulga
2025-04-12 14:27:11 -07:00
committed by PyTorch MergeBot
parent 7762bddd87
commit 9699cc3eb9
3 changed files with 37 additions and 17 deletions

View File

@ -91,7 +91,7 @@ opmath_t<T> threadgroup_prod(
}
template <typename T>
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
float m = data[0];
float m2 = 0;
@ -100,7 +100,7 @@ float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
m += delta / (idx + 1);
m2 += delta * (data[idx] - m);
}
return float2(m, m2);
return float3(m, m2, size);
}
// Each vec3type is tuple of mean, m2 and weight

View File

@ -209,6 +209,7 @@ for test_name in [
"test_max_min",
"test_max_pool2d2",
"test_multilayer_prime_size",
"test_multilayer_var_lowp",
"test_min_max_reduction_nan",
"test_nan_to_num",
"test_neg_max_uint8",

View File

@ -119,6 +119,8 @@ class MetalExprPrinter(ExprPrinter_):
class MetalOverrides(OpOverrides):
"""Implements Metal-specific overrids for ops. Base class emits Python-friendly overrides"""
@staticmethod
def to_dtype(
x: CSEVariable,
@ -458,6 +460,8 @@ MetalOverrides._initialize_pointwise_overrides("mps")
class MetalKernel(SIMDKernel):
"""Implement Metal codegen based on the SIMDKernel abstraction"""
overrides = MetalOverrides # type: ignore[assignment]
suffix = ";"
newvar_prefix = "auto "
@ -640,31 +644,46 @@ class MetalKernel(SIMDKernel):
dtype=dtype,
)
if reduction_type == "welford_reduce":
assert not self.multistage_reduction, (
f"Multistage reduction not yet supported for {reduction_type}"
if not self.multistage_reduction:
acc_buf = self._new_idxvar(src_dtype, acc_buf_size)
self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};")
wf_res = self.cse.generate(
self.compute,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
)
self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap(
(f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z")
)
return result_tuple
acc_buf = self._new_idxvar("float3", acc_buf_size)
acc_thread_var = f"{acc_buf}[{reduction_idx}]"
self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
self.compute.writeline(
f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, float3({value}, 0.0, 1.0));"
)
acc_buf = self._new_idxvar(src_dtype, acc_buf_size)
self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};")
wf_res = self.cse.generate(
self.compute,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
self.stores,
f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})",
)
self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap(
(f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
(f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z")
)
return result_tuple
if reduction_type == "welford_combine":
assert not self.multistage_reduction, (
f"Multistage reduction not yet supported for {reduction_type}"
)
assert isinstance(value, tuple), "Input to welford combine must be tuple"
acc_buf = self._new_idxvar("float3", acc_buf_size)
self.compute.splice(
f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});"
)
acc_thread_var = f"{acc_buf}[{reduction_idx}]"
inp_value = f"float3({value[0]}, {value[1]}, {value[2]})"
self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
if self.multistage_reduction:
self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
self.compute.writeline(
f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, {inp_value});"
)
else:
self.compute.writeline(f"{acc_thread_var} = {inp_value};")
wf_res = self.cse.generate(
self.compute,
self.stores if self.multistage_reduction else self.compute,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
)
self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap(