Return b factors and occupancy from to_res_arrays()

Change to_res_arrays() to return a ResArrays dataclass which contains [num_res, num_atoms, *] shaped arrays for all atom-level data in a Structure.

PiperOrigin-RevId: 814245949
Change-Id: I38abeec96c8d4ed97a06964e6c7b7fc74f37f97d
This commit is contained in:
DeepMind
2025-10-02 08:31:01 -07:00
committed by Copybara-Service
parent a8ecdb2d7a
commit 03a6c295e5
2 changed files with 34 additions and 8 deletions

View File

@ -862,9 +862,11 @@ def get_polymer_features(
chain_sequence = chain.chain_single_letter_sequence()[label_chain_id]
polymer = _POLYMERS[chain_poly_type]
positions, positions_mask = chain.to_res_arrays(
res_arrays = chain.to_res_arrays(
include_missing_residues=True, atom_order=polymer.atom_order
)
positions = res_arrays.atom_positions
positions_mask = res_arrays.atom_mask
template_all_atom_positions = np.zeros(
(query_sequence_length, polymer.num_atom_types, 3), dtype=np.float64
)

View File

@ -275,6 +275,24 @@ class StructureTables:
bonds: structure_tables.Bonds
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ResArrays:
"""Atom-level data arrays with a residue dimension.
Attributes:
atom_positions: float32 of shape [num_res, num_atom_type, 3] coordinates.
atom_mask: float32 of shape [num_res, num_atom_type] indicating if an atom
is present.
atom_b_factor: float32 of shape [num_res, num_atom_type] b_factors.
atom_occupancy: float32 of shape [num_res, num_atom_type] occupancies.
"""
atom_positions: np.ndarray
atom_mask: np.ndarray
atom_b_factor: np.ndarray
atom_occupancy: np.ndarray
class Structure(table.Database):
"""Structure class for representing and processing molecular structures."""
@ -2438,8 +2456,8 @@ class Structure(table.Database):
*,
include_missing_residues: bool,
atom_order: Mapping[str, int] = atom_types.ATOM37_ORDER,
) -> tuple[np.ndarray, np.ndarray]:
"""Returns an atom position and atom mask array with a num_res dimension.
) -> ResArrays:
"""Returns atom-level information in arrays containing a num_res dimension.
NB: All residues in the structure will appear in the residue dimension but
atoms will only have a True (1.0) mask value if the residue + atom
@ -2455,15 +2473,14 @@ class Structure(table.Database):
choose atom_types.ATOM29_ORDER for nucleics.
Returns:
A pair of arrays:
* atom_positions: [num_res, atom_type_num, 3] float32 array of coords.
* atom_mask: [num_res, atom_type_num] float32 atom mask denoting
which atoms are present in this Structure.
A ResArrays object.
"""
num_res = self.num_residues(count_unresolved=include_missing_residues)
atom_type_num = len(atom_order)
atom_positions = np.zeros((num_res, atom_type_num, 3), dtype=np.float32)
atom_mask = np.zeros((num_res, atom_type_num), dtype=np.float32)
atom_b_factor = np.zeros((num_res, atom_type_num), dtype=np.float32)
atom_occupancy = np.zeros((num_res, atom_type_num), dtype=np.float32)
all_residues = None if not include_missing_residues else self.all_residues
for i, atom in enumerate_residues(self.iter_atoms(), all_residues):
@ -2473,8 +2490,15 @@ class Structure(table.Database):
atom_positions[i, atom_idx, 1] = atom['atom_y']
atom_positions[i, atom_idx, 2] = atom['atom_z']
atom_mask[i, atom_idx] = 1.0
atom_b_factor[i, atom_idx] = atom['atom_b_factor']
atom_occupancy[i, atom_idx] = atom['atom_occupancy']
return atom_positions, atom_mask
return ResArrays(
atom_positions=atom_positions,
atom_mask=atom_mask,
atom_b_factor=atom_b_factor,
atom_occupancy=atom_occupancy,
)
def to_res_atom_lists(
self, *, include_missing_residues: bool