Cache the underlying CCD dictionary instead of the whole CCD object.

* Enables removal of the chemical_components.cached_ccd function.
* If there are multiple custom CCDs (tiny in comparison to the full CCD) in a single run, this will consume only O(1) RAM instead of O(num_inputs).

Addresses:
* https://github.com/google-deepmind/alphafold3/pull/514
* https://github.com/google-deepmind/alphafold3/issues/509
PiperOrigin-RevId: 804321599
Change-Id: I484f851ce79cd6b1c5c31bf70b6cd945791f8b66
This commit is contained in:
Augustin Zidek
2025-09-08 01:43:38 -07:00
committed by Copybara-Service
parent 85a03ec086
commit a8ecdb2d7a
5 changed files with 17 additions and 12 deletions

View File

@ -449,7 +449,7 @@ def predict_structure(
print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...') print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...')
featurisation_start_time = time.time() featurisation_start_time = time.time()
ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd) ccd = chemical_components.Ccd(user_ccd=fold_input.user_ccd)
featurised_examples = featurisation.featurise_input( featurised_examples = featurisation.featurise_input(
fold_input=fold_input, fold_input=fold_input,
buckets=buckets, buckets=buckets,

View File

@ -162,7 +162,9 @@ class DataPipelineTest(parameterized.TestCase):
{ {
'protein': { 'protein': {
'id': 'P', 'id': 'P',
'sequence': 'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN', 'sequence': (
'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN'
),
'modifications': [], 'modifications': [],
'unpairedMsa': None, 'unpairedMsa': None,
'pairedMsa': None, 'pairedMsa': None,
@ -205,7 +207,7 @@ class DataPipelineTest(parameterized.TestCase):
full_fold_input = data_pipeline.process(fold_input) full_fold_input = data_pipeline.process(fold_input)
featurised_example = featurisation.featurise_input( featurised_example = featurisation.featurise_input(
full_fold_input, full_fold_input,
ccd=chemical_components.cached_ccd(), ccd=chemical_components.Ccd(),
buckets=None, buckets=None,
) )
del featurised_example[0]['ref_pos'] # Depends on specific RDKit version. del featurised_example[0]['ref_pos'] # Depends on specific RDKit version.

View File

@ -25,6 +25,15 @@ _CCD_PICKLE_FILE = resources.filename(
) )
@functools.cache
def _load_ccd_pickle_cached(
path: os.PathLike[str],
) -> dict[str, Mapping[str, Sequence[str]]]:
"""Loads the CCD pickle file and caches it so that it is only loaded once."""
with open(path, 'rb') as f:
return pickle.loads(f.read())
class Ccd(Mapping[str, Mapping[str, Sequence[str]]]): class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
"""Chemical Components found in PDB (CCD) constants. """Chemical Components found in PDB (CCD) constants.
@ -52,8 +61,7 @@ class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
be used to override specific entries in the CCD if desired. be used to override specific entries in the CCD if desired.
""" """
self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE
with open(self._ccd_pickle_path, 'rb') as f: self._dict = _load_ccd_pickle_cached(self._ccd_pickle_path)
self._dict = pickle.loads(f.read())
if user_ccd is not None: if user_ccd is not None:
if not user_ccd: if not user_ccd:
@ -94,11 +102,6 @@ class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
return self._dict.keys() return self._dict.keys()
@functools.cache
def cached_ccd(user_ccd: str | None = None) -> Ccd:
return Ccd(user_ccd=user_ccd)
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) @dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ComponentInfo: class ComponentInfo:
name: str name: str

View File

@ -117,7 +117,7 @@ def _mol_from_ligand_struc(
def _maybe_mol_from_ccd(res_name: str) -> rd_chem.Mol | None: def _maybe_mol_from_ccd(res_name: str) -> rd_chem.Mol | None:
"""Creates a Mol object from CCD information if res_name is in the CCD.""" """Creates a Mol object from CCD information if res_name is in the CCD."""
ccd = chemical_components.cached_ccd() ccd = chemical_components.Ccd()
ccd_cif = ccd.get(res_name) ccd_cif = ccd.get(res_name)
if not ccd_cif: if not ccd_cif:
logging.warning('No ccd information for residue %s.', res_name) logging.warning('No ccd information for residue %s.', res_name)

View File

@ -165,7 +165,7 @@ def get_or_infer_type_symbol(
_atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name) _atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name)
and the CCD. and the CCD.
""" """
ccd = ccd or chemical_components.cached_ccd() ccd = ccd or chemical_components.Ccd()
type_symbol_fn = lambda res_name, atom_name: chemical_components.type_symbol( type_symbol_fn = lambda res_name, atom_name: chemical_components.type_symbol(
ccd, res_name, atom_name ccd, res_name, atom_name
) )