Compare commits

...

1 Commits

Author SHA1 Message Date
2950d38184 xpu still use libdevice.sqrt 2025-10-20 00:51:26 +00:00

View File

@ -791,6 +791,9 @@ class TritonPrinter(PythonPrinter):
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
def _helper_sqrt(self, expr: sympy.Expr) -> str:
# work around for https://github.com/pytorch/pytorch/issues/165738
if torch.xpu.is_available():
return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))"
return f"tl.sqrt_rn(({self._print(expr)}).to(tl.float32))"
def _print_FloatPow(self, expr: sympy.Expr) -> str:
@ -1201,6 +1204,9 @@ class TritonOverrides(OpOverrides):
@staticmethod
@maybe_upcast_float32()
def sqrt(x):
# work around for https://github.com/pytorch/pytorch/issues/165738
if torch.xpu.is_available():
return f"libdevice.sqrt({x})"
return f"tl.sqrt_rn({x})"
@staticmethod