mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Add an option to save the distogram
Based on the implementation in https://github.com/google-deepmind/alphafold3/pull/345 by @PasqM. PiperOrigin-RevId: 758128292 Change-Id: I868d43018833f889d858e5c9a7d687a44cfd3e82
This commit is contained in:
committed by
Copybara-Service
parent
6a0e8b2afe
commit
424201003f
@ -15,8 +15,18 @@ The following structure is used within the output directory:
|
||||
`seed-<seed value>_sample-<sample number>`. Each of these directories
|
||||
contains a confidence JSON, summary confidence JSON, and the mmCIF with the
|
||||
predicted structure.
|
||||
* Embeddings for each seed: `seed-<seed value>_embeddings/embeddings.npz`.
|
||||
Only saved if AlphaFold 3 is run with `--save_embeddings=true`.
|
||||
* Distogram for each seed: `seed-<seed value>_distogram/distogram.npz`. The
|
||||
Numpy zip file contains a single key: `distogram`. The distogram can be
|
||||
large, its shape is `(num_tokens, num_tokens, 64)` and dtype `np.float16`
|
||||
(almost 3 GiB for a 5,000-token input). Only saved if AlphaFold 3 is run
|
||||
with `--save_distogram=true`.
|
||||
* Embeddings for each seed: `seed-<seed value>_embeddings/embeddings.npz`. The
|
||||
Numpy zip file contains 2 keys: `single_embeddings` and `pair_embeddings`.
|
||||
The embeddings can be large, their shapes are `(num_tokens, 384)` for
|
||||
`single_embeddings`, and `(num_tokens, num_tokens, 128)` for
|
||||
`pair_embeddings`. Their dtype is `np.float32` (almost 12 GiB for a
|
||||
5,000-token input). Only saved if AlphaFold 3 is run with
|
||||
`--save_embeddings=true`.
|
||||
* Top-ranking prediction mmCIF: `<job_name>_model.cif`. This file contains the
|
||||
predicted coordinates and should be compatible with most structural biology
|
||||
tools. We do not provide the output in the PDB format, the CIF file can be
|
||||
@ -35,6 +45,8 @@ Fold", that has been ran with 1 seed and 5 samples:
|
||||
|
||||
```txt
|
||||
hello_fold/
|
||||
├── seed-1234_distogram # Only if --save_distogram=true.
|
||||
│ └── hello_fold_seed-1234_distogram.npz # Only if --save_distogram=true.
|
||||
├── seed-1234_embeddings # Only if --save_embeddings=true.
|
||||
│ └── hello_fold_seed-1234_embeddings.npz # Only if --save_embeddings=true.
|
||||
├── seed-1234_sample-0/
|
||||
|
@ -274,6 +274,12 @@ _SAVE_EMBEDDINGS = flags.DEFINE_bool(
|
||||
False,
|
||||
'Whether to save the final trunk single and pair embeddings in the output.',
|
||||
)
|
||||
_SAVE_DISTOGRAM = flags.DEFINE_bool(
|
||||
'save_distogram',
|
||||
False,
|
||||
'Whether to save the final distogram in the output. Note that the distoram '
|
||||
'large: num_tokens * num_tokens * 39.',
|
||||
)
|
||||
_FORCE_OUTPUT_DIR = flags.DEFINE_bool(
|
||||
'force_output_dir',
|
||||
False,
|
||||
@ -289,6 +295,7 @@ def make_model_config(
|
||||
num_diffusion_samples: int = 5,
|
||||
num_recycles: int = 10,
|
||||
return_embeddings: bool = False,
|
||||
return_distogram: bool = False,
|
||||
) -> model.Model.Config:
|
||||
"""Returns a model config with some defaults overridden."""
|
||||
config = model.Model.Config()
|
||||
@ -298,6 +305,7 @@ def make_model_config(
|
||||
config.heads.diffusion.eval.num_samples = num_diffusion_samples
|
||||
config.num_recycles = num_recycles
|
||||
config.return_embeddings = return_embeddings
|
||||
config.return_distogram = return_distogram
|
||||
return config
|
||||
|
||||
|
||||
@ -355,19 +363,23 @@ class ModelRunner:
|
||||
result['__identifier__'] = identifier
|
||||
return result
|
||||
|
||||
def extract_inference_results_and_maybe_embeddings(
|
||||
def extract_inference_results(
|
||||
self,
|
||||
batch: features.BatchDict,
|
||||
result: model.ModelResult,
|
||||
target_name: str,
|
||||
) -> tuple[list[model.InferenceResult], dict[str, np.ndarray] | None]:
|
||||
"""Extracts inference results and embeddings (if set) from model outputs."""
|
||||
inference_results = list(
|
||||
) -> list[model.InferenceResult]:
|
||||
"""Extracts inference results from model outputs."""
|
||||
return list(
|
||||
model.Model.get_inference_result(
|
||||
batch=batch, result=result, target_name=target_name
|
||||
)
|
||||
)
|
||||
num_tokens = len(inference_results[0].metadata['token_chain_ids'])
|
||||
|
||||
def extract_embeddings(
|
||||
self, result: model.ModelResult, num_tokens: int
|
||||
) -> dict[str, np.ndarray] | None:
|
||||
"""Extracts embeddings from model outputs."""
|
||||
embeddings = {}
|
||||
if 'single_embeddings' in result:
|
||||
embeddings['single_embeddings'] = result['single_embeddings'][:num_tokens]
|
||||
@ -375,7 +387,16 @@ class ModelRunner:
|
||||
embeddings['pair_embeddings'] = result['pair_embeddings'][
|
||||
:num_tokens, :num_tokens
|
||||
]
|
||||
return inference_results, embeddings or None
|
||||
return embeddings or None
|
||||
|
||||
def extract_distogram(
|
||||
self, result: model.ModelResult, num_tokens: int
|
||||
) -> np.ndarray | None:
|
||||
"""Extracts distogram from model outputs."""
|
||||
if 'distogram' not in result['distogram']:
|
||||
return None
|
||||
distogram = result['distogram']['distogram'][:num_tokens, :num_tokens, :]
|
||||
return distogram
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
|
||||
@ -388,12 +409,14 @@ class ResultsForSeed:
|
||||
full_fold_input: The fold input that must also include the results of
|
||||
running the data pipeline - MSA and templates.
|
||||
embeddings: The final trunk single and pair embeddings, if requested.
|
||||
distogram: The token distance histogram, if requested.
|
||||
"""
|
||||
|
||||
seed: int
|
||||
inference_results: Sequence[model.InferenceResult]
|
||||
full_fold_input: folding_input.Input
|
||||
embeddings: dict[str, np.ndarray] | None = None
|
||||
distogram: np.ndarray | None = None
|
||||
|
||||
|
||||
def predict_structure(
|
||||
@ -437,10 +460,15 @@ def predict_structure(
|
||||
)
|
||||
print(f'Extracting inference results with seed {seed}...')
|
||||
extract_structures = time.time()
|
||||
inference_results, embeddings = (
|
||||
model_runner.extract_inference_results_and_maybe_embeddings(
|
||||
batch=example, result=result, target_name=fold_input.name
|
||||
)
|
||||
inference_results = model_runner.extract_inference_results(
|
||||
batch=example, result=result, target_name=fold_input.name
|
||||
)
|
||||
num_tokens = len(inference_results[0].metadata['token_chain_ids'])
|
||||
embeddings = model_runner.extract_embeddings(
|
||||
result=result, num_tokens=num_tokens
|
||||
)
|
||||
distogram = model_runner.extract_distogram(
|
||||
result=result, num_tokens=num_tokens
|
||||
)
|
||||
print(
|
||||
f'Extracting {len(inference_results)} inference samples with'
|
||||
@ -453,6 +481,7 @@ def predict_structure(
|
||||
inference_results=inference_results,
|
||||
full_fold_input=fold_input,
|
||||
embeddings=embeddings,
|
||||
distogram=distogram,
|
||||
)
|
||||
)
|
||||
print(
|
||||
@ -515,6 +544,15 @@ def write_outputs(
|
||||
name=f'{job_name}_seed-{seed}',
|
||||
)
|
||||
|
||||
if (distogram := results_for_seed.distogram) is not None:
|
||||
distogram_dir = os.path.join(output_dir, f'seed-{seed}_distogram')
|
||||
os.makedirs(distogram_dir, exist_ok=True)
|
||||
distogram_path = os.path.join(
|
||||
distogram_dir, f'{job_name}_seed-{seed}_distogram.npz'
|
||||
)
|
||||
with open(distogram_path, 'wb') as f:
|
||||
np.savez_compressed(f, distogram=distogram.astype(np.float16))
|
||||
|
||||
if max_ranking_result is not None: # True iff ranking_scores non-empty.
|
||||
post_processing.write_output(
|
||||
inference_result=max_ranking_result,
|
||||
@ -794,6 +832,7 @@ def main(_):
|
||||
num_diffusion_samples=_NUM_DIFFUSION_SAMPLES.value,
|
||||
num_recycles=_NUM_RECYCLES.value,
|
||||
return_embeddings=_SAVE_EMBEDDINGS.value,
|
||||
return_distogram=_SAVE_DISTOGRAM.value,
|
||||
),
|
||||
device=devices[_GPU_DEVICE.value],
|
||||
model_dir=pathlib.Path(MODEL_DIR.value),
|
||||
|
@ -144,7 +144,9 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
}
|
||||
self._test_input_json = json.dumps(test_input)
|
||||
self._model_config = run_alphafold.make_model_config(
|
||||
return_embeddings=True, flash_attention_implementation='triton'
|
||||
flash_attention_implementation='triton',
|
||||
return_embeddings=True,
|
||||
return_distogram=True,
|
||||
)
|
||||
self._runner = run_alphafold.ModelRunner(
|
||||
config=self._model_config,
|
||||
@ -164,9 +166,13 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
featurised_example, jax.random.PRNGKey(0)
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
_, embeddings = self._runner.extract_inference_results_and_maybe_embeddings(
|
||||
inference_results = self._runner.extract_inference_results(
|
||||
batch=featurised_example, result=result, target_name='target'
|
||||
)
|
||||
embeddings = self._runner.extract_embeddings(
|
||||
result=result,
|
||||
num_tokens=len(inference_results[0].metadata['token_chain_ids']),
|
||||
)
|
||||
self.assertLen(embeddings, 2)
|
||||
|
||||
def test_process_fold_input_runs_only_inference(self):
|
||||
@ -232,6 +238,7 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
f'{prefix}_sample-3',
|
||||
f'{prefix}_sample-4',
|
||||
f'{prefix}_embeddings',
|
||||
f'{prefix}_distogram',
|
||||
# Top ranking result.
|
||||
expected_confidences_filename,
|
||||
expected_model_cif_filename,
|
||||
@ -273,9 +280,20 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
# Ligand 7BU has 41 tokens.
|
||||
num_tokens = len(fold_input.protein_chains[0].sequence) + 41
|
||||
self.assertEqual(embeddings['single_embeddings'].shape, (num_tokens, 384))
|
||||
self.assertEqual(embeddings['single_embeddings'].dtype, np.float32)
|
||||
self.assertEqual(
|
||||
embeddings['pair_embeddings'].shape, (num_tokens, num_tokens, 128)
|
||||
)
|
||||
self.assertEqual(embeddings['pair_embeddings'].dtype, np.float32)
|
||||
|
||||
distogram_dir = os.path.join(output_dir, f'{prefix}_distogram')
|
||||
distogram_filename = f'{fold_input.sanitised_name()}_{prefix}_distogram.npz'
|
||||
self.assertSameElements(os.listdir(distogram_dir), [distogram_filename])
|
||||
|
||||
with open(os.path.join(distogram_dir, distogram_filename), 'rb') as f:
|
||||
distogram = np.load(f)['distogram']
|
||||
self.assertEqual(distogram.shape, (num_tokens, num_tokens, 64))
|
||||
self.assertEqual(distogram.dtype, np.float16)
|
||||
|
||||
with open(os.path.join(output_dir, expected_data_json_filename), 'rt') as f:
|
||||
actual_input_json = json.load(f)
|
||||
|
@ -224,6 +224,7 @@ class Model(hk.Module):
|
||||
heads: 'Model.HeadsConfig' = base_config.autocreate()
|
||||
num_recycles: int = 10
|
||||
return_embeddings: bool = False
|
||||
return_distogram: bool = False
|
||||
|
||||
def __init__(self, config: Config, name: str = 'diffuser'):
|
||||
super().__init__(name=name)
|
||||
@ -327,7 +328,7 @@ class Model(hk.Module):
|
||||
|
||||
distogram = distogram_head.DistogramHead(
|
||||
self.config.heads.distogram, self.global_config
|
||||
)(batch, embeddings)
|
||||
)(batch, embeddings, return_distogram=self.config.return_distogram)
|
||||
|
||||
output = {
|
||||
'diffusion_samples': samples,
|
||||
|
@ -47,6 +47,7 @@ class DistogramHead(hk.Module):
|
||||
self,
|
||||
batch: feat_batch.Batch,
|
||||
embeddings: dict[str, jnp.ndarray],
|
||||
return_distogram: bool = False,
|
||||
) -> dict[str, jnp.ndarray]:
|
||||
pair_act = embeddings['pair']
|
||||
seq_mask = batch.token_features.mask.astype(bool)
|
||||
@ -75,7 +76,8 @@ class DistogramHead(hk.Module):
|
||||
)
|
||||
contact_probs = pair_mask * contact_probs
|
||||
|
||||
return {
|
||||
'bin_edges': breaks,
|
||||
'contact_probs': contact_probs,
|
||||
}
|
||||
return_dict = {'bin_edges': breaks, 'contact_probs': contact_probs}
|
||||
if return_distogram:
|
||||
return_dict['distogram'] = logits
|
||||
|
||||
return return_dict
|
||||
|
@ -226,5 +226,6 @@
|
||||
}
|
||||
},
|
||||
"num_recycles": 10,
|
||||
"return_distogram": false,
|
||||
"return_embeddings": false
|
||||
}
|
Reference in New Issue
Block a user