Make register_dataclass compatible with JAX < v0.4.36

PiperOrigin-RevId: 765135084
Change-Id: Ifc89843c9706c38cbe6a9e2ed3c4f918b2c36954
This commit is contained in:
Augustin Zidek
2025-05-30 04:35:10 -07:00
committed by Copybara-Service
parent 6a34fea9a2
commit 5ecdfe883a
3 changed files with 98 additions and 14 deletions

View File

@ -16,7 +16,6 @@ from alphafold3.model import features
import jax import jax
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Batch: class Batch:
"""Dataclass containing batch.""" """Dataclass containing batch."""
@ -76,3 +75,10 @@ class Batch:
**self.frames.as_data_dict(), **self.frames.as_data_dict(),
} }
return output return output
jax.tree_util.register_dataclass(
Batch,
data_fields=[f.name for f in dataclasses.fields(Batch)],
meta_fields=[],
)

View File

@ -100,7 +100,6 @@ def _unwrap(obj):
return obj return obj
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Chains: class Chains:
chain_id: np.ndarray chain_id: np.ndarray
@ -109,6 +108,13 @@ class Chains:
sym_id: np.ndarray sym_id: np.ndarray
jax.tree_util.register_dataclass(
Chains,
data_fields=[f.name for f in dataclasses.fields(Chains)],
meta_fields=[],
)
def _compute_asym_entity_and_sym_id( def _compute_asym_entity_and_sym_id(
all_tokens: atom_layout.AtomLayout, all_tokens: atom_layout.AtomLayout,
) -> Chains: ) -> Chains:
@ -392,7 +398,6 @@ def tokenizer(
return all_tokens, all_token_atoms_layout, standard_token_idxs return all_tokens, all_token_atoms_layout, standard_token_idxs
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class MSA: class MSA:
"""Dataclass containing MSA.""" """Dataclass containing MSA."""
@ -689,7 +694,13 @@ class MSA:
} }
@jax.tree_util.register_dataclass jax.tree_util.register_dataclass(
MSA,
data_fields=[f.name for f in dataclasses.fields(MSA)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Templates: class Templates:
"""Dataclass containing templates.""" """Dataclass containing templates."""
@ -852,6 +863,13 @@ class Templates:
} }
jax.tree_util.register_dataclass(
Templates,
data_fields=[f.name for f in dataclasses.fields(Templates)],
meta_fields=[],
)
def _reduce_template_features( def _reduce_template_features(
template_features: data3.FeatureDict, template_features: data3.FeatureDict,
max_templates: int, max_templates: int,
@ -870,7 +888,6 @@ def _reduce_template_features(
return template_features return template_features
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TokenFeatures: class TokenFeatures:
"""Dataclass containing features for tokens.""" """Dataclass containing features for tokens."""
@ -1013,7 +1030,13 @@ class TokenFeatures:
} }
@jax.tree_util.register_dataclass jax.tree_util.register_dataclass(
TokenFeatures,
data_fields=[f.name for f in dataclasses.fields(TokenFeatures)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class PredictedStructureInfo: class PredictedStructureInfo:
"""Contains information necessary to work with predicted structure.""" """Contains information necessary to work with predicted structure."""
@ -1076,7 +1099,13 @@ class PredictedStructureInfo:
} }
@jax.tree_util.register_dataclass jax.tree_util.register_dataclass(
PredictedStructureInfo,
data_fields=[f.name for f in dataclasses.fields(PredictedStructureInfo)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class PolymerLigandBondInfo: class PolymerLigandBondInfo:
"""Contains information about polymer-ligand bonds.""" """Contains information about polymer-ligand bonds."""
@ -1193,7 +1222,13 @@ class PolymerLigandBondInfo:
} }
@jax.tree_util.register_dataclass jax.tree_util.register_dataclass(
PolymerLigandBondInfo,
data_fields=[f.name for f in dataclasses.fields(PolymerLigandBondInfo)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class LigandLigandBondInfo: class LigandLigandBondInfo:
"""Contains information about the location of ligand-ligand bonds.""" """Contains information about the location of ligand-ligand bonds."""
@ -1290,7 +1325,13 @@ class LigandLigandBondInfo:
} }
@jax.tree_util.register_dataclass jax.tree_util.register_dataclass(
LigandLigandBondInfo,
data_fields=[f.name for f in dataclasses.fields(LigandLigandBondInfo)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class PseudoBetaInfo: class PseudoBetaInfo:
"""Contains information for extracting pseudo-beta and equivalent atoms.""" """Contains information for extracting pseudo-beta and equivalent atoms."""
@ -1406,6 +1447,13 @@ class PseudoBetaInfo:
} }
jax.tree_util.register_dataclass(
PseudoBetaInfo,
data_fields=[f.name for f in dataclasses.fields(PseudoBetaInfo)],
meta_fields=[],
)
_DEFAULT_BLANK_REF = { _DEFAULT_BLANK_REF = {
'positions': np.zeros(3), 'positions': np.zeros(3),
'mask': 0, 'mask': 0,
@ -1606,7 +1654,6 @@ def get_reference(
return features, from_atom, dest_atom return features, from_atom, dest_atom
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class RefStructure: class RefStructure:
"""Contains ref structure information.""" """Contains ref structure information."""
@ -1765,7 +1812,13 @@ class RefStructure:
} }
@jax.tree_util.register_dataclass jax.tree_util.register_dataclass(
RefStructure,
data_fields=[f.name for f in dataclasses.fields(RefStructure)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ConvertModelOutput: class ConvertModelOutput:
"""Contains atom layout info.""" """Contains atom layout info."""
@ -1827,7 +1880,13 @@ class ConvertModelOutput:
} }
@jax.tree_util.register_dataclass jax.tree_util.register_dataclass(
ConvertModelOutput,
data_fields=[f.name for f in dataclasses.fields(ConvertModelOutput)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class AtomCrossAtt: class AtomCrossAtt:
"""Operate on flat atoms.""" """Operate on flat atoms."""
@ -1972,7 +2031,13 @@ class AtomCrossAtt:
} }
@jax.tree_util.register_dataclass jax.tree_util.register_dataclass(
AtomCrossAtt,
data_fields=[f.name for f in dataclasses.fields(AtomCrossAtt)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Frames: class Frames:
"""Features for backbone frames.""" """Features for backbone frames."""
@ -2090,3 +2155,10 @@ class Frames:
def as_data_dict(self) -> BatchDict: def as_data_dict(self) -> BatchDict:
return {'frames_mask': self.mask} return {'frames_mask': self.mask}
jax.tree_util.register_dataclass(
Frames,
data_fields=[f.name for f in dataclasses.fields(Frames)],
meta_fields=[],
)

View File

@ -96,7 +96,6 @@ def _per_atom_conditioning(
return act, pair_act return act, pair_act
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class AtomCrossAttEncoderOutput: class AtomCrossAttEncoderOutput:
token_act: jnp.ndarray # (num_tokens, ch) token_act: jnp.ndarray # (num_tokens, ch)
@ -108,6 +107,13 @@ class AtomCrossAttEncoderOutput:
pair_cond: jnp.ndarray # (num_subsets, num_queries, num_keys, ch) pair_cond: jnp.ndarray # (num_subsets, num_queries, num_keys, ch)
jax.tree_util.register_dataclass(
AtomCrossAttEncoderOutput,
data_fields=[f.name for f in dataclasses.fields(AtomCrossAttEncoderOutput)],
meta_fields=[],
)
def atom_cross_att_encoder( def atom_cross_att_encoder(
token_atoms_act: jnp.ndarray | None, # (num_tokens, max_atoms_per_token, 3) token_atoms_act: jnp.ndarray | None, # (num_tokens, max_atoms_per_token, 3)
trunk_single_cond: jnp.ndarray | None, # (num_tokens, ch) trunk_single_cond: jnp.ndarray | None, # (num_tokens, ch)