mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Save embeddings as float16 instead of float32 to make them 2x smaller
PiperOrigin-RevId: 758140305 Change-Id: I70cb7adc53440ce19d20fa89d64bbb36b84b13f7
This commit is contained in:
committed by
Copybara-Service
parent
424201003f
commit
267ab05055
@@ -24,7 +24,7 @@ The following structure is used within the output directory:
|
|||||||
Numpy zip file contains 2 keys: `single_embeddings` and `pair_embeddings`.
|
Numpy zip file contains 2 keys: `single_embeddings` and `pair_embeddings`.
|
||||||
The embeddings can be large, their shapes are `(num_tokens, 384)` for
|
The embeddings can be large, their shapes are `(num_tokens, 384)` for
|
||||||
`single_embeddings`, and `(num_tokens, num_tokens, 128)` for
|
`single_embeddings`, and `(num_tokens, num_tokens, 128)` for
|
||||||
`pair_embeddings`. Their dtype is `np.float32` (almost 12 GiB for a
|
`pair_embeddings`. Their dtype is `np.float16` (almost 6 GiB for a
|
||||||
5,000-token input). Only saved if AlphaFold 3 is run with
|
5,000-token input). Only saved if AlphaFold 3 is run with
|
||||||
`--save_embeddings=true`.
|
`--save_embeddings=true`.
|
||||||
* Top-ranking prediction mmCIF: `<job_name>_model.cif`. This file contains the
|
* Top-ranking prediction mmCIF: `<job_name>_model.cif`. This file contains the
|
||||||
|
@@ -382,11 +382,13 @@ class ModelRunner:
|
|||||||
"""Extracts embeddings from model outputs."""
|
"""Extracts embeddings from model outputs."""
|
||||||
embeddings = {}
|
embeddings = {}
|
||||||
if 'single_embeddings' in result:
|
if 'single_embeddings' in result:
|
||||||
embeddings['single_embeddings'] = result['single_embeddings'][:num_tokens]
|
embeddings['single_embeddings'] = result['single_embeddings'][
|
||||||
|
:num_tokens
|
||||||
|
].astype(np.float16)
|
||||||
if 'pair_embeddings' in result:
|
if 'pair_embeddings' in result:
|
||||||
embeddings['pair_embeddings'] = result['pair_embeddings'][
|
embeddings['pair_embeddings'] = result['pair_embeddings'][
|
||||||
:num_tokens, :num_tokens
|
:num_tokens, :num_tokens
|
||||||
]
|
].astype(np.float16)
|
||||||
return embeddings or None
|
return embeddings or None
|
||||||
|
|
||||||
def extract_distogram(
|
def extract_distogram(
|
||||||
|
@@ -280,11 +280,11 @@ class InferenceTest(test_utils.StructureTestCase):
|
|||||||
# Ligand 7BU has 41 tokens.
|
# Ligand 7BU has 41 tokens.
|
||||||
num_tokens = len(fold_input.protein_chains[0].sequence) + 41
|
num_tokens = len(fold_input.protein_chains[0].sequence) + 41
|
||||||
self.assertEqual(embeddings['single_embeddings'].shape, (num_tokens, 384))
|
self.assertEqual(embeddings['single_embeddings'].shape, (num_tokens, 384))
|
||||||
self.assertEqual(embeddings['single_embeddings'].dtype, np.float32)
|
self.assertEqual(embeddings['single_embeddings'].dtype, np.float16)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
embeddings['pair_embeddings'].shape, (num_tokens, num_tokens, 128)
|
embeddings['pair_embeddings'].shape, (num_tokens, num_tokens, 128)
|
||||||
)
|
)
|
||||||
self.assertEqual(embeddings['pair_embeddings'].dtype, np.float32)
|
self.assertEqual(embeddings['pair_embeddings'].dtype, np.float16)
|
||||||
|
|
||||||
distogram_dir = os.path.join(output_dir, f'{prefix}_distogram')
|
distogram_dir = os.path.join(output_dir, f'{prefix}_distogram')
|
||||||
distogram_filename = f'{fold_input.sanitised_name()}_{prefix}_distogram.npz'
|
distogram_filename = f'{fold_input.sanitised_name()}_{prefix}_distogram.npz'
|
||||||
|
Reference in New Issue
Block a user