Compare commits

10 Commits

Author SHA1 Message Date
3c89cc7b89 Improve error message and consistency of ordering
PiperOrigin-RevId: 820694312
Change-Id: I6b04514b1f838c49bf5dca1e138e1d6b2ff81139
2025-10-17 08:06:15 -07:00
142e4bc6e5 Fix residue mapping bug in _maybe_add_missing_scheme_tables
PiperOrigin-RevId: 817086740
Change-Id: Ib4a74efd410ec25f0cd92b6d6ffa23f684e660fb
2025-10-09 02:29:51 -07:00
03a6c295e5 Return b factors and occupancy from to_res_arrays()
Change to_res_arrays() to return a ResArrays dataclass which contains [num_res, num_atoms, *] shaped arrays for all atom-level data in a Structure.

PiperOrigin-RevId: 814245949
Change-Id: I38abeec96c8d4ed97a06964e6c7b7fc74f37f97d
2025-10-02 08:31:47 -07:00
a8ecdb2d7a Cache the underlying CCD dictionary instead of the whole CCD object.
* Enables removal of the chemical_components.cached_ccd function.
* If there are multiple custom CCDs (tiny in comparison to the full CCD) in a single run, this will consume only O(1) RAM instead of O(num_inputs).

Addresses:
* https://github.com/google-deepmind/alphafold3/pull/514
* https://github.com/google-deepmind/alphafold3/issues/509
PiperOrigin-RevId: 804321599
Change-Id: I484f851ce79cd6b1c5c31bf70b6cd945791f8b66
2025-09-08 01:44:10 -07:00
85a03ec086 Use the CPU count available to the process instead of total CPU count
Addresses https://github.com/google-deepmind/alphafold3/issues/513

PiperOrigin-RevId: 803078244
Change-Id: I6446deac260f13ad08873d1c464e65444c099283
2025-09-04 10:20:38 -07:00
b467f92160 Improve validation of CIF files.
* Make sure the data name in `data_<name>` is non-empty.
* Check for duplicate key names.
* Check that the multi-line tokens are closed at the end of the file.
* Fix also two broken mmCIF files this check has uncovered.

PiperOrigin-RevId: 802572236
Change-Id: Ie7a3a5ec816ec5b97158508cc1b12064cf0e70a8
2025-09-03 08:51:00 -07:00
4208665547 Make sure Structure doesn't have a name that is an empty string
PiperOrigin-RevId: 802558849
Change-Id: Ic469eec11e92d14df339d3a3d9205f144dd005fc
2025-09-03 08:10:28 -07:00
0ea324ce74 Validate that a non-loop CIF key has only a single value
PiperOrigin-RevId: 802455341
Change-Id: I089c4e0a9a52862b5eba6d26023184a4a34efc96
2025-09-03 01:48:24 -07:00
7b816f4035 Add support for protein/DNA/RNA/ligand descriptions
Suggested in https://github.com/google-deepmind/alphafold3/issues/496.

PiperOrigin-RevId: 802081348
Change-Id: I666466fd6a770b6f4a891ed33e6a26651d600c4a
2025-09-02 04:11:53 -07:00
a2b03dab51 Validate the number of tokens in a CIF loop
PiperOrigin-RevId: 797771843
Change-Id: Ib34796c2e9f402cd2cad79d67127ddb32f0803ae
2025-08-21 08:17:40 -07:00
11 changed files with 239 additions and 59 deletions

View File

@ -117,7 +117,7 @@ The top-level structure of the input JSON is:
"userCCD": "...", # Optional, mutually exclusive with userCCDPath.
"userCCDPath": "...", # Optional, mutually exclusive with userCCD.
"dialect": "alphafold3", # Required.
"version": 3 # Required.
"version": 4 # Required.
}
```
@ -166,6 +166,8 @@ The top-level `version` field (for the `alphafold3` dialect) can be either `1`,
added fields `unpairedMsaPath`, `pairedMsaPath`, and `mmcifPath`.
* `3`: added the option of specifying external user-provided CCD using newly
added field `userCCDPath`.
* `4`: added the option of specifying textual `description` of protein chains,
RNA chains, DNA chains, or ligands.
## Sequences
@ -186,6 +188,7 @@ Specifies a single protein chain.
{"ptmType": "HY3", "ptmPosition": 1},
{"ptmType": "P1L", "ptmPosition": 5}
],
"description": ..., # Optional.
"unpairedMsa": ..., # Mutually exclusive with unpairedMsaPath.
"unpairedMsaPath": ..., # Mutually exclusive with unpairedMsa.
"pairedMsa": ..., # Mutually exclusive with pairedMsaPath.
@ -207,6 +210,9 @@ The fields specify the following:
post-translational modifications. Each modification is specified using its
CCD code and 1-based residue position. In the example above, we see that the
first residue won't be a proline (`P`) but instead `HY3`.
* `description: str`: An optional textual description of this chain. This
field will is only used in the JSON format and serves as a comment
describing this chain.
* `unpairedMsa: str`: An optional multiple sequence alignment for this chain.
This is specified using the A3M format (equivalent to the FASTA format, but
also allows gaps denoted by the hyphen `-` character). See more details
@ -239,6 +245,7 @@ Specifies a single RNA chain.
{"modificationType": "2MG", "basePosition": 1},
{"modificationType": "5MC", "basePosition": 4}
],
"description": ..., # Optional.
"unpairedMsa": ..., # Mutually exclusive with unpairedMsaPath.
"unpairedMsaPath": ... # Mutually exclusive with unpairedMsa.
}
@ -255,6 +262,9 @@ The fields specify the following:
letters `A`, `C`, `G`, `U`.
* `modifications: list[RnaModification]`: An optional list of modifications.
Each modification is specified using its CCD code and 1-based base position.
* `description: str`: An optional textual description of this chain. This
field will is only used in the JSON format and serves as a comment
describing this chain.
* `unpairedMsa: str`: An optional multiple sequence alignment for this chain.
This is specified using the A3M format. See more details below.
* `unpairedMsaPath: str`: An optional path to a file that contains the
@ -275,7 +285,8 @@ Specifies a single DNA chain.
"modifications": [
{"modificationType": "6OG", "basePosition": 1},
{"modificationType": "6MA", "basePosition": 2}
]
],
"description": ... # Optional.
}
}
```
@ -290,6 +301,9 @@ The fields specify the following:
letters `A`, `C`, `G`, `T`.
* `modifications: list[DnaModification]`: An optional list of modifications.
Each modification is specified using its CCD code and 1-based base position.
* `description: str`: An optional textual description of this chain. This
field will is only used in the JSON format and serves as a comment
describing this chain.
### Ligands
@ -314,19 +328,22 @@ Specifies a single ligand. Ligands can be specified using 3 different formats:
{
"ligand": {
"id": ["G", "H", "I"],
"ccdCodes": ["ATP"]
"ccdCodes": ["ATP"],
"description": ... # Optional.
}
},
{
"ligand": {
"id": "J",
"ccdCodes": ["LIG-1337"]
"ccdCodes": ["LIG-1337"],
"description": ... # Optional.
}
},
{
"ligand": {
"id": "K",
"smiles": "CC(=O)OC1C[NH+]2CCC1CC2"
"smiles": "CC(=O)OC1C[NH+]2CCC1CC2",
"description": ... # Optional.
}
}
```
@ -342,6 +359,9 @@ The fields specify the following:
[user-provided CCD](#user-provided-ccd).
* `smiles: str`: An optional string defining the ligand using a SMILES string.
The SMILES string must be correctly JSON-escaped.
* `description: str`: An optional textual description of this chain. This
field will is only used in the JSON format and serves as a comment
describing this ligand.
Each ligand may be specified using CCD codes or SMILES but not both, i.e. for a
given ligand, the `ccdCodes` and `smiles` fields are mutually exclusive.
@ -919,6 +939,7 @@ certain fields and the sequences are not biologically meaningful.
{"ptmType": "HY3", "ptmPosition": 1},
{"ptmType": "P1L", "ptmPosition": 5}
],
"description": "10-residue protein with 2 modifications",
"unpairedMsa": ...,
"pairedMsa": ""
}
@ -982,7 +1003,6 @@ certain fields and the sequences are not biologically meaningful.
],
"userCCD": ...,
"dialect": "alphafold3",
"version": 3
"version": 4
}
```

View File

@ -24,7 +24,6 @@ import csv
import dataclasses
import datetime
import functools
import multiprocessing
import os
import pathlib
import shutil
@ -178,14 +177,16 @@ _SEQRES_DATABASE_PATH = flags.DEFINE_string(
# Number of CPUs to use for MSA tools.
_JACKHMMER_N_CPU = flags.DEFINE_integer(
'jackhmmer_n_cpu',
min(multiprocessing.cpu_count(), 8),
# Unfortunately, os.process_cpu_count() is only available in Python 3.13+.
min(len(os.sched_getaffinity(0)), 8),
'Number of CPUs to use for Jackhmmer. Defaults to min(cpu_count, 8). Going'
' above 8 CPUs provides very little additional speedup.',
lower_bound=0,
)
_NHMMER_N_CPU = flags.DEFINE_integer(
'nhmmer_n_cpu',
min(multiprocessing.cpu_count(), 8),
# Unfortunately, os.process_cpu_count() is only available in Python 3.13+.
min(len(os.sched_getaffinity(0)), 8),
'Number of CPUs to use for Nhmmer. Defaults to min(cpu_count, 8). Going'
' above 8 CPUs provides very little additional speedup.',
lower_bound=0,
@ -448,7 +449,7 @@ def predict_structure(
print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...')
featurisation_start_time = time.time()
ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd)
ccd = chemical_components.Ccd(user_ccd=fold_input.user_ccd)
featurised_examples = featurisation.featurise_input(
fold_input=fold_input,
buckets=buckets,

View File

@ -162,7 +162,9 @@ class DataPipelineTest(parameterized.TestCase):
{
'protein': {
'id': 'P',
'sequence': 'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN',
'sequence': (
'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN'
),
'modifications': [],
'unpairedMsa': None,
'pairedMsa': None,
@ -205,7 +207,7 @@ class DataPipelineTest(parameterized.TestCase):
full_fold_input = data_pipeline.process(fold_input)
featurised_example = featurisation.featurise_input(
full_fold_input,
ccd=chemical_components.cached_ccd(),
ccd=chemical_components.Ccd(),
buckets=None,
)
del featurised_example[0]['ref_pos'] # Depends on specific RDKit version.

View File

@ -36,7 +36,7 @@ import zstandard as zstd
BondAtomId: TypeAlias = tuple[str, int, str]
JSON_DIALECT: Final[str] = 'alphafold3'
JSON_VERSIONS: Final[tuple[int, ...]] = (1, 2, 3)
JSON_VERSIONS: Final[tuple[int, ...]] = (1, 2, 3, 4)
JSON_VERSION: Final[int] = JSON_VERSIONS[-1]
ALPHAFOLDSERVER_JSON_DIALECT: Final[str] = 'alphafoldserver'
@ -127,6 +127,7 @@ class ProteinChain:
'_id',
'_sequence',
'_ptms',
'_description',
'_paired_msa',
'_unpaired_msa',
'_templates',
@ -138,6 +139,7 @@ class ProteinChain:
id: str, # pylint: disable=redefined-builtin
sequence: str,
ptms: Sequence[tuple[str, int]],
description: str | None = None,
paired_msa: str | None = None,
unpaired_msa: str | None = None,
templates: Sequence[Template] | None = None,
@ -149,6 +151,7 @@ class ProteinChain:
sequence: The amino acid sequence of the chain.
ptms: A list of tuples containing the post-translational modification type
and the (1-based) residue index where the modification is applied.
description: An optional textual description of the protein chain.
paired_msa: Paired A3M-formatted MSA for this chain. This MSA is not
deduplicated and will be used to compute paired features. If None, this
field is unset and must be filled in by the data pipeline before
@ -175,6 +178,7 @@ class ProteinChain:
self._id = id
self._sequence = sequence
self._ptms = tuple(ptms)
self._description = description
self._paired_msa = paired_msa
self._unpaired_msa = unpaired_msa
self._templates = tuple(templates) if templates is not None else None
@ -198,6 +202,10 @@ class ProteinChain:
def ptms(self) -> Sequence[tuple[str, int]]:
return self._ptms
@property
def description(self) -> str | None:
return self._description
@property
def paired_msa(self) -> str | None:
return self._paired_msa
@ -218,6 +226,7 @@ class ProteinChain:
self._id == other._id
and self._sequence == other._sequence
and self._ptms == other._ptms
and self._description == other._description
and self._paired_msa == other._paired_msa
and self._unpaired_msa == other._unpaired_msa
and self._templates == other._templates
@ -228,6 +237,7 @@ class ProteinChain:
self._id,
self._sequence,
self._ptms,
self._description,
self._paired_msa,
self._unpaired_msa,
self._templates,
@ -238,6 +248,7 @@ class ProteinChain:
return hash((
self._sequence,
self._ptms,
self._description,
self._paired_msa,
self._unpaired_msa,
self._templates,
@ -298,6 +309,7 @@ class ProteinChain:
'id',
'sequence',
'modifications',
'description',
'unpairedMsa',
'unpairedMsaPath',
'pairedMsa',
@ -368,6 +380,7 @@ class ProteinChain:
id=seq_id or json_dict['id'],
sequence=sequence,
ptms=ptms,
description=json_dict.get('description', None),
paired_msa=paired_msa,
unpaired_msa=unpaired_msa,
templates=templates,
@ -400,6 +413,8 @@ class ProteinChain:
'pairedMsa': self._paired_msa,
'templates': templates,
}
if self._description is not None:
contents['description'] = self._description
return {'protein': contents}
def to_ccd_sequence(self) -> Sequence[str]:
@ -418,6 +433,7 @@ class ProteinChain:
id=self.id,
sequence=self._sequence,
ptms=self._ptms,
description=self._description,
unpaired_msa=self._unpaired_msa or '',
paired_msa=self._paired_msa or '',
templates=self._templates or [],
@ -427,7 +443,13 @@ class ProteinChain:
class RnaChain:
"""RNA chain input."""
__slots__ = ('_id', '_sequence', '_modifications', '_unpaired_msa')
__slots__ = (
'_id',
'_sequence',
'_modifications',
'_description',
'_unpaired_msa',
)
def __init__(
self,
@ -435,6 +457,7 @@ class RnaChain:
id: str, # pylint: disable=redefined-builtin
sequence: str,
modifications: Sequence[tuple[str, int]],
description: str | None = None,
unpaired_msa: str | None = None,
):
"""Initializes a single strand RNA chain input.
@ -444,6 +467,7 @@ class RnaChain:
sequence: The RNA sequence of the chain.
modifications: A list of tuples containing the modification type and the
(1-based) residue index where the modification is applied.
description: An optional textual description of the RNA chain.
unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be
deduplicated and used to compute unpaired features. If None, this field
is unset and must be filled in by the data pipeline before
@ -463,6 +487,7 @@ class RnaChain:
self._sequence = sequence
# Use hashable container for modifications.
self._modifications = tuple(modifications)
self._description = description
self._unpaired_msa = unpaired_msa
@property
@ -484,6 +509,10 @@ class RnaChain:
def modifications(self) -> Sequence[tuple[str, int]]:
return self._modifications
@property
def description(self) -> str | None:
return self._description
@property
def unpaired_msa(self) -> str | None:
return self._unpaired_msa
@ -496,17 +525,27 @@ class RnaChain:
self._id == other._id
and self._sequence == other._sequence
and self._modifications == other._modifications
and self._description == other._description
and self._unpaired_msa == other._unpaired_msa
)
def __hash__(self) -> int:
return hash(
(self._id, self._sequence, self._modifications, self._unpaired_msa)
)
return hash((
self._id,
self._sequence,
self._modifications,
self._description,
self._unpaired_msa,
))
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((self._sequence, self._modifications, self._unpaired_msa))
return hash((
self._sequence,
self._modifications,
self._description,
self._unpaired_msa,
))
@classmethod
def from_alphafoldserver_dict(
@ -532,7 +571,14 @@ class RnaChain:
json_dict = json_dict['rna']
_validate_keys(
json_dict.keys(),
{'id', 'sequence', 'unpairedMsa', 'unpairedMsaPath', 'modifications'},
{
'id',
'sequence',
'modifications',
'description',
'unpairedMsa',
'unpairedMsaPath',
},
)
sequence = json_dict['sequence']
modifications = [
@ -559,6 +605,7 @@ class RnaChain:
id=seq_id or json_dict['id'],
sequence=sequence,
modifications=modifications,
description=json_dict.get('description', None),
unpaired_msa=unpaired_msa,
)
@ -575,6 +622,8 @@ class RnaChain:
],
'unpairedMsa': self._unpaired_msa,
}
if self._description is not None:
contents['description'] = self._description
return {'rna': contents}
def to_ccd_sequence(self) -> Sequence[str]:
@ -600,7 +649,7 @@ class RnaChain:
class DnaChain:
"""Single strand DNA chain input."""
__slots__ = ('_id', '_sequence', '_modifications')
__slots__ = ('_id', '_sequence', '_modifications', '_description')
def __init__(
self,
@ -608,6 +657,7 @@ class DnaChain:
id: str, # pylint: disable=redefined-builtin
sequence: str,
modifications: Sequence[tuple[str, int]],
description: str | None = None,
):
"""Initializes a single strand DNA chain input.
@ -616,6 +666,7 @@ class DnaChain:
sequence: The DNA sequence of the chain.
modifications: A list of tuples containing the modification type and the
(1-based) residue index where the modification is applied.
description: An optional textual description of the DNA chain.
"""
if not all(res.isalpha() for res in sequence):
raise ValueError(f'DNA must contain only letters, got "{sequence}"')
@ -630,6 +681,7 @@ class DnaChain:
self._sequence = sequence
# Use hashable container for modifications.
self._modifications = tuple(modifications)
self._description = description
@property
def id(self) -> str:
@ -646,6 +698,10 @@ class DnaChain:
for r in self.to_ccd_sequence()
])
@property
def description(self) -> str | None:
return self._description
def __len__(self) -> int:
return len(self._sequence)
@ -654,17 +710,20 @@ class DnaChain:
self._id == other._id
and self._sequence == other._sequence
and self._modifications == other._modifications
and self._description == other._description
)
def __hash__(self) -> int:
return hash((self._id, self._sequence, self._modifications))
return hash(
(self._id, self._sequence, self._modifications, self._description)
)
def modifications(self) -> Sequence[tuple[str, int]]:
return self._modifications
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((self._sequence, self._modifications))
return hash((self._sequence, self._modifications, self._description))
@classmethod
def from_alphafoldserver_dict(
@ -685,7 +744,9 @@ class DnaChain:
) -> Self:
"""Constructs DnaChain from the AlphaFold JSON dict."""
json_dict = json_dict['dna']
_validate_keys(json_dict.keys(), {'id', 'sequence', 'modifications'})
_validate_keys(
json_dict.keys(), {'id', 'sequence', 'modifications', 'description'}
)
sequence = json_dict['sequence']
modifications = [
(mod['modificationType'], mod['basePosition'])
@ -695,6 +756,7 @@ class DnaChain:
id=seq_id or json_dict['id'],
sequence=sequence,
modifications=modifications,
description=json_dict.get('description', None),
)
def to_dict(
@ -709,6 +771,8 @@ class DnaChain:
for mod in self._modifications
],
}
if self._description is not None:
contents['description'] = self._description
return {'dna': contents}
def to_ccd_sequence(self) -> Sequence[str]:
@ -734,11 +798,13 @@ class Ligand:
a bond linking these components should be added to the bonded_atom_pairs
Input field.
smiles: The SMILES representation of the ligand.
description: An optional textual description of the ligand.
"""
id: str
ccd_ids: Sequence[str] | None = None
smiles: str | None = None
description: str | None = None
def __post_init__(self):
if (self.ccd_ids is None) == (self.smiles is None):
@ -761,7 +827,7 @@ class Ligand:
def hash_without_id(self) -> int:
"""Returns a hash ignoring the ID - useful for deduplication."""
return hash((self.ccd_ids, self.smiles))
return hash((self.ccd_ids, self.smiles, self.description))
@classmethod
def from_alphafoldserver_dict(
@ -783,7 +849,9 @@ class Ligand:
) -> Self:
"""Constructs Ligand from the AlphaFold JSON dict."""
json_dict = json_dict['ligand']
_validate_keys(json_dict.keys(), {'id', 'ccdCodes', 'smiles'})
_validate_keys(
json_dict.keys(), {'id', 'ccdCodes', 'smiles', 'description'}
)
if json_dict.get('ccdCodes') and json_dict.get('smiles'):
raise ValueError(
'Ligand cannot have both CCD code and SMILES set at the same time, '
@ -797,9 +865,17 @@ class Ligand:
'CCD codes must be a list of strings, got '
f'{type(ccd_codes).__name__} instead: {ccd_codes}'
)
return cls(id=seq_id or json_dict['id'], ccd_ids=ccd_codes)
return cls(
id=seq_id or json_dict['id'],
ccd_ids=ccd_codes,
description=json_dict.get('description', None),
)
elif 'smiles' in json_dict:
return cls(id=seq_id or json_dict['id'], smiles=json_dict['smiles'])
return cls(
id=seq_id or json_dict['id'],
smiles=json_dict['smiles'],
description=json_dict.get('description', None),
)
else:
raise ValueError(f'Unknown ligand type: {json_dict}')
@ -812,6 +888,8 @@ class Ligand:
contents['ccdCodes'] = self.ccd_ids
if self.smiles is not None:
contents['smiles'] = self.smiles
if self.description is not None:
contents['description'] = self.description
return {'ligand': contents}

View File

@ -25,6 +25,15 @@ _CCD_PICKLE_FILE = resources.filename(
)
@functools.cache
def _load_ccd_pickle_cached(
path: os.PathLike[str],
) -> dict[str, Mapping[str, Sequence[str]]]:
"""Loads the CCD pickle file and caches it so that it is only loaded once."""
with open(path, 'rb') as f:
return pickle.loads(f.read())
class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
"""Chemical Components found in PDB (CCD) constants.
@ -52,8 +61,7 @@ class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
be used to override specific entries in the CCD if desired.
"""
self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE
with open(self._ccd_pickle_path, 'rb') as f:
self._dict = pickle.loads(f.read())
self._dict = _load_ccd_pickle_cached(self._ccd_pickle_path)
if user_ccd is not None:
if not user_ccd:
@ -94,11 +102,6 @@ class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
return self._dict.keys()
@functools.cache
def cached_ccd(user_ccd: str | None = None) -> Ccd:
return Ccd(user_ccd=user_ccd)
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ComponentInfo:
name: str

View File

@ -862,9 +862,11 @@ def get_polymer_features(
chain_sequence = chain.chain_single_letter_sequence()[label_chain_id]
polymer = _POLYMERS[chain_poly_type]
positions, positions_mask = chain.to_res_arrays(
res_arrays = chain.to_res_arrays(
include_missing_residues=True, atom_order=polymer.atom_order
)
positions = res_arrays.atom_positions
positions_mask = res_arrays.atom_mask
template_all_atom_positions = np.zeros(
(query_sequence_length, polymer.num_atom_types, 3), dtype=np.float64
)

View File

@ -117,7 +117,7 @@ def _mol_from_ligand_struc(
def _maybe_mol_from_ccd(res_name: str) -> rd_chem.Mol | None:
"""Creates a Mol object from CCD information if res_name is in the CCD."""
ccd = chemical_components.cached_ccd()
ccd = chemical_components.Ccd()
ccd_cif = ccd.get(res_name)
if not ccd_cif:
logging.warning('No ccd information for residue %s.', res_name)

View File

@ -134,6 +134,9 @@ absl::StatusOr<std::vector<absl::string_view>> TokenizeInternal(
line_num++;
if (!multiline.empty() && multiline[0] == ';') {
break;
} else if (line_num == lines.size()) {
return absl::InvalidArgumentError(
"Last multiline token is not terminated by a semicolon.");
}
multiline_tokens.push_back(multiline);
}
@ -340,6 +343,16 @@ struct GroupedKeys {
int value_size;
};
absl::Status CheckLoopColumnSizes(int num_loop_keys, int num_loop_values) {
if ((num_loop_keys > 0) && (num_loop_values % num_loop_keys != 0)) {
return absl::InvalidArgumentError(absl::StrFormat(
"The number of values (%d) in a loop is not a multiple of the "
"number of the loop's columns (%d)",
num_loop_values, num_loop_keys));
}
return absl::OkStatus();
}
} // namespace
absl::StatusOr<CifDict> CifDict::FromString(absl::string_view cif_string) {
@ -364,6 +377,10 @@ absl::StatusOr<CifDict> CifDict::FromString(absl::string_view cif_string) {
return absl::InvalidArgumentError(
"The CIF file does not start with the data_ field.");
}
if (first_token.empty()) {
return absl::InvalidArgumentError(
"The CIF file does not contain a data block name.");
}
cif["data_"].emplace_back(first_token);
// Counters for CIF loop_ regions.
@ -380,7 +397,12 @@ absl::StatusOr<CifDict> CifDict::FromString(absl::string_view cif_string) {
++token_itr) {
auto token = *token_itr;
if (absl::EqualsIgnoreCase(token, "loop_")) {
// A new loop started, get rid of old loop's data.
// A new loop started, check the previous loop and get rid of its data.
absl::Status loop_status =
CheckLoopColumnSizes(num_loop_keys, loop_token_index);
if (!loop_status.ok()) {
return loop_status;
}
loop_flag = true;
loop_column_values.clear();
loop_token_index = 0;
@ -398,7 +420,12 @@ absl::StatusOr<CifDict> CifDict::FromString(absl::string_view cif_string) {
loop_flag = false;
} else {
// We are in the keys (column names) section of the loop.
auto& columns = cif[token];
auto [it, inserted] = cif.try_emplace(token);
if (!inserted) {
return absl::InvalidArgumentError(
absl::StrCat("Duplicate loop key: '", token, "'"));
}
auto& columns = it->second;
columns.clear();
// Heuristic: _atom_site is typically the largest table in an mmCIF
@ -428,11 +455,25 @@ absl::StatusOr<CifDict> CifDict::FromString(absl::string_view cif_string) {
}
if (key.empty()) {
key = token;
if (!absl::StartsWith(key, "_")) {
return absl::InvalidArgumentError(
absl::StrCat("Key '", key, "' does not start with an underscore."));
}
} else {
cif[key].emplace_back(token);
auto [it, inserted] = cif.try_emplace(key);
if (!inserted) {
return absl::InvalidArgumentError(
absl::StrCat("Duplicate key: '", key, "'"));
}
(it->second).emplace_back(token);
key = "";
}
}
absl::Status loop_status =
CheckLoopColumnSizes(num_loop_keys, loop_token_index);
if (!loop_status.ok()) {
return loop_status;
}
return CifDict(std::move(cif));
}

View File

@ -165,7 +165,7 @@ def get_or_infer_type_symbol(
_atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name)
and the CCD.
"""
ccd = ccd or chemical_components.cached_ccd()
ccd = ccd or chemical_components.Ccd()
type_symbol_fn = lambda res_name, atom_name: chemical_components.type_symbol(
ccd, res_name, atom_name
)

View File

@ -1131,12 +1131,12 @@ def _maybe_add_missing_scheme_tables(
zip(cif['_entity.id'], cif['_entity.type'], strict=True)
)
# Remap asym ID -> entity ID.
chain_type = string_array.remap(
label_entity_id = string_array.remap(
label_asym_ids, mapping=entity_id_by_chain_id, inplace=False
)
# Remap entity ID -> chain type.
string_array.remap(
chain_type, mapping=chain_type_by_entity_id, inplace=True
chain_type = string_array.remap(
label_entity_id, mapping=chain_type_by_entity_id, inplace=False
)
res_mask = np.zeros_like(label_seq_ids, dtype=bool)
res_mask[res_starts] = True
@ -1154,11 +1154,19 @@ def _maybe_add_missing_scheme_tables(
poly_seq_entity_id = cif.get_array(
'_entity_poly_seq.entity_id', dtype=object
)
# We have to add the entity ID to the residue ID because multiple residues
# can share the same ID. This also allows using string_array.remap.
label_seq_id_to_auth_seq_id = dict(
zip(label_seq_ids[res_mask], auth_seq_ids[res_mask], strict=True)
zip(
np.char.add(label_entity_id[res_mask], label_seq_ids[res_mask]),
auth_seq_ids[res_mask],
strict=True,
)
)
scheme_pdb_seq_num = string_array.remap(
poly_seq_num, mapping=label_seq_id_to_auth_seq_id, default_value='.'
np.char.add(poly_seq_entity_id, poly_seq_num),
mapping=label_seq_id_to_auth_seq_id,
default_value='.',
)
label_seq_id_to_ins_code = dict(
zip(label_seq_ids[res_mask], pdb_ins_codes[res_mask], strict=True)
@ -1201,11 +1209,11 @@ def _maybe_add_missing_scheme_tables(
inplace=False,
)
update['_pdbx_poly_seq_scheme.asym_id'] = res_asym_ids
update['_pdbx_poly_seq_scheme.pdb_strand_id'] = res_strand_ids
update['_pdbx_poly_seq_scheme.pdb_seq_num'] = auth_seq_ids[res_mask]
update['_pdbx_poly_seq_scheme.pdb_ins_code'] = pdb_ins_codes[res_mask]
update['_pdbx_poly_seq_scheme.seq_id'] = label_seq_ids[res_mask]
update['_pdbx_poly_seq_scheme.mon_id'] = label_comp_ids[res_mask]
update['_pdbx_poly_seq_scheme.pdb_strand_id'] = res_strand_ids
required_nonpoly_scheme_cols = (
'_pdbx_nonpoly_scheme.mon_id',
@ -1604,10 +1612,11 @@ def get_tables(
except KeyError as e:
raise ValueError(
'Lookup for the following atom from the _atom_site table failed: '
f'(atom_id, auth_seq_id, res_name, ins_code)={e}. This is '
'likely due to a known issue with some multi-model mmCIFs that only '
'match the first model in _atom_site table to the _pdbx_poly_scheme, '
'_pdbx_nonpoly_scheme, or _pdbx_branch_scheme tables.'
f'(label_asym_id, auth_seq_id, res_name, ins_code)={e}. This typically '
'indicates that the _pdbx_poly_seq_scheme, _pdbx_nonpoly_scheme, or '
'_pdbx_branch_scheme tables do not have data for all residues present '
'in the _atom_site table. It could also be due to a known issue with '
'a small number of multi-model mmCIFs.'
) from e
# The residue ID will be shared for all atoms within that residue.

View File

@ -275,6 +275,24 @@ class StructureTables:
bonds: structure_tables.Bonds
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ResArrays:
"""Atom-level data arrays with a residue dimension.
Attributes:
atom_positions: float32 of shape [num_res, num_atom_type, 3] coordinates.
atom_mask: float32 of shape [num_res, num_atom_type] indicating if an atom
is present.
atom_b_factor: float32 of shape [num_res, num_atom_type] b_factors.
atom_occupancy: float32 of shape [num_res, num_atom_type] occupancies.
"""
atom_positions: np.ndarray
atom_mask: np.ndarray
atom_b_factor: np.ndarray
atom_occupancy: np.ndarray
class Structure(table.Database):
"""Structure class for representing and processing molecular structures."""
@ -308,7 +326,7 @@ class Structure(table.Database):
# b/345221494 Rename this variable when structure_v1 compatibility code
# is removed.
self._VERSION = '2.0.0' # pylint: disable=invalid-name
self._name = name
self._name = name or 'unset'
self._release_date = release_date
self._resolution = resolution
self._structure_method = structure_method
@ -2438,8 +2456,8 @@ class Structure(table.Database):
*,
include_missing_residues: bool,
atom_order: Mapping[str, int] = atom_types.ATOM37_ORDER,
) -> tuple[np.ndarray, np.ndarray]:
"""Returns an atom position and atom mask array with a num_res dimension.
) -> ResArrays:
"""Returns atom-level information in arrays containing a num_res dimension.
NB: All residues in the structure will appear in the residue dimension but
atoms will only have a True (1.0) mask value if the residue + atom
@ -2455,15 +2473,14 @@ class Structure(table.Database):
choose atom_types.ATOM29_ORDER for nucleics.
Returns:
A pair of arrays:
* atom_positions: [num_res, atom_type_num, 3] float32 array of coords.
* atom_mask: [num_res, atom_type_num] float32 atom mask denoting
which atoms are present in this Structure.
A ResArrays object.
"""
num_res = self.num_residues(count_unresolved=include_missing_residues)
atom_type_num = len(atom_order)
atom_positions = np.zeros((num_res, atom_type_num, 3), dtype=np.float32)
atom_mask = np.zeros((num_res, atom_type_num), dtype=np.float32)
atom_b_factor = np.zeros((num_res, atom_type_num), dtype=np.float32)
atom_occupancy = np.zeros((num_res, atom_type_num), dtype=np.float32)
all_residues = None if not include_missing_residues else self.all_residues
for i, atom in enumerate_residues(self.iter_atoms(), all_residues):
@ -2473,8 +2490,15 @@ class Structure(table.Database):
atom_positions[i, atom_idx, 1] = atom['atom_y']
atom_positions[i, atom_idx, 2] = atom['atom_z']
atom_mask[i, atom_idx] = 1.0
atom_b_factor[i, atom_idx] = atom['atom_b_factor']
atom_occupancy[i, atom_idx] = atom['atom_occupancy']
return atom_positions, atom_mask
return ResArrays(
atom_positions=atom_positions,
atom_mask=atom_mask,
atom_b_factor=atom_b_factor,
atom_occupancy=atom_occupancy,
)
def to_res_atom_lists(
self, *, include_missing_residues: bool