mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)
By using `welford_combine` primitive in the loop This fixes `GPUTests.test_multilayer_var_lowp_mps` Pull Request resolved: https://github.com/pytorch/pytorch/pull/151152 Approved by: https://github.com/jansel ghstack dependencies: #151042, #150824, #151151
This commit is contained in:
committed by
PyTorch MergeBot
parent
7762bddd87
commit
9699cc3eb9
@ -91,7 +91,7 @@ opmath_t<T> threadgroup_prod(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
|
||||
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
float m = data[0];
|
||||
float m2 = 0;
|
||||
@ -100,7 +100,7 @@ float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
|
||||
m += delta / (idx + 1);
|
||||
m2 += delta * (data[idx] - m);
|
||||
}
|
||||
return float2(m, m2);
|
||||
return float3(m, m2, size);
|
||||
}
|
||||
|
||||
// Each vec3type is tuple of mean, m2 and weight
|
||||
|
@ -209,6 +209,7 @@ for test_name in [
|
||||
"test_max_min",
|
||||
"test_max_pool2d2",
|
||||
"test_multilayer_prime_size",
|
||||
"test_multilayer_var_lowp",
|
||||
"test_min_max_reduction_nan",
|
||||
"test_nan_to_num",
|
||||
"test_neg_max_uint8",
|
||||
|
@ -119,6 +119,8 @@ class MetalExprPrinter(ExprPrinter_):
|
||||
|
||||
|
||||
class MetalOverrides(OpOverrides):
|
||||
"""Implements Metal-specific overrids for ops. Base class emits Python-friendly overrides"""
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(
|
||||
x: CSEVariable,
|
||||
@ -458,6 +460,8 @@ MetalOverrides._initialize_pointwise_overrides("mps")
|
||||
|
||||
|
||||
class MetalKernel(SIMDKernel):
|
||||
"""Implement Metal codegen based on the SIMDKernel abstraction"""
|
||||
|
||||
overrides = MetalOverrides # type: ignore[assignment]
|
||||
suffix = ";"
|
||||
newvar_prefix = "auto "
|
||||
@ -640,31 +644,46 @@ class MetalKernel(SIMDKernel):
|
||||
dtype=dtype,
|
||||
)
|
||||
if reduction_type == "welford_reduce":
|
||||
assert not self.multistage_reduction, (
|
||||
f"Multistage reduction not yet supported for {reduction_type}"
|
||||
if not self.multistage_reduction:
|
||||
acc_buf = self._new_idxvar(src_dtype, acc_buf_size)
|
||||
self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};")
|
||||
wf_res = self.cse.generate(
|
||||
self.compute,
|
||||
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
|
||||
)
|
||||
self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap(
|
||||
(f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z")
|
||||
)
|
||||
return result_tuple
|
||||
acc_buf = self._new_idxvar("float3", acc_buf_size)
|
||||
acc_thread_var = f"{acc_buf}[{reduction_idx}]"
|
||||
self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
|
||||
self.compute.writeline(
|
||||
f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, float3({value}, 0.0, 1.0));"
|
||||
)
|
||||
acc_buf = self._new_idxvar(src_dtype, acc_buf_size)
|
||||
self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};")
|
||||
wf_res = self.cse.generate(
|
||||
self.compute,
|
||||
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
|
||||
self.stores,
|
||||
f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})",
|
||||
)
|
||||
self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap(
|
||||
(f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
|
||||
(f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z")
|
||||
)
|
||||
return result_tuple
|
||||
|
||||
if reduction_type == "welford_combine":
|
||||
assert not self.multistage_reduction, (
|
||||
f"Multistage reduction not yet supported for {reduction_type}"
|
||||
)
|
||||
assert isinstance(value, tuple), "Input to welford combine must be tuple"
|
||||
acc_buf = self._new_idxvar("float3", acc_buf_size)
|
||||
self.compute.splice(
|
||||
f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});"
|
||||
)
|
||||
acc_thread_var = f"{acc_buf}[{reduction_idx}]"
|
||||
inp_value = f"float3({value[0]}, {value[1]}, {value[2]})"
|
||||
self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
|
||||
if self.multistage_reduction:
|
||||
self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
|
||||
self.compute.writeline(
|
||||
f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, {inp_value});"
|
||||
)
|
||||
else:
|
||||
self.compute.writeline(f"{acc_thread_var} = {inp_value};")
|
||||
wf_res = self.cse.generate(
|
||||
self.compute,
|
||||
self.stores if self.multistage_reduction else self.compute,
|
||||
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
|
||||
)
|
||||
self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap(
|
||||
|
Reference in New Issue
Block a user