mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Make register_dataclass compatible with JAX < v0.4.36
PiperOrigin-RevId: 765135084 Change-Id: Ifc89843c9706c38cbe6a9e2ed3c4f918b2c36954
This commit is contained in:
committed by
Copybara-Service
parent
6a34fea9a2
commit
5ecdfe883a
@ -16,7 +16,6 @@ from alphafold3.model import features
|
||||
import jax
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Batch:
|
||||
"""Dataclass containing batch."""
|
||||
@ -76,3 +75,10 @@ class Batch:
|
||||
**self.frames.as_data_dict(),
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
jax.tree_util.register_dataclass(
|
||||
Batch,
|
||||
data_fields=[f.name for f in dataclasses.fields(Batch)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
@ -100,7 +100,6 @@ def _unwrap(obj):
|
||||
return obj
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Chains:
|
||||
chain_id: np.ndarray
|
||||
@ -109,6 +108,13 @@ class Chains:
|
||||
sym_id: np.ndarray
|
||||
|
||||
|
||||
jax.tree_util.register_dataclass(
|
||||
Chains,
|
||||
data_fields=[f.name for f in dataclasses.fields(Chains)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
def _compute_asym_entity_and_sym_id(
|
||||
all_tokens: atom_layout.AtomLayout,
|
||||
) -> Chains:
|
||||
@ -392,7 +398,6 @@ def tokenizer(
|
||||
return all_tokens, all_token_atoms_layout, standard_token_idxs
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class MSA:
|
||||
"""Dataclass containing MSA."""
|
||||
@ -689,7 +694,13 @@ class MSA:
|
||||
}
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
jax.tree_util.register_dataclass(
|
||||
MSA,
|
||||
data_fields=[f.name for f in dataclasses.fields(MSA)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Templates:
|
||||
"""Dataclass containing templates."""
|
||||
@ -852,6 +863,13 @@ class Templates:
|
||||
}
|
||||
|
||||
|
||||
jax.tree_util.register_dataclass(
|
||||
Templates,
|
||||
data_fields=[f.name for f in dataclasses.fields(Templates)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
def _reduce_template_features(
|
||||
template_features: data3.FeatureDict,
|
||||
max_templates: int,
|
||||
@ -870,7 +888,6 @@ def _reduce_template_features(
|
||||
return template_features
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TokenFeatures:
|
||||
"""Dataclass containing features for tokens."""
|
||||
@ -1013,7 +1030,13 @@ class TokenFeatures:
|
||||
}
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
jax.tree_util.register_dataclass(
|
||||
TokenFeatures,
|
||||
data_fields=[f.name for f in dataclasses.fields(TokenFeatures)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class PredictedStructureInfo:
|
||||
"""Contains information necessary to work with predicted structure."""
|
||||
@ -1076,7 +1099,13 @@ class PredictedStructureInfo:
|
||||
}
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
jax.tree_util.register_dataclass(
|
||||
PredictedStructureInfo,
|
||||
data_fields=[f.name for f in dataclasses.fields(PredictedStructureInfo)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class PolymerLigandBondInfo:
|
||||
"""Contains information about polymer-ligand bonds."""
|
||||
@ -1193,7 +1222,13 @@ class PolymerLigandBondInfo:
|
||||
}
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
jax.tree_util.register_dataclass(
|
||||
PolymerLigandBondInfo,
|
||||
data_fields=[f.name for f in dataclasses.fields(PolymerLigandBondInfo)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class LigandLigandBondInfo:
|
||||
"""Contains information about the location of ligand-ligand bonds."""
|
||||
@ -1290,7 +1325,13 @@ class LigandLigandBondInfo:
|
||||
}
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
jax.tree_util.register_dataclass(
|
||||
LigandLigandBondInfo,
|
||||
data_fields=[f.name for f in dataclasses.fields(LigandLigandBondInfo)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class PseudoBetaInfo:
|
||||
"""Contains information for extracting pseudo-beta and equivalent atoms."""
|
||||
@ -1406,6 +1447,13 @@ class PseudoBetaInfo:
|
||||
}
|
||||
|
||||
|
||||
jax.tree_util.register_dataclass(
|
||||
PseudoBetaInfo,
|
||||
data_fields=[f.name for f in dataclasses.fields(PseudoBetaInfo)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_BLANK_REF = {
|
||||
'positions': np.zeros(3),
|
||||
'mask': 0,
|
||||
@ -1606,7 +1654,6 @@ def get_reference(
|
||||
return features, from_atom, dest_atom
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RefStructure:
|
||||
"""Contains ref structure information."""
|
||||
@ -1765,7 +1812,13 @@ class RefStructure:
|
||||
}
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
jax.tree_util.register_dataclass(
|
||||
RefStructure,
|
||||
data_fields=[f.name for f in dataclasses.fields(RefStructure)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ConvertModelOutput:
|
||||
"""Contains atom layout info."""
|
||||
@ -1827,7 +1880,13 @@ class ConvertModelOutput:
|
||||
}
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
jax.tree_util.register_dataclass(
|
||||
ConvertModelOutput,
|
||||
data_fields=[f.name for f in dataclasses.fields(ConvertModelOutput)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class AtomCrossAtt:
|
||||
"""Operate on flat atoms."""
|
||||
@ -1972,7 +2031,13 @@ class AtomCrossAtt:
|
||||
}
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
jax.tree_util.register_dataclass(
|
||||
AtomCrossAtt,
|
||||
data_fields=[f.name for f in dataclasses.fields(AtomCrossAtt)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Frames:
|
||||
"""Features for backbone frames."""
|
||||
@ -2090,3 +2155,10 @@ class Frames:
|
||||
|
||||
def as_data_dict(self) -> BatchDict:
|
||||
return {'frames_mask': self.mask}
|
||||
|
||||
|
||||
jax.tree_util.register_dataclass(
|
||||
Frames,
|
||||
data_fields=[f.name for f in dataclasses.fields(Frames)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
@ -96,7 +96,6 @@ def _per_atom_conditioning(
|
||||
return act, pair_act
|
||||
|
||||
|
||||
@jax.tree_util.register_dataclass
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class AtomCrossAttEncoderOutput:
|
||||
token_act: jnp.ndarray # (num_tokens, ch)
|
||||
@ -108,6 +107,13 @@ class AtomCrossAttEncoderOutput:
|
||||
pair_cond: jnp.ndarray # (num_subsets, num_queries, num_keys, ch)
|
||||
|
||||
|
||||
jax.tree_util.register_dataclass(
|
||||
AtomCrossAttEncoderOutput,
|
||||
data_fields=[f.name for f in dataclasses.fields(AtomCrossAttEncoderOutput)],
|
||||
meta_fields=[],
|
||||
)
|
||||
|
||||
|
||||
def atom_cross_att_encoder(
|
||||
token_atoms_act: jnp.ndarray | None, # (num_tokens, max_atoms_per_token, 3)
|
||||
trunk_single_cond: jnp.ndarray | None, # (num_tokens, ch)
|
||||
|
Reference in New Issue
Block a user