[alphafold3] avoid use of deprecated jax.util APIs

jax.util was deprecated in JAX v0.6.0, and will be removed in JAX v0.7.0.

PiperOrigin-RevId: 756703396
Change-Id: I9e0b0267c4a403939aeb413f59d8ff7ef3ee390b
This commit is contained in:
Jake VanderPlas
2025-05-09 04:09:02 -07:00
committed by Copybara-Service
parent e274d27978
commit 6a0e8b2afe

View File

@ -12,7 +12,7 @@
from collections.abc import Callable, Sequence
import functools
from typing import Any
from typing import Any, TypeVar
import haiku as hk
import jax
@ -25,6 +25,8 @@ PytreeJaxArray = Any
partial = functools.partial
PROXY = object()
T = TypeVar("T")
def _maybe_slice(array, i, slice_size, axis):
if axis is PROXY:
@ -82,6 +84,16 @@ def sharded_map(
return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
def _set_docstring(docstr: str) -> Callable[[T], T]:
"""Decorator for setting the docstring of a function."""
def wrapped(fun: T) -> T:
fun.__doc__ = docstr.format(fun=getattr(fun, "__name__", repr(fun)))
return fun
return wrapped
def sharded_apply(
fun: Callable[..., PytreeJaxArray],
shard_size: int | None = 1,
@ -120,7 +132,8 @@ def sharded_apply(
if shard_size is None:
return fun
@jax.util.wraps(fun, docstr=docstr)
@_set_docstring(docstr)
@functools.wraps(fun)
def mapped_fn(*args, **kwargs):
# Expand in axes and determine loop range.
in_axes_ = _expand_axes(in_axes, args)