mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPSInductor] Add constant
, isinf
and isnan
ops (#144156)
Per Table 6.5 of [Metal Language Specification](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf) infinity is `HUGE_VALF` Pull Request resolved: https://github.com/pytorch/pytorch/pull/144156 Approved by: https://github.com/Skylion007, https://github.com/jansel ghstack dependencies: #144055, #144051, #144122, #144105
This commit is contained in:
committed by
PyTorch MergeBot
parent
383ff4011c
commit
52e107a7ca
@ -41,6 +41,8 @@ class MPSBasicTests(TestCase):
|
||||
test_add_const_int = CommonTemplate.test_add_const_int
|
||||
test_add_inplace_permuted_mps = CommonTemplate.test_add_inplace_permuted
|
||||
test_max_min = CommonTemplate.test_max_min
|
||||
test_inf = CommonTemplate.test_inf
|
||||
test_nan_to_num = CommonTemplate.test_nan_to_num
|
||||
test_views6 = CommonTemplate.test_views6
|
||||
test_addmm = CommonTemplate.test_addmm
|
||||
test_signbit = CommonTemplate.test_signbit
|
||||
|
@ -60,6 +60,14 @@ class MetalOverrides(OpOverrides):
|
||||
) -> str:
|
||||
return f"static_cast<{DTYPE_TO_METAL[dtype]}>({x})"
|
||||
|
||||
@staticmethod
|
||||
def constant(val: CSEVariable, dtype: torch.dtype) -> str:
|
||||
if val == torch.inf:
|
||||
return "HUGE_VALF"
|
||||
elif val == -torch.inf:
|
||||
return "-HUGE_VALF"
|
||||
return str(val)
|
||||
|
||||
@staticmethod
|
||||
def where(a: CSEVariable, b: CSEVariable, c: CSEVariable) -> str:
|
||||
return f"{a} ? {b} : {c}"
|
||||
@ -81,6 +89,14 @@ class MetalOverrides(OpOverrides):
|
||||
def logical_and(a: CSEVariable, b: CSEVariable) -> str:
|
||||
return f"{a} && {b}"
|
||||
|
||||
@staticmethod
|
||||
def isnan(x: CSEVariable) -> str:
|
||||
return f"metal::isnan({x})"
|
||||
|
||||
@staticmethod
|
||||
def isinf(x: CSEVariable) -> str:
|
||||
return f"metal::isinf({x})"
|
||||
|
||||
@staticmethod
|
||||
def abs(x: CSEVariable) -> str:
|
||||
return f"metal::abs({x})"
|
||||
|
Reference in New Issue
Block a user