Undo backwards incompatible Triton API change

PiperOrigin-RevId: 790639906
Change-Id: If98343e8a5601baeeb882229a4e4c5dcdc601881
This commit is contained in:
Augustin Zidek
2025-08-04 01:36:43 -07:00
committed by Copybara-Service
parent 903a931d36
commit f2edd59f87

View File

@ -65,20 +65,20 @@ def get_tl_dot_fn(
*,
trans_a: bool = False,
trans_b: bool = False,
_semantic,
_builder,
):
if in_dtype == tl.float32:
tl.static_assert(a.dtype == tl.float32, _semantic=_semantic)
tl.static_assert(b.dtype == tl.float32, _semantic=_semantic)
tl.static_assert(a.dtype == tl.float32, _builder=_builder)
tl.static_assert(b.dtype == tl.float32, _builder=_builder)
else:
tl.static_assert(a.dtype.is_standard_floating(), _semantic=_semantic)
tl.static_assert(b.dtype.is_standard_floating(), _semantic=_semantic)
a = a.to(in_dtype, _semantic=_semantic)
b = b.to(in_dtype, _semantic=_semantic)
a = tl.trans(a, _semantic=_semantic) if trans_a else a
b = tl.trans(b, _semantic=_semantic) if trans_b else b
tl.static_assert(a.dtype.is_standard_floating(), _builder=_builder)
tl.static_assert(b.dtype.is_standard_floating(), _builder=_builder)
a = a.to(in_dtype, _builder=_builder)
b = b.to(in_dtype, _builder=_builder)
a = tl.trans(a, _builder=_builder) if trans_a else a
b = tl.trans(b, _builder=_builder) if trans_b else b
return tl.dot(
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _semantic=_semantic
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _builder=_builder
)
return _dot_fn