Integrate Triton up to [6af1c4b5](6af1c4b507)

https://github.com/openxla/triton/tree/triton_integrate_branch-1.10

PiperOrigin-RevId: 789217115
Change-Id: Ib89b60601802c247b4dc96cfe6cbff4752a41402
This commit is contained in:
Tori Baker
2025-07-31 00:13:23 -07:00
committed by Copybara-Service
parent a14376d249
commit ec4254ac5d
3 changed files with 10 additions and 10 deletions

View File

@ -65,20 +65,20 @@ def get_tl_dot_fn(
*,
trans_a: bool = False,
trans_b: bool = False,
_builder,
_semantic,
):
if in_dtype == tl.float32:
tl.static_assert(a.dtype == tl.float32, _builder=_builder)
tl.static_assert(b.dtype == tl.float32, _builder=_builder)
tl.static_assert(a.dtype == tl.float32, _semantic=_semantic)
tl.static_assert(b.dtype == tl.float32, _semantic=_semantic)
else:
tl.static_assert(a.dtype.is_standard_floating(), _builder=_builder)
tl.static_assert(b.dtype.is_standard_floating(), _builder=_builder)
a = a.to(in_dtype, _builder=_builder)
b = b.to(in_dtype, _builder=_builder)
a = tl.trans(a, _builder=_builder) if trans_a else a
b = tl.trans(b, _builder=_builder) if trans_b else b
tl.static_assert(a.dtype.is_standard_floating(), _semantic=_semantic)
tl.static_assert(b.dtype.is_standard_floating(), _semantic=_semantic)
a = a.to(in_dtype, _semantic=_semantic)
b = b.to(in_dtype, _semantic=_semantic)
a = tl.trans(a, _semantic=_semantic) if trans_a else a
b = tl.trans(b, _semantic=_semantic) if trans_b else b
return tl.dot(
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _builder=_builder
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _semantic=_semantic
)
return _dot_fn