Mirror of https://github.com/google-deepmind/alphafold3.git, synced 2025-10-20 13:23:47 +08:00
Remove unused code
PiperOrigin-RevId: 744678177 Change-Id: Ib2444fa8ba6c9850f2331d133e038883fc32e681
Committed by Augustin Zidek · parent dfd8b9cb63 · commit 5fbae81a9e
@@ -129,7 +129,9 @@ class InferenceTest(test_utils.StructureTestCase):
         {
             'protein': {
                 'id': 'P',
-                'sequence': 'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN',
+                'sequence': (
+                    'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN'
+                ),
                 'modifications': [],
                 'unpairedMsa': None,
                 'pairedMsa': None,
@@ -98,7 +98,7 @@ def sharded_apply(

   Args:
     fun: Function to apply smap transform to.
-    shard_size: Integer denoting shard size.
+    shard_size: Integer denoting shard size. None will return `fun` unchanged.
     in_axes: Either integer or pytree describing which axis to map over for each
       input to `fun`, None denotes broadcasting.
     out_axes: Integer or pytree denoting to what axis in the output the mapped
@@ -117,7 +117,6 @@ def sharded_apply(
   if new_out_axes:
     raise NotImplementedError("New output axes not yet implemented.")

-  # shard size None denotes no sharding
   if shard_size is None:
     return fun

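For context (not part of this commit): a minimal, standalone sketch of the behaviour the updated docstring describes, where shard_size=None returns the wrapped function unchanged and an integer shard size maps the function over chunks of the mapped axis. The helper name toy_sharded_apply and its simplified signature are assumptions for illustration only, not the repository's sharded_apply.

import jax.numpy as jnp

def toy_sharded_apply(fun, shard_size=None, in_axes=0, out_axes=0):
  # shard_size=None denotes no sharding: return `fun` unchanged.
  if shard_size is None:
    return fun

  def mapped_fn(x):
    # Split the mapped axis into shards of `shard_size`, apply `fun` to each
    # shard, then concatenate the results on the output axis.
    num_shards = x.shape[in_axes] // shard_size
    shards = jnp.split(x, num_shards, axis=in_axes)
    return jnp.concatenate([fun(s) for s in shards], axis=out_axes)

  return mapped_fn

double = lambda x: x * 2.0
x = jnp.arange(8.0).reshape(8, 1)
assert toy_sharded_apply(double, shard_size=None) is double
assert jnp.array_equal(toy_sharded_apply(double, shard_size=2)(x), double(x))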
@@ -202,37 +201,6 @@ def sharded_apply(
   return mapped_fn


-def reshape_partitioned_inputs(
-    batched_args: Sequence[PytreeJaxArray],
-    partitioned_dim: int,
-    subbatch_size: int,
-) -> Sequence[PytreeJaxArray]:
-  """Reshapes so subbatching doesn't happen on the partitioned dim."""
-  subbatched_args = []
-  for arg in batched_args:
-    shape = arg.shape
-
-    new_shape = (
-        shape[:partitioned_dim]
-        + (subbatch_size, shape[partitioned_dim] // subbatch_size)
-        + shape[partitioned_dim + 1 :]
-    )
-    subbatched_args.append(arg.reshape(new_shape))
-  return subbatched_args
-
-
-def reshape_partitioned_output(
-    output: jax.Array, output_subbatch_dim: int
-) -> jax.Array:
-  """Reshapes outputs as if reshape_partitioned_inputs were never applied."""
-  out_shape = (
-      output.shape[: output_subbatch_dim - 1]
-      + (-1,)
-      + output.shape[output_subbatch_dim + 1 :]
-  )
-  return output.reshape(out_shape)
-
-
 def inference_subbatch(
     module: Callable[..., PytreeJaxArray],
     subbatch_size: int,
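For reference (an illustration, not part of the diff): a small worked example of the shape arithmetic performed by the two removed helpers, assuming a rank-3 JAX array whose axis 1 is the partitioned dimension.

import jax.numpy as jnp

x = jnp.zeros((4, 128, 64))      # axis 1 is the partitioned dim (size 128)
subbatch_size = 8

# reshape_partitioned_inputs(..., partitioned_dim=1, subbatch_size=8):
# (4, 128, 64) -> (4, 8, 16, 64); subbatching then walks the new axis of size 16.
y = x.reshape(x.shape[:1] + (subbatch_size, 128 // subbatch_size) + x.shape[2:])
assert y.shape == (4, 8, 16, 64)

# reshape_partitioned_output(..., output_subbatch_dim=2) undoes the reshape:
# (4, 8, 16, 64) -> (4, 128, 64).
z = y.reshape(y.shape[:1] + (-1,) + y.shape[3:])
assert z.shape == (4, 128, 64)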
@@ -240,7 +208,6 @@ def inference_subbatch(
     nonbatched_args: Sequence[PytreeJaxArray],
     input_subbatch_dim: int = 0,
     output_subbatch_dim: int | None = None,
-    input_subbatch_dim_is_partitioned: bool = False,
 ) -> PytreeJaxArray:
   """Run through subbatches (like batch apply but with split and concat)."""
   assert len(batched_args) > 0  # pylint: disable=g-explicit-length-test
@@ -252,38 +219,9 @@ def inference_subbatch(
   if output_subbatch_dim is None:
     output_subbatch_dim = input_subbatch_dim

-  if input_subbatch_dim_is_partitioned:
-    # Subbatching along the partitioned axis would induce an all-gather that
-    # undoes the partitioning. So instead we reshape such that
-    # [..., partitioned_input_size, ...] becomes [..., subbatch_size,
-    # partitioned_input_size // subbatch_size, ...] and then actually subbatch
-    # along the partitioned_input_size // subbatch_size axis in slices of
-    # size 1. Partitioning is then preserved on the partitioned axis, except
-    # that dimension is now of size subbatch_size instead of
-    # partitioned_input_size. Note that the module itself still sees inputs of
-    # size [..., subbatch_size, ...], just as it would if this reshaping were
-    # not applied.
-    batched_args = reshape_partitioned_inputs(
-        batched_args, input_subbatch_dim, subbatch_size
-    )
-    input_subbatch_dim += 1
-    output_subbatch_dim += 1
-    subbatch_size = 1
-
   def run_module(*batched_args):
-    if input_subbatch_dim_is_partitioned:
-      # Squeeze off the singleton dimension (otherwise the module would see
-      # [..., subbatch_size, 1, ...]).
-      batched_args = [b.squeeze(axis=input_subbatch_dim) for b in batched_args]
     args = list(batched_args) + list(nonbatched_args)
     res = module(*args)
-    if input_subbatch_dim_is_partitioned:
-      # Add back in the singleton dimension so the outputs are stacked on the
-      # axis we are actually subbatching over (i.e stacked back to
-      # [..., subbatch_size, partitioned_input_size // subbatch_size, ...]),
-      # rather than on the partitioned axis, which would again induce an
-      # all-gather that breaks partitioning.
-      res = jnp.expand_dims(res, axis=output_subbatch_dim)
     return res

   sharded_module = sharded_apply(
@@ -293,11 +231,5 @@ def inference_subbatch(
       out_axes=output_subbatch_dim,
   )
   output = sharded_module(*batched_args)
-  if input_subbatch_dim_is_partitioned:
-    # The is of the same shape as the inputs [..., subbatch_size,
-    # partitioned_input_size // subbatch_size, ...]. Reshape to
-    # [..., partitioned_input_size, ...] as if the reshaping due to partitioning
-    # had never been applied.
-    output = reshape_partitioned_output(output, output_subbatch_dim)

   return output
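For orientation (a toy, eager-mode illustration; not code from the repository): the partition-aware path removed above reshaped the partitioned axis, subbatched the new axis in slices of size 1, stacked results back on that slice axis, and finally undid the reshape. The stand-in `module` and the explicit Python loop below are assumptions used purely to make the shapes visible; the real code drives this through sharded_apply.

import jax.numpy as jnp

def module(x):                  # stand-in for the wrapped computation
  return x + 1.0

x = jnp.ones((8, 96))           # axis 1 is the partitioned axis (size 96)
subbatch_size = 4

# Reshape so the slice loop walks a new axis of size 96 // 4 = 24, leaving the
# partitioned axis (now of size subbatch_size) intact.
x = x.reshape(8, subbatch_size, 96 // subbatch_size)        # (8, 4, 24)

slices = []
for i in range(x.shape[2]):
  s = x[:, :, i]                # the module still sees [..., subbatch_size] = (8, 4)
  out = module(s)
  # Stack outputs back on the slice axis rather than the partitioned axis.
  slices.append(jnp.expand_dims(out, axis=2))
y = jnp.concatenate(slices, axis=2)                          # (8, 4, 24)

# Reshape back as if the partition-aware path had never been applied.
y = y.reshape(8, -1)                                         # (8, 96)
assert y.shape == (8, 96)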
@@ -143,7 +143,9 @@ class DiffusionHead(hk.Module):
     pair_embedding = use_conditioning * embeddings['pair']

     rel_features = featurization.create_relative_encoding(
-        batch.token_features, max_relative_idx=32, max_relative_chain=2
+        seq_features=batch.token_features,
+        max_relative_idx=32,
+        max_relative_chain=2,
     ).astype(pair_embedding.dtype)
     features_2d = jnp.concatenate([pair_embedding, rel_features], axis=-1)
     pair_cond = hm.Linear(
@@ -79,9 +79,9 @@ class Evoformer(hk.Module):
   ) -> jnp.ndarray:
     """Add relative position encodings."""
     rel_feat = featurization.create_relative_encoding(
-        batch.token_features,
-        self.config.max_relative_idx,
-        self.config.max_relative_chain,
+        seq_features=batch.token_features,
+        max_relative_idx=self.config.max_relative_idx,
+        max_relative_chain=self.config.max_relative_chain,
     )
     rel_feat = rel_feat.astype(pair_activations.dtype)

@@ -149,7 +149,12 @@ class GridSelfAttention(hk.Module):
     self.transpose = transpose

   @hk.transparent
-  def _attention(self, act, mask, bias):
+  def _attention(
+      self,
+      act,
+      mask,
+      bias,
+  ):
     num_channels = act.shape[-1]
     assert num_channels % self.config.num_head == 0
     # Triton requires a minimum dimension of 16 for doing matmul.
@@ -230,7 +235,6 @@ class GridSelfAttention(hk.Module):
         chunk_size,
         batched_args=[act, pair_mask],
         nonbatched_args=[nonbatched_bias],
-        input_subbatch_dim_is_partitioned=False,
     )

     if self.transpose:
@@ -405,7 +409,6 @@ class OuterProductMean(hk.Module):
         nonbatched_args=[],
         input_subbatch_dim=1,
         output_subbatch_dim=0,
-        input_subbatch_dim_is_partitioned=False,
     )

     epsilon = 1e-3