[MPS] coalesce for sparse tensors (#159729)

MPS coalesce function for sparse tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159729
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Isalia20
2025-08-08 13:49:55 +00:00
committed by PyTorch MergeBot
parent 556e2a73f4
commit 7f4cb4a3e0
9 changed files with 416 additions and 11 deletions

View File

@ -2849,14 +2849,13 @@ def main() -> None:
# TODO: stop generating CUDA kernels for non-CUDA builds
ignore_keys = set()
MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS}
if options.mps or options.update_aoti_c_shim:
functions_keys.add(DispatchKey.MPS)
functions_keys.update(MPS_KEYS)
aoti_backends.add(DispatchKey.MPS)
else:
ignore_keys.add(DispatchKey.MPS)
if DispatchKey.MPS in dispatch_keys:
del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
ignore_keys.update(MPS_KEYS)
dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS]
if options.xpu or options.update_aoti_c_shim:
functions_keys.add(DispatchKey.XPU)