From a8ecdb2d7a433c5e9f8510a0f52434a0c55018c4 Mon Sep 17 00:00:00 2001 From: Augustin Zidek Date: Mon, 8 Sep 2025 01:43:38 -0700 Subject: [PATCH] Cache the underlying CCD dictionary instead of the whole CCD object. * Enables removal of the chemical_components.cached_ccd function. * If there are multiple custom CCDs (tiny in comparison to the full CCD) in a single run, this will consume only O(1) RAM instead of O(num_inputs). Addresses: * https://github.com/google-deepmind/alphafold3/pull/514 * https://github.com/google-deepmind/alphafold3/issues/509 PiperOrigin-RevId: 804321599 Change-Id: I484f851ce79cd6b1c5c31bf70b6cd945791f8b66 --- run_alphafold.py | 2 +- run_alphafold_data_test.py | 6 ++++-- src/alphafold3/constants/chemical_components.py | 17 ++++++++++------- src/alphafold3/model/scoring/chirality.py | 2 +- src/alphafold3/structure/mmcif.py | 2 +- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/run_alphafold.py b/run_alphafold.py index 6f2c6f7..b7c6590 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -449,7 +449,7 @@ def predict_structure( print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...') featurisation_start_time = time.time() - ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd) + ccd = chemical_components.Ccd(user_ccd=fold_input.user_ccd) featurised_examples = featurisation.featurise_input( fold_input=fold_input, buckets=buckets, diff --git a/run_alphafold_data_test.py b/run_alphafold_data_test.py index 0e15afd..5465e88 100644 --- a/run_alphafold_data_test.py +++ b/run_alphafold_data_test.py @@ -162,7 +162,9 @@ class DataPipelineTest(parameterized.TestCase): { 'protein': { 'id': 'P', - 'sequence': 'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN', + 'sequence': ( + 'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN' + ), 'modifications': [], 'unpairedMsa': None, 'pairedMsa': None, @@ -205,7 +207,7 @@ class DataPipelineTest(parameterized.TestCase): full_fold_input = data_pipeline.process(fold_input) featurised_example = featurisation.featurise_input( full_fold_input, - ccd=chemical_components.cached_ccd(), + ccd=chemical_components.Ccd(), buckets=None, ) del featurised_example[0]['ref_pos'] # Depends on specific RDKit version. diff --git a/src/alphafold3/constants/chemical_components.py b/src/alphafold3/constants/chemical_components.py index 6b106a4..f1bbc9d 100644 --- a/src/alphafold3/constants/chemical_components.py +++ b/src/alphafold3/constants/chemical_components.py @@ -25,6 +25,15 @@ _CCD_PICKLE_FILE = resources.filename( ) +@functools.cache +def _load_ccd_pickle_cached( + path: os.PathLike[str], +) -> dict[str, Mapping[str, Sequence[str]]]: + """Loads the CCD pickle file and caches it so that it is only loaded once.""" + with open(path, 'rb') as f: + return pickle.loads(f.read()) + + class Ccd(Mapping[str, Mapping[str, Sequence[str]]]): """Chemical Components found in PDB (CCD) constants. @@ -52,8 +61,7 @@ class Ccd(Mapping[str, Mapping[str, Sequence[str]]]): be used to override specific entries in the CCD if desired. """ self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE - with open(self._ccd_pickle_path, 'rb') as f: - self._dict = pickle.loads(f.read()) + self._dict = _load_ccd_pickle_cached(self._ccd_pickle_path) if user_ccd is not None: if not user_ccd: @@ -94,11 +102,6 @@ class Ccd(Mapping[str, Mapping[str, Sequence[str]]]): return self._dict.keys() -@functools.cache -def cached_ccd(user_ccd: str | None = None) -> Ccd: - return Ccd(user_ccd=user_ccd) - - @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class ComponentInfo: name: str diff --git a/src/alphafold3/model/scoring/chirality.py b/src/alphafold3/model/scoring/chirality.py index 249fcca..0b5c85e 100644 --- a/src/alphafold3/model/scoring/chirality.py +++ b/src/alphafold3/model/scoring/chirality.py @@ -117,7 +117,7 @@ def _mol_from_ligand_struc( def _maybe_mol_from_ccd(res_name: str) -> rd_chem.Mol | None: """Creates a Mol object from CCD information if res_name is in the CCD.""" - ccd = chemical_components.cached_ccd() + ccd = chemical_components.Ccd() ccd_cif = ccd.get(res_name) if not ccd_cif: logging.warning('No ccd information for residue %s.', res_name) diff --git a/src/alphafold3/structure/mmcif.py b/src/alphafold3/structure/mmcif.py index 7fdf723..78fd267 100644 --- a/src/alphafold3/structure/mmcif.py +++ b/src/alphafold3/structure/mmcif.py @@ -165,7 +165,7 @@ def get_or_infer_type_symbol( _atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name) and the CCD. """ - ccd = ccd or chemical_components.cached_ccd() + ccd = ccd or chemical_components.Ccd() type_symbol_fn = lambda res_name, atom_name: chemical_components.type_symbol( ccd, res_name, atom_name )