mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Add a flag to enable skipping unpaired MSA deduplication against paired MSA
Addresses an issue reported in https://github.com/google-deepmind/alphafold3/issues/404. PiperOrigin-RevId: 761917258 Change-Id: I26592af0f40481a40ba611130739a40a62e4a6dc
This commit is contained in:
committed by
Copybara-Service
parent
fbabfe25b0
commit
e64804c49f
@ -486,6 +486,12 @@ opposed to relying on name-matching post-processing heuristics used for
|
||||
When setting `unpairedMsa` manually, the `pairedMsa` must be explicitly set to
|
||||
an empty string (`""`).
|
||||
|
||||
Make sure to run with `--resolve_msa_overlaps=false`. This prevents
|
||||
deduplication of the unpaired MSA within each chain against the paired MSA
|
||||
sequences. Even if you set `pairedMsa` to an empty string, the query sequence(s)
|
||||
will still be added to it, and the deduplication procedure could destroy the
|
||||
carefully crafted sequence positioning in the unpaired MSA.
|
||||
|
||||
For instance, if there are two chains `DEEP` and `MIND` which we want to be
|
||||
paired on organisms A and C, we can achieve this as follows:
|
||||
|
||||
|
@ -179,17 +179,28 @@ _SEQRES_DATABASE_PATH = flags.DEFINE_string(
|
||||
_JACKHMMER_N_CPU = flags.DEFINE_integer(
|
||||
'jackhmmer_n_cpu',
|
||||
min(multiprocessing.cpu_count(), 8),
|
||||
'Number of CPUs to use for Jackhmmer. Default to min(cpu_count, 8). Going'
|
||||
' beyond 8 CPUs provides very little additional speedup.',
|
||||
'Number of CPUs to use for Jackhmmer. Defaults to min(cpu_count, 8). Going'
|
||||
' above 8 CPUs provides very little additional speedup.',
|
||||
lower_bound=0,
|
||||
)
|
||||
_NHMMER_N_CPU = flags.DEFINE_integer(
|
||||
'nhmmer_n_cpu',
|
||||
min(multiprocessing.cpu_count(), 8),
|
||||
'Number of CPUs to use for Nhmmer. Default to min(cpu_count, 8). Going'
|
||||
' beyond 8 CPUs provides very little additional speedup.',
|
||||
'Number of CPUs to use for Nhmmer. Defaults to min(cpu_count, 8). Going'
|
||||
' above 8 CPUs provides very little additional speedup.',
|
||||
lower_bound=0,
|
||||
)
|
||||
|
||||
# Template search configuration.
|
||||
# Data pipeline configuration.
|
||||
_RESOLVE_MSA_OVERLAPS = flags.DEFINE_bool(
|
||||
'resolve_msa_overlaps',
|
||||
True,
|
||||
'Whether to deduplicate unpaired MSA against paired MSA. The default'
|
||||
' behaviour matches the method described in the AlphaFold 3 paper. Set this'
|
||||
' to false if providing custom paired MSA using the unpaired MSA field to'
|
||||
' keep it exactly as is as deduplication against the paired MSA could break'
|
||||
' the manually crafted pairing between MSA sequences.',
|
||||
)
|
||||
_MAX_TEMPLATE_DATE = flags.DEFINE_string(
|
||||
'max_template_date',
|
||||
'2021-09-30', # By default, use the date from the AlphaFold 3 paper.
|
||||
@ -200,12 +211,12 @@ _MAX_TEMPLATE_DATE = flags.DEFINE_string(
|
||||
' coordinates set. Only for components that have been released before this'
|
||||
' date the model coordinates can be used as a fallback.',
|
||||
)
|
||||
|
||||
_CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer(
|
||||
'conformer_max_iterations',
|
||||
None, # Default to RDKit default parameters value.
|
||||
'Optional override for maximum number of iterations to run for RDKit '
|
||||
'conformer search.',
|
||||
lower_bound=0,
|
||||
)
|
||||
|
||||
# JAX inference performance tuning.
|
||||
@ -429,6 +440,7 @@ def predict_structure(
|
||||
buckets: Sequence[int] | None = None,
|
||||
ref_max_modified_date: datetime.date | None = None,
|
||||
conformer_max_iterations: int | None = None,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
) -> Sequence[ResultsForSeed]:
|
||||
"""Runs the full inference pipeline to predict structures for each seed."""
|
||||
|
||||
@ -442,6 +454,7 @@ def predict_structure(
|
||||
verbose=True,
|
||||
ref_max_modified_date=ref_max_modified_date,
|
||||
conformer_max_iterations=conformer_max_iterations,
|
||||
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||
)
|
||||
print(
|
||||
f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
|
||||
@ -600,6 +613,7 @@ def process_fold_input(
|
||||
buckets: Sequence[int] | None = None,
|
||||
ref_max_modified_date: datetime.date | None = None,
|
||||
conformer_max_iterations: int | None = None,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
force_output_dir: bool = False,
|
||||
) -> folding_input.Input:
|
||||
...
|
||||
@ -614,6 +628,7 @@ def process_fold_input(
|
||||
buckets: Sequence[int] | None = None,
|
||||
ref_max_modified_date: datetime.date | None = None,
|
||||
conformer_max_iterations: int | None = None,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
force_output_dir: bool = False,
|
||||
) -> Sequence[ResultsForSeed]:
|
||||
...
|
||||
@ -627,6 +642,7 @@ def process_fold_input(
|
||||
buckets: Sequence[int] | None = None,
|
||||
ref_max_modified_date: datetime.date | None = None,
|
||||
conformer_max_iterations: int | None = None,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
force_output_dir: bool = False,
|
||||
) -> folding_input.Input | Sequence[ResultsForSeed]:
|
||||
"""Runs data pipeline and/or inference on a single fold input.
|
||||
@ -649,6 +665,11 @@ def process_fold_input(
|
||||
date the model coordinates can be used as a fallback.
|
||||
conformer_max_iterations: Optional override for maximum number of iterations
|
||||
to run for RDKit conformer search.
|
||||
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
|
||||
MSA. The default behaviour matches the method described in the AlphaFold 3
|
||||
paper. Set this to false if providing custom paired MSA using the unpaired
|
||||
MSA field to keep it exactly as is as deduplication against the paired MSA
|
||||
could break the manually crafted pairing between MSA sequences.
|
||||
force_output_dir: If True, do not create a new output directory even if the
|
||||
existing one is non-empty. Instead use the existing output directory and
|
||||
potentially overwrite existing files. If False, create a new timestamped
|
||||
@ -702,6 +723,7 @@ def process_fold_input(
|
||||
buckets=buckets,
|
||||
ref_max_modified_date=ref_max_modified_date,
|
||||
conformer_max_iterations=conformer_max_iterations,
|
||||
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||
)
|
||||
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
|
||||
write_outputs(
|
||||
@ -860,6 +882,7 @@ def main(_):
|
||||
buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
|
||||
ref_max_modified_date=max_template_date,
|
||||
conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
|
||||
resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value,
|
||||
force_output_dir=_FORCE_OUTPUT_DIR.value,
|
||||
)
|
||||
num_fold_inputs += 1
|
||||
|
@ -41,6 +41,7 @@ def featurise_input(
|
||||
buckets: Sequence[int] | None,
|
||||
ref_max_modified_date: datetime.date | None = None,
|
||||
conformer_max_iterations: int | None = None,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
verbose: bool = False,
|
||||
) -> Sequence[features.BatchDict]:
|
||||
"""Featurise the folding input.
|
||||
@ -60,6 +61,11 @@ def featurise_input(
|
||||
date the model coordinates can be used as a fallback.
|
||||
conformer_max_iterations: Optional override for maximum number of iterations
|
||||
to run for RDKit conformer search.
|
||||
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
|
||||
MSA. The default behaviour matches the method described in the AlphaFold 3
|
||||
paper. Set this to false if providing custom paired MSA using the unpaired
|
||||
MSA field to keep it exactly as is as deduplication against the paired MSA
|
||||
could break the manually crafted pairing between MSA sequences.
|
||||
verbose: Whether to print progress messages.
|
||||
|
||||
Returns:
|
||||
@ -73,6 +79,7 @@ def featurise_input(
|
||||
buckets=buckets,
|
||||
ref_max_modified_date=ref_max_modified_date,
|
||||
conformer_max_iterations=conformer_max_iterations,
|
||||
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -415,6 +415,7 @@ class MSA:
|
||||
fold_input: folding_input.Input,
|
||||
logging_name: str,
|
||||
max_paired_sequence_per_species: int,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
) -> Self:
|
||||
"""Compute the msa features."""
|
||||
seen_entities = {}
|
||||
@ -533,9 +534,10 @@ class MSA:
|
||||
nonempty_chain_ids=nonempty_chain_ids,
|
||||
max_hits_per_species=max_paired_sequence_per_species,
|
||||
)
|
||||
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
|
||||
np_chains_list
|
||||
)
|
||||
if resolve_msa_overlaps:
|
||||
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
|
||||
np_chains_list
|
||||
)
|
||||
|
||||
# Remove all gapped rows from all seqs.
|
||||
nonempty_asym_ids = []
|
||||
|
@ -118,6 +118,12 @@ class WholePdbPipeline:
|
||||
symmetric polymer chains.
|
||||
deterministic_frames: Whether to use fixed-seed reference positions to
|
||||
construct deterministic frames.
|
||||
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
|
||||
MSA. The default behaviour matches the method described in the AlphaFold
|
||||
3 paper. Set this to false if providing custom paired MSA using the
|
||||
unpaired MSA field to keep it exactly as is as deduplication against
|
||||
the paired MSA could break the manually crafted pairing between MSA
|
||||
sequences.
|
||||
"""
|
||||
|
||||
max_atoms_per_token: int = 24
|
||||
@ -140,6 +146,7 @@ class WholePdbPipeline:
|
||||
remove_nonsymmetric_bonds: bool = False
|
||||
deterministic_frames: bool = True
|
||||
conformer_max_iterations: int | None = None
|
||||
resolve_msa_overlaps: bool = True
|
||||
|
||||
def __init__(self, *, config: Config):
|
||||
"""Initializes WholePdb data pipeline.
|
||||
@ -338,6 +345,7 @@ class WholePdbPipeline:
|
||||
fold_input=fold_input,
|
||||
logging_name=logging_name,
|
||||
max_paired_sequence_per_species=self._config.max_paired_sequence_per_species,
|
||||
resolve_msa_overlaps=self._config.resolve_msa_overlaps,
|
||||
)
|
||||
|
||||
# Create template features
|
||||
|
Reference in New Issue
Block a user