Add a flag to enable skipping unpaired MSA deduplication against paired MSA

Addresses an issue reported in https://github.com/google-deepmind/alphafold3/issues/404.

PiperOrigin-RevId: 761917258
Change-Id: I26592af0f40481a40ba611130739a40a62e4a6dc
This commit is contained in:
Augustin Zidek
2025-05-22 05:10:40 -07:00
committed by Copybara-Service
parent fbabfe25b0
commit e64804c49f
5 changed files with 55 additions and 9 deletions

View File

@ -486,6 +486,12 @@ opposed to relying on name-matching post-processing heuristics used for
When setting `unpairedMsa` manually, the `pairedMsa` must be explicitly set to
an empty string (`""`).
Make sure to run with `--resolve_msa_overlaps=false`. This prevents
deduplication of the unpaired MSA within each chain against the paired MSA
sequences. Even if you set `pairedMsa` to an empty string, the query sequence(s)
will still be added in there and the deduplication procedure could destroy the
carefully crafted sequence positioning in the unpaired MSA.
For instance, if there are two chains `DEEP` and `MIND` which we want to be
paired on organism A and C, we can achieve it as follows:

View File

@ -179,17 +179,28 @@ _SEQRES_DATABASE_PATH = flags.DEFINE_string(
_JACKHMMER_N_CPU = flags.DEFINE_integer(
'jackhmmer_n_cpu',
min(multiprocessing.cpu_count(), 8),
'Number of CPUs to use for Jackhmmer. Default to min(cpu_count, 8). Going'
' beyond 8 CPUs provides very little additional speedup.',
'Number of CPUs to use for Jackhmmer. Defaults to min(cpu_count, 8). Going'
' above 8 CPUs provides very little additional speedup.',
lower_bound=0,
)
_NHMMER_N_CPU = flags.DEFINE_integer(
'nhmmer_n_cpu',
min(multiprocessing.cpu_count(), 8),
'Number of CPUs to use for Nhmmer. Default to min(cpu_count, 8). Going'
' beyond 8 CPUs provides very little additional speedup.',
'Number of CPUs to use for Nhmmer. Defaults to min(cpu_count, 8). Going'
' above 8 CPUs provides very little additional speedup.',
lower_bound=0,
)
# Template search configuration.
# Data pipeline configuration.
_RESOLVE_MSA_OVERLAPS = flags.DEFINE_bool(
'resolve_msa_overlaps',
True,
'Whether to deduplicate unpaired MSA against paired MSA. The default'
' behaviour matches the method described in the AlphaFold 3 paper. Set this'
' to false if providing custom paired MSA using the unpaired MSA field to'
' keep it exactly as is as deduplication against the paired MSA could break'
' the manually crafted pairing between MSA sequences.',
)
_MAX_TEMPLATE_DATE = flags.DEFINE_string(
'max_template_date',
'2021-09-30', # By default, use the date from the AlphaFold 3 paper.
@ -200,12 +211,12 @@ _MAX_TEMPLATE_DATE = flags.DEFINE_string(
' coordinates set. Only for components that have been released before this'
' date the model coordinates can be used as a fallback.',
)
_CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer(
'conformer_max_iterations',
None, # Default to RDKit default parameters value.
'Optional override for maximum number of iterations to run for RDKit '
'conformer search.',
lower_bound=0,
)
# JAX inference performance tuning.
@ -429,6 +440,7 @@ def predict_structure(
buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
) -> Sequence[ResultsForSeed]:
"""Runs the full inference pipeline to predict structures for each seed."""
@ -442,6 +454,7 @@ def predict_structure(
verbose=True,
ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
)
print(
f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
@ -600,6 +613,7 @@ def process_fold_input(
buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False,
) -> folding_input.Input:
...
@ -614,6 +628,7 @@ def process_fold_input(
buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False,
) -> Sequence[ResultsForSeed]:
...
@ -627,6 +642,7 @@ def process_fold_input(
buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False,
) -> folding_input.Input | Sequence[ResultsForSeed]:
"""Runs data pipeline and/or inference on a single fold input.
@ -649,6 +665,11 @@ def process_fold_input(
date the model coordinates can be used as a fallback.
conformer_max_iterations: Optional override for maximum number of iterations
to run for RDKit conformer search.
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
MSA. The default behaviour matches the method described in the AlphaFold 3
paper. Set this to false if providing custom paired MSA using the unpaired
MSA field to keep it exactly as is as deduplication against the paired MSA
could break the manually crafted pairing between MSA sequences.
force_output_dir: If True, do not create a new output directory even if the
existing one is non-empty. Instead use the existing output directory and
potentially overwrite existing files. If False, create a new timestamped
@ -702,6 +723,7 @@ def process_fold_input(
buckets=buckets,
ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
)
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
write_outputs(
@ -860,6 +882,7 @@ def main(_):
buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
ref_max_modified_date=max_template_date,
conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value,
force_output_dir=_FORCE_OUTPUT_DIR.value,
)
num_fold_inputs += 1

View File

@ -41,6 +41,7 @@ def featurise_input(
buckets: Sequence[int] | None,
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
verbose: bool = False,
) -> Sequence[features.BatchDict]:
"""Featurise the folding input.
@ -60,6 +61,11 @@ def featurise_input(
date the model coordinates can be used as a fallback.
conformer_max_iterations: Optional override for maximum number of iterations
to run for RDKit conformer search.
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
MSA. The default behaviour matches the method described in the AlphaFold 3
paper. Set this to false if providing custom paired MSA using the unpaired
MSA field to keep it exactly as is as deduplication against the paired MSA
could break the manually crafted pairing between MSA sequences.
verbose: Whether to print progress messages.
Returns:
@ -73,6 +79,7 @@ def featurise_input(
buckets=buckets,
ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
),
)

View File

@ -415,6 +415,7 @@ class MSA:
fold_input: folding_input.Input,
logging_name: str,
max_paired_sequence_per_species: int,
resolve_msa_overlaps: bool = True,
) -> Self:
"""Compute the msa features."""
seen_entities = {}
@ -533,9 +534,10 @@ class MSA:
nonempty_chain_ids=nonempty_chain_ids,
max_hits_per_species=max_paired_sequence_per_species,
)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
np_chains_list
)
if resolve_msa_overlaps:
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
np_chains_list
)
# Remove all gapped rows from all seqs.
nonempty_asym_ids = []

View File

@ -118,6 +118,12 @@ class WholePdbPipeline:
symmetric polymer chains.
deterministic_frames: Whether to use fixed-seed reference positions to
construct deterministic frames.
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
MSA. The default behaviour matches the method described in the AlphaFold
3 paper. Set this to false if providing custom paired MSA using the
unpaired MSA field to keep it exactly as is as deduplication against
the paired MSA could break the manually crafted pairing between MSA
sequences.
"""
max_atoms_per_token: int = 24
@ -140,6 +146,7 @@ class WholePdbPipeline:
remove_nonsymmetric_bonds: bool = False
deterministic_frames: bool = True
conformer_max_iterations: int | None = None
resolve_msa_overlaps: bool = True
def __init__(self, *, config: Config):
"""Initializes WholePdb data pipeline.
@ -338,6 +345,7 @@ class WholePdbPipeline:
fold_input=fold_input,
logging_name=logging_name,
max_paired_sequence_per_species=self._config.max_paired_sequence_per_species,
resolve_msa_overlaps=self._config.resolve_msa_overlaps,
)
# Create template features