mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Cache the underlying CCD dictionary instead of the whole CCD object.
* Enables removal of the chemical_components.cached_ccd function. * If there are multiple custom CCDs (tiny in comparison to the full CCD) in a single run, this will consume only O(1) RAM instead of O(num_inputs). Addresses: * https://github.com/google-deepmind/alphafold3/pull/514 * https://github.com/google-deepmind/alphafold3/issues/509 PiperOrigin-RevId: 804321599 Change-Id: I484f851ce79cd6b1c5c31bf70b6cd945791f8b66
This commit is contained in:
committed by
Copybara-Service
parent
85a03ec086
commit
a8ecdb2d7a
@ -449,7 +449,7 @@ def predict_structure(
|
|||||||
|
|
||||||
print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...')
|
print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...')
|
||||||
featurisation_start_time = time.time()
|
featurisation_start_time = time.time()
|
||||||
ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd)
|
ccd = chemical_components.Ccd(user_ccd=fold_input.user_ccd)
|
||||||
featurised_examples = featurisation.featurise_input(
|
featurised_examples = featurisation.featurise_input(
|
||||||
fold_input=fold_input,
|
fold_input=fold_input,
|
||||||
buckets=buckets,
|
buckets=buckets,
|
||||||
|
@ -162,7 +162,9 @@ class DataPipelineTest(parameterized.TestCase):
|
|||||||
{
|
{
|
||||||
'protein': {
|
'protein': {
|
||||||
'id': 'P',
|
'id': 'P',
|
||||||
'sequence': 'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN',
|
'sequence': (
|
||||||
|
'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN'
|
||||||
|
),
|
||||||
'modifications': [],
|
'modifications': [],
|
||||||
'unpairedMsa': None,
|
'unpairedMsa': None,
|
||||||
'pairedMsa': None,
|
'pairedMsa': None,
|
||||||
@ -205,7 +207,7 @@ class DataPipelineTest(parameterized.TestCase):
|
|||||||
full_fold_input = data_pipeline.process(fold_input)
|
full_fold_input = data_pipeline.process(fold_input)
|
||||||
featurised_example = featurisation.featurise_input(
|
featurised_example = featurisation.featurise_input(
|
||||||
full_fold_input,
|
full_fold_input,
|
||||||
ccd=chemical_components.cached_ccd(),
|
ccd=chemical_components.Ccd(),
|
||||||
buckets=None,
|
buckets=None,
|
||||||
)
|
)
|
||||||
del featurised_example[0]['ref_pos'] # Depends on specific RDKit version.
|
del featurised_example[0]['ref_pos'] # Depends on specific RDKit version.
|
||||||
|
@ -25,6 +25,15 @@ _CCD_PICKLE_FILE = resources.filename(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def _load_ccd_pickle_cached(
|
||||||
|
path: os.PathLike[str],
|
||||||
|
) -> dict[str, Mapping[str, Sequence[str]]]:
|
||||||
|
"""Loads the CCD pickle file and caches it so that it is only loaded once."""
|
||||||
|
with open(path, 'rb') as f:
|
||||||
|
return pickle.loads(f.read())
|
||||||
|
|
||||||
|
|
||||||
class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
|
class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
|
||||||
"""Chemical Components found in PDB (CCD) constants.
|
"""Chemical Components found in PDB (CCD) constants.
|
||||||
|
|
||||||
@ -52,8 +61,7 @@ class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
|
|||||||
be used to override specific entries in the CCD if desired.
|
be used to override specific entries in the CCD if desired.
|
||||||
"""
|
"""
|
||||||
self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE
|
self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE
|
||||||
with open(self._ccd_pickle_path, 'rb') as f:
|
self._dict = _load_ccd_pickle_cached(self._ccd_pickle_path)
|
||||||
self._dict = pickle.loads(f.read())
|
|
||||||
|
|
||||||
if user_ccd is not None:
|
if user_ccd is not None:
|
||||||
if not user_ccd:
|
if not user_ccd:
|
||||||
@ -94,11 +102,6 @@ class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
|
|||||||
return self._dict.keys()
|
return self._dict.keys()
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def cached_ccd(user_ccd: str | None = None) -> Ccd:
|
|
||||||
return Ccd(user_ccd=user_ccd)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
|
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
|
||||||
class ComponentInfo:
|
class ComponentInfo:
|
||||||
name: str
|
name: str
|
||||||
|
@ -117,7 +117,7 @@ def _mol_from_ligand_struc(
|
|||||||
|
|
||||||
def _maybe_mol_from_ccd(res_name: str) -> rd_chem.Mol | None:
|
def _maybe_mol_from_ccd(res_name: str) -> rd_chem.Mol | None:
|
||||||
"""Creates a Mol object from CCD information if res_name is in the CCD."""
|
"""Creates a Mol object from CCD information if res_name is in the CCD."""
|
||||||
ccd = chemical_components.cached_ccd()
|
ccd = chemical_components.Ccd()
|
||||||
ccd_cif = ccd.get(res_name)
|
ccd_cif = ccd.get(res_name)
|
||||||
if not ccd_cif:
|
if not ccd_cif:
|
||||||
logging.warning('No ccd information for residue %s.', res_name)
|
logging.warning('No ccd information for residue %s.', res_name)
|
||||||
|
@ -165,7 +165,7 @@ def get_or_infer_type_symbol(
|
|||||||
_atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name)
|
_atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name)
|
||||||
and the CCD.
|
and the CCD.
|
||||||
"""
|
"""
|
||||||
ccd = ccd or chemical_components.cached_ccd()
|
ccd = ccd or chemical_components.Ccd()
|
||||||
type_symbol_fn = lambda res_name, atom_name: chemical_components.type_symbol(
|
type_symbol_fn = lambda res_name, atom_name: chemical_components.type_symbol(
|
||||||
ccd, res_name, atom_name
|
ccd, res_name, atom_name
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user