Remove unnecessary chex dependency

PiperOrigin-RevId: 765086591
Change-Id: I34d7e7b83073d84fe083ee091a3b54e90dcdeba3
This commit is contained in:
Augustin Zidek
2025-05-30 01:43:37 -07:00
committed by Copybara-Service
parent 17afe151ea
commit 565f286892
9 changed files with 43 additions and 63 deletions

View File

@ -140,7 +140,6 @@ AlphaFold 3 uses the following separate libraries and packages:
* [abseil-cpp](https://github.com/abseil/abseil-cpp) and
[abseil-py](https://github.com/abseil/abseil-py)
* [Chex](https://github.com/deepmind/chex)
* [Docker](https://www.docker.com)
* [DSSP](https://github.com/PDB-REDO/dssp)
* [HMMER Suite](https://github.com/EddyRivasLab/hmmer)

View File

@ -9,13 +9,8 @@ absl-py==2.1.0 \
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
# via
# alphafold3 (pyproject.toml)
# chex
# dm-haiku
# jax-triton
chex==0.1.87 \
--hash=sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d \
--hash=sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16
# via alphafold3 (pyproject.toml)
dm-haiku==0.0.13 \
--hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \
--hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43
@ -77,7 +72,6 @@ jax[cuda12]==0.4.34 \
--hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e
# via
# alphafold3 (pyproject.toml)
# chex
# jax-triton
jax-cuda12-pjrt==0.4.34 \
--hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \
@ -118,9 +112,7 @@ jaxlib==0.4.34 \
--hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \
--hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \
--hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232
# via
# chex
# jax
# via jax
jaxtyping==0.2.34 \
--hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
--hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
@ -212,7 +204,6 @@ numpy==2.1.3 \
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
# via
# alphafold3 (pyproject.toml)
# chex
# dm-haiku
# jax
# jaxlib
@ -427,10 +418,6 @@ tabulate==0.9.0 \
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
# via dm-haiku
toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
# via chex
tqdm==4.67.0 \
--hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \
--hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a
@ -447,11 +434,9 @@ triton==3.1.0 \
typeguard==2.13.3 \
--hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
--hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
# via jaxtyping
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
# via chex
# via
# alphafold3 (pyproject.toml)
# jaxtyping
zstandard==0.23.0 \
--hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
--hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \

View File

@ -16,7 +16,6 @@ readme = "README.md"
license = {file = "LICENSE"}
dependencies = [
"absl-py",
"chex",
"dm-haiku==0.0.13",
"dm-tree",
"jax==0.4.34",

View File

@ -9,13 +9,8 @@ absl-py==2.1.0 \
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
# via
# alphafold3 (pyproject.toml)
# chex
# dm-haiku
# jax-triton
chex==0.1.87 \
--hash=sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d \
--hash=sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16
# via alphafold3 (pyproject.toml)
dm-haiku==0.0.13 \
--hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \
--hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43
@ -77,7 +72,6 @@ jax[cuda12]==0.4.34 \
--hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e
# via
# alphafold3 (pyproject.toml)
# chex
# jax-triton
jax-cuda12-pjrt==0.4.34 \
--hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \
@ -118,9 +112,7 @@ jaxlib==0.4.34 \
--hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \
--hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \
--hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232
# via
# chex
# jax
# via jax
jaxtyping==0.2.34 \
--hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
--hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
@ -212,7 +204,6 @@ numpy==2.1.3 \
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
# via
# alphafold3 (pyproject.toml)
# chex
# dm-haiku
# jax
# jaxlib
@ -427,10 +418,6 @@ tabulate==0.9.0 \
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
# via dm-haiku
toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
# via chex
tqdm==4.67.0 \
--hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \
--hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a
@ -447,11 +434,9 @@ triton==3.1.0 \
typeguard==2.13.3 \
--hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
--hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
# via jaxtyping
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
# via chex
# via
# alphafold3 (pyproject.toml)
# jaxtyping
zstandard==0.23.0 \
--hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
--hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \

View File

@ -9,13 +9,15 @@
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Batch dataclass."""
import dataclasses
from typing import Self
from alphafold3.model import features
import chex
import jax
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Batch:
"""Dataclass containing batch."""

View File

@ -32,7 +32,7 @@ from alphafold3.model import merging_features
from alphafold3.model import msa_pairing
from alphafold3.model.atom_layout import atom_layout
from alphafold3.structure import chemical_components as struc_chem_comps
import chex
import jax
import jax.numpy as jnp
import numpy as np
from rdkit import Chem
@ -100,7 +100,8 @@ def _unwrap(obj):
return obj
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Chains:
chain_id: np.ndarray
asym_id: np.ndarray
@ -391,7 +392,8 @@ def tokenizer(
return all_tokens, all_token_atoms_layout, standard_token_idxs
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class MSA:
"""Dataclass containing MSA."""
@ -687,7 +689,8 @@ class MSA:
}
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Templates:
"""Dataclass containing templates."""
@ -867,7 +870,8 @@ def _reduce_template_features(
return template_features
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class TokenFeatures:
"""Dataclass containing features for tokens."""
@ -1009,7 +1013,8 @@ class TokenFeatures:
}
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class PredictedStructureInfo:
"""Contains information necessary to work with predicted structure."""
@ -1071,7 +1076,8 @@ class PredictedStructureInfo:
}
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class PolymerLigandBondInfo:
"""Contains information about polymer-ligand bonds."""
@ -1187,7 +1193,8 @@ class PolymerLigandBondInfo:
}
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class LigandLigandBondInfo:
"""Contains information about the location of ligand-ligand bonds."""
@ -1283,7 +1290,8 @@ class LigandLigandBondInfo:
}
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class PseudoBetaInfo:
"""Contains information for extracting pseudo-beta and equivalent atoms."""
@ -1598,7 +1606,8 @@ def get_reference(
return features, from_atom, dest_atom
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class RefStructure:
"""Contains ref structure information."""
@ -1756,7 +1765,8 @@ class RefStructure:
}
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class ConvertModelOutput:
"""Contains atom layout info."""
@ -1817,7 +1827,8 @@ class ConvertModelOutput:
}
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class AtomCrossAtt:
"""Operate on flat atoms."""
@ -1961,7 +1972,8 @@ class AtomCrossAtt:
}
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Frames:
"""Features for backbone frames."""

View File

@ -9,6 +9,7 @@
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Per-atom cross attention."""
import dataclasses
from alphafold3.common import base_config
from alphafold3.model import feat_batch
@ -17,7 +18,6 @@ from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.components import utils
from alphafold3.model.network import diffusion_transformer
import chex
import jax
import jax.numpy as jnp
@ -96,7 +96,8 @@ def _per_atom_conditioning(
return act, pair_act
@chex.dataclass(mappable_dataclass=False, frozen=True)
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class AtomCrossAttEncoderOutput:
token_act: jnp.ndarray # (num_tokens, ch)
skip_connection: jnp.ndarray # (num_subsets, num_queries, ch)

View File

@ -21,7 +21,6 @@ from alphafold3.model.network import atom_cross_attention
from alphafold3.model.network import diffusion_transformer
from alphafold3.model.network import featurization
from alphafold3.model.network import noise_level_embeddings
import chex
import haiku as hk
import jax
import jax.numpy as jnp
@ -239,7 +238,6 @@ class DiffusionHead(hk.Module):
act = enc.token_act
# Token-token attention
chex.assert_shape(act, (None, self.config.per_token_channels))
act = jnp.asarray(act, dtype=jnp.float32)
act += hm.Linear(

View File

@ -16,7 +16,6 @@ from alphafold3.constants import residue_names
from alphafold3.model import feat_batch
from alphafold3.model import features
from alphafold3.model.components import utils
import chex
import jax
import jax.numpy as jnp
@ -109,7 +108,7 @@ def gumbel_argsort_sample_idx(
return perm[::-1]
def create_msa_feat(msa: features.MSA) -> chex.ArrayDevice:
def create_msa_feat(msa: features.MSA) -> jax.Array:
"""Create and concatenate MSA features."""
msa_1hot = jax.nn.one_hot(
msa.rows, residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1
@ -137,7 +136,7 @@ def truncate_msa_batch(msa: features.MSA, num_msa: int) -> features.MSA:
def create_target_feat(
batch: feat_batch.Batch,
append_per_atom_features: bool,
) -> chex.ArrayDevice:
) -> jax.Array:
"""Make target feat."""
token_features = batch.token_features
target_features = []
@ -170,7 +169,7 @@ def create_relative_encoding(
seq_features: features.TokenFeatures,
max_relative_idx: int,
max_relative_chain: int,
) -> chex.ArrayDevice:
) -> jax.Array:
"""Add relative position encodings."""
rel_feats = []
token_index = seq_features.token_index