mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Make register_dataclass compatible with JAX < v0.4.36
PiperOrigin-RevId: 765135084 Change-Id: Ifc89843c9706c38cbe6a9e2ed3c4f918b2c36954
This commit is contained in:
committed by
Copybara-Service
parent
6a34fea9a2
commit
5ecdfe883a
@ -16,7 +16,6 @@ from alphafold3.model import features
|
|||||||
import jax
|
import jax
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Batch:
|
class Batch:
|
||||||
"""Dataclass containing batch."""
|
"""Dataclass containing batch."""
|
||||||
@ -76,3 +75,10 @@ class Batch:
|
|||||||
**self.frames.as_data_dict(),
|
**self.frames.as_data_dict(),
|
||||||
}
|
}
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
jax.tree_util.register_dataclass(
|
||||||
|
Batch,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(Batch)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
@ -100,7 +100,6 @@ def _unwrap(obj):
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Chains:
|
class Chains:
|
||||||
chain_id: np.ndarray
|
chain_id: np.ndarray
|
||||||
@ -109,6 +108,13 @@ class Chains:
|
|||||||
sym_id: np.ndarray
|
sym_id: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
jax.tree_util.register_dataclass(
|
||||||
|
Chains,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(Chains)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _compute_asym_entity_and_sym_id(
|
def _compute_asym_entity_and_sym_id(
|
||||||
all_tokens: atom_layout.AtomLayout,
|
all_tokens: atom_layout.AtomLayout,
|
||||||
) -> Chains:
|
) -> Chains:
|
||||||
@ -392,7 +398,6 @@ def tokenizer(
|
|||||||
return all_tokens, all_token_atoms_layout, standard_token_idxs
|
return all_tokens, all_token_atoms_layout, standard_token_idxs
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class MSA:
|
class MSA:
|
||||||
"""Dataclass containing MSA."""
|
"""Dataclass containing MSA."""
|
||||||
@ -689,7 +694,13 @@ class MSA:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
jax.tree_util.register_dataclass(
|
||||||
|
MSA,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(MSA)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Templates:
|
class Templates:
|
||||||
"""Dataclass containing templates."""
|
"""Dataclass containing templates."""
|
||||||
@ -852,6 +863,13 @@ class Templates:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
jax.tree_util.register_dataclass(
|
||||||
|
Templates,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(Templates)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _reduce_template_features(
|
def _reduce_template_features(
|
||||||
template_features: data3.FeatureDict,
|
template_features: data3.FeatureDict,
|
||||||
max_templates: int,
|
max_templates: int,
|
||||||
@ -870,7 +888,6 @@ def _reduce_template_features(
|
|||||||
return template_features
|
return template_features
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class TokenFeatures:
|
class TokenFeatures:
|
||||||
"""Dataclass containing features for tokens."""
|
"""Dataclass containing features for tokens."""
|
||||||
@ -1013,7 +1030,13 @@ class TokenFeatures:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
jax.tree_util.register_dataclass(
|
||||||
|
TokenFeatures,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(TokenFeatures)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class PredictedStructureInfo:
|
class PredictedStructureInfo:
|
||||||
"""Contains information necessary to work with predicted structure."""
|
"""Contains information necessary to work with predicted structure."""
|
||||||
@ -1076,7 +1099,13 @@ class PredictedStructureInfo:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
jax.tree_util.register_dataclass(
|
||||||
|
PredictedStructureInfo,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(PredictedStructureInfo)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class PolymerLigandBondInfo:
|
class PolymerLigandBondInfo:
|
||||||
"""Contains information about polymer-ligand bonds."""
|
"""Contains information about polymer-ligand bonds."""
|
||||||
@ -1193,7 +1222,13 @@ class PolymerLigandBondInfo:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
jax.tree_util.register_dataclass(
|
||||||
|
PolymerLigandBondInfo,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(PolymerLigandBondInfo)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class LigandLigandBondInfo:
|
class LigandLigandBondInfo:
|
||||||
"""Contains information about the location of ligand-ligand bonds."""
|
"""Contains information about the location of ligand-ligand bonds."""
|
||||||
@ -1290,7 +1325,13 @@ class LigandLigandBondInfo:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
jax.tree_util.register_dataclass(
|
||||||
|
LigandLigandBondInfo,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(LigandLigandBondInfo)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class PseudoBetaInfo:
|
class PseudoBetaInfo:
|
||||||
"""Contains information for extracting pseudo-beta and equivalent atoms."""
|
"""Contains information for extracting pseudo-beta and equivalent atoms."""
|
||||||
@ -1406,6 +1447,13 @@ class PseudoBetaInfo:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
jax.tree_util.register_dataclass(
|
||||||
|
PseudoBetaInfo,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(PseudoBetaInfo)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_BLANK_REF = {
|
_DEFAULT_BLANK_REF = {
|
||||||
'positions': np.zeros(3),
|
'positions': np.zeros(3),
|
||||||
'mask': 0,
|
'mask': 0,
|
||||||
@ -1606,7 +1654,6 @@ def get_reference(
|
|||||||
return features, from_atom, dest_atom
|
return features, from_atom, dest_atom
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class RefStructure:
|
class RefStructure:
|
||||||
"""Contains ref structure information."""
|
"""Contains ref structure information."""
|
||||||
@ -1765,7 +1812,13 @@ class RefStructure:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
jax.tree_util.register_dataclass(
|
||||||
|
RefStructure,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(RefStructure)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class ConvertModelOutput:
|
class ConvertModelOutput:
|
||||||
"""Contains atom layout info."""
|
"""Contains atom layout info."""
|
||||||
@ -1827,7 +1880,13 @@ class ConvertModelOutput:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
jax.tree_util.register_dataclass(
|
||||||
|
ConvertModelOutput,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(ConvertModelOutput)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class AtomCrossAtt:
|
class AtomCrossAtt:
|
||||||
"""Operate on flat atoms."""
|
"""Operate on flat atoms."""
|
||||||
@ -1972,7 +2031,13 @@ class AtomCrossAtt:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
jax.tree_util.register_dataclass(
|
||||||
|
AtomCrossAtt,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(AtomCrossAtt)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Frames:
|
class Frames:
|
||||||
"""Features for backbone frames."""
|
"""Features for backbone frames."""
|
||||||
@ -2090,3 +2155,10 @@ class Frames:
|
|||||||
|
|
||||||
def as_data_dict(self) -> BatchDict:
|
def as_data_dict(self) -> BatchDict:
|
||||||
return {'frames_mask': self.mask}
|
return {'frames_mask': self.mask}
|
||||||
|
|
||||||
|
|
||||||
|
jax.tree_util.register_dataclass(
|
||||||
|
Frames,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(Frames)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
@ -96,7 +96,6 @@ def _per_atom_conditioning(
|
|||||||
return act, pair_act
|
return act, pair_act
|
||||||
|
|
||||||
|
|
||||||
@jax.tree_util.register_dataclass
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class AtomCrossAttEncoderOutput:
|
class AtomCrossAttEncoderOutput:
|
||||||
token_act: jnp.ndarray # (num_tokens, ch)
|
token_act: jnp.ndarray # (num_tokens, ch)
|
||||||
@ -108,6 +107,13 @@ class AtomCrossAttEncoderOutput:
|
|||||||
pair_cond: jnp.ndarray # (num_subsets, num_queries, num_keys, ch)
|
pair_cond: jnp.ndarray # (num_subsets, num_queries, num_keys, ch)
|
||||||
|
|
||||||
|
|
||||||
|
jax.tree_util.register_dataclass(
|
||||||
|
AtomCrossAttEncoderOutput,
|
||||||
|
data_fields=[f.name for f in dataclasses.fields(AtomCrossAttEncoderOutput)],
|
||||||
|
meta_fields=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def atom_cross_att_encoder(
|
def atom_cross_att_encoder(
|
||||||
token_atoms_act: jnp.ndarray | None, # (num_tokens, max_atoms_per_token, 3)
|
token_atoms_act: jnp.ndarray | None, # (num_tokens, max_atoms_per_token, 3)
|
||||||
trunk_single_cond: jnp.ndarray | None, # (num_tokens, ch)
|
trunk_single_cond: jnp.ndarray | None, # (num_tokens, ch)
|
||||||
|
Reference in New Issue
Block a user