mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Internal change.
PiperOrigin-RevId: 748595613 Change-Id: I7e84c89c15c4887b51a66fc708d15c835aa8a0a8
This commit is contained in:
committed by
Copybara-Service
parent
7a4a2f7142
commit
e274d27978
@ -210,15 +210,17 @@ def _gated_linear_unit(
|
||||
else:
|
||||
input_output_aliases = {3: 0}
|
||||
|
||||
compiler_params = dict(
|
||||
triton=dict(num_warps=config.num_warps, num_stages=config.num_stages)
|
||||
)
|
||||
|
||||
return pl.pallas_call(
|
||||
kernel,
|
||||
name=name,
|
||||
grid=(pl.cdiv(m, config.block_m) * pl.cdiv(n, config.block_n),),
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype) if dst is None else dst,
|
||||
input_output_aliases=input_output_aliases,
|
||||
compiler_params=dict(
|
||||
triton=dict(num_warps=config.num_warps, num_stages=config.num_stages)
|
||||
),
|
||||
compiler_params=compiler_params,
|
||||
)(x, weights_projection, weights_gate, dst, epilogue_args)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user