Add a flag to enable skipping unpaired MSA deduplication against paired MSA

Addresses an issue reported in https://github.com/google-deepmind/alphafold3/issues/404.

PiperOrigin-RevId: 761917258
Change-Id: I26592af0f40481a40ba611130739a40a62e4a6dc
This commit is contained in:
Augustin Zidek
2025-05-22 05:10:40 -07:00
committed by Copybara-Service
parent fbabfe25b0
commit e64804c49f
5 changed files with 55 additions and 9 deletions

View File

@ -486,6 +486,12 @@ opposed to relying on name-matching post-processing heuristics used for
When setting `unpairedMsa` manually, the `pairedMsa` must be explicitly set to When setting `unpairedMsa` manually, the `pairedMsa` must be explicitly set to
an empty string (`""`). an empty string (`""`).
Make sure to run with `--resolve_msa_overlaps=false`. This prevents
deduplication of the unpaired MSA within each chain against the paired MSA
sequences. Even if you set `pairedMsa` to an empty string, the query sequence(s)
will still be added to the paired MSA, and the deduplication procedure could
then destroy the carefully crafted sequence positioning in the unpaired MSA.
For instance, if there are two chains `DEEP` and `MIND` which we want to be For instance, if there are two chains `DEEP` and `MIND` which we want to be
paired on organism A and C, we can achieve it as follows: paired on organism A and C, we can achieve it as follows:

View File

@ -179,17 +179,28 @@ _SEQRES_DATABASE_PATH = flags.DEFINE_string(
_JACKHMMER_N_CPU = flags.DEFINE_integer( _JACKHMMER_N_CPU = flags.DEFINE_integer(
'jackhmmer_n_cpu', 'jackhmmer_n_cpu',
min(multiprocessing.cpu_count(), 8), min(multiprocessing.cpu_count(), 8),
'Number of CPUs to use for Jackhmmer. Default to min(cpu_count, 8). Going' 'Number of CPUs to use for Jackhmmer. Defaults to min(cpu_count, 8). Going'
' beyond 8 CPUs provides very little additional speedup.', ' above 8 CPUs provides very little additional speedup.',
lower_bound=0,
) )
_NHMMER_N_CPU = flags.DEFINE_integer( _NHMMER_N_CPU = flags.DEFINE_integer(
'nhmmer_n_cpu', 'nhmmer_n_cpu',
min(multiprocessing.cpu_count(), 8), min(multiprocessing.cpu_count(), 8),
'Number of CPUs to use for Nhmmer. Default to min(cpu_count, 8). Going' 'Number of CPUs to use for Nhmmer. Defaults to min(cpu_count, 8). Going'
' beyond 8 CPUs provides very little additional speedup.', ' above 8 CPUs provides very little additional speedup.',
lower_bound=0,
) )
# Template search configuration. # Data pipeline configuration.
_RESOLVE_MSA_OVERLAPS = flags.DEFINE_bool(
'resolve_msa_overlaps',
True,
'Whether to deduplicate unpaired MSA against paired MSA. The default'
' behaviour matches the method described in the AlphaFold 3 paper. Set this'
' to false if providing custom paired MSA using the unpaired MSA field to'
' keep it exactly as is as deduplication against the paired MSA could break'
' the manually crafted pairing between MSA sequences.',
)
_MAX_TEMPLATE_DATE = flags.DEFINE_string( _MAX_TEMPLATE_DATE = flags.DEFINE_string(
'max_template_date', 'max_template_date',
'2021-09-30', # By default, use the date from the AlphaFold 3 paper. '2021-09-30', # By default, use the date from the AlphaFold 3 paper.
@ -200,12 +211,12 @@ _MAX_TEMPLATE_DATE = flags.DEFINE_string(
' coordinates set. Only for components that have been released before this' ' coordinates set. Only for components that have been released before this'
' date the model coordinates can be used as a fallback.', ' date the model coordinates can be used as a fallback.',
) )
_CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer( _CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer(
'conformer_max_iterations', 'conformer_max_iterations',
None, # Default to RDKit default parameters value. None, # Default to RDKit default parameters value.
'Optional override for maximum number of iterations to run for RDKit ' 'Optional override for maximum number of iterations to run for RDKit '
'conformer search.', 'conformer search.',
lower_bound=0,
) )
# JAX inference performance tuning. # JAX inference performance tuning.
@ -429,6 +440,7 @@ def predict_structure(
buckets: Sequence[int] | None = None, buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None, ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None, conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
) -> Sequence[ResultsForSeed]: ) -> Sequence[ResultsForSeed]:
"""Runs the full inference pipeline to predict structures for each seed.""" """Runs the full inference pipeline to predict structures for each seed."""
@ -442,6 +454,7 @@ def predict_structure(
verbose=True, verbose=True,
ref_max_modified_date=ref_max_modified_date, ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations, conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
) )
print( print(
f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took' f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
@ -600,6 +613,7 @@ def process_fold_input(
buckets: Sequence[int] | None = None, buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None, ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None, conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False, force_output_dir: bool = False,
) -> folding_input.Input: ) -> folding_input.Input:
... ...
@ -614,6 +628,7 @@ def process_fold_input(
buckets: Sequence[int] | None = None, buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None, ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None, conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False, force_output_dir: bool = False,
) -> Sequence[ResultsForSeed]: ) -> Sequence[ResultsForSeed]:
... ...
@ -627,6 +642,7 @@ def process_fold_input(
buckets: Sequence[int] | None = None, buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None, ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None, conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
force_output_dir: bool = False, force_output_dir: bool = False,
) -> folding_input.Input | Sequence[ResultsForSeed]: ) -> folding_input.Input | Sequence[ResultsForSeed]:
"""Runs data pipeline and/or inference on a single fold input. """Runs data pipeline and/or inference on a single fold input.
@ -649,6 +665,11 @@ def process_fold_input(
date the model coordinates can be used as a fallback. date the model coordinates can be used as a fallback.
conformer_max_iterations: Optional override for maximum number of iterations conformer_max_iterations: Optional override for maximum number of iterations
to run for RDKit conformer search. to run for RDKit conformer search.
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
MSA. The default behaviour matches the method described in the AlphaFold 3
paper. Set this to False when providing custom paired MSA via the unpaired
MSA field, so that it is kept exactly as is: deduplication against the
paired MSA could break the manually crafted pairing between MSA sequences.
force_output_dir: If True, do not create a new output directory even if the force_output_dir: If True, do not create a new output directory even if the
existing one is non-empty. Instead use the existing output directory and existing one is non-empty. Instead use the existing output directory and
potentially overwrite existing files. If False, create a new timestamped potentially overwrite existing files. If False, create a new timestamped
@ -702,6 +723,7 @@ def process_fold_input(
buckets=buckets, buckets=buckets,
ref_max_modified_date=ref_max_modified_date, ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations, conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
) )
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...') print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
write_outputs( write_outputs(
@ -860,6 +882,7 @@ def main(_):
buckets=tuple(int(bucket) for bucket in _BUCKETS.value), buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
ref_max_modified_date=max_template_date, ref_max_modified_date=max_template_date,
conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value, conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value,
force_output_dir=_FORCE_OUTPUT_DIR.value, force_output_dir=_FORCE_OUTPUT_DIR.value,
) )
num_fold_inputs += 1 num_fold_inputs += 1

View File

@ -41,6 +41,7 @@ def featurise_input(
buckets: Sequence[int] | None, buckets: Sequence[int] | None,
ref_max_modified_date: datetime.date | None = None, ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None, conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
verbose: bool = False, verbose: bool = False,
) -> Sequence[features.BatchDict]: ) -> Sequence[features.BatchDict]:
"""Featurise the folding input. """Featurise the folding input.
@ -60,6 +61,11 @@ def featurise_input(
date the model coordinates can be used as a fallback. date the model coordinates can be used as a fallback.
conformer_max_iterations: Optional override for maximum number of iterations conformer_max_iterations: Optional override for maximum number of iterations
to run for RDKit conformer search. to run for RDKit conformer search.
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
MSA. The default behaviour matches the method described in the AlphaFold 3
paper. Set this to False when providing custom paired MSA via the unpaired
MSA field, so that it is kept exactly as is: deduplication against the
paired MSA could break the manually crafted pairing between MSA sequences.
verbose: Whether to print progress messages. verbose: Whether to print progress messages.
Returns: Returns:
@ -73,6 +79,7 @@ def featurise_input(
buckets=buckets, buckets=buckets,
ref_max_modified_date=ref_max_modified_date, ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations, conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
), ),
) )

View File

@ -415,6 +415,7 @@ class MSA:
fold_input: folding_input.Input, fold_input: folding_input.Input,
logging_name: str, logging_name: str,
max_paired_sequence_per_species: int, max_paired_sequence_per_species: int,
resolve_msa_overlaps: bool = True,
) -> Self: ) -> Self:
"""Compute the msa features.""" """Compute the msa features."""
seen_entities = {} seen_entities = {}
@ -533,9 +534,10 @@ class MSA:
nonempty_chain_ids=nonempty_chain_ids, nonempty_chain_ids=nonempty_chain_ids,
max_hits_per_species=max_paired_sequence_per_species, max_hits_per_species=max_paired_sequence_per_species,
) )
np_chains_list = msa_pairing.deduplicate_unpaired_sequences( if resolve_msa_overlaps:
np_chains_list np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
) np_chains_list
)
# Remove all gapped rows from all seqs. # Remove all gapped rows from all seqs.
nonempty_asym_ids = [] nonempty_asym_ids = []

View File

@ -118,6 +118,12 @@ class WholePdbPipeline:
symmetric polymer chains. symmetric polymer chains.
deterministic_frames: Whether to use fixed-seed reference positions to deterministic_frames: Whether to use fixed-seed reference positions to
construct deterministic frames. construct deterministic frames.
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
MSA. The default behaviour matches the method described in the AlphaFold
3 paper. Set this to False when providing custom paired MSA via the
unpaired MSA field, so that it is kept exactly as is: deduplication
against the paired MSA could break the manually crafted pairing between
MSA sequences.
""" """
max_atoms_per_token: int = 24 max_atoms_per_token: int = 24
@ -140,6 +146,7 @@ class WholePdbPipeline:
remove_nonsymmetric_bonds: bool = False remove_nonsymmetric_bonds: bool = False
deterministic_frames: bool = True deterministic_frames: bool = True
conformer_max_iterations: int | None = None conformer_max_iterations: int | None = None
resolve_msa_overlaps: bool = True
def __init__(self, *, config: Config): def __init__(self, *, config: Config):
"""Initializes WholePdb data pipeline. """Initializes WholePdb data pipeline.
@ -338,6 +345,7 @@ class WholePdbPipeline:
fold_input=fold_input, fold_input=fold_input,
logging_name=logging_name, logging_name=logging_name,
max_paired_sequence_per_species=self._config.max_paired_sequence_per_species, max_paired_sequence_per_species=self._config.max_paired_sequence_per_species,
resolve_msa_overlaps=self._config.resolve_msa_overlaps,
) )
# Create template features # Create template features