[MPSInductor] Add constant, isinf and isnan ops (#144156)

Per Table 6.5 of [Metal Language Specification](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf) infinity is `HUGE_VALF`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144156
Approved by: https://github.com/Skylion007, https://github.com/jansel
ghstack dependencies: #144055, #144051, #144122, #144105
This commit is contained in:
Nikita Shulga
2025-01-03 10:27:28 -08:00
committed by PyTorch MergeBot
parent 383ff4011c
commit 52e107a7ca
2 changed files with 18 additions and 0 deletions

View File

@ -41,6 +41,8 @@ class MPSBasicTests(TestCase):
test_add_const_int = CommonTemplate.test_add_const_int
test_add_inplace_permuted_mps = CommonTemplate.test_add_inplace_permuted
test_max_min = CommonTemplate.test_max_min
test_inf = CommonTemplate.test_inf
test_nan_to_num = CommonTemplate.test_nan_to_num
test_views6 = CommonTemplate.test_views6
test_addmm = CommonTemplate.test_addmm
test_signbit = CommonTemplate.test_signbit

View File

@ -60,6 +60,14 @@ class MetalOverrides(OpOverrides):
) -> str:
return f"static_cast<{DTYPE_TO_METAL[dtype]}>({x})"
@staticmethod
def constant(val: CSEVariable, dtype: torch.dtype) -> str:
if val == torch.inf:
return "HUGE_VALF"
elif val == -torch.inf:
return "-HUGE_VALF"
return str(val)
@staticmethod
def where(a: CSEVariable, b: CSEVariable, c: CSEVariable) -> str:
return f"{a} ? {b} : {c}"
@ -81,6 +89,14 @@ class MetalOverrides(OpOverrides):
def logical_and(a: CSEVariable, b: CSEVariable) -> str:
return f"{a} && {b}"
@staticmethod
def isnan(x: CSEVariable) -> str:
return f"metal::isnan({x})"
@staticmethod
def isinf(x: CSEVariable) -> str:
return f"metal::isinf({x})"
@staticmethod
def abs(x: CSEVariable) -> str:
return f"metal::abs({x})"