[mps/inductor] Introduce is_mps_backend/skip_if_mps decorators. (#145035)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145035
Approved by: https://github.com/jansel
This commit is contained in:
Davide Italiano
2025-01-17 05:36:38 +00:00
committed by PyTorch MergeBot
parent cfd9cc19a3
commit fd8e0e3e10

View File

@ -779,6 +779,16 @@ def skip_if_halide(fn):
return wrapper
def skip_if_mps(fn):
@functools.wraps(fn)
def wrapper(self):
if is_mps_backend(self.device):
raise unittest.SkipTest("mps not supported")
return fn(self)
return wrapper
def skip_if_triton(fn):
@functools.wraps(fn)
def wrapper(self):
@ -805,6 +815,10 @@ def is_halide_backend(device):
return config.cuda_backend == "halide"
def is_mps_backend(device):
return getattr(device, "type", device) == "mps"
def is_triton_backend(device):
if getattr(device, "type", device) == "cpu":
return config.cpu_backend == "triton"