[MPSInductor] Improve _default dtype inference (#156121)
By just adding 'mps' as one of the backend options and fixing the reduction op to actually return a tuple of CSEVariables rather than a tuple of strings.

Test plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156121
Approved by: https://github.com/dcci
committed by PyTorch MergeBot
parent 508cdc4fc9
commit bb1f3d1a55
@@ -2387,7 +2387,7 @@ class CSEProxy(DefaultHandler):
             output_dtype = V.interpreter.current_node.meta.get(
                 OptimizationContext.key, None
             ).dtype
-        elif backend in ("triton", "cpp"):
+        elif backend in ("triton", "cpp", "mps"):
             dtype_op = getattr(dtype_handler, name)
             output_dtype = dtype_op(*args, **kwargs)
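For context, the hunk above routes the "mps" backend through the same dtype-propagation handler as triton and cpp instead of the fallback path. A minimal self-contained sketch of that dispatch pattern (FakeDtypeHandler and infer_output_dtype are hypothetical stand-ins for Inductor's dtype handler and the surrounding CSEProxy logic, not the real API):

import torch

class FakeDtypeHandler:
    # Hypothetical stand-in: one method per op, each inferring the
    # output dtype from the input dtypes.
    @staticmethod
    def add(a: torch.dtype, b: torch.dtype) -> torch.dtype:
        return torch.promote_types(a, b)

def infer_output_dtype(backend: str, name: str, *args: torch.dtype) -> torch.dtype:
    dtype_handler = FakeDtypeHandler()
    # Before this commit "mps" missed this branch; now it shares the
    # dtype-propagation path with triton and cpp.
    if backend in ("triton", "cpp", "mps"):
        dtype_op = getattr(dtype_handler, name)
        return dtype_op(*args)
    raise NotImplementedError(backend)

print(infer_output_dtype("mps", "add", torch.float16, torch.float32))  # torch.float32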
@@ -566,6 +566,12 @@ class MetalKernel(SIMDKernel):
         assert self.inside_reduction
         assert not self._load_mask
 
+        def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]:
+            # Unwraps vec3 dtype into individual components
+            return OpsWrapper._unwrap(
+                [CSEVariable(f"{res3}.{t}", res3.bounds, res3.dtype) for t in "xyz"]
+            )
+
         # Establish reduction buffer size and index expression
         reduction_idx = ""
         acc_buf_size = 1
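The helper added above is what lets reductions return CSEVariables (which carry bounds and dtype) instead of bare strings, so dtype inference keeps working downstream. A self-contained sketch of the pattern (Var is a hypothetical stand-in for Inductor's CSEVariable; the real one also tracks value-range bounds):

from dataclasses import dataclass
import torch

@dataclass
class Var:
    # Hypothetical stand-in for CSEVariable: a generated-code name plus
    # the dtype metadata that inference relies on.
    name: str
    dtype: torch.dtype

    def __str__(self) -> str:
        return self.name

def unwrap_vec3(res3: Var) -> tuple[Var, ...]:
    # Mirrors _unwrap_helper: split a Metal float3 result into .x/.y/.z
    # components, each still a Var rather than a bare string.
    return tuple(Var(f"{res3}.{t}", res3.dtype) for t in "xyz")

mean, m2, weight = unwrap_vec3(Var("tmp0", torch.float32))
print(mean, m2, weight.dtype)  # tmp0.x tmp0.y torch.float32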
@@ -669,8 +675,9 @@ class MetalKernel(SIMDKernel):
             wf_res = self.cse.generate(
                 self.compute,
                 f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
+                dtype=torch.float32,
             )
-            return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
+            return _unwrap_helper(wf_res)
         acc_buf = self._new_idxvar("float3", acc_buf_size)
         acc_thread_var = f"{acc_buf}[{reduction_idx}]"
         self.indexing_code.splice(f"{acc_thread_var} = 0.0;")
@@ -680,8 +687,9 @@ class MetalKernel(SIMDKernel):
             wf_res = self.cse.generate(
                 self.stores,
                 f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})",
+                dtype=torch.float32,
             )
-            return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
+            return _unwrap_helper(wf_res)
         if reduction_type == "welford_combine":
             assert isinstance(value, tuple), "Input to welford combine must be tuple"
             acc_buf = self._new_idxvar("float3", acc_buf_size)
@@ -698,8 +706,9 @@ class MetalKernel(SIMDKernel):
             wf_res = self.cse.generate(
                 self.stores if self.multistage_reduction_entry else self.compute,
                 f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
+                dtype=torch.float32,
             )
-            return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
+            return _unwrap_helper(wf_res)
         raise NotImplementedError(reduction_type)
 
     def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None:
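For background on why these reductions produce a float3: a Welford reduction carries (mean, m2, weight) as one value, and threadgroup_welford_combine merges two such states. A hedged Python sketch of the standard combine formula (illustrative only, not the Metal implementation):

def welford_combine(a: tuple[float, float, float], b: tuple[float, float, float]) -> tuple[float, float, float]:
    # Standard parallel-variance merge of two (mean, m2, weight) states.
    mean_a, m2_a, w_a = a
    mean_b, m2_b, w_b = b
    w = w_a + w_b
    if w == 0:
        return (0.0, 0.0, 0.0)
    delta = mean_b - mean_a
    mean = mean_a + delta * (w_b / w)
    m2 = m2_a + m2_b + delta * delta * (w_a * w_b / w)
    return (mean, m2, w)

# Two single observations, 1.0 and 3.0: mean 2.0, m2 2.0, weight 2.0.
print(welford_combine((1.0, 0.0, 1.0), (3.0, 0.0, 1.0)))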