mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-13 21:59:56 +08:00
[mps/inductor] Introduce is_mps_backend/skip_if_mps decorators. (#145035)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145035 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
cfd9cc19a3
commit
fd8e0e3e10
@ -779,6 +779,16 @@ def skip_if_halide(fn):
|
||||
return wrapper
|
||||
|
||||
|
||||
def skip_if_mps(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self):
|
||||
if is_mps_backend(self.device):
|
||||
raise unittest.SkipTest("mps not supported")
|
||||
return fn(self)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def skip_if_triton(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self):
|
||||
@ -805,6 +815,10 @@ def is_halide_backend(device):
|
||||
return config.cuda_backend == "halide"
|
||||
|
||||
|
||||
def is_mps_backend(device):
|
||||
return getattr(device, "type", device) == "mps"
|
||||
|
||||
|
||||
def is_triton_backend(device):
|
||||
if getattr(device, "type", device) == "cpu":
|
||||
return config.cpu_backend == "triton"
|
||||
|
||||
Reference in New Issue
Block a user