mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Save embeddings as float16 instead of float32 to make them 2x smaller
PiperOrigin-RevId: 758140305 Change-Id: I70cb7adc53440ce19d20fa89d64bbb36b84b13f7
This commit is contained in:
committed by
Copybara-Service
parent
424201003f
commit
267ab05055
@ -24,7 +24,7 @@ The following structure is used within the output directory:
|
||||
Numpy zip file contains 2 keys: `single_embeddings` and `pair_embeddings`.
|
||||
The embeddings can be large, their shapes are `(num_tokens, 384)` for
|
||||
`single_embeddings`, and `(num_tokens, num_tokens, 128)` for
|
||||
`pair_embeddings`. Their dtype is `np.float32` (almost 12 GiB for a
|
||||
`pair_embeddings`. Their dtype is `np.float16` (almost 6 GiB for a
|
||||
5,000-token input). Only saved if AlphaFold 3 is run with
|
||||
`--save_embeddings=true`.
|
||||
* Top-ranking prediction mmCIF: `<job_name>_model.cif`. This file contains the
|
||||
|
@ -382,11 +382,13 @@ class ModelRunner:
|
||||
"""Extracts embeddings from model outputs."""
|
||||
embeddings = {}
|
||||
if 'single_embeddings' in result:
|
||||
embeddings['single_embeddings'] = result['single_embeddings'][:num_tokens]
|
||||
embeddings['single_embeddings'] = result['single_embeddings'][
|
||||
:num_tokens
|
||||
].astype(np.float16)
|
||||
if 'pair_embeddings' in result:
|
||||
embeddings['pair_embeddings'] = result['pair_embeddings'][
|
||||
:num_tokens, :num_tokens
|
||||
]
|
||||
].astype(np.float16)
|
||||
return embeddings or None
|
||||
|
||||
def extract_distogram(
|
||||
|
@ -280,11 +280,11 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
# Ligand 7BU has 41 tokens.
|
||||
num_tokens = len(fold_input.protein_chains[0].sequence) + 41
|
||||
self.assertEqual(embeddings['single_embeddings'].shape, (num_tokens, 384))
|
||||
self.assertEqual(embeddings['single_embeddings'].dtype, np.float32)
|
||||
self.assertEqual(embeddings['single_embeddings'].dtype, np.float16)
|
||||
self.assertEqual(
|
||||
embeddings['pair_embeddings'].shape, (num_tokens, num_tokens, 128)
|
||||
)
|
||||
self.assertEqual(embeddings['pair_embeddings'].dtype, np.float32)
|
||||
self.assertEqual(embeddings['pair_embeddings'].dtype, np.float16)
|
||||
|
||||
distogram_dir = os.path.join(output_dir, f'{prefix}_distogram')
|
||||
distogram_filename = f'{fold_input.sanitised_name()}_{prefix}_distogram.npz'
|
||||
|
Reference in New Issue
Block a user