mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] sparse norm (#164961)
Norms for sparse mps tensors Pull Request resolved: https://github.com/pytorch/pytorch/pull/164961 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b15f7ae05
commit
1d182dd81c
@ -1381,8 +1381,9 @@ def largeTensorTest(size, device=None, inductor=TEST_WITH_TORCHINDUCTOR):
|
||||
|
||||
|
||||
class expectedFailure:
|
||||
def __init__(self, device_type):
|
||||
def __init__(self, device_type, dtype=None):
|
||||
self.device_type = device_type
|
||||
self.dtype = dtype
|
||||
|
||||
def __call__(self, fn):
|
||||
@wraps(fn)
|
||||
@ -1396,7 +1397,13 @@ class expectedFailure:
|
||||
else:
|
||||
target_device_type = slf.device_type
|
||||
|
||||
if self.device_type is None or self.device_type == target_device_type:
|
||||
target_dtype = kwargs.get("dtype", getattr(slf, "dtype", None))
|
||||
device_matches = (
|
||||
self.device_type is None or self.device_type == target_device_type
|
||||
)
|
||||
dtype_matches = self.dtype is None or self.dtype == target_dtype
|
||||
|
||||
if device_matches and dtype_matches:
|
||||
try:
|
||||
fn(slf, *args, **kwargs)
|
||||
except Exception:
|
||||
@ -1716,6 +1723,10 @@ def expectedFailureMPS(fn):
|
||||
return expectedFailure("mps")(fn)
|
||||
|
||||
|
||||
def expectedFailureMPSComplex(fn):
|
||||
return expectedFailure("mps", torch.complex64)(fn)
|
||||
|
||||
|
||||
def expectedFailureMPSPre15(fn):
|
||||
import platform
|
||||
|
||||
|
Reference in New Issue
Block a user