mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Integrate Triton up to [6af1c4b5](6af1c4b507
)
https://github.com/openxla/triton/tree/triton_integrate_branch-1.10 PiperOrigin-RevId: 789217115 Change-Id: Ib89b60601802c247b4dc96cfe6cbff4752a41402
This commit is contained in:
committed by
Copybara-Service
parent
a14376d249
commit
ec4254ac5d
@ -65,20 +65,20 @@ def get_tl_dot_fn(
|
||||
*,
|
||||
trans_a: bool = False,
|
||||
trans_b: bool = False,
|
||||
_builder,
|
||||
_semantic,
|
||||
):
|
||||
if in_dtype == tl.float32:
|
||||
tl.static_assert(a.dtype == tl.float32, _builder=_builder)
|
||||
tl.static_assert(b.dtype == tl.float32, _builder=_builder)
|
||||
tl.static_assert(a.dtype == tl.float32, _semantic=_semantic)
|
||||
tl.static_assert(b.dtype == tl.float32, _semantic=_semantic)
|
||||
else:
|
||||
tl.static_assert(a.dtype.is_standard_floating(), _builder=_builder)
|
||||
tl.static_assert(b.dtype.is_standard_floating(), _builder=_builder)
|
||||
a = a.to(in_dtype, _builder=_builder)
|
||||
b = b.to(in_dtype, _builder=_builder)
|
||||
a = tl.trans(a, _builder=_builder) if trans_a else a
|
||||
b = tl.trans(b, _builder=_builder) if trans_b else b
|
||||
tl.static_assert(a.dtype.is_standard_floating(), _semantic=_semantic)
|
||||
tl.static_assert(b.dtype.is_standard_floating(), _semantic=_semantic)
|
||||
a = a.to(in_dtype, _semantic=_semantic)
|
||||
b = b.to(in_dtype, _semantic=_semantic)
|
||||
a = tl.trans(a, _semantic=_semantic) if trans_a else a
|
||||
b = tl.trans(b, _semantic=_semantic) if trans_b else b
|
||||
return tl.dot(
|
||||
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _builder=_builder
|
||||
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _semantic=_semantic
|
||||
)
|
||||
|
||||
return _dot_fn
|
||||
|
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user