[MPSInductor] Improve _default dtype inference (#156121)

By just adding 'mps' as one of the backend options and fixing reduction op to actually return tuple of CSEVariable's rather than tuple of strings

Test plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156121
Approved by: https://github.com/dcci
This commit is contained in:
Nikita Shulga
2025-06-16 15:14:40 -07:00
committed by PyTorch MergeBot
parent 508cdc4fc9
commit bb1f3d1a55
2 changed files with 13 additions and 4 deletions

View File

@ -2387,7 +2387,7 @@ class CSEProxy(DefaultHandler):
output_dtype = V.interpreter.current_node.meta.get(
OptimizationContext.key, None
).dtype
elif backend in ("triton", "cpp"):
elif backend in ("triton", "cpp", "mps"):
dtype_op = getattr(dtype_handler, name)
output_dtype = dtype_op(*args, **kwargs)

View File

@ -566,6 +566,12 @@ class MetalKernel(SIMDKernel):
assert self.inside_reduction
assert not self._load_mask
def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]:
# Uwraps vec3 dtype into individual components
return OpsWrapper._unwrap(
[CSEVariable(f"{res3}.{t}", res3.bounds, res3.dtype) for t in "xyz"]
)
# Establish reduction buffer size and index expression
reduction_idx = ""
acc_buf_size = 1
@ -669,8 +675,9 @@ class MetalKernel(SIMDKernel):
wf_res = self.cse.generate(
self.compute,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
dtype=torch.float32,
)
return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
return _unwrap_helper(wf_res)
acc_buf = self._new_idxvar("float3", acc_buf_size)
acc_thread_var = f"{acc_buf}[{reduction_idx}]"
self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
@ -680,8 +687,9 @@ class MetalKernel(SIMDKernel):
wf_res = self.cse.generate(
self.stores,
f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})",
dtype=torch.float32,
)
return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
return _unwrap_helper(wf_res)
if reduction_type == "welford_combine":
assert isinstance(value, tuple), "Input to welford combine must be tuple"
acc_buf = self._new_idxvar("float3", acc_buf_size)
@ -698,8 +706,9 @@ class MetalKernel(SIMDKernel):
wf_res = self.cse.generate(
self.stores if self.multistage_reduction_entry else self.compute,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
dtype=torch.float32,
)
return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
return _unwrap_helper(wf_res)
raise NotImplementedError(reduction_type)
def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: