From 5ecdfe883a1fa9a95b1bf044c8a496f9da41fa77 Mon Sep 17 00:00:00 2001 From: Augustin Zidek Date: Fri, 30 May 2025 04:35:10 -0700 Subject: [PATCH] Make register_dataclass compatible with JAX < v0.4.36 PiperOrigin-RevId: 765135084 Change-Id: Ifc89843c9706c38cbe6a9e2ed3c4f918b2c36954 --- src/alphafold3/model/feat_batch.py | 8 +- src/alphafold3/model/features.py | 96 ++++++++++++++++--- .../model/network/atom_cross_attention.py | 8 +- 3 files changed, 98 insertions(+), 14 deletions(-) diff --git a/src/alphafold3/model/feat_batch.py b/src/alphafold3/model/feat_batch.py index 3025ae5..e2f8de4 100644 --- a/src/alphafold3/model/feat_batch.py +++ b/src/alphafold3/model/feat_batch.py @@ -16,7 +16,6 @@ from alphafold3.model import features import jax -@jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class Batch: """Dataclass containing batch.""" @@ -76,3 +75,10 @@ class Batch: **self.frames.as_data_dict(), } return output + + +jax.tree_util.register_dataclass( + Batch, + data_fields=[f.name for f in dataclasses.fields(Batch)], + meta_fields=[], +) diff --git a/src/alphafold3/model/features.py b/src/alphafold3/model/features.py index 86c4214..dfc13e3 100644 --- a/src/alphafold3/model/features.py +++ b/src/alphafold3/model/features.py @@ -100,7 +100,6 @@ def _unwrap(obj): return obj -@jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class Chains: chain_id: np.ndarray @@ -109,6 +108,13 @@ class Chains: sym_id: np.ndarray +jax.tree_util.register_dataclass( + Chains, + data_fields=[f.name for f in dataclasses.fields(Chains)], + meta_fields=[], +) + + def _compute_asym_entity_and_sym_id( all_tokens: atom_layout.AtomLayout, ) -> Chains: @@ -392,7 +398,6 @@ def tokenizer( return all_tokens, all_token_atoms_layout, standard_token_idxs -@jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class MSA: """Dataclass containing MSA.""" @@ -689,7 +694,13 @@ class MSA: } -@jax.tree_util.register_dataclass +jax.tree_util.register_dataclass( + MSA, + data_fields=[f.name for f in dataclasses.fields(MSA)], + meta_fields=[], +) + + @dataclasses.dataclass(frozen=True) class Templates: """Dataclass containing templates.""" @@ -852,6 +863,13 @@ class Templates: } +jax.tree_util.register_dataclass( + Templates, + data_fields=[f.name for f in dataclasses.fields(Templates)], + meta_fields=[], +) + + def _reduce_template_features( template_features: data3.FeatureDict, max_templates: int, @@ -870,7 +888,6 @@ def _reduce_template_features( return template_features -@jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class TokenFeatures: """Dataclass containing features for tokens.""" @@ -1013,7 +1030,13 @@ class TokenFeatures: } -@jax.tree_util.register_dataclass +jax.tree_util.register_dataclass( + TokenFeatures, + data_fields=[f.name for f in dataclasses.fields(TokenFeatures)], + meta_fields=[], +) + + @dataclasses.dataclass(frozen=True) class PredictedStructureInfo: """Contains information necessary to work with predicted structure.""" @@ -1076,7 +1099,13 @@ class PredictedStructureInfo: } -@jax.tree_util.register_dataclass +jax.tree_util.register_dataclass( + PredictedStructureInfo, + data_fields=[f.name for f in dataclasses.fields(PredictedStructureInfo)], + meta_fields=[], +) + + @dataclasses.dataclass(frozen=True) class PolymerLigandBondInfo: """Contains information about polymer-ligand bonds.""" @@ -1193,7 +1222,13 @@ class PolymerLigandBondInfo: } -@jax.tree_util.register_dataclass +jax.tree_util.register_dataclass( + PolymerLigandBondInfo, + data_fields=[f.name for f in dataclasses.fields(PolymerLigandBondInfo)], + meta_fields=[], +) + + @dataclasses.dataclass(frozen=True) class LigandLigandBondInfo: """Contains information about the location of ligand-ligand bonds.""" @@ -1290,7 +1325,13 @@ class LigandLigandBondInfo: } -@jax.tree_util.register_dataclass +jax.tree_util.register_dataclass( + LigandLigandBondInfo, + data_fields=[f.name for f in dataclasses.fields(LigandLigandBondInfo)], + meta_fields=[], +) + + @dataclasses.dataclass(frozen=True) class PseudoBetaInfo: """Contains information for extracting pseudo-beta and equivalent atoms.""" @@ -1406,6 +1447,13 @@ class PseudoBetaInfo: } +jax.tree_util.register_dataclass( + PseudoBetaInfo, + data_fields=[f.name for f in dataclasses.fields(PseudoBetaInfo)], + meta_fields=[], +) + + _DEFAULT_BLANK_REF = { 'positions': np.zeros(3), 'mask': 0, @@ -1606,7 +1654,6 @@ def get_reference( return features, from_atom, dest_atom -@jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class RefStructure: """Contains ref structure information.""" @@ -1765,7 +1812,13 @@ class RefStructure: } -@jax.tree_util.register_dataclass +jax.tree_util.register_dataclass( + RefStructure, + data_fields=[f.name for f in dataclasses.fields(RefStructure)], + meta_fields=[], +) + + @dataclasses.dataclass(frozen=True) class ConvertModelOutput: """Contains atom layout info.""" @@ -1827,7 +1880,13 @@ class ConvertModelOutput: } -@jax.tree_util.register_dataclass +jax.tree_util.register_dataclass( + ConvertModelOutput, + data_fields=[f.name for f in dataclasses.fields(ConvertModelOutput)], + meta_fields=[], +) + + @dataclasses.dataclass(frozen=True) class AtomCrossAtt: """Operate on flat atoms.""" @@ -1972,7 +2031,13 @@ class AtomCrossAtt: } -@jax.tree_util.register_dataclass +jax.tree_util.register_dataclass( + AtomCrossAtt, + data_fields=[f.name for f in dataclasses.fields(AtomCrossAtt)], + meta_fields=[], +) + + @dataclasses.dataclass(frozen=True) class Frames: """Features for backbone frames.""" @@ -2090,3 +2155,10 @@ class Frames: def as_data_dict(self) -> BatchDict: return {'frames_mask': self.mask} + + +jax.tree_util.register_dataclass( + Frames, + data_fields=[f.name for f in dataclasses.fields(Frames)], + meta_fields=[], +) diff --git a/src/alphafold3/model/network/atom_cross_attention.py b/src/alphafold3/model/network/atom_cross_attention.py index 6b1af41..5833313 100644 --- a/src/alphafold3/model/network/atom_cross_attention.py +++ b/src/alphafold3/model/network/atom_cross_attention.py @@ -96,7 +96,6 @@ def _per_atom_conditioning( return act, pair_act -@jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class AtomCrossAttEncoderOutput: token_act: jnp.ndarray # (num_tokens, ch) @@ -108,6 +107,13 @@ class AtomCrossAttEncoderOutput: pair_cond: jnp.ndarray # (num_subsets, num_queries, num_keys, ch) +jax.tree_util.register_dataclass( + AtomCrossAttEncoderOutput, + data_fields=[f.name for f in dataclasses.fields(AtomCrossAttEncoderOutput)], + meta_fields=[], +) + + def atom_cross_att_encoder( token_atoms_act: jnp.ndarray | None, # (num_tokens, max_atoms_per_token, 3) trunk_single_cond: jnp.ndarray | None, # (num_tokens, ch)