[MPS] sparse norm (#164961)

Norms for sparse mps tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164961
Approved by: https://github.com/malfet
This commit is contained in:
Isalia20
2025-10-08 21:41:38 +00:00
committed by PyTorch MergeBot
parent 0b15f7ae05
commit 1d182dd81c
3 changed files with 20 additions and 8 deletions

View File

@ -1381,8 +1381,9 @@ def largeTensorTest(size, device=None, inductor=TEST_WITH_TORCHINDUCTOR):
class expectedFailure:
def __init__(self, device_type):
def __init__(self, device_type, dtype=None):
self.device_type = device_type
self.dtype = dtype
def __call__(self, fn):
@wraps(fn)
@ -1396,7 +1397,13 @@ class expectedFailure:
else:
target_device_type = slf.device_type
if self.device_type is None or self.device_type == target_device_type:
target_dtype = kwargs.get("dtype", getattr(slf, "dtype", None))
device_matches = (
self.device_type is None or self.device_type == target_device_type
)
dtype_matches = self.dtype is None or self.dtype == target_dtype
if device_matches and dtype_matches:
try:
fn(slf, *args, **kwargs)
except Exception:
@ -1716,6 +1723,10 @@ def expectedFailureMPS(fn):
return expectedFailure("mps")(fn)
def expectedFailureMPSComplex(fn):
return expectedFailure("mps", torch.complex64)(fn)
def expectedFailureMPSPre15(fn):
import platform