mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Remove unnecessary chex dependency
PiperOrigin-RevId: 765086591 Change-Id: I34d7e7b83073d84fe083ee091a3b54e90dcdeba3
This commit is contained in:
committed by
Copybara-Service
parent
17afe151ea
commit
565f286892
@ -140,7 +140,6 @@ AlphaFold 3 uses the following separate libraries and packages:
|
|||||||
|
|
||||||
* [abseil-cpp](https://github.com/abseil/abseil-cpp) and
|
* [abseil-cpp](https://github.com/abseil/abseil-cpp) and
|
||||||
[abseil-py](https://github.com/abseil/abseil-py)
|
[abseil-py](https://github.com/abseil/abseil-py)
|
||||||
* [Chex](https://github.com/deepmind/chex)
|
|
||||||
* [Docker](https://www.docker.com)
|
* [Docker](https://www.docker.com)
|
||||||
* [DSSP](https://github.com/PDB-REDO/dssp)
|
* [DSSP](https://github.com/PDB-REDO/dssp)
|
||||||
* [HMMER Suite](https://github.com/EddyRivasLab/hmmer)
|
* [HMMER Suite](https://github.com/EddyRivasLab/hmmer)
|
||||||
|
@ -9,13 +9,8 @@ absl-py==2.1.0 \
|
|||||||
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
|
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
|
||||||
# via
|
# via
|
||||||
# alphafold3 (pyproject.toml)
|
# alphafold3 (pyproject.toml)
|
||||||
# chex
|
|
||||||
# dm-haiku
|
# dm-haiku
|
||||||
# jax-triton
|
# jax-triton
|
||||||
chex==0.1.87 \
|
|
||||||
--hash=sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d \
|
|
||||||
--hash=sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16
|
|
||||||
# via alphafold3 (pyproject.toml)
|
|
||||||
dm-haiku==0.0.13 \
|
dm-haiku==0.0.13 \
|
||||||
--hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \
|
--hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \
|
||||||
--hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43
|
--hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43
|
||||||
@ -77,7 +72,6 @@ jax[cuda12]==0.4.34 \
|
|||||||
--hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e
|
--hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e
|
||||||
# via
|
# via
|
||||||
# alphafold3 (pyproject.toml)
|
# alphafold3 (pyproject.toml)
|
||||||
# chex
|
|
||||||
# jax-triton
|
# jax-triton
|
||||||
jax-cuda12-pjrt==0.4.34 \
|
jax-cuda12-pjrt==0.4.34 \
|
||||||
--hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \
|
--hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \
|
||||||
@ -118,9 +112,7 @@ jaxlib==0.4.34 \
|
|||||||
--hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \
|
--hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \
|
||||||
--hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \
|
--hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \
|
||||||
--hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232
|
--hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232
|
||||||
# via
|
# via jax
|
||||||
# chex
|
|
||||||
# jax
|
|
||||||
jaxtyping==0.2.34 \
|
jaxtyping==0.2.34 \
|
||||||
--hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
|
--hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
|
||||||
--hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
|
--hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
|
||||||
@ -212,7 +204,6 @@ numpy==2.1.3 \
|
|||||||
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
|
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
|
||||||
# via
|
# via
|
||||||
# alphafold3 (pyproject.toml)
|
# alphafold3 (pyproject.toml)
|
||||||
# chex
|
|
||||||
# dm-haiku
|
# dm-haiku
|
||||||
# jax
|
# jax
|
||||||
# jaxlib
|
# jaxlib
|
||||||
@ -427,10 +418,6 @@ tabulate==0.9.0 \
|
|||||||
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
|
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
|
||||||
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
|
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
|
||||||
# via dm-haiku
|
# via dm-haiku
|
||||||
toolz==1.0.0 \
|
|
||||||
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
|
|
||||||
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
|
|
||||||
# via chex
|
|
||||||
tqdm==4.67.0 \
|
tqdm==4.67.0 \
|
||||||
--hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \
|
--hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \
|
||||||
--hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a
|
--hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a
|
||||||
@ -447,11 +434,9 @@ triton==3.1.0 \
|
|||||||
typeguard==2.13.3 \
|
typeguard==2.13.3 \
|
||||||
--hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
|
--hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
|
||||||
--hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
|
--hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
|
||||||
# via jaxtyping
|
# via
|
||||||
typing-extensions==4.12.2 \
|
# alphafold3 (pyproject.toml)
|
||||||
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
|
# jaxtyping
|
||||||
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
|
|
||||||
# via chex
|
|
||||||
zstandard==0.23.0 \
|
zstandard==0.23.0 \
|
||||||
--hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
|
--hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
|
||||||
--hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \
|
--hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \
|
||||||
|
@ -16,7 +16,6 @@ readme = "README.md"
|
|||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"absl-py",
|
"absl-py",
|
||||||
"chex",
|
|
||||||
"dm-haiku==0.0.13",
|
"dm-haiku==0.0.13",
|
||||||
"dm-tree",
|
"dm-tree",
|
||||||
"jax==0.4.34",
|
"jax==0.4.34",
|
||||||
|
@ -9,13 +9,8 @@ absl-py==2.1.0 \
|
|||||||
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
|
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
|
||||||
# via
|
# via
|
||||||
# alphafold3 (pyproject.toml)
|
# alphafold3 (pyproject.toml)
|
||||||
# chex
|
|
||||||
# dm-haiku
|
# dm-haiku
|
||||||
# jax-triton
|
# jax-triton
|
||||||
chex==0.1.87 \
|
|
||||||
--hash=sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d \
|
|
||||||
--hash=sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16
|
|
||||||
# via alphafold3 (pyproject.toml)
|
|
||||||
dm-haiku==0.0.13 \
|
dm-haiku==0.0.13 \
|
||||||
--hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \
|
--hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \
|
||||||
--hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43
|
--hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43
|
||||||
@ -77,7 +72,6 @@ jax[cuda12]==0.4.34 \
|
|||||||
--hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e
|
--hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e
|
||||||
# via
|
# via
|
||||||
# alphafold3 (pyproject.toml)
|
# alphafold3 (pyproject.toml)
|
||||||
# chex
|
|
||||||
# jax-triton
|
# jax-triton
|
||||||
jax-cuda12-pjrt==0.4.34 \
|
jax-cuda12-pjrt==0.4.34 \
|
||||||
--hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \
|
--hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \
|
||||||
@ -118,9 +112,7 @@ jaxlib==0.4.34 \
|
|||||||
--hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \
|
--hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \
|
||||||
--hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \
|
--hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \
|
||||||
--hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232
|
--hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232
|
||||||
# via
|
# via jax
|
||||||
# chex
|
|
||||||
# jax
|
|
||||||
jaxtyping==0.2.34 \
|
jaxtyping==0.2.34 \
|
||||||
--hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
|
--hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
|
||||||
--hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
|
--hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
|
||||||
@ -212,7 +204,6 @@ numpy==2.1.3 \
|
|||||||
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
|
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
|
||||||
# via
|
# via
|
||||||
# alphafold3 (pyproject.toml)
|
# alphafold3 (pyproject.toml)
|
||||||
# chex
|
|
||||||
# dm-haiku
|
# dm-haiku
|
||||||
# jax
|
# jax
|
||||||
# jaxlib
|
# jaxlib
|
||||||
@ -427,10 +418,6 @@ tabulate==0.9.0 \
|
|||||||
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
|
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
|
||||||
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
|
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
|
||||||
# via dm-haiku
|
# via dm-haiku
|
||||||
toolz==1.0.0 \
|
|
||||||
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
|
|
||||||
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
|
|
||||||
# via chex
|
|
||||||
tqdm==4.67.0 \
|
tqdm==4.67.0 \
|
||||||
--hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \
|
--hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \
|
||||||
--hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a
|
--hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a
|
||||||
@ -447,11 +434,9 @@ triton==3.1.0 \
|
|||||||
typeguard==2.13.3 \
|
typeguard==2.13.3 \
|
||||||
--hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
|
--hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
|
||||||
--hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
|
--hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
|
||||||
# via jaxtyping
|
# via
|
||||||
typing-extensions==4.12.2 \
|
# alphafold3 (pyproject.toml)
|
||||||
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
|
# jaxtyping
|
||||||
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
|
|
||||||
# via chex
|
|
||||||
zstandard==0.23.0 \
|
zstandard==0.23.0 \
|
||||||
--hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
|
--hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
|
||||||
--hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \
|
--hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \
|
||||||
|
@ -9,13 +9,15 @@
|
|||||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||||
|
|
||||||
"""Batch dataclass."""
|
"""Batch dataclass."""
|
||||||
|
import dataclasses
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
from alphafold3.model import features
|
from alphafold3.model import features
|
||||||
import chex
|
import jax
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Batch:
|
class Batch:
|
||||||
"""Dataclass containing batch."""
|
"""Dataclass containing batch."""
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ from alphafold3.model import merging_features
|
|||||||
from alphafold3.model import msa_pairing
|
from alphafold3.model import msa_pairing
|
||||||
from alphafold3.model.atom_layout import atom_layout
|
from alphafold3.model.atom_layout import atom_layout
|
||||||
from alphafold3.structure import chemical_components as struc_chem_comps
|
from alphafold3.structure import chemical_components as struc_chem_comps
|
||||||
import chex
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from rdkit import Chem
|
from rdkit import Chem
|
||||||
@ -100,7 +100,8 @@ def _unwrap(obj):
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Chains:
|
class Chains:
|
||||||
chain_id: np.ndarray
|
chain_id: np.ndarray
|
||||||
asym_id: np.ndarray
|
asym_id: np.ndarray
|
||||||
@ -391,7 +392,8 @@ def tokenizer(
|
|||||||
return all_tokens, all_token_atoms_layout, standard_token_idxs
|
return all_tokens, all_token_atoms_layout, standard_token_idxs
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class MSA:
|
class MSA:
|
||||||
"""Dataclass containing MSA."""
|
"""Dataclass containing MSA."""
|
||||||
|
|
||||||
@ -687,7 +689,8 @@ class MSA:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Templates:
|
class Templates:
|
||||||
"""Dataclass containing templates."""
|
"""Dataclass containing templates."""
|
||||||
|
|
||||||
@ -867,7 +870,8 @@ def _reduce_template_features(
|
|||||||
return template_features
|
return template_features
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class TokenFeatures:
|
class TokenFeatures:
|
||||||
"""Dataclass containing features for tokens."""
|
"""Dataclass containing features for tokens."""
|
||||||
|
|
||||||
@ -1009,7 +1013,8 @@ class TokenFeatures:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class PredictedStructureInfo:
|
class PredictedStructureInfo:
|
||||||
"""Contains information necessary to work with predicted structure."""
|
"""Contains information necessary to work with predicted structure."""
|
||||||
|
|
||||||
@ -1071,7 +1076,8 @@ class PredictedStructureInfo:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class PolymerLigandBondInfo:
|
class PolymerLigandBondInfo:
|
||||||
"""Contains information about polymer-ligand bonds."""
|
"""Contains information about polymer-ligand bonds."""
|
||||||
|
|
||||||
@ -1187,7 +1193,8 @@ class PolymerLigandBondInfo:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class LigandLigandBondInfo:
|
class LigandLigandBondInfo:
|
||||||
"""Contains information about the location of ligand-ligand bonds."""
|
"""Contains information about the location of ligand-ligand bonds."""
|
||||||
|
|
||||||
@ -1283,7 +1290,8 @@ class LigandLigandBondInfo:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class PseudoBetaInfo:
|
class PseudoBetaInfo:
|
||||||
"""Contains information for extracting pseudo-beta and equivalent atoms."""
|
"""Contains information for extracting pseudo-beta and equivalent atoms."""
|
||||||
|
|
||||||
@ -1598,7 +1606,8 @@ def get_reference(
|
|||||||
return features, from_atom, dest_atom
|
return features, from_atom, dest_atom
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class RefStructure:
|
class RefStructure:
|
||||||
"""Contains ref structure information."""
|
"""Contains ref structure information."""
|
||||||
|
|
||||||
@ -1756,7 +1765,8 @@ class RefStructure:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class ConvertModelOutput:
|
class ConvertModelOutput:
|
||||||
"""Contains atom layout info."""
|
"""Contains atom layout info."""
|
||||||
|
|
||||||
@ -1817,7 +1827,8 @@ class ConvertModelOutput:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class AtomCrossAtt:
|
class AtomCrossAtt:
|
||||||
"""Operate on flat atoms."""
|
"""Operate on flat atoms."""
|
||||||
|
|
||||||
@ -1961,7 +1972,8 @@ class AtomCrossAtt:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Frames:
|
class Frames:
|
||||||
"""Features for backbone frames."""
|
"""Features for backbone frames."""
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||||
|
|
||||||
"""Per-atom cross attention."""
|
"""Per-atom cross attention."""
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
from alphafold3.common import base_config
|
from alphafold3.common import base_config
|
||||||
from alphafold3.model import feat_batch
|
from alphafold3.model import feat_batch
|
||||||
@ -17,7 +18,6 @@ from alphafold3.model.atom_layout import atom_layout
|
|||||||
from alphafold3.model.components import haiku_modules as hm
|
from alphafold3.model.components import haiku_modules as hm
|
||||||
from alphafold3.model.components import utils
|
from alphafold3.model.components import utils
|
||||||
from alphafold3.model.network import diffusion_transformer
|
from alphafold3.model.network import diffusion_transformer
|
||||||
import chex
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
@ -96,7 +96,8 @@ def _per_atom_conditioning(
|
|||||||
return act, pair_act
|
return act, pair_act
|
||||||
|
|
||||||
|
|
||||||
@chex.dataclass(mappable_dataclass=False, frozen=True)
|
@jax.tree_util.register_dataclass
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
class AtomCrossAttEncoderOutput:
|
class AtomCrossAttEncoderOutput:
|
||||||
token_act: jnp.ndarray # (num_tokens, ch)
|
token_act: jnp.ndarray # (num_tokens, ch)
|
||||||
skip_connection: jnp.ndarray # (num_subsets, num_queries, ch)
|
skip_connection: jnp.ndarray # (num_subsets, num_queries, ch)
|
||||||
|
@ -21,7 +21,6 @@ from alphafold3.model.network import atom_cross_attention
|
|||||||
from alphafold3.model.network import diffusion_transformer
|
from alphafold3.model.network import diffusion_transformer
|
||||||
from alphafold3.model.network import featurization
|
from alphafold3.model.network import featurization
|
||||||
from alphafold3.model.network import noise_level_embeddings
|
from alphafold3.model.network import noise_level_embeddings
|
||||||
import chex
|
|
||||||
import haiku as hk
|
import haiku as hk
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@ -239,7 +238,6 @@ class DiffusionHead(hk.Module):
|
|||||||
act = enc.token_act
|
act = enc.token_act
|
||||||
|
|
||||||
# Token-token attention
|
# Token-token attention
|
||||||
chex.assert_shape(act, (None, self.config.per_token_channels))
|
|
||||||
act = jnp.asarray(act, dtype=jnp.float32)
|
act = jnp.asarray(act, dtype=jnp.float32)
|
||||||
|
|
||||||
act += hm.Linear(
|
act += hm.Linear(
|
||||||
|
@ -16,7 +16,6 @@ from alphafold3.constants import residue_names
|
|||||||
from alphafold3.model import feat_batch
|
from alphafold3.model import feat_batch
|
||||||
from alphafold3.model import features
|
from alphafold3.model import features
|
||||||
from alphafold3.model.components import utils
|
from alphafold3.model.components import utils
|
||||||
import chex
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
@ -109,7 +108,7 @@ def gumbel_argsort_sample_idx(
|
|||||||
return perm[::-1]
|
return perm[::-1]
|
||||||
|
|
||||||
|
|
||||||
def create_msa_feat(msa: features.MSA) -> chex.ArrayDevice:
|
def create_msa_feat(msa: features.MSA) -> jax.Array:
|
||||||
"""Create and concatenate MSA features."""
|
"""Create and concatenate MSA features."""
|
||||||
msa_1hot = jax.nn.one_hot(
|
msa_1hot = jax.nn.one_hot(
|
||||||
msa.rows, residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1
|
msa.rows, residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1
|
||||||
@ -137,7 +136,7 @@ def truncate_msa_batch(msa: features.MSA, num_msa: int) -> features.MSA:
|
|||||||
def create_target_feat(
|
def create_target_feat(
|
||||||
batch: feat_batch.Batch,
|
batch: feat_batch.Batch,
|
||||||
append_per_atom_features: bool,
|
append_per_atom_features: bool,
|
||||||
) -> chex.ArrayDevice:
|
) -> jax.Array:
|
||||||
"""Make target feat."""
|
"""Make target feat."""
|
||||||
token_features = batch.token_features
|
token_features = batch.token_features
|
||||||
target_features = []
|
target_features = []
|
||||||
@ -170,7 +169,7 @@ def create_relative_encoding(
|
|||||||
seq_features: features.TokenFeatures,
|
seq_features: features.TokenFeatures,
|
||||||
max_relative_idx: int,
|
max_relative_idx: int,
|
||||||
max_relative_chain: int,
|
max_relative_chain: int,
|
||||||
) -> chex.ArrayDevice:
|
) -> jax.Array:
|
||||||
"""Add relative position encodings."""
|
"""Add relative position encodings."""
|
||||||
rel_feats = []
|
rel_feats = []
|
||||||
token_index = seq_features.token_index
|
token_index = seq_features.token_index
|
||||||
|
Reference in New Issue
Block a user