mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2025-10-20 13:23:47 +08:00
Add a flag to enable skipping unpaired MSA deduplication against paired MSA
Addresses an issue reported in https://github.com/google-deepmind/alphafold3/issues/404. PiperOrigin-RevId: 761917258 Change-Id: I26592af0f40481a40ba611130739a40a62e4a6dc
This commit is contained in:
committed by
Copybara-Service
parent
fbabfe25b0
commit
e64804c49f
@ -486,6 +486,12 @@ opposed to relying on name-matching post-processing heuristics used for
|
|||||||
When setting `unpairedMsa` manually, the `pairedMsa` must be explicitly set to
|
When setting `unpairedMsa` manually, the `pairedMsa` must be explicitly set to
|
||||||
an empty string (`""`).
|
an empty string (`""`).
|
||||||
|
|
||||||
|
Make sure to run with `--resolve_msa_overlaps=false`. This prevents
|
||||||
|
deduplication of the unpaired MSA within each chain against the paired MSA
|
||||||
|
sequences. Even if you set `pairedMsa` to an empty string, the query sequence(s)
|
||||||
|
will still be added to the paired MSA, and the deduplication procedure could destroy the
|
||||||
|
carefully crafted sequence positioning in the unpaired MSA.
|
||||||
|
|
||||||
For instance, if there are two chains `DEEP` and `MIND` which we want to be
|
For instance, if there are two chains `DEEP` and `MIND` which we want to be
|
||||||
paired on organism A and C, we can achieve it as follows:
|
paired on organism A and C, we can achieve it as follows:
|
||||||
|
|
||||||
|
@ -179,17 +179,28 @@ _SEQRES_DATABASE_PATH = flags.DEFINE_string(
|
|||||||
_JACKHMMER_N_CPU = flags.DEFINE_integer(
|
_JACKHMMER_N_CPU = flags.DEFINE_integer(
|
||||||
'jackhmmer_n_cpu',
|
'jackhmmer_n_cpu',
|
||||||
min(multiprocessing.cpu_count(), 8),
|
min(multiprocessing.cpu_count(), 8),
|
||||||
'Number of CPUs to use for Jackhmmer. Default to min(cpu_count, 8). Going'
|
'Number of CPUs to use for Jackhmmer. Defaults to min(cpu_count, 8). Going'
|
||||||
' beyond 8 CPUs provides very little additional speedup.',
|
' above 8 CPUs provides very little additional speedup.',
|
||||||
|
lower_bound=0,
|
||||||
)
|
)
|
||||||
_NHMMER_N_CPU = flags.DEFINE_integer(
|
_NHMMER_N_CPU = flags.DEFINE_integer(
|
||||||
'nhmmer_n_cpu',
|
'nhmmer_n_cpu',
|
||||||
min(multiprocessing.cpu_count(), 8),
|
min(multiprocessing.cpu_count(), 8),
|
||||||
'Number of CPUs to use for Nhmmer. Default to min(cpu_count, 8). Going'
|
'Number of CPUs to use for Nhmmer. Defaults to min(cpu_count, 8). Going'
|
||||||
' beyond 8 CPUs provides very little additional speedup.',
|
' above 8 CPUs provides very little additional speedup.',
|
||||||
|
lower_bound=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Template search configuration.
|
# Data pipeline configuration.
|
||||||
|
_RESOLVE_MSA_OVERLAPS = flags.DEFINE_bool(
|
||||||
|
'resolve_msa_overlaps',
|
||||||
|
True,
|
||||||
|
'Whether to deduplicate unpaired MSA against paired MSA. The default'
|
||||||
|
' behaviour matches the method described in the AlphaFold 3 paper. Set this'
|
||||||
|
' to false if providing custom paired MSA using the unpaired MSA field to'
|
||||||
|
' keep it exactly as is as deduplication against the paired MSA could break'
|
||||||
|
' the manually crafted pairing between MSA sequences.',
|
||||||
|
)
|
||||||
_MAX_TEMPLATE_DATE = flags.DEFINE_string(
|
_MAX_TEMPLATE_DATE = flags.DEFINE_string(
|
||||||
'max_template_date',
|
'max_template_date',
|
||||||
'2021-09-30', # By default, use the date from the AlphaFold 3 paper.
|
'2021-09-30', # By default, use the date from the AlphaFold 3 paper.
|
||||||
@ -200,12 +211,12 @@ _MAX_TEMPLATE_DATE = flags.DEFINE_string(
|
|||||||
' coordinates set. Only for components that have been released before this'
|
' coordinates set. Only for components that have been released before this'
|
||||||
' date the model coordinates can be used as a fallback.',
|
' date the model coordinates can be used as a fallback.',
|
||||||
)
|
)
|
||||||
|
|
||||||
_CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer(
|
_CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer(
|
||||||
'conformer_max_iterations',
|
'conformer_max_iterations',
|
||||||
None, # Default to RDKit default parameters value.
|
None, # Default to RDKit default parameters value.
|
||||||
'Optional override for maximum number of iterations to run for RDKit '
|
'Optional override for maximum number of iterations to run for RDKit '
|
||||||
'conformer search.',
|
'conformer search.',
|
||||||
|
lower_bound=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# JAX inference performance tuning.
|
# JAX inference performance tuning.
|
||||||
@ -429,6 +440,7 @@ def predict_structure(
|
|||||||
buckets: Sequence[int] | None = None,
|
buckets: Sequence[int] | None = None,
|
||||||
ref_max_modified_date: datetime.date | None = None,
|
ref_max_modified_date: datetime.date | None = None,
|
||||||
conformer_max_iterations: int | None = None,
|
conformer_max_iterations: int | None = None,
|
||||||
|
resolve_msa_overlaps: bool = True,
|
||||||
) -> Sequence[ResultsForSeed]:
|
) -> Sequence[ResultsForSeed]:
|
||||||
"""Runs the full inference pipeline to predict structures for each seed."""
|
"""Runs the full inference pipeline to predict structures for each seed."""
|
||||||
|
|
||||||
@ -442,6 +454,7 @@ def predict_structure(
|
|||||||
verbose=True,
|
verbose=True,
|
||||||
ref_max_modified_date=ref_max_modified_date,
|
ref_max_modified_date=ref_max_modified_date,
|
||||||
conformer_max_iterations=conformer_max_iterations,
|
conformer_max_iterations=conformer_max_iterations,
|
||||||
|
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
|
f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
|
||||||
@ -600,6 +613,7 @@ def process_fold_input(
|
|||||||
buckets: Sequence[int] | None = None,
|
buckets: Sequence[int] | None = None,
|
||||||
ref_max_modified_date: datetime.date | None = None,
|
ref_max_modified_date: datetime.date | None = None,
|
||||||
conformer_max_iterations: int | None = None,
|
conformer_max_iterations: int | None = None,
|
||||||
|
resolve_msa_overlaps: bool = True,
|
||||||
force_output_dir: bool = False,
|
force_output_dir: bool = False,
|
||||||
) -> folding_input.Input:
|
) -> folding_input.Input:
|
||||||
...
|
...
|
||||||
@ -614,6 +628,7 @@ def process_fold_input(
|
|||||||
buckets: Sequence[int] | None = None,
|
buckets: Sequence[int] | None = None,
|
||||||
ref_max_modified_date: datetime.date | None = None,
|
ref_max_modified_date: datetime.date | None = None,
|
||||||
conformer_max_iterations: int | None = None,
|
conformer_max_iterations: int | None = None,
|
||||||
|
resolve_msa_overlaps: bool = True,
|
||||||
force_output_dir: bool = False,
|
force_output_dir: bool = False,
|
||||||
) -> Sequence[ResultsForSeed]:
|
) -> Sequence[ResultsForSeed]:
|
||||||
...
|
...
|
||||||
@ -627,6 +642,7 @@ def process_fold_input(
|
|||||||
buckets: Sequence[int] | None = None,
|
buckets: Sequence[int] | None = None,
|
||||||
ref_max_modified_date: datetime.date | None = None,
|
ref_max_modified_date: datetime.date | None = None,
|
||||||
conformer_max_iterations: int | None = None,
|
conformer_max_iterations: int | None = None,
|
||||||
|
resolve_msa_overlaps: bool = True,
|
||||||
force_output_dir: bool = False,
|
force_output_dir: bool = False,
|
||||||
) -> folding_input.Input | Sequence[ResultsForSeed]:
|
) -> folding_input.Input | Sequence[ResultsForSeed]:
|
||||||
"""Runs data pipeline and/or inference on a single fold input.
|
"""Runs data pipeline and/or inference on a single fold input.
|
||||||
@ -649,6 +665,11 @@ def process_fold_input(
|
|||||||
date the model coordinates can be used as a fallback.
|
date the model coordinates can be used as a fallback.
|
||||||
conformer_max_iterations: Optional override for maximum number of iterations
|
conformer_max_iterations: Optional override for maximum number of iterations
|
||||||
to run for RDKit conformer search.
|
to run for RDKit conformer search.
|
||||||
|
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
|
||||||
|
MSA. The default behaviour matches the method described in the AlphaFold 3
|
||||||
|
paper. Set this to false if providing custom paired MSA using the unpaired
|
||||||
|
MSA field to keep it exactly as is, since deduplication against the paired MSA
|
||||||
|
could break the manually crafted pairing between MSA sequences.
|
||||||
force_output_dir: If True, do not create a new output directory even if the
|
force_output_dir: If True, do not create a new output directory even if the
|
||||||
existing one is non-empty. Instead use the existing output directory and
|
existing one is non-empty. Instead use the existing output directory and
|
||||||
potentially overwrite existing files. If False, create a new timestamped
|
potentially overwrite existing files. If False, create a new timestamped
|
||||||
@ -702,6 +723,7 @@ def process_fold_input(
|
|||||||
buckets=buckets,
|
buckets=buckets,
|
||||||
ref_max_modified_date=ref_max_modified_date,
|
ref_max_modified_date=ref_max_modified_date,
|
||||||
conformer_max_iterations=conformer_max_iterations,
|
conformer_max_iterations=conformer_max_iterations,
|
||||||
|
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||||
)
|
)
|
||||||
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
|
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
|
||||||
write_outputs(
|
write_outputs(
|
||||||
@ -860,6 +882,7 @@ def main(_):
|
|||||||
buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
|
buckets=tuple(int(bucket) for bucket in _BUCKETS.value),
|
||||||
ref_max_modified_date=max_template_date,
|
ref_max_modified_date=max_template_date,
|
||||||
conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
|
conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
|
||||||
|
resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value,
|
||||||
force_output_dir=_FORCE_OUTPUT_DIR.value,
|
force_output_dir=_FORCE_OUTPUT_DIR.value,
|
||||||
)
|
)
|
||||||
num_fold_inputs += 1
|
num_fold_inputs += 1
|
||||||
|
@ -41,6 +41,7 @@ def featurise_input(
|
|||||||
buckets: Sequence[int] | None,
|
buckets: Sequence[int] | None,
|
||||||
ref_max_modified_date: datetime.date | None = None,
|
ref_max_modified_date: datetime.date | None = None,
|
||||||
conformer_max_iterations: int | None = None,
|
conformer_max_iterations: int | None = None,
|
||||||
|
resolve_msa_overlaps: bool = True,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> Sequence[features.BatchDict]:
|
) -> Sequence[features.BatchDict]:
|
||||||
"""Featurise the folding input.
|
"""Featurise the folding input.
|
||||||
@ -60,6 +61,11 @@ def featurise_input(
|
|||||||
date the model coordinates can be used as a fallback.
|
date the model coordinates can be used as a fallback.
|
||||||
conformer_max_iterations: Optional override for maximum number of iterations
|
conformer_max_iterations: Optional override for maximum number of iterations
|
||||||
to run for RDKit conformer search.
|
to run for RDKit conformer search.
|
||||||
|
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
|
||||||
|
MSA. The default behaviour matches the method described in the AlphaFold 3
|
||||||
|
paper. Set this to false if providing custom paired MSA using the unpaired
|
||||||
|
MSA field to keep it exactly as is, since deduplication against the paired MSA
|
||||||
|
could break the manually crafted pairing between MSA sequences.
|
||||||
verbose: Whether to print progress messages.
|
verbose: Whether to print progress messages.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -73,6 +79,7 @@ def featurise_input(
|
|||||||
buckets=buckets,
|
buckets=buckets,
|
||||||
ref_max_modified_date=ref_max_modified_date,
|
ref_max_modified_date=ref_max_modified_date,
|
||||||
conformer_max_iterations=conformer_max_iterations,
|
conformer_max_iterations=conformer_max_iterations,
|
||||||
|
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -415,6 +415,7 @@ class MSA:
|
|||||||
fold_input: folding_input.Input,
|
fold_input: folding_input.Input,
|
||||||
logging_name: str,
|
logging_name: str,
|
||||||
max_paired_sequence_per_species: int,
|
max_paired_sequence_per_species: int,
|
||||||
|
resolve_msa_overlaps: bool = True,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Compute the msa features."""
|
"""Compute the msa features."""
|
||||||
seen_entities = {}
|
seen_entities = {}
|
||||||
@ -533,9 +534,10 @@ class MSA:
|
|||||||
nonempty_chain_ids=nonempty_chain_ids,
|
nonempty_chain_ids=nonempty_chain_ids,
|
||||||
max_hits_per_species=max_paired_sequence_per_species,
|
max_hits_per_species=max_paired_sequence_per_species,
|
||||||
)
|
)
|
||||||
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
|
if resolve_msa_overlaps:
|
||||||
np_chains_list
|
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
|
||||||
)
|
np_chains_list
|
||||||
|
)
|
||||||
|
|
||||||
# Remove all gapped rows from all seqs.
|
# Remove all gapped rows from all seqs.
|
||||||
nonempty_asym_ids = []
|
nonempty_asym_ids = []
|
||||||
|
@ -118,6 +118,12 @@ class WholePdbPipeline:
|
|||||||
symmetric polymer chains.
|
symmetric polymer chains.
|
||||||
deterministic_frames: Whether to use fixed-seed reference positions to
|
deterministic_frames: Whether to use fixed-seed reference positions to
|
||||||
construct deterministic frames.
|
construct deterministic frames.
|
||||||
|
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
|
||||||
|
MSA. The default behaviour matches the method described in the AlphaFold
|
||||||
|
3 paper. Set this to false if providing custom paired MSA using the
|
||||||
|
unpaired MSA field to keep it exactly as is, since deduplication against
|
||||||
|
the paired MSA could break the manually crafted pairing between MSA
|
||||||
|
sequences.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_atoms_per_token: int = 24
|
max_atoms_per_token: int = 24
|
||||||
@ -140,6 +146,7 @@ class WholePdbPipeline:
|
|||||||
remove_nonsymmetric_bonds: bool = False
|
remove_nonsymmetric_bonds: bool = False
|
||||||
deterministic_frames: bool = True
|
deterministic_frames: bool = True
|
||||||
conformer_max_iterations: int | None = None
|
conformer_max_iterations: int | None = None
|
||||||
|
resolve_msa_overlaps: bool = True
|
||||||
|
|
||||||
def __init__(self, *, config: Config):
|
def __init__(self, *, config: Config):
|
||||||
"""Initializes WholePdb data pipeline.
|
"""Initializes WholePdb data pipeline.
|
||||||
@ -338,6 +345,7 @@ class WholePdbPipeline:
|
|||||||
fold_input=fold_input,
|
fold_input=fold_input,
|
||||||
logging_name=logging_name,
|
logging_name=logging_name,
|
||||||
max_paired_sequence_per_species=self._config.max_paired_sequence_per_species,
|
max_paired_sequence_per_species=self._config.max_paired_sequence_per_species,
|
||||||
|
resolve_msa_overlaps=self._config.resolve_msa_overlaps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create template features
|
# Create template features
|
||||||
|
Reference in New Issue
Block a user