Remove excess code from run_alphafold_test, clean up the testing a little.
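
The per-bucket golden ranking scores are replaced by a simple range check: every ranking score read from ranking_scores.csv must fall in [0.66, 0.76] (see the diff below). A condensed, self-contained sketch of that style of check, with illustrative CSV content; in the test the rows come from the output directory's ranking_scores.csv via csv.DictReader:

    import csv
    import io

    # Hypothetical CSV content; the column names match ranking_scores.csv.
    csv_text = 'seed,sample,ranking_score\n1234,0,0.70\n1234,1,0.72\n'
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    scores = [float(row['ranking_score']) for row in rows]
    assert all(0.66 <= score <= 0.76 for score in scores), scores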

PiperOrigin-RevId: 725260840
Change-Id: I5199a62b3b3cac9e898993dc65620c8d135451a9
commit bb0e5415ef (parent 90c3d05aa4)
Author: Josh Abramson
Date: 2025-02-10 10:16:25 -08:00
Committed by: Copybara-Service

run_alphafold_test.py

@@ -14,23 +14,18 @@ import contextlib
 import csv
 import datetime
 import difflib
-import functools
-import hashlib
-import json
 import os
 import pathlib
 import pickle
-from typing import Any
 
 from absl import logging
 from absl.testing import absltest
 from absl.testing import parameterized
 from alphafold3 import structure
 from alphafold3.common import folding_input
 from alphafold3.common import resources
 from alphafold3.common.testing import data as testing_data
 from alphafold3.data import pipeline
-from alphafold3.model.atom_layout import atom_layout
 from alphafold3.model.scoring import alignment
 from alphafold3.structure import test_utils
 import jax
@@ -68,40 +63,6 @@ def _generate_diff(actual: str, expected: str) -> str:
   )
 
 
-@functools.singledispatch
-def _hash_data(x: Any, /) -> str:
-  if x is None:
-    return '<<None>>'
-  return _hash_data(json.dumps(x).encode('utf-8'))
-
-
-@_hash_data.register
-def _(x: bytes, /) -> str:
-  return hashlib.sha256(x).hexdigest()
-
-
-@_hash_data.register
-def _(x: jax.Array) -> str:
-  return _hash_data(jax.device_get(x))
-
-
-@_hash_data.register
-def _(x: np.ndarray) -> str:
-  if x.dtype == object:
-    return ';'.join(map(_hash_data, x.ravel().tolist()))
-  return _hash_data(x.tobytes())
-
-
-@_hash_data.register
-def _(_: structure.Structure) -> str:
-  return '<<structure>>'
-
-
-@_hash_data.register
-def _(_: atom_layout.AtomLayout) -> str:
-  return '<<atom-layout>>'
-
-
 class InferenceTest(test_utils.StructureTestCase):
   """Test AlphaFold 3 inference."""
@@ -185,22 +146,8 @@ class InferenceTest(test_utils.StructureTestCase):
         model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
     )
 
-  def compare_golden(self, result_path: str) -> None:
-    filename = os.path.split(result_path)[1]
-    golden_path = testing_data.Data(
-        resources.ROOT / f'test_data/{filename}'
-    ).path()
-    with open(golden_path, 'r') as golden_file:
-      golden_text = golden_file.read()
-    with open(result_path, 'r') as result_file:
-      result_text = result_file.read()
-    diff = _generate_diff(result_text, golden_text)
-    self.assertEqual(diff, "", f"Result differs from golden:\n{diff}")
-
   def test_model_inference(self):
-    """Run model inference and assert that the output is as expected."""
+    """Run model inference and assert that output exists."""
     featurised_examples = pickle.loads(
         (resources.ROOT / 'test_data' / 'featurised_example.pkl').read_bytes()
     )
@@ -210,7 +157,6 @@ class InferenceTest(test_utils.StructureTestCase):
     result = self._runner.run_inference(
         featurised_example, jax.random.PRNGKey(0)
     )
-    result_hashes = jax.tree_util.tree_map(_hash_data, result)
     self.assertIsNotNone(result)
     _, embeddings = self._runner.extract_inference_results_and_maybe_embeddings(
         batch=featurised_example, result=result, target_name='target'
@@ -232,15 +178,13 @@ class InferenceTest(test_utils.StructureTestCase):
       {
           'testcase_name': 'default_bucket',
           'bucket': None,
-          'exp_ranking_scores': [0.69, 0.69, 0.72, 0.75, 0.70],
       },
       {
           'testcase_name': 'bucket_1024',
           'bucket': 1024,
-          'exp_ranking_scores': [0.69, 0.71, 0.71, 0.69, 0.70],
       },
   )
-  def test_inference(self, bucket, exp_ranking_scores):
+  def test_inference(self, bucket):
     """Run AlphaFold 3 inference."""
 
     ### Prepare inputs.
@@ -327,21 +271,20 @@ class InferenceTest(test_utils.StructureTestCase):
     )
 
     with open(os.path.join(output_dir, 'ranking_scores.csv'), 'rt') as f:
-      actual_ranking_scores = list(csv.DictReader(f))
-    self.assertLen(actual_ranking_scores, 5)
-    self.assertEqual(
-        [int(s['seed']) for s in actual_ranking_scores], [1234] * 5
-    )
+      ranking_scores = list(csv.DictReader(f))
+    self.assertLen(ranking_scores, 5)
+    self.assertEqual([int(s['seed']) for s in ranking_scores], [1234] * 5)
     self.assertEqual(
-        [int(s['sample']) for s in actual_ranking_scores], [0, 1, 2, 3, 4]
+        [int(s['sample']) for s in ranking_scores], [0, 1, 2, 3, 4]
     )
-    np.testing.assert_array_almost_equal(
-        [float(s['ranking_score']) for s in actual_ranking_scores],
-        exp_ranking_scores,
-        decimal=2,
-    )
+    # Ranking score should be between 0.66 and 0.76 for all samples.
+    ranking_scores = [float(s['ranking_score']) for s in ranking_scores]
+    scores_ok = [0.66 <= score <= 0.76 for score in ranking_scores]
+    if not all(scores_ok):
+      self.fail(f'{ranking_scores=} are not in expected range [0.66, 0.76]')
 
     with open(os.path.join(output_dir, 'TERMS_OF_USE.md'), 'rt') as f:
       actual_terms_of_use = f.read()
     self.assertStartsWith(
@@ -384,22 +327,25 @@ class InferenceTest(test_utils.StructureTestCase):
         expected_inf.inference_results,
         strict=True,
     ):
-      # Check RMSD is within tolerance.
-      # 5tgy is very stable, NMR samples were all within 3.0 RMSD.
-      actual_rmsd = alignment.rmsd_from_coords(
-          actual_inf.predicted_structure.coords,
-          expected_inf.predicted_structure.coords,
-      )
-      self.assertLess(actual_rmsd, 3.0)
-      np.testing.assert_array_equal(
-          actual_inf.predicted_structure.atom_occupancy,
-          [1.0] * actual_inf.predicted_structure.num_atoms,
-      )
       # Make sure the token chain IDs are the same as the input chain IDs.
       self.assertEqual(
           actual_inf.metadata['token_chain_ids'],
           ['P'] * len(fold_input.protein_chains[0].sequence) + ['LL'] * 41,
       )
+      # All atom occupancies should be 1.0.
+      np.testing.assert_array_equal(
+          actual_inf.predicted_structure.atom_occupancy,
+          [1.0] * actual_inf.predicted_structure.num_atoms,
+      )
+      # Check RMSD is within tolerance.
+      # 5tgy is stably predicted, samples should be all within 3.0 RMSD
+      # regardless of bucket, device type, etc.
+      actual_rmsd = alignment.rmsd_from_coords(
+          actual_inf.predicted_structure.coords,
+          expected_inf.predicted_structure.coords,
+      )
+      self.assertLess(actual_rmsd, 3.0)
 
 
 if __name__ == '__main__':
   absltest.main()