Make register_dataclass compatible with JAX < v0.4.36

PiperOrigin-RevId: 765135084
Change-Id: Ifc89843c9706c38cbe6a9e2ed3c4f918b2c36954
This commit is contained in:
Augustin Zidek
2025-05-30 04:35:10 -07:00
committed by Copybara-Service
parent 6a34fea9a2
commit 5ecdfe883a
3 changed files with 98 additions and 14 deletions

View File

@ -16,7 +16,6 @@ from alphafold3.model import features
import jax
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Batch:
"""Dataclass containing batch."""
@ -76,3 +75,10 @@ class Batch:
**self.frames.as_data_dict(),
}
return output
jax.tree_util.register_dataclass(
Batch,
data_fields=[f.name for f in dataclasses.fields(Batch)],
meta_fields=[],
)

View File

@ -100,7 +100,6 @@ def _unwrap(obj):
return obj
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Chains:
chain_id: np.ndarray
@ -109,6 +108,13 @@ class Chains:
sym_id: np.ndarray
jax.tree_util.register_dataclass(
Chains,
data_fields=[f.name for f in dataclasses.fields(Chains)],
meta_fields=[],
)
def _compute_asym_entity_and_sym_id(
all_tokens: atom_layout.AtomLayout,
) -> Chains:
@ -392,7 +398,6 @@ def tokenizer(
return all_tokens, all_token_atoms_layout, standard_token_idxs
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class MSA:
"""Dataclass containing MSA."""
@ -689,7 +694,13 @@ class MSA:
}
@jax.tree_util.register_dataclass
jax.tree_util.register_dataclass(
MSA,
data_fields=[f.name for f in dataclasses.fields(MSA)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class Templates:
"""Dataclass containing templates."""
@ -852,6 +863,13 @@ class Templates:
}
jax.tree_util.register_dataclass(
Templates,
data_fields=[f.name for f in dataclasses.fields(Templates)],
meta_fields=[],
)
def _reduce_template_features(
template_features: data3.FeatureDict,
max_templates: int,
@ -870,7 +888,6 @@ def _reduce_template_features(
return template_features
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class TokenFeatures:
"""Dataclass containing features for tokens."""
@ -1013,7 +1030,13 @@ class TokenFeatures:
}
@jax.tree_util.register_dataclass
jax.tree_util.register_dataclass(
TokenFeatures,
data_fields=[f.name for f in dataclasses.fields(TokenFeatures)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class PredictedStructureInfo:
"""Contains information necessary to work with predicted structure."""
@ -1076,7 +1099,13 @@ class PredictedStructureInfo:
}
@jax.tree_util.register_dataclass
jax.tree_util.register_dataclass(
PredictedStructureInfo,
data_fields=[f.name for f in dataclasses.fields(PredictedStructureInfo)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class PolymerLigandBondInfo:
"""Contains information about polymer-ligand bonds."""
@ -1193,7 +1222,13 @@ class PolymerLigandBondInfo:
}
@jax.tree_util.register_dataclass
jax.tree_util.register_dataclass(
PolymerLigandBondInfo,
data_fields=[f.name for f in dataclasses.fields(PolymerLigandBondInfo)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class LigandLigandBondInfo:
"""Contains information about the location of ligand-ligand bonds."""
@ -1290,7 +1325,13 @@ class LigandLigandBondInfo:
}
@jax.tree_util.register_dataclass
jax.tree_util.register_dataclass(
LigandLigandBondInfo,
data_fields=[f.name for f in dataclasses.fields(LigandLigandBondInfo)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class PseudoBetaInfo:
"""Contains information for extracting pseudo-beta and equivalent atoms."""
@ -1406,6 +1447,13 @@ class PseudoBetaInfo:
}
jax.tree_util.register_dataclass(
PseudoBetaInfo,
data_fields=[f.name for f in dataclasses.fields(PseudoBetaInfo)],
meta_fields=[],
)
_DEFAULT_BLANK_REF = {
'positions': np.zeros(3),
'mask': 0,
@ -1606,7 +1654,6 @@ def get_reference(
return features, from_atom, dest_atom
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class RefStructure:
"""Contains ref structure information."""
@ -1765,7 +1812,13 @@ class RefStructure:
}
@jax.tree_util.register_dataclass
jax.tree_util.register_dataclass(
RefStructure,
data_fields=[f.name for f in dataclasses.fields(RefStructure)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class ConvertModelOutput:
"""Contains atom layout info."""
@ -1827,7 +1880,13 @@ class ConvertModelOutput:
}
@jax.tree_util.register_dataclass
jax.tree_util.register_dataclass(
ConvertModelOutput,
data_fields=[f.name for f in dataclasses.fields(ConvertModelOutput)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class AtomCrossAtt:
"""Operate on flat atoms."""
@ -1972,7 +2031,13 @@ class AtomCrossAtt:
}
@jax.tree_util.register_dataclass
jax.tree_util.register_dataclass(
AtomCrossAtt,
data_fields=[f.name for f in dataclasses.fields(AtomCrossAtt)],
meta_fields=[],
)
@dataclasses.dataclass(frozen=True)
class Frames:
"""Features for backbone frames."""
@ -2090,3 +2155,10 @@ class Frames:
def as_data_dict(self) -> BatchDict:
return {'frames_mask': self.mask}
jax.tree_util.register_dataclass(
Frames,
data_fields=[f.name for f in dataclasses.fields(Frames)],
meta_fields=[],
)

View File

@ -96,7 +96,6 @@ def _per_atom_conditioning(
return act, pair_act
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class AtomCrossAttEncoderOutput:
token_act: jnp.ndarray # (num_tokens, ch)
@ -108,6 +107,13 @@ class AtomCrossAttEncoderOutput:
pair_cond: jnp.ndarray # (num_subsets, num_queries, num_keys, ch)
jax.tree_util.register_dataclass(
AtomCrossAttEncoderOutput,
data_fields=[f.name for f in dataclasses.fields(AtomCrossAttEncoderOutput)],
meta_fields=[],
)
def atom_cross_att_encoder(
token_atoms_act: jnp.ndarray | None, # (num_tokens, max_atoms_per_token, 3)
trunk_single_cond: jnp.ndarray | None, # (num_tokens, ch)