Internal change.

PiperOrigin-RevId: 748595613
Change-Id: I7e84c89c15c4887b51a66fc708d15c835aa8a0a8
This commit is contained in:
Augustin Zidek
2025-04-17 01:49:16 -07:00
committed by Copybara-Service
parent 7a4a2f7142
commit e274d27978

View File

@ -210,15 +210,17 @@ def _gated_linear_unit(
else:
input_output_aliases = {3: 0}
compiler_params = dict(
triton=dict(num_warps=config.num_warps, num_stages=config.num_stages)
)
return pl.pallas_call(
kernel,
name=name,
grid=(pl.cdiv(m, config.block_m) * pl.cdiv(n, config.block_n),),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype) if dst is None else dst,
input_output_aliases=input_output_aliases,
compiler_params=dict(
triton=dict(num_warps=config.num_warps, num_stages=config.num_stages)
),
compiler_params=compiler_params,
)(x, weights_projection, weights_gate, dst, epilogue_args)