mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 21:33:47 +08:00
Undo backwards incompatible Triton API change
PiperOrigin-RevId: 790639906 Change-Id: If98343e8a5601baeeb882229a4e4c5dcdc601881
This commit is contained in:
committed by
Copybara-Service
parent
903a931d36
commit
f2edd59f87
@ -65,20 +65,20 @@ def get_tl_dot_fn(
|
||||
*,
|
||||
trans_a: bool = False,
|
||||
trans_b: bool = False,
|
||||
_semantic,
|
||||
_builder,
|
||||
):
|
||||
if in_dtype == tl.float32:
|
||||
tl.static_assert(a.dtype == tl.float32, _semantic=_semantic)
|
||||
tl.static_assert(b.dtype == tl.float32, _semantic=_semantic)
|
||||
tl.static_assert(a.dtype == tl.float32, _builder=_builder)
|
||||
tl.static_assert(b.dtype == tl.float32, _builder=_builder)
|
||||
else:
|
||||
tl.static_assert(a.dtype.is_standard_floating(), _semantic=_semantic)
|
||||
tl.static_assert(b.dtype.is_standard_floating(), _semantic=_semantic)
|
||||
a = a.to(in_dtype, _semantic=_semantic)
|
||||
b = b.to(in_dtype, _semantic=_semantic)
|
||||
a = tl.trans(a, _semantic=_semantic) if trans_a else a
|
||||
b = tl.trans(b, _semantic=_semantic) if trans_b else b
|
||||
tl.static_assert(a.dtype.is_standard_floating(), _builder=_builder)
|
||||
tl.static_assert(b.dtype.is_standard_floating(), _builder=_builder)
|
||||
a = a.to(in_dtype, _builder=_builder)
|
||||
b = b.to(in_dtype, _builder=_builder)
|
||||
a = tl.trans(a, _builder=_builder) if trans_a else a
|
||||
b = tl.trans(b, _builder=_builder) if trans_b else b
|
||||
return tl.dot(
|
||||
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _semantic=_semantic
|
||||
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _builder=_builder
|
||||
)
|
||||
|
||||
return _dot_fn
|
||||
|
Reference in New Issue
Block a user