mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
[alphafold3] avoid use of deprecated jax.util
APIs
jax.util was deprecated in JAX v0.6.0, and will be removed in JAX v0.7.0. PiperOrigin-RevId: 756703396 Change-Id: I9e0b0267c4a403939aeb413f59d8ff7ef3ee390b
This commit is contained in:
committed by
Copybara-Service
parent
e274d27978
commit
6a0e8b2afe
@ -12,7 +12,7 @@
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
from typing import Any
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
@ -25,6 +25,8 @@ PytreeJaxArray = Any
|
||||
partial = functools.partial
|
||||
PROXY = object()
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _maybe_slice(array, i, slice_size, axis):
|
||||
if axis is PROXY:
|
||||
@ -82,6 +84,16 @@ def sharded_map(
|
||||
return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
|
||||
|
||||
|
||||
def _set_docstring(docstr: str) -> Callable[[T], T]:
|
||||
"""Decorator for setting the docstring of a function."""
|
||||
|
||||
def wrapped(fun: T) -> T:
|
||||
fun.__doc__ = docstr.format(fun=getattr(fun, "__name__", repr(fun)))
|
||||
return fun
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def sharded_apply(
|
||||
fun: Callable[..., PytreeJaxArray],
|
||||
shard_size: int | None = 1,
|
||||
@ -120,7 +132,8 @@ def sharded_apply(
|
||||
if shard_size is None:
|
||||
return fun
|
||||
|
||||
@jax.util.wraps(fun, docstr=docstr)
|
||||
@_set_docstring(docstr)
|
||||
@functools.wraps(fun)
|
||||
def mapped_fn(*args, **kwargs):
|
||||
# Expand in axes and determine loop range.
|
||||
in_axes_ = _expand_axes(in_axes, args)
|
||||
|
Reference in New Issue
Block a user