[MPS] coalesce for sparse tensors (#159729)
Adds a coalesce implementation for sparse tensors on the MPS backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159729
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: 556e2a73f4
Commit: 7f4cb4a3e0
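A usage sketch, not taken from the commit: assuming an MPS-enabled build of PyTorch that includes this change, coalesce() can be called on a sparse COO tensor placed on the MPS device. The tensor shapes and values below are illustrative only.

import torch

# Illustrative sketch: coalesce a sparse COO tensor on the MPS device.
# Assumes an MPS-enabled build of PyTorch that includes this change.
indices = torch.tensor([[0, 0, 1],
                        [1, 1, 2]], device="mps")
values = torch.tensor([3.0, 4.0, 5.0], device="mps")
sp = torch.sparse_coo_tensor(indices, values, size=(2, 3))

print(sp.is_coalesced())   # False: the duplicate index (0, 1) has not been merged yet
spc = sp.coalesce()        # sorts indices and sums values at duplicate indices
print(spc.is_coalesced())  # True
print(spc.values())        # values at (0, 1) summed: [7., 5.]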
@@ -2849,14 +2849,13 @@ def main() -> None:
     # TODO: stop generating CUDA kernels for non-CUDA builds
     ignore_keys = set()
 
+    MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS}
     if options.mps or options.update_aoti_c_shim:
-        functions_keys.add(DispatchKey.MPS)
+        functions_keys.update(MPS_KEYS)
         aoti_backends.add(DispatchKey.MPS)
     else:
-        ignore_keys.add(DispatchKey.MPS)
-
-        if DispatchKey.MPS in dispatch_keys:
-            del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
+        ignore_keys.update(MPS_KEYS)
+        dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS]
 
     if options.xpu or options.update_aoti_c_shim:
         functions_keys.add(DispatchKey.XPU)
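The hunk above groups the dense and sparse MPS dispatch keys into MPS_KEYS so the code generator handles them together, and it replaces the old per-key del with an in-place slice assignment. A minimal standalone sketch of that filtering pattern follows; DispatchKey here is a stand-in enum for illustration, not the torchgen class.

from enum import Enum

# Stand-in enum for illustration; the real DispatchKey lives in torchgen.
class DispatchKey(Enum):
    CPU = "CPU"
    MPS = "MPS"
    SparseMPS = "SparseMPS"
    SparseCsrMPS = "SparseCsrMPS"

MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS}

dispatch_keys = [DispatchKey.CPU, DispatchKey.MPS, DispatchKey.SparseMPS]
alias = dispatch_keys  # a second reference to the same list object

# Slice assignment filters the list in place, so existing references such as
# `alias` observe the pruned contents; rebinding with `dispatch_keys = [...]`
# would leave other references pointing at the unfiltered list.
dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS]
assert alias == [DispatchKey.CPU]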