Remove unnecessary chex dependency

PiperOrigin-RevId: 765086591
Change-Id: I34d7e7b83073d84fe083ee091a3b54e90dcdeba3
This commit is contained in:
Augustin Zidek
2025-05-30 01:43:37 -07:00
committed by Copybara-Service
parent 17afe151ea
commit 565f286892
9 changed files with 43 additions and 63 deletions

View File

@ -140,7 +140,6 @@ AlphaFold 3 uses the following separate libraries and packages:
* [abseil-cpp](https://github.com/abseil/abseil-cpp) and * [abseil-cpp](https://github.com/abseil/abseil-cpp) and
[abseil-py](https://github.com/abseil/abseil-py) [abseil-py](https://github.com/abseil/abseil-py)
* [Chex](https://github.com/deepmind/chex)
* [Docker](https://www.docker.com) * [Docker](https://www.docker.com)
* [DSSP](https://github.com/PDB-REDO/dssp) * [DSSP](https://github.com/PDB-REDO/dssp)
* [HMMER Suite](https://github.com/EddyRivasLab/hmmer) * [HMMER Suite](https://github.com/EddyRivasLab/hmmer)

View File

@ -9,13 +9,8 @@ absl-py==2.1.0 \
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
# via # via
# alphafold3 (pyproject.toml) # alphafold3 (pyproject.toml)
# chex
# dm-haiku # dm-haiku
# jax-triton # jax-triton
chex==0.1.87 \
--hash=sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d \
--hash=sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16
# via alphafold3 (pyproject.toml)
dm-haiku==0.0.13 \ dm-haiku==0.0.13 \
--hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \ --hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \
--hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43 --hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43
@ -77,7 +72,6 @@ jax[cuda12]==0.4.34 \
--hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e --hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e
# via # via
# alphafold3 (pyproject.toml) # alphafold3 (pyproject.toml)
# chex
# jax-triton # jax-triton
jax-cuda12-pjrt==0.4.34 \ jax-cuda12-pjrt==0.4.34 \
--hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \ --hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \
@ -118,9 +112,7 @@ jaxlib==0.4.34 \
--hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \ --hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \
--hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \ --hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \
--hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232 --hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232
# via # via jax
# chex
# jax
jaxtyping==0.2.34 \ jaxtyping==0.2.34 \
--hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \ --hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
--hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf --hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
@ -212,7 +204,6 @@ numpy==2.1.3 \
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4 --hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
# via # via
# alphafold3 (pyproject.toml) # alphafold3 (pyproject.toml)
# chex
# dm-haiku # dm-haiku
# jax # jax
# jaxlib # jaxlib
@ -427,10 +418,6 @@ tabulate==0.9.0 \
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \ --hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f --hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
# via dm-haiku # via dm-haiku
toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
# via chex
tqdm==4.67.0 \ tqdm==4.67.0 \
--hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \ --hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \
--hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a --hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a
@ -447,11 +434,9 @@ triton==3.1.0 \
typeguard==2.13.3 \ typeguard==2.13.3 \
--hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \ --hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
--hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1 --hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
# via jaxtyping # via
typing-extensions==4.12.2 \ # alphafold3 (pyproject.toml)
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ # jaxtyping
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
# via chex
zstandard==0.23.0 \ zstandard==0.23.0 \
--hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
--hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \

View File

@ -16,7 +16,6 @@ readme = "README.md"
license = {file = "LICENSE"} license = {file = "LICENSE"}
dependencies = [ dependencies = [
"absl-py", "absl-py",
"chex",
"dm-haiku==0.0.13", "dm-haiku==0.0.13",
"dm-tree", "dm-tree",
"jax==0.4.34", "jax==0.4.34",

View File

@ -9,13 +9,8 @@ absl-py==2.1.0 \
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
# via # via
# alphafold3 (pyproject.toml) # alphafold3 (pyproject.toml)
# chex
# dm-haiku # dm-haiku
# jax-triton # jax-triton
chex==0.1.87 \
--hash=sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d \
--hash=sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16
# via alphafold3 (pyproject.toml)
dm-haiku==0.0.13 \ dm-haiku==0.0.13 \
--hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \ --hash=sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d \
--hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43 --hash=sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43
@ -77,7 +72,6 @@ jax[cuda12]==0.4.34 \
--hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e --hash=sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e
# via # via
# alphafold3 (pyproject.toml) # alphafold3 (pyproject.toml)
# chex
# jax-triton # jax-triton
jax-cuda12-pjrt==0.4.34 \ jax-cuda12-pjrt==0.4.34 \
--hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \ --hash=sha256:0c7cc98f962cc7fc8e0a5ea6331b42a0cee516f202f1c3019f6aa5cd9530cca0 \
@ -118,9 +112,7 @@ jaxlib==0.4.34 \
--hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \ --hash=sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10 \
--hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \ --hash=sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8 \
--hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232 --hash=sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232
# via # via jax
# chex
# jax
jaxtyping==0.2.34 \ jaxtyping==0.2.34 \
--hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \ --hash=sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5 \
--hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf --hash=sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf
@ -212,7 +204,6 @@ numpy==2.1.3 \
--hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4 --hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4
# via # via
# alphafold3 (pyproject.toml) # alphafold3 (pyproject.toml)
# chex
# dm-haiku # dm-haiku
# jax # jax
# jaxlib # jaxlib
@ -427,10 +418,6 @@ tabulate==0.9.0 \
--hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \ --hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \
--hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f --hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f
# via dm-haiku # via dm-haiku
toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
# via chex
tqdm==4.67.0 \ tqdm==4.67.0 \
--hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \ --hash=sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be \
--hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a --hash=sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a
@ -447,11 +434,9 @@ triton==3.1.0 \
typeguard==2.13.3 \ typeguard==2.13.3 \
--hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \ --hash=sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4 \
--hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1 --hash=sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1
# via jaxtyping # via
typing-extensions==4.12.2 \ # alphafold3 (pyproject.toml)
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ # jaxtyping
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
# via chex
zstandard==0.23.0 \ zstandard==0.23.0 \
--hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \
--hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \

View File

@ -9,13 +9,15 @@
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md # https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Batch dataclass.""" """Batch dataclass."""
import dataclasses
from typing import Self from typing import Self
from alphafold3.model import features from alphafold3.model import features
import chex import jax
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Batch: class Batch:
"""Dataclass containing batch.""" """Dataclass containing batch."""

View File

@ -32,7 +32,7 @@ from alphafold3.model import merging_features
from alphafold3.model import msa_pairing from alphafold3.model import msa_pairing
from alphafold3.model.atom_layout import atom_layout from alphafold3.model.atom_layout import atom_layout
from alphafold3.structure import chemical_components as struc_chem_comps from alphafold3.structure import chemical_components as struc_chem_comps
import chex import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from rdkit import Chem from rdkit import Chem
@ -100,7 +100,8 @@ def _unwrap(obj):
return obj return obj
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Chains: class Chains:
chain_id: np.ndarray chain_id: np.ndarray
asym_id: np.ndarray asym_id: np.ndarray
@ -391,7 +392,8 @@ def tokenizer(
return all_tokens, all_token_atoms_layout, standard_token_idxs return all_tokens, all_token_atoms_layout, standard_token_idxs
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class MSA: class MSA:
"""Dataclass containing MSA.""" """Dataclass containing MSA."""
@ -687,7 +689,8 @@ class MSA:
} }
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Templates: class Templates:
"""Dataclass containing templates.""" """Dataclass containing templates."""
@ -867,7 +870,8 @@ def _reduce_template_features(
return template_features return template_features
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class TokenFeatures: class TokenFeatures:
"""Dataclass containing features for tokens.""" """Dataclass containing features for tokens."""
@ -1009,7 +1013,8 @@ class TokenFeatures:
} }
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class PredictedStructureInfo: class PredictedStructureInfo:
"""Contains information necessary to work with predicted structure.""" """Contains information necessary to work with predicted structure."""
@ -1071,7 +1076,8 @@ class PredictedStructureInfo:
} }
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class PolymerLigandBondInfo: class PolymerLigandBondInfo:
"""Contains information about polymer-ligand bonds.""" """Contains information about polymer-ligand bonds."""
@ -1187,7 +1193,8 @@ class PolymerLigandBondInfo:
} }
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class LigandLigandBondInfo: class LigandLigandBondInfo:
"""Contains information about the location of ligand-ligand bonds.""" """Contains information about the location of ligand-ligand bonds."""
@ -1283,7 +1290,8 @@ class LigandLigandBondInfo:
} }
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class PseudoBetaInfo: class PseudoBetaInfo:
"""Contains information for extracting pseudo-beta and equivalent atoms.""" """Contains information for extracting pseudo-beta and equivalent atoms."""
@ -1598,7 +1606,8 @@ def get_reference(
return features, from_atom, dest_atom return features, from_atom, dest_atom
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class RefStructure: class RefStructure:
"""Contains ref structure information.""" """Contains ref structure information."""
@ -1756,7 +1765,8 @@ class RefStructure:
} }
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class ConvertModelOutput: class ConvertModelOutput:
"""Contains atom layout info.""" """Contains atom layout info."""
@ -1817,7 +1827,8 @@ class ConvertModelOutput:
} }
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class AtomCrossAtt: class AtomCrossAtt:
"""Operate on flat atoms.""" """Operate on flat atoms."""
@ -1961,7 +1972,8 @@ class AtomCrossAtt:
} }
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Frames: class Frames:
"""Features for backbone frames.""" """Features for backbone frames."""

View File

@ -9,6 +9,7 @@
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md # https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Per-atom cross attention.""" """Per-atom cross attention."""
import dataclasses
from alphafold3.common import base_config from alphafold3.common import base_config
from alphafold3.model import feat_batch from alphafold3.model import feat_batch
@ -17,7 +18,6 @@ from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.components import haiku_modules as hm from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.components import utils from alphafold3.model.components import utils
from alphafold3.model.network import diffusion_transformer from alphafold3.model.network import diffusion_transformer
import chex
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
@ -96,7 +96,8 @@ def _per_atom_conditioning(
return act, pair_act return act, pair_act
@chex.dataclass(mappable_dataclass=False, frozen=True) @jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class AtomCrossAttEncoderOutput: class AtomCrossAttEncoderOutput:
token_act: jnp.ndarray # (num_tokens, ch) token_act: jnp.ndarray # (num_tokens, ch)
skip_connection: jnp.ndarray # (num_subsets, num_queries, ch) skip_connection: jnp.ndarray # (num_subsets, num_queries, ch)

View File

@ -21,7 +21,6 @@ from alphafold3.model.network import atom_cross_attention
from alphafold3.model.network import diffusion_transformer from alphafold3.model.network import diffusion_transformer
from alphafold3.model.network import featurization from alphafold3.model.network import featurization
from alphafold3.model.network import noise_level_embeddings from alphafold3.model.network import noise_level_embeddings
import chex
import haiku as hk import haiku as hk
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
@ -239,7 +238,6 @@ class DiffusionHead(hk.Module):
act = enc.token_act act = enc.token_act
# Token-token attention # Token-token attention
chex.assert_shape(act, (None, self.config.per_token_channels))
act = jnp.asarray(act, dtype=jnp.float32) act = jnp.asarray(act, dtype=jnp.float32)
act += hm.Linear( act += hm.Linear(

View File

@ -16,7 +16,6 @@ from alphafold3.constants import residue_names
from alphafold3.model import feat_batch from alphafold3.model import feat_batch
from alphafold3.model import features from alphafold3.model import features
from alphafold3.model.components import utils from alphafold3.model.components import utils
import chex
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
@ -109,7 +108,7 @@ def gumbel_argsort_sample_idx(
return perm[::-1] return perm[::-1]
def create_msa_feat(msa: features.MSA) -> chex.ArrayDevice: def create_msa_feat(msa: features.MSA) -> jax.Array:
"""Create and concatenate MSA features.""" """Create and concatenate MSA features."""
msa_1hot = jax.nn.one_hot( msa_1hot = jax.nn.one_hot(
msa.rows, residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1 msa.rows, residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1
@ -137,7 +136,7 @@ def truncate_msa_batch(msa: features.MSA, num_msa: int) -> features.MSA:
def create_target_feat( def create_target_feat(
batch: feat_batch.Batch, batch: feat_batch.Batch,
append_per_atom_features: bool, append_per_atom_features: bool,
) -> chex.ArrayDevice: ) -> jax.Array:
"""Make target feat.""" """Make target feat."""
token_features = batch.token_features token_features = batch.token_features
target_features = [] target_features = []
@ -170,7 +169,7 @@ def create_relative_encoding(
seq_features: features.TokenFeatures, seq_features: features.TokenFeatures,
max_relative_idx: int, max_relative_idx: int,
max_relative_chain: int, max_relative_chain: int,
) -> chex.ArrayDevice: ) -> jax.Array:
"""Add relative position encodings.""" """Add relative position encodings."""
rel_feats = [] rel_feats = []
token_index = seq_features.token_index token_index = seq_features.token_index