Save embeddings as float16 instead of float32 to make them 2x smaller

PiperOrigin-RevId: 758140305
Change-Id: I70cb7adc53440ce19d20fa89d64bbb36b84b13f7
This commit is contained in:
Augustin Zidek
2025-05-13 03:22:32 -07:00
committed by Copybara-Service
parent 424201003f
commit 267ab05055
3 changed files with 7 additions and 5 deletions

View File

@@ -24,7 +24,7 @@ The following structure is used within the output directory:
Numpy zip file contains 2 keys: `single_embeddings` and `pair_embeddings`. Numpy zip file contains 2 keys: `single_embeddings` and `pair_embeddings`.
The embeddings can be large, their shapes are `(num_tokens, 384)` for The embeddings can be large, their shapes are `(num_tokens, 384)` for
`single_embeddings`, and `(num_tokens, num_tokens, 128)` for `single_embeddings`, and `(num_tokens, num_tokens, 128)` for
`pair_embeddings`. Their dtype is `np.float32` (almost 12 GiB for a `pair_embeddings`. Their dtype is `np.float16` (almost 6 GiB for a
5,000-token input). Only saved if AlphaFold 3 is run with 5,000-token input). Only saved if AlphaFold 3 is run with
`--save_embeddings=true`. `--save_embeddings=true`.
* Top-ranking prediction mmCIF: `<job_name>_model.cif`. This file contains the * Top-ranking prediction mmCIF: `<job_name>_model.cif`. This file contains the

View File

@@ -382,11 +382,13 @@ class ModelRunner:
"""Extracts embeddings from model outputs.""" """Extracts embeddings from model outputs."""
embeddings = {} embeddings = {}
if 'single_embeddings' in result: if 'single_embeddings' in result:
embeddings['single_embeddings'] = result['single_embeddings'][:num_tokens] embeddings['single_embeddings'] = result['single_embeddings'][
:num_tokens
].astype(np.float16)
if 'pair_embeddings' in result: if 'pair_embeddings' in result:
embeddings['pair_embeddings'] = result['pair_embeddings'][ embeddings['pair_embeddings'] = result['pair_embeddings'][
:num_tokens, :num_tokens :num_tokens, :num_tokens
] ].astype(np.float16)
return embeddings or None return embeddings or None
def extract_distogram( def extract_distogram(

View File

@@ -280,11 +280,11 @@ class InferenceTest(test_utils.StructureTestCase):
# Ligand 7BU has 41 tokens. # Ligand 7BU has 41 tokens.
num_tokens = len(fold_input.protein_chains[0].sequence) + 41 num_tokens = len(fold_input.protein_chains[0].sequence) + 41
self.assertEqual(embeddings['single_embeddings'].shape, (num_tokens, 384)) self.assertEqual(embeddings['single_embeddings'].shape, (num_tokens, 384))
self.assertEqual(embeddings['single_embeddings'].dtype, np.float32) self.assertEqual(embeddings['single_embeddings'].dtype, np.float16)
self.assertEqual( self.assertEqual(
embeddings['pair_embeddings'].shape, (num_tokens, num_tokens, 128) embeddings['pair_embeddings'].shape, (num_tokens, num_tokens, 128)
) )
self.assertEqual(embeddings['pair_embeddings'].dtype, np.float32) self.assertEqual(embeddings['pair_embeddings'].dtype, np.float16)
distogram_dir = os.path.join(output_dir, f'{prefix}_distogram') distogram_dir = os.path.join(output_dir, f'{prefix}_distogram')
distogram_filename = f'{fold_input.sanitised_name()}_{prefix}_distogram.npz' distogram_filename = f'{fold_input.sanitised_name()}_{prefix}_distogram.npz'