Add support for protein/DNA/RNA/ligand descriptions

Suggested in https://github.com/google-deepmind/alphafold3/issues/496.

PiperOrigin-RevId: 802081348
Change-Id: I666466fd6a770b6f4a891ed33e6a26651d600c4a
This commit is contained in:
Augustin Zidek
2025-09-02 04:11:09 -07:00
committed by Copybara-Service
parent a2b03dab51
commit 7b816f4035
2 changed files with 120 additions and 22 deletions

View File

@ -117,7 +117,7 @@ The top-level structure of the input JSON is:
"userCCD": "...", # Optional, mutually exclusive with userCCDPath.
"userCCDPath": "...", # Optional, mutually exclusive with userCCD.
"dialect": "alphafold3", # Required.
"version": 3 # Required.
"version": 4 # Required.
}
```
@ -166,6 +166,8 @@ The top-level `version` field (for the `alphafold3` dialect) can be either `1`,
added fields `unpairedMsaPath`, `pairedMsaPath`, and `mmcifPath`.
* `3`: added the option of specifying external user-provided CCD using newly
added field `userCCDPath`.
* `4`: added the option of specifying textual `description` of protein chains,
RNA chains, DNA chains, or ligands.
## Sequences
@ -186,6 +188,7 @@ Specifies a single protein chain.
{"ptmType": "HY3", "ptmPosition": 1},
{"ptmType": "P1L", "ptmPosition": 5}
],
"description": ..., # Optional.
"unpairedMsa": ..., # Mutually exclusive with unpairedMsaPath.
"unpairedMsaPath": ..., # Mutually exclusive with unpairedMsa.
"pairedMsa": ..., # Mutually exclusive with pairedMsaPath.
@ -207,6 +210,9 @@ The fields specify the following:
post-translational modifications. Each modification is specified using its
CCD code and 1-based residue position. In the example above, we see that the
first residue won't be a proline (`P`) but instead `HY3`.
* `description: str`: An optional textual description of this chain. This
field will is only used in the JSON format and serves as a comment
describing this chain.
* `unpairedMsa: str`: An optional multiple sequence alignment for this chain.
This is specified using the A3M format (equivalent to the FASTA format, but
also allows gaps denoted by the hyphen `-` character). See more details
@ -239,6 +245,7 @@ Specifies a single RNA chain.
{"modificationType": "2MG", "basePosition": 1},
{"modificationType": "5MC", "basePosition": 4}
],
"description": ..., # Optional.
"unpairedMsa": ..., # Mutually exclusive with unpairedMsaPath.
"unpairedMsaPath": ... # Mutually exclusive with unpairedMsa.
}
@ -255,6 +262,9 @@ The fields specify the following:
letters `A`, `C`, `G`, `U`.
* `modifications: list[RnaModification]`: An optional list of modifications.
Each modification is specified using its CCD code and 1-based base position.
* `description: str`: An optional textual description of this chain. This
field will is only used in the JSON format and serves as a comment
describing this chain.
* `unpairedMsa: str`: An optional multiple sequence alignment for this chain.
This is specified using the A3M format. See more details below.
* `unpairedMsaPath: str`: An optional path to a file that contains the
@ -275,7 +285,8 @@ Specifies a single DNA chain.
"modifications": [
{"modificationType": "6OG", "basePosition": 1},
{"modificationType": "6MA", "basePosition": 2}
]
],
"description": ... # Optional.
}
}
```
@ -290,6 +301,9 @@ The fields specify the following:
letters `A`, `C`, `G`, `T`.
* `modifications: list[DnaModification]`: An optional list of modifications.
Each modification is specified using its CCD code and 1-based base position.
* `description: str`: An optional textual description of this chain. This
field will is only used in the JSON format and serves as a comment
describing this chain.
### Ligands
@ -314,19 +328,22 @@ Specifies a single ligand. Ligands can be specified using 3 different formats:
{
"ligand": {
"id": ["G", "H", "I"],
"ccdCodes": ["ATP"]
"ccdCodes": ["ATP"],
"description": ... # Optional.
}
},
{
"ligand": {
"id": "J",
"ccdCodes": ["LIG-1337"]
"ccdCodes": ["LIG-1337"],
"description": ... # Optional.
}
},
{
"ligand": {
"id": "K",
"smiles": "CC(=O)OC1C[NH+]2CCC1CC2"
"smiles": "CC(=O)OC1C[NH+]2CCC1CC2",
"description": ... # Optional.
}
}
```
@ -342,6 +359,9 @@ The fields specify the following:
[user-provided CCD](#user-provided-ccd).
* `smiles: str`: An optional string defining the ligand using a SMILES string.
The SMILES string must be correctly JSON-escaped.
* `description: str`: An optional textual description of this chain. This
field will is only used in the JSON format and serves as a comment
describing this ligand.
Each ligand may be specified using CCD codes or SMILES but not both, i.e. for a
given ligand, the `ccdCodes` and `smiles` fields are mutually exclusive.
@ -919,6 +939,7 @@ certain fields and the sequences are not biologically meaningful.
{"ptmType": "HY3", "ptmPosition": 1},
{"ptmType": "P1L", "ptmPosition": 5}
],
"description": "10-residue protein with 2 modifications",
"unpairedMsa": ...,
"pairedMsa": ""
}
@ -982,7 +1003,6 @@ certain fields and the sequences are not biologically meaningful.
],
"userCCD": ...,
"dialect": "alphafold3",
"version": 3
"version": 4
}
```

View File

@ -36,7 +36,7 @@ import zstandard as zstd
BondAtomId: TypeAlias = tuple[str, int, str]
JSON_DIALECT: Final[str] = 'alphafold3'
JSON_VERSIONS: Final[tuple[int, ...]] = (1, 2, 3)
JSON_VERSIONS: Final[tuple[int, ...]] = (1, 2, 3, 4)
JSON_VERSION: Final[int] = JSON_VERSIONS[-1]
ALPHAFOLDSERVER_JSON_DIALECT: Final[str] = 'alphafoldserver'
@ -127,6 +127,7 @@ class ProteinChain:
'_id',
'_sequence',
'_ptms',
'_description',
'_paired_msa',
'_unpaired_msa',
'_templates',
@ -138,6 +139,7 @@ class ProteinChain:
id: str, # pylint: disable=redefined-builtin
sequence: str,
ptms: Sequence[tuple[str, int]],
description: str | None = None,
paired_msa: str | None = None,
unpaired_msa: str | None = None,
templates: Sequence[Template] | None = None,
@ -149,6 +151,7 @@ class ProteinChain:
sequence: The amino acid sequence of the chain.
ptms: A list of tuples containing the post-translational modification type
and the (1-based) residue index where the modification is applied.
description: An optional textual description of the protein chain.
paired_msa: Paired A3M-formatted MSA for this chain. This MSA is not
deduplicated and will be used to compute paired features. If None, this
field is unset and must be filled in by the data pipeline before
@ -175,6 +178,7 @@ class ProteinChain:
self._id = id
self._sequence = sequence
self._ptms = tuple(ptms)
self._description = description
self._paired_msa = paired_msa
self._unpaired_msa = unpaired_msa
self._templates = tuple(templates) if templates is not None else None
@ -198,6 +202,10 @@ class ProteinChain:
def ptms(self) -> Sequence[tuple[str, int]]:
return self._ptms
@property
def description(self) -> str | None:
return self._description
@property
def paired_msa(self) -> str | None:
return self._paired_msa
@ -218,6 +226,7 @@ class ProteinChain:
self._id == other._id
and self._sequence == other._sequence
and self._ptms == other._ptms
and self._description == other._description
and self._paired_msa == other._paired_msa
and self._unpaired_msa == other._unpaired_msa
and self._templates == other._templates
@ -228,6 +237,7 @@ class ProteinChain:
self._id,
self._sequence,
self._ptms,
self._description,
self._paired_msa,
self._unpaired_msa,
self._templates,
@ -238,6 +248,7 @@ class ProteinChain:
return hash((
self._sequence,
self._ptms,
self._description,
self._paired_msa,
self._unpaired_msa,
self._templates,
@ -298,6 +309,7 @@ class ProteinChain:
'id',
'sequence',
'modifications',
'description',
'unpairedMsa',
'unpairedMsaPath',
'pairedMsa',
@ -368,6 +380,7 @@ class ProteinChain:
id=seq_id or json_dict['id'],
sequence=sequence,
ptms=ptms,
description=json_dict.get('description', None),
paired_msa=paired_msa,
unpaired_msa=unpaired_msa,
templates=templates,
@ -400,6 +413,8 @@ class ProteinChain:
'pairedMsa': self._paired_msa,
'templates': templates,
}
if self._description is not None:
contents['description'] = self._description
return {'protein': contents}
def to_ccd_sequence(self) -> Sequence[str]:
@ -418,6 +433,7 @@ class ProteinChain:
id=self.id,
sequence=self._sequence,
ptms=self._ptms,
description=self._description,
unpaired_msa=self._unpaired_msa or '',
paired_msa=self._paired_msa or '',
templates=self._templates or [],
@ -427,7 +443,13 @@ class ProteinChain:
class RnaChain:
"""RNA chain input."""
__slots__ = ('_id', '_sequence', '_modifications', '_unpaired_msa')
__slots__ = (
'_id',
'_sequence',
'_modifications',
'_description',
'_unpaired_msa',
)
def __init__(
self,
@ -435,6 +457,7 @@ class RnaChain:
id: str, # pylint: disable=redefined-builtin
sequence: str,
modifications: Sequence[tuple[str, int]],
description: str | None = None,
unpaired_msa: str | None = None,
):
"""Initializes a single strand RNA chain input.
@ -444,6 +467,7 @@ class RnaChain:
sequence: The RNA sequence of the chain.
modifications: A list of tuples containing the modification type and the
(1-based) residue index where the modification is applied.
description: An optional textual description of the RNA chain.
unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be
deduplicated and used to compute unpaired features. If None, this field
is unset and must be filled in by the data pipeline before
@ -463,6 +487,7 @@ class RnaChain:
self._sequence = sequence
# Use hashable container for modifications.
self._modifications = tuple(modifications)
self._description = description
self._unpaired_msa = unpaired_msa
@property
@ -484,6 +509,10 @@ class RnaChain:
def modifications(self) -> Sequence[tuple[str, int]]:
return self._modifications
@property
def description(self) -> str | None:
return self._description
@property
def unpaired_msa(self) -> str | None:
return self._unpaired_msa
@ -496,17 +525,27 @@ class RnaChain:
self._id == other._id
and self._sequence == other._sequence
and self._modifications == other._modifications
and self._description == other._description
and self._unpaired_msa == other._unpaired_msa
)
def __hash__(self) -> int:
return hash(
(self._id, self._sequence, self._modifications, self._unpaired_msa)
)
return hash((
self._id,
self._sequence,
self._modifications,
self._description,
self._unpaired_msa,
))
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((self._sequence, self._modifications, self._unpaired_msa))
return hash((
self._sequence,
self._modifications,
self._description,
self._unpaired_msa,
))
@classmethod
def from_alphafoldserver_dict(
@ -532,7 +571,14 @@ class RnaChain:
json_dict = json_dict['rna']
_validate_keys(
json_dict.keys(),
{'id', 'sequence', 'unpairedMsa', 'unpairedMsaPath', 'modifications'},
{
'id',
'sequence',
'modifications',
'description',
'unpairedMsa',
'unpairedMsaPath',
},
)
sequence = json_dict['sequence']
modifications = [
@ -559,6 +605,7 @@ class RnaChain:
id=seq_id or json_dict['id'],
sequence=sequence,
modifications=modifications,
description=json_dict.get('description', None),
unpaired_msa=unpaired_msa,
)
@ -575,6 +622,8 @@ class RnaChain:
],
'unpairedMsa': self._unpaired_msa,
}
if self._description is not None:
contents['description'] = self._description
return {'rna': contents}
def to_ccd_sequence(self) -> Sequence[str]:
@ -600,7 +649,7 @@ class RnaChain:
class DnaChain:
"""Single strand DNA chain input."""
__slots__ = ('_id', '_sequence', '_modifications')
__slots__ = ('_id', '_sequence', '_modifications', '_description')
def __init__(
self,
@ -608,6 +657,7 @@ class DnaChain:
id: str, # pylint: disable=redefined-builtin
sequence: str,
modifications: Sequence[tuple[str, int]],
description: str | None = None,
):
"""Initializes a single strand DNA chain input.
@ -616,6 +666,7 @@ class DnaChain:
sequence: The DNA sequence of the chain.
modifications: A list of tuples containing the modification type and the
(1-based) residue index where the modification is applied.
description: An optional textual description of the DNA chain.
"""
if not all(res.isalpha() for res in sequence):
raise ValueError(f'DNA must contain only letters, got "{sequence}"')
@ -630,6 +681,7 @@ class DnaChain:
self._sequence = sequence
# Use hashable container for modifications.
self._modifications = tuple(modifications)
self._description = description
@property
def id(self) -> str:
@ -646,6 +698,10 @@ class DnaChain:
for r in self.to_ccd_sequence()
])
@property
def description(self) -> str | None:
return self._description
def __len__(self) -> int:
return len(self._sequence)
@ -654,17 +710,20 @@ class DnaChain:
self._id == other._id
and self._sequence == other._sequence
and self._modifications == other._modifications
and self._description == other._description
)
def __hash__(self) -> int:
return hash((self._id, self._sequence, self._modifications))
return hash(
(self._id, self._sequence, self._modifications, self._description)
)
def modifications(self) -> Sequence[tuple[str, int]]:
return self._modifications
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((self._sequence, self._modifications))
return hash((self._sequence, self._modifications, self._description))
@classmethod
def from_alphafoldserver_dict(
@ -685,7 +744,9 @@ class DnaChain:
) -> Self:
"""Constructs DnaChain from the AlphaFold JSON dict."""
json_dict = json_dict['dna']
_validate_keys(json_dict.keys(), {'id', 'sequence', 'modifications'})
_validate_keys(
json_dict.keys(), {'id', 'sequence', 'modifications', 'description'}
)
sequence = json_dict['sequence']
modifications = [
(mod['modificationType'], mod['basePosition'])
@ -695,6 +756,7 @@ class DnaChain:
id=seq_id or json_dict['id'],
sequence=sequence,
modifications=modifications,
description=json_dict.get('description', None),
)
def to_dict(
@ -709,6 +771,8 @@ class DnaChain:
for mod in self._modifications
],
}
if self._description is not None:
contents['description'] = self._description
return {'dna': contents}
def to_ccd_sequence(self) -> Sequence[str]:
@ -734,11 +798,13 @@ class Ligand:
a bond linking these components should be added to the bonded_atom_pairs
Input field.
smiles: The SMILES representation of the ligand.
description: An optional textual description of the ligand.
"""
id: str
ccd_ids: Sequence[str] | None = None
smiles: str | None = None
description: str | None = None
def __post_init__(self):
if (self.ccd_ids is None) == (self.smiles is None):
@ -761,7 +827,7 @@ class Ligand:
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((self.ccd_ids, self.smiles))
return hash((self.ccd_ids, self.smiles, self.description))
@classmethod
def from_alphafoldserver_dict(
@ -783,7 +849,9 @@ class Ligand:
) -> Self:
"""Constructs Ligand from the AlphaFold JSON dict."""
json_dict = json_dict['ligand']
_validate_keys(json_dict.keys(), {'id', 'ccdCodes', 'smiles'})
_validate_keys(
json_dict.keys(), {'id', 'ccdCodes', 'smiles', 'description'}
)
if json_dict.get('ccdCodes') and json_dict.get('smiles'):
raise ValueError(
'Ligand cannot have both CCD code and SMILES set at the same time, '
@ -797,9 +865,17 @@ class Ligand:
'CCD codes must be a list of strings, got '
f'{type(ccd_codes).__name__} instead: {ccd_codes}'
)
return cls(id=seq_id or json_dict['id'], ccd_ids=ccd_codes)
return cls(
id=seq_id or json_dict['id'],
ccd_ids=ccd_codes,
description=json_dict.get('description', None),
)
elif 'smiles' in json_dict:
return cls(id=seq_id or json_dict['id'], smiles=json_dict['smiles'])
return cls(
id=seq_id or json_dict['id'],
smiles=json_dict['smiles'],
description=json_dict.get('description', None),
)
else:
raise ValueError(f'Unknown ligand type: {json_dict}')
@ -812,6 +888,8 @@ class Ligand:
contents['ccdCodes'] = self.ccd_ids
if self.smiles is not None:
contents['smiles'] = self.smiles
if self.description is not None:
contents['description'] = self.description
return {'ligand': contents}