mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Remove excess code from run_alphafold_test, clean up the testing a little.
PiperOrigin-RevId: 725260840 Change-Id: I5199a62b3b3cac9e898993dc65620c8d135451a9
This commit is contained in:
committed by
Copybara-Service
parent
90c3d05aa4
commit
bb0e5415ef
@ -14,23 +14,18 @@ import contextlib
|
||||
import csv
|
||||
import datetime
|
||||
import difflib
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import pickle
|
||||
from typing import Any
|
||||
|
||||
from absl import logging
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from alphafold3 import structure
|
||||
from alphafold3.common import folding_input
|
||||
from alphafold3.common import resources
|
||||
from alphafold3.common.testing import data as testing_data
|
||||
from alphafold3.data import pipeline
|
||||
from alphafold3.model.atom_layout import atom_layout
|
||||
from alphafold3.model.scoring import alignment
|
||||
from alphafold3.structure import test_utils
|
||||
import jax
|
||||
@ -68,40 +63,6 @@ def _generate_diff(actual: str, expected: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
@functools.singledispatch
def _hash_data(x: Any, /) -> str:
  """Returns a string digest for `x` (generic fallback).

  `None` maps to the literal marker '<<None>>'. Any other value is serialised
  with `json.dumps`, encoded as UTF-8 and passed back through the dispatcher,
  so the `bytes` overload produces the final digest.

  Args:
    x: Value to hash. Must be JSON-serialisable unless it is `None` or a type
      with its own registered overload.

  Returns:
    '<<None>>' for `None`, otherwise the digest of the JSON encoding of `x`.
  """
  if x is not None:
    serialised = json.dumps(x)
    return _hash_data(serialised.encode('utf-8'))
  return '<<None>>'


@_hash_data.register
def _(x: bytes, /) -> str:
  """Returns the hexadecimal SHA-256 digest of the raw bytes `x`."""
  digest = hashlib.sha256(x)
  return digest.hexdigest()
|
||||
|
||||
|
||||
@_hash_data.register
def _(x: jax.Array) -> str:
  """Hashes a JAX array by fetching it with `jax.device_get` first.

  The fetched value is handed back to `_hash_data`, which dispatches on its
  type (presumably the NumPy overload -- confirm against `jax.device_get`).
  """
  fetched = jax.device_get(x)
  return _hash_data(fetched)
|
||||
|
||||
|
||||
@_hash_data.register
def _(x: np.ndarray) -> str:
  """Hashes a NumPy array.

  Arrays with a non-object dtype are hashed from their `tobytes()` buffer
  alone; shape and dtype are not fed into the digest. Object arrays are
  flattened and each element is hashed separately through `_hash_data`, with
  the per-element results joined by ';'.
  """
  if x.dtype != object:
    return _hash_data(x.tobytes())
  elements = x.ravel().tolist()
  return ';'.join(_hash_data(element) for element in elements)
|
||||
|
||||
|
||||
@_hash_data.register
def _(_: structure.Structure) -> str:
  """Returns a fixed placeholder; the structure's contents are not hashed."""
  placeholder = '<<structure>>'
  return placeholder
|
||||
|
||||
|
||||
@_hash_data.register
def _(_: atom_layout.AtomLayout) -> str:
  """Returns a fixed placeholder; the layout's contents are not hashed."""
  placeholder = '<<atom-layout>>'
  return placeholder
|
||||
|
||||
|
||||
class InferenceTest(test_utils.StructureTestCase):
|
||||
"""Test AlphaFold 3 inference."""
|
||||
|
||||
@ -185,22 +146,8 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
|
||||
)
|
||||
|
||||
def compare_golden(self, result_path: str) -> None:
  """Asserts that a result file matches the golden file of the same name.

  The golden file is looked up under `test_data/` (relative to
  `resources.ROOT`) using the base name of `result_path`. Both files are read
  as text and compared through `_generate_diff`; the test fails, with the diff
  in the message, when the diff is not empty.

  Args:
    result_path: Path of the produced file to check.
  """
  golden_name = os.path.basename(result_path)
  golden_resource = testing_data.Data(
      resources.ROOT / f'test_data/{golden_name}'
  )

  # Read the golden file first, then the result, each in text mode.
  contents = []
  for text_path in (golden_resource.path(), result_path):
    with open(text_path, 'r') as handle:
      contents.append(handle.read())
  expected_text, actual_text = contents

  diff = _generate_diff(actual_text, expected_text)

  self.assertEqual(diff, "", f"Result differs from golden:\n{diff}")
|
||||
|
||||
def test_model_inference(self):
|
||||
"""Run model inference and assert that the output is as expected."""
|
||||
"""Run model inference and assert that output exists."""
|
||||
featurised_examples = pickle.loads(
|
||||
(resources.ROOT / 'test_data' / 'featurised_example.pkl').read_bytes()
|
||||
)
|
||||
@ -210,7 +157,6 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
result = self._runner.run_inference(
|
||||
featurised_example, jax.random.PRNGKey(0)
|
||||
)
|
||||
result_hashes = jax.tree_util.tree_map(_hash_data, result)
|
||||
self.assertIsNotNone(result)
|
||||
_, embeddings = self._runner.extract_inference_results_and_maybe_embeddings(
|
||||
batch=featurised_example, result=result, target_name='target'
|
||||
@ -232,15 +178,13 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
{
|
||||
'testcase_name': 'default_bucket',
|
||||
'bucket': None,
|
||||
'exp_ranking_scores': [0.69, 0.69, 0.72, 0.75, 0.70],
|
||||
},
|
||||
{
|
||||
'testcase_name': 'bucket_1024',
|
||||
'bucket': 1024,
|
||||
'exp_ranking_scores': [0.69, 0.71, 0.71, 0.69, 0.70],
|
||||
},
|
||||
)
|
||||
def test_inference(self, bucket, exp_ranking_scores):
|
||||
def test_inference(self, bucket):
|
||||
"""Run AlphaFold 3 inference."""
|
||||
|
||||
### Prepare inputs.
|
||||
@ -327,21 +271,20 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
)
|
||||
|
||||
with open(os.path.join(output_dir, 'ranking_scores.csv'), 'rt') as f:
|
||||
actual_ranking_scores = list(csv.DictReader(f))
|
||||
ranking_scores = list(csv.DictReader(f))
|
||||
|
||||
self.assertLen(actual_ranking_scores, 5)
|
||||
self.assertLen(ranking_scores, 5)
|
||||
self.assertEqual([int(s['seed']) for s in ranking_scores], [1234] * 5)
|
||||
self.assertEqual(
|
||||
[int(s['seed']) for s in actual_ranking_scores], [1234] * 5
|
||||
)
|
||||
self.assertEqual(
|
||||
[int(s['sample']) for s in actual_ranking_scores], [0, 1, 2, 3, 4]
|
||||
)
|
||||
np.testing.assert_array_almost_equal(
|
||||
[float(s['ranking_score']) for s in actual_ranking_scores],
|
||||
exp_ranking_scores,
|
||||
decimal=2,
|
||||
[int(s['sample']) for s in ranking_scores], [0, 1, 2, 3, 4]
|
||||
)
|
||||
|
||||
# Ranking score should be between 0.66 and 0.76 for all samples.
|
||||
ranking_scores = [float(s['ranking_score']) for s in ranking_scores]
|
||||
scores_ok = [0.66 <= score <= 0.76 for score in ranking_scores]
|
||||
if not all(scores_ok):
|
||||
self.fail(f'{ranking_scores=} are not in expected range [0.66, 0.76]')
|
||||
|
||||
with open(os.path.join(output_dir, 'TERMS_OF_USE.md'), 'rt') as f:
|
||||
actual_terms_of_use = f.read()
|
||||
self.assertStartsWith(
|
||||
@ -384,22 +327,25 @@ class InferenceTest(test_utils.StructureTestCase):
|
||||
expected_inf.inference_results,
|
||||
strict=True,
|
||||
):
|
||||
# Check RMSD is within tolerance.
|
||||
# 5tgy is very stable, NMR samples were all within 3.0 RMSD.
|
||||
actual_rmsd = alignment.rmsd_from_coords(
|
||||
actual_inf.predicted_structure.coords,
|
||||
expected_inf.predicted_structure.coords,
|
||||
)
|
||||
self.assertLess(actual_rmsd, 3.0)
|
||||
np.testing.assert_array_equal(
|
||||
actual_inf.predicted_structure.atom_occupancy,
|
||||
[1.0] * actual_inf.predicted_structure.num_atoms,
|
||||
)
|
||||
# Make sure the token chain IDs are the same as the input chain IDs.
|
||||
self.assertEqual(
|
||||
actual_inf.metadata['token_chain_ids'],
|
||||
['P'] * len(fold_input.protein_chains[0].sequence) + ['LL'] * 41,
|
||||
)
|
||||
# All atom occupancies should be 1.0.
|
||||
np.testing.assert_array_equal(
|
||||
actual_inf.predicted_structure.atom_occupancy,
|
||||
[1.0] * actual_inf.predicted_structure.num_atoms,
|
||||
)
|
||||
# Check RMSD is within tolerance.
|
||||
# 5tgy is stably predicted, samples should be all within 3.0 RMSD
|
||||
# regardless of bucket, device type, etc.
|
||||
actual_rmsd = alignment.rmsd_from_coords(
|
||||
actual_inf.predicted_structure.coords,
|
||||
expected_inf.predicted_structure.coords,
|
||||
)
|
||||
self.assertLess(actual_rmsd, 3.0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Reference in New Issue
Block a user