Remove unused code

PiperOrigin-RevId: 744678177
Change-Id: Ib2444fa8ba6c9850f2331d133e038883fc32e681
Josh Abramson
2025-04-02 08:46:45 -07:00
committed by Augustin Zidek
parent dfd8b9cb63
commit 5fbae81a9e
5 changed files with 16 additions and 77 deletions

View File

@@ -129,7 +129,9 @@ class InferenceTest(test_utils.StructureTestCase):
{
'protein': {
'id': 'P',
'sequence': 'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN',
'sequence': (
'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN'
),
'modifications': [],
'unpairedMsa': None,
'pairedMsa': None,
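
The fixture above is a single per-chain input entry. A minimal, hypothetical sketch of building one in Python; the helper name and example sequence are illustrative, only the field names come from the test:

def make_protein_entry(chain_id: str, sequence: str) -> dict:
  # Field names mirror the test fixture above; everything else is assumed.
  return {
      'protein': {
          'id': chain_id,
          'sequence': sequence,
          'modifications': [],
          'unpairedMsa': None,
          'pairedMsa': None,
      }
  }

entry = make_protein_entry('P', 'SEFEKLRQTGDELVQAFQRL')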

View File

@@ -98,7 +98,7 @@ def sharded_apply(
Args:
fun: Function to apply smap transform to.
shard_size: Integer denoting shard size.
shard_size: Integer denoting shard size. None will return `fun` unchanged.
in_axes: Either integer or pytree describing which axis to map over for each
input to `fun`, None denotes broadcasting.
out_axes: Integer or pytree denoting to what axis in the output the mapped
@@ -117,7 +117,6 @@ def sharded_apply(
if new_out_axes:
raise NotImplementedError("New output axes not yet implemented.")
# shard size None denotes no sharding
if shard_size is None:
return fun
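
A minimal sketch of the shard_size=None behaviour documented above; the toy function is illustrative:

# Assumes sharded_apply from the module in this diff is already imported.
def double(x):
  return 2 * x

# With shard_size=None there is no sharding: the function is handed back
# unchanged instead of being wrapped in a sharded loop.
unsharded = sharded_apply(double, shard_size=None)
assert unsharded is double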
@@ -202,37 +201,6 @@ def sharded_apply(
return mapped_fn
def reshape_partitioned_inputs(
batched_args: Sequence[PytreeJaxArray],
partitioned_dim: int,
subbatch_size: int,
) -> Sequence[PytreeJaxArray]:
"""Reshapes so subbatching doesn't happen on the partitioned dim."""
subbatched_args = []
for arg in batched_args:
shape = arg.shape
new_shape = (
shape[:partitioned_dim]
+ (subbatch_size, shape[partitioned_dim] // subbatch_size)
+ shape[partitioned_dim + 1 :]
)
subbatched_args.append(arg.reshape(new_shape))
return subbatched_args
def reshape_partitioned_output(
output: jax.Array, output_subbatch_dim: int
) -> jax.Array:
"""Reshapes outputs as if reshape_partitioned_inputs were never applied."""
out_shape = (
output.shape[: output_subbatch_dim - 1]
+ (-1,)
+ output.shape[output_subbatch_dim + 1 :]
)
return output.reshape(out_shape)
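
A small round-trip sketch of what the two removed helpers did, assuming both are still in scope; the shapes are illustrative:

import jax.numpy as jnp

x = jnp.zeros((2, 8, 3))  # axis 1 is the partitioned dim, size 8
# [..., 8, ...] -> [..., 4, 2, ...]: the partitioned axis keeps size
# subbatch_size (4) and subbatching happens over the new size-2 axis.
(reshaped,) = reshape_partitioned_inputs(
    [x], partitioned_dim=1, subbatch_size=4
)
assert reshaped.shape == (2, 4, 2, 3)
# Undo the reshape on the output; the subbatch dim has shifted by one, to 2.
restored = reshape_partitioned_output(reshaped, output_subbatch_dim=2)
assert restored.shape == (2, 8, 3)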
def inference_subbatch(
module: Callable[..., PytreeJaxArray],
subbatch_size: int,
@@ -240,7 +208,6 @@ def inference_subbatch(
nonbatched_args: Sequence[PytreeJaxArray],
input_subbatch_dim: int = 0,
output_subbatch_dim: int | None = None,
input_subbatch_dim_is_partitioned: bool = False,
) -> PytreeJaxArray:
"""Run through subbatches (like batch apply but with split and concat)."""
assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test
@@ -252,38 +219,9 @@ def inference_subbatch(
if output_subbatch_dim is None:
output_subbatch_dim = input_subbatch_dim
if input_subbatch_dim_is_partitioned:
# Subbatching along the partitioned axis would induce an all-gather that
# undoes the partitioning. So instead we reshape such that
# [..., partitioned_input_size, ...] becomes [..., subbatch_size,
# partitioned_input_size // subbatch_size, ...] and then actually subbatch
# along the partitioned_input_size // subbatch_size axis in slices of
# size 1. Partitioning is then preserved on the partitioned axis, except
# that dimension is now of size subbatch_size instead of
# partitioned_input_size. Note that the module itself still sees inputs of
# size [..., subbatch_size, ...], just as it would if this reshaping were
# not applied.
batched_args = reshape_partitioned_inputs(
batched_args, input_subbatch_dim, subbatch_size
)
input_subbatch_dim += 1
output_subbatch_dim += 1
subbatch_size = 1
def run_module(*batched_args):
if input_subbatch_dim_is_partitioned:
# Squeeze off the singleton dimension (otherwise the module would see
# [..., subbatch_size, 1, ...]).
batched_args = [b.squeeze(axis=input_subbatch_dim) for b in batched_args]
args = list(batched_args) + list(nonbatched_args)
res = module(*args)
if input_subbatch_dim_is_partitioned:
# Add back in the singleton dimension so the outputs are stacked on the
# axis we are actually subbatching over (i.e. stacked back to
# [..., subbatch_size, partitioned_input_size // subbatch_size, ...]),
# rather than on the partitioned axis, which would again induce an
# all-gather that breaks partitioning.
res = jnp.expand_dims(res, axis=output_subbatch_dim)
return res
sharded_module = sharded_apply(
@@ -293,11 +231,5 @@ def inference_subbatch(
out_axes=output_subbatch_dim,
)
output = sharded_module(*batched_args)
if input_subbatch_dim_is_partitioned:
# The output is of the same shape as the inputs [..., subbatch_size,
# partitioned_input_size // subbatch_size, ...]. Reshape to
# [..., partitioned_input_size, ...] as if the reshaping due to partitioning
# had never been applied.
output = reshape_partitioned_output(output, output_subbatch_dim)
return output
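
With the partitioned pathway removed, inference_subbatch just splits the batched args into subbatches, applies the module, and concatenates the results. A by-hand sketch of those semantics (the toy module and shapes are assumptions, not this repo's API):

import jax.numpy as jnp

def add_bias(act, mask, bias):
  # Toy stand-in for a module: act and mask are subbatched, bias is broadcast.
  return act * mask + bias

act = jnp.ones((8, 16))
mask = jnp.ones((8, 16))
bias = jnp.float32(0.5)

# Hand-rolled equivalent of inference_subbatch(add_bias, subbatch_size=2,
# batched_args=[act, mask], nonbatched_args=[bias]) with the default
# input/output subbatch dim of 0.
subbatches = [
    add_bias(act[i:i + 2], mask[i:i + 2], bias) for i in range(0, 8, 2)
]
out = jnp.concatenate(subbatches, axis=0)
assert out.shape == (8, 16)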

View File

@@ -143,7 +143,9 @@ class DiffusionHead(hk.Module):
pair_embedding = use_conditioning * embeddings['pair']
rel_features = featurization.create_relative_encoding(
batch.token_features, max_relative_idx=32, max_relative_chain=2
seq_features=batch.token_features,
max_relative_idx=32,
max_relative_chain=2,
).astype(pair_embedding.dtype)
features_2d = jnp.concatenate([pair_embedding, rel_features], axis=-1)
pair_cond = hm.Linear(

View File

@@ -79,9 +79,9 @@ class Evoformer(hk.Module):
) -> jnp.ndarray:
"""Add relative position encodings."""
rel_feat = featurization.create_relative_encoding(
batch.token_features,
self.config.max_relative_idx,
self.config.max_relative_chain,
seq_features=batch.token_features,
max_relative_idx=self.config.max_relative_idx,
max_relative_chain=self.config.max_relative_chain,
)
rel_feat = rel_feat.astype(pair_activations.dtype)
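
Both call sites above only switch create_relative_encoding to keyword arguments. For orientation, a deliberately simplified, hypothetical sketch of what a relative position encoding with max_relative_idx computes; the repo's create_relative_encoding also encodes chain-level features, which are omitted here:

import jax
import jax.numpy as jnp

def toy_relative_encoding(residue_index, max_relative_idx):
  # Pairwise residue-index offsets, clipped to [-max_relative_idx,
  # max_relative_idx] and one-hot encoded into 2 * max_relative_idx + 1 bins.
  offset = residue_index[:, None] - residue_index[None, :]
  clipped = jnp.clip(offset + max_relative_idx, 0, 2 * max_relative_idx)
  return jax.nn.one_hot(clipped, 2 * max_relative_idx + 1)

rel = toy_relative_encoding(jnp.arange(8), max_relative_idx=32)
assert rel.shape == (8, 8, 65)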

View File

@@ -149,7 +149,12 @@ class GridSelfAttention(hk.Module):
self.transpose = transpose
@hk.transparent
def _attention(self, act, mask, bias):
def _attention(
self,
act,
mask,
bias,
):
num_channels = act.shape[-1]
assert num_channels % self.config.num_head == 0
# Triton requires a minimum dimension of 16 for doing matmul.
@@ -230,7 +235,6 @@ class GridSelfAttention(hk.Module):
chunk_size,
batched_args=[act, pair_mask],
nonbatched_args=[nonbatched_bias],
input_subbatch_dim_is_partitioned=False,
)
if self.transpose:
@@ -405,7 +409,6 @@ class OuterProductMean(hk.Module):
nonbatched_args=[],
input_subbatch_dim=1,
output_subbatch_dim=0,
input_subbatch_dim_is_partitioned=False,
)
epsilon = 1e-3
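
The OuterProductMean call above subbatches along input axis 1 while stacking the per-chunk outputs on axis 0. A by-hand sketch of that pattern (the toy module and shapes are illustrative, not the repo's code):

import jax.numpy as jnp

def per_column_norm(act):
  # act: [num_seq, num_res_chunk, channels] -> [num_res_chunk, num_seq]
  return jnp.linalg.norm(act, axis=-1).T

act = jnp.ones((3, 12, 8))
# Chunk axis 1 of the input in slices of 4 (input_subbatch_dim=1), apply the
# module per chunk, and concatenate the results on axis 0 of the output
# (output_subbatch_dim=0).
chunks = [act[:, i:i + 4, :] for i in range(0, act.shape[1], 4)]
out = jnp.concatenate([per_column_norm(c) for c in chunks], axis=0)
assert out.shape == (12, 3)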