Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"

This reverts commit 5edfb4c4fad1bb9504482d930a2540d22427d383.

Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to I should have waited for lint ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798249264))
This commit is contained in:
PyTorch MergeBot
2025-04-12 00:21:14 +00:00
parent ca2e8cd352
commit 83f14c0b06
3 changed files with 2 additions and 42 deletions

View File

@ -92,7 +92,6 @@ opmath_t<T> threadgroup_prod(
template <typename T>
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
float m = data[0];
float m2 = 0;
for (unsigned idx = 1; idx < size; ++idx) {
@ -103,28 +102,6 @@ float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
return float2(m, m2);
}
// Each vec3type is tuple of mean, m2 and weight
template <typename T>
float3 welford_combine(T a, T b) {
float delta = b.x - a.x;
float new_weight = a.z + b.z;
auto w2_over_w = new_weight != 0 ? b.z / new_weight : 0.0;
return float3(
a.x + delta * w2_over_w,
a.y + b.y + delta * delta * a.z * w2_over_w,
new_weight);
}
template <typename T>
float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
float3 rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = welford_combine(rc, data[idx]);
}
return rc;
}
template <typename T>
T threadgroup_max(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee

View File

@ -175,7 +175,6 @@ for test_name in [
"test_argmax_argmin2",
"test_avg_pool2d5",
"test_avg_pool2d8",
"test_batch_norm_2d_2",
"test_bernoulli1",
"test_builtins_round",
"test_builtins_round_float_ndigits_neg",

View File

@ -510,22 +510,19 @@ class MetalKernel(SIMDKernel):
def _new_idxvar(
self,
dtype: Union[str | torch.dtype],
dtype: torch.dtype,
elem_count: Optional[int] = None,
default_value: Optional[Any] = None,
is_threadgroup: bool = True,
bounds: ValueRanges[Any] = ValueRanges.unknown(),
) -> CSEVariable:
if isinstance(dtype, torch.dtype):
dtype = self.dtype_to_str(dtype)
var_name = f"tmp_acc_{next(self.acc_var_ids)}"
var = V.kernel.create_cse_var(var_name, bounds, dtype)
var_def = "threadgroup " if is_threadgroup else ""
var_def += f"{dtype} {var_name}"
var_def += f"{self.dtype_to_str(dtype)} {var_name}"
if elem_count:
var_def += f"[{elem_count}]"
if default_value is not None:
assert not is_threadgroup, "Thread group var can not have default value"
var_def += f" = {default_value}"
self.indexing_code.writeline(var_def + self.suffix)
return var
@ -644,19 +641,6 @@ class MetalKernel(SIMDKernel):
return OpsWrapper._unwrap(
(f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
)
if reduction_type == "welford_combine":
assert not self.multistage_reduction, (
f"Multistage reduction not yet supported for {reduction_type}"
)
acc_buf = self._new_idxvar("float3", acc_buf_size)
self.compute.splice(
f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});"
)
wf_res = self.cse.generate(
self.compute,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
)
return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
raise NotImplementedError(reduction_type)
def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: