diff --git a/src/alphafold3/model/components/mapping.py b/src/alphafold3/model/components/mapping.py index e315649..44ab109 100644 --- a/src/alphafold3/model/components/mapping.py +++ b/src/alphafold3/model/components/mapping.py @@ -12,7 +12,7 @@ from collections.abc import Callable, Sequence import functools -from typing import Any +from typing import Any, TypeVar import haiku as hk import jax @@ -25,6 +25,8 @@ PytreeJaxArray = Any partial = functools.partial PROXY = object() +T = TypeVar("T") + def _maybe_slice(array, i, slice_size, axis): if axis is PROXY: @@ -82,6 +84,16 @@ def sharded_map( return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes) +def _set_docstring(docstr: str) -> Callable[[T], T]: + """Decorator for setting the docstring of a function.""" + + def wrapped(fun: T) -> T: + fun.__doc__ = docstr.format(fun=getattr(fun, "__name__", repr(fun))) + return fun + + return wrapped + + def sharded_apply( fun: Callable[..., PytreeJaxArray], shard_size: int | None = 1, @@ -120,7 +132,8 @@ def sharded_apply( if shard_size is None: return fun - @jax.util.wraps(fun, docstr=docstr) + @_set_docstring(docstr) + @functools.wraps(fun) def mapped_fn(*args, **kwargs): # Expand in axes and determine loop range. in_axes_ = _expand_axes(in_axes, args)