mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 5edfb4c4fad1bb9504482d930a2540d22427d383. Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to I should have waited for lint ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798249264))
This commit is contained in:
@ -92,7 +92,6 @@ opmath_t<T> threadgroup_prod(
|
||||
|
||||
template <typename T>
|
||||
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
float m = data[0];
|
||||
float m2 = 0;
|
||||
for (unsigned idx = 1; idx < size; ++idx) {
|
||||
@ -103,28 +102,6 @@ float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
|
||||
return float2(m, m2);
|
||||
}
|
||||
|
||||
// Each vec3type is tuple of mean, m2 and weight
|
||||
template <typename T>
|
||||
float3 welford_combine(T a, T b) {
|
||||
float delta = b.x - a.x;
|
||||
float new_weight = a.z + b.z;
|
||||
auto w2_over_w = new_weight != 0 ? b.z / new_weight : 0.0;
|
||||
return float3(
|
||||
a.x + delta * w2_over_w,
|
||||
a.y + b.y + delta * delta * a.z * w2_over_w,
|
||||
new_weight);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
float3 rc = data[0];
|
||||
for (unsigned idx = 1; idx < size; ++idx) {
|
||||
rc = welford_combine(rc, data[idx]);
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T threadgroup_max(threadgroup T* data, unsigned size) {
|
||||
// TODO: This should be moved to the callee
|
||||
|
||||
@ -175,7 +175,6 @@ for test_name in [
|
||||
"test_argmax_argmin2",
|
||||
"test_avg_pool2d5",
|
||||
"test_avg_pool2d8",
|
||||
"test_batch_norm_2d_2",
|
||||
"test_bernoulli1",
|
||||
"test_builtins_round",
|
||||
"test_builtins_round_float_ndigits_neg",
|
||||
|
||||
@ -510,22 +510,19 @@ class MetalKernel(SIMDKernel):
|
||||
|
||||
def _new_idxvar(
|
||||
self,
|
||||
dtype: Union[str | torch.dtype],
|
||||
dtype: torch.dtype,
|
||||
elem_count: Optional[int] = None,
|
||||
default_value: Optional[Any] = None,
|
||||
is_threadgroup: bool = True,
|
||||
bounds: ValueRanges[Any] = ValueRanges.unknown(),
|
||||
) -> CSEVariable:
|
||||
if isinstance(dtype, torch.dtype):
|
||||
dtype = self.dtype_to_str(dtype)
|
||||
var_name = f"tmp_acc_{next(self.acc_var_ids)}"
|
||||
var = V.kernel.create_cse_var(var_name, bounds, dtype)
|
||||
var_def = "threadgroup " if is_threadgroup else ""
|
||||
var_def += f"{dtype} {var_name}"
|
||||
var_def += f"{self.dtype_to_str(dtype)} {var_name}"
|
||||
if elem_count:
|
||||
var_def += f"[{elem_count}]"
|
||||
if default_value is not None:
|
||||
assert not is_threadgroup, "Thread group var can not have default value"
|
||||
var_def += f" = {default_value}"
|
||||
self.indexing_code.writeline(var_def + self.suffix)
|
||||
return var
|
||||
@ -644,19 +641,6 @@ class MetalKernel(SIMDKernel):
|
||||
return OpsWrapper._unwrap(
|
||||
(f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
|
||||
)
|
||||
if reduction_type == "welford_combine":
|
||||
assert not self.multistage_reduction, (
|
||||
f"Multistage reduction not yet supported for {reduction_type}"
|
||||
)
|
||||
acc_buf = self._new_idxvar("float3", acc_buf_size)
|
||||
self.compute.splice(
|
||||
f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});"
|
||||
)
|
||||
wf_res = self.cse.generate(
|
||||
self.compute,
|
||||
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
|
||||
)
|
||||
return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
|
||||
raise NotImplementedError(reduction_type)
|
||||
|
||||
def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None:
|
||||
|
||||
Reference in New Issue
Block a user