diff --git a/docs/output.md b/docs/output.md index 1386a3b..45bf4a2 100644 --- a/docs/output.md +++ b/docs/output.md @@ -24,7 +24,7 @@ The following structure is used within the output directory: Numpy zip file contains 2 keys: `single_embeddings` and `pair_embeddings`. The embeddings can be large, their shapes are `(num_tokens, 384)` for `single_embeddings`, and `(num_tokens, num_tokens, 128)` for - `pair_embeddings`. Their dtype is `np.float32` (almost 12 GiB for a + `pair_embeddings`. Their dtype is `np.float16` (almost 6 GiB for a 5,000-token input). Only saved if AlphaFold 3 is run with `--save_embeddings=true`. * Top-ranking prediction mmCIF: `_model.cif`. This file contains the diff --git a/run_alphafold.py b/run_alphafold.py index f9ccff3..9015033 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -382,11 +382,13 @@ class ModelRunner: """Extracts embeddings from model outputs.""" embeddings = {} if 'single_embeddings' in result: - embeddings['single_embeddings'] = result['single_embeddings'][:num_tokens] + embeddings['single_embeddings'] = result['single_embeddings'][ + :num_tokens + ].astype(np.float16) if 'pair_embeddings' in result: embeddings['pair_embeddings'] = result['pair_embeddings'][ :num_tokens, :num_tokens - ] + ].astype(np.float16) return embeddings or None def extract_distogram( diff --git a/run_alphafold_test.py b/run_alphafold_test.py index 0d7ff22..22d0787 100644 --- a/run_alphafold_test.py +++ b/run_alphafold_test.py @@ -280,11 +280,11 @@ class InferenceTest(test_utils.StructureTestCase): # Ligand 7BU has 41 tokens. num_tokens = len(fold_input.protein_chains[0].sequence) + 41 self.assertEqual(embeddings['single_embeddings'].shape, (num_tokens, 384)) - self.assertEqual(embeddings['single_embeddings'].dtype, np.float32) + self.assertEqual(embeddings['single_embeddings'].dtype, np.float16) self.assertEqual( embeddings['pair_embeddings'].shape, (num_tokens, num_tokens, 128) ) - self.assertEqual(embeddings['pair_embeddings'].dtype, np.float32) + self.assertEqual(embeddings['pair_embeddings'].dtype, np.float16) distogram_dir = os.path.join(output_dir, f'{prefix}_distogram') distogram_filename = f'{fold_input.sanitised_name()}_{prefix}_distogram.npz'