streamline computation of clashes (#43)
* result of unique is already sorted

* streamline computation of clashes

* account for 1-based indexing of asym_id
arogozhnikov authored Sep 13, 2024
1 parent 5b0d770 commit 76f0f26
Showing 3 changed files with 67 additions and 173 deletions.
228 changes: 61 additions & 167 deletions chai_lab/ranking/clashes.py
@@ -1,14 +1,10 @@
from dataclasses import dataclass
from itertools import combinations

import torch
from einops import rearrange, reduce, repeat
from torch import Tensor

import chai_lab.ranking.utils as rutils
from chai_lab.ranking.utils import (
get_chain_masks_and_asyms,
)
from chai_lab.utils.tensor_utils import cdist, und_self
from chai_lab.utils.typing import Bool, Float, Int, typecheck

@@ -24,11 +20,11 @@ class ClashScores:
per_chain_pair_clashes: number of inter-chain clashes for each chain pair in the complex
"""

total_clashes: Float[Tensor, "..."]
total_inter_chain_clashes: Float[Tensor, "..."]
per_chain_intra_clashes: Float[Tensor, "... n_chains"]
per_chain_pair_clashes: Float[Tensor, "... n_chains n_chains"]
has_clashes: Bool[Tensor, "..."]
total_clashes: Int[Tensor, "..."]
total_inter_chain_clashes: Int[Tensor, "..."]
chain_intra_clashes: Int[Tensor, "... n_chains"]
chain_chain_inter_clashes: Int[Tensor, "... n_chains n_chains"]
has_inter_chain_clashes: Bool[Tensor, "..."]


@typecheck
@@ -46,123 +42,11 @@ def _compute_clashes(


@typecheck
def maybe_compute_clashes(
atom_coords: Float[Tensor, "... a 3"],
atom_mask: Bool[Tensor, "... a"],
clash_matrix: Bool[Tensor, "... a a"] | None = None,
clash_threshold: float = 1.1,
) -> Bool[Tensor, "... a a"]:
if clash_matrix is None:
return _compute_clashes(atom_coords, atom_mask, clash_threshold)
else:
return clash_matrix


@typecheck
def total_clashes(
atom_coords: Float[Tensor, "... a 3"],
atom_mask: Bool[Tensor, "... a"],
clash_matrix: Bool[Tensor, "... a a"] | None = None,
clash_threshold: float = 1.1,
) -> Float[Tensor, "..."]:
"""
Computes the total number of clashes in the complex
"""
clash_matrix = maybe_compute_clashes(
atom_coords, atom_mask, clash_matrix, clash_threshold
)
# clash matrix is symmetric
return reduce(clash_matrix, "... a1 a2 -> ...", "sum") / 2


@typecheck
def total_inter_chain_clashes(
atom_coords: Float[Tensor, "... a 3"],
atom_mask: Bool[Tensor, "... a"],
asym_id: Int[Tensor, "... a"],
clash_matrix: Bool[Tensor, "... a a"] | None = None,
clash_threshold: float = 1.1,
) -> Float[Tensor, "..."]:
"""Compute total number of inter-chain clashes in the complex"""
clash_matrix = maybe_compute_clashes(
atom_coords, atom_mask, clash_matrix, clash_threshold
).clone() # don't overwrite an input
# clash matrix is symmetric
clash_matrix &= rearrange(asym_id, "... a -> ... a 1") != rearrange(
asym_id, "... a -> ... 1 a"
)
# account for double counting
return reduce(clash_matrix, "... a1 a2 -> ...", "sum") / 2


@typecheck
def per_chain_intra_clashes(
atom_coords: Float[Tensor, "... a 3"],
atom_mask: Bool[Tensor, "... a"],
asym_id: Int[Tensor, "... a"],
clash_matrix: Bool[Tensor, "... a a"] | None = None,
clash_threshold: float = 1.1,
) -> tuple[Float[Tensor, "... n_chains"], Int[Tensor, "n_chains"]]:
clash_matrix = maybe_compute_clashes(
atom_coords, atom_mask, clash_matrix, clash_threshold
).clone() # don't overwrite an input
# clash matrix is symmetric
clash_matrix &= rearrange(asym_id, "... a -> ... a 1") == rearrange(
asym_id, "... a -> ... 1 a"
)
per_atom_clashes = reduce(clash_matrix, "... a -> ...", "sum") / 2
# add dimension for chains
per_atom_clashes = rearrange(per_atom_clashes, "... a -> ... 1 a")
chain_masks, asyms = get_chain_masks_and_asyms(asym_id, atom_mask)
return reduce(per_atom_clashes * chain_masks, "... c a -> ... c", "sum"), asyms


@typecheck
def per_chain_pair_clashes(
atom_coords: Float[Tensor, "... a 3"],
atom_mask: Bool[Tensor, "... a"],
asym_id: Int[Tensor, "... a"],
clash_matrix: Bool[Tensor, "... a a"] | None = None,
clash_threshold: float = 1.1,
) -> tuple[Float[Tensor, "... n_chains n_chains"], Int[Tensor, "n_chains"]]:
"""
Compute the number of inter-chain clashes for each chain in the complex
"""
clash_matrix = maybe_compute_clashes(
atom_coords, atom_mask, clash_matrix, clash_threshold
).clone() # don't overwrite an input
clash_matrix &= rearrange(asym_id, "... a -> ... a 1") != rearrange(
asym_id, "... a -> ... 1 a"
)
chain_masks, asyms = get_chain_masks_and_asyms(asym_id, atom_mask)
per_chain_clashes = torch.zeros(
*chain_masks.shape[:-2],
len(asyms),
len(asyms),
device=atom_coords.device,
dtype=torch.float32,
)
# compute in loop to minimize peak memory
for i, j in combinations(range(len(asyms)), 2):
chain_pair_mask = torch.einsum(
"...i,...j->...ij", chain_masks[..., i, :], chain_masks[..., j, :]
)
# chain_pair_mask is triangular, so don't need to account for double counting
per_chain_clashes[..., i, j] = reduce(
clash_matrix * chain_pair_mask, "... i j -> ...", "sum"
)
symm_clashes = per_chain_clashes + rearrange(
per_chain_clashes, "... i j -> ... j i"
)
return symm_clashes, asyms


@typecheck
def has_clashes(
def has_inter_chain_clashes(
atom_mask: Bool[Tensor, "... a"],
atom_asym_id: Int[Tensor, "... a"],
atom_entity_type: Int[Tensor, "... a"],
per_chain_pair_clashes: Float[Tensor, "... n_chains n_chains"],
per_chain_pair_clashes: Int[Tensor, "... n_chains n_chains"],
max_clashes: int = 100,
max_clash_ratio: float = 0.5,
) -> Bool[Tensor, "..."]:
@@ -186,32 +70,25 @@ def has_clashes(
# if a chain pair has less than max_clashes clashes, but more than
# max_clash_ratio of the smaller chain's total atoms, then also
# consider it a clash
c = atoms_per_chain.shape[-1]
atoms_per_chain_row = repeat(atoms_per_chain, "... c -> ... (c k)", k=c)
atoms_per_chain_col = repeat(atoms_per_chain, "... c -> ... (k c)", k=c)
min_atoms_per_chain_pair, _ = torch.min(
torch.stack([atoms_per_chain_row, atoms_per_chain_col], dim=-1), dim=-1
)
min_atoms_per_chain_pair = rearrange(
min_atoms_per_chain_pair,
"... (c_row c_col) -> ... c_row c_col",
c_row=c,
)
has_clashes |= (
per_chain_pair_clashes / torch.clamp(min_atoms_per_chain_pair, min=1)
) >= max_clash_ratio
per_chain_pair_clashes
/ rearrange(atoms_per_chain, "... c -> ... c 1").clamp(min=1)
).ge(max_clash_ratio)

has_clashes |= (
per_chain_pair_clashes / rearrange(atoms_per_chain, "b c -> b 1 c").clamp(min=1)
).ge(max_clash_ratio)

# only consider clashes between pairs of polymer chains
polymer_chains = rutils.chain_is_polymer(
asym_id=atom_asym_id,
mask=atom_mask,
entity_type=atom_entity_type,
)
is_polymer_pair = rearrange(polymer_chains, "... c -> ... c 1") & rearrange(
polymer_chains, "... c -> ... 1 c"
)
is_polymer_pair = und_self(polymer_chains, "... c1, ... c2 -> ... c1 c2")

# reduce over all chain pairs
return torch.any(has_clashes & is_polymer_pair, dim=(-1, -2))
return reduce(has_clashes & is_polymer_pair, "... c1 c2 -> ...", torch.any)


@typecheck
@@ -224,37 +101,54 @@ def get_scores(
max_clashes: int = 100,
max_clash_ratio: float = 0.5,
) -> ClashScores:
clash_matrix = _compute_clashes(atom_coords, atom_mask, clash_threshold)
_per_chain_pair_clashes = per_chain_pair_clashes(
atom_coords, atom_mask, atom_asym_id, clash_matrix, clash_threshold
)[0]
# shift asym_id from 1-based to 0-based
assert atom_asym_id.dtype in (torch.int32, torch.int64)
atom_asym_id = (atom_asym_id - 1).to(torch.int64)
assert torch.amin(atom_asym_id) >= 0

# dimensions
n_chains = atom_asym_id.amax().add(1).item()
assert isinstance(n_chains, int)
*b, a = atom_mask.shape

clashes_a_a = _compute_clashes(atom_coords, atom_mask, clash_threshold)
clashes_a_a = clashes_a_a.to(torch.int32) # b a a

clashes_a_chain = clashes_a_a.new_zeros(*b, a, n_chains)
clashes_a_chain.scatter_add_(
dim=-1,
src=clashes_a_a,
index=repeat(atom_asym_id, f"b a -> b {a} a"),
)

clashes_chain_chain = clashes_a_a.new_zeros(*b, n_chains, n_chains)
clashes_chain_chain.scatter_add_(
dim=-2,
src=clashes_a_chain,
index=repeat(atom_asym_id, f"b a -> b a {n_chains}"),
)
# i, j enumerate chains
total_clashes = reduce(clashes_chain_chain, "... i j -> ...", "sum") // 2
# NB: diagonal term (self-interaction of chain), contains doubled self-interaction
per_chain_intra_clashes = torch.einsum("... i i -> ... i", clashes_chain_chain) // 2
# delete self-interaction for simplicity
non_diag = 1 - torch.diag(clashes_a_a.new_ones(n_chains))
inter_chain_chain = non_diag * clashes_chain_chain

inter_chain_clashes = (
reduce(inter_chain_chain, "... i j -> ... ", "sum") // 2
) # div by 2 to compensate for symmetricity of matrix

return ClashScores(
total_clashes=total_clashes(
atom_coords=atom_coords,
atom_mask=atom_mask,
clash_matrix=clash_matrix,
clash_threshold=clash_threshold,
),
total_inter_chain_clashes=total_inter_chain_clashes(
atom_coords=atom_coords,
atom_mask=atom_mask,
asym_id=atom_asym_id,
clash_matrix=clash_matrix,
clash_threshold=clash_threshold,
),
per_chain_intra_clashes=per_chain_intra_clashes(
atom_coords=atom_coords,
atom_mask=atom_mask,
asym_id=atom_asym_id,
clash_matrix=clash_matrix,
clash_threshold=clash_threshold,
)[0],
per_chain_pair_clashes=_per_chain_pair_clashes,
has_clashes=has_clashes(
total_clashes=total_clashes,
total_inter_chain_clashes=inter_chain_clashes,
chain_intra_clashes=per_chain_intra_clashes,
chain_chain_inter_clashes=inter_chain_chain,
has_inter_chain_clashes=has_inter_chain_clashes(
atom_mask=atom_mask,
atom_asym_id=atom_asym_id,
atom_entity_type=atom_entity_type,
per_chain_pair_clashes=_per_chain_pair_clashes,
per_chain_pair_clashes=inter_chain_chain,
max_clashes=max_clashes,
max_clash_ratio=max_clash_ratio,
),
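A minimal sketch (illustrative, not the repository's code) of the scatter_add_ aggregation that the rewritten get_scores above builds on: the atom-atom clash matrix is collapsed into chain-level counts in two passes instead of looping over chain pairs. Toy shapes and variable names are chosen here for illustration, and asym_id is assumed to be already 0-based:

import torch
from einops import repeat

b, a, n_chains = 2, 6, 3  # batch, atoms, chains
asym_id = repeat(torch.tensor([0, 0, 1, 1, 2, 2]), "a -> b a", b=b)
clashes_a_a = torch.randint(0, 2, (b, a, a), dtype=torch.int32)

# first pass: sum each atom's clash partners into their chains,
# (b, a, a) -> (b, a, n_chains)
clashes_a_chain = clashes_a_a.new_zeros(b, a, n_chains)
clashes_a_chain.scatter_add_(
    dim=-1, src=clashes_a_a, index=repeat(asym_id, "b a2 -> b a1 a2", a1=a)
)
# second pass: sum the remaining atom axis into its own chain,
# (b, a, n_chains) -> (b, n_chains, n_chains)
clashes_chain_chain = clashes_a_a.new_zeros(b, n_chains, n_chains)
clashes_chain_chain.scatter_add_(
    dim=-2, src=clashes_a_chain, index=repeat(asym_id, "b a -> b a c", c=n_chains)
)
# clashes_chain_chain[..., i, j] counts clashes between chains i and j;
# the diagonal holds doubled intra-chain counts, as the diff above notes.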
9 changes: 5 additions & 4 deletions chai_lab/ranking/rank.py
@@ -98,7 +98,7 @@ def rank(
aggregate_score = (
0.2 * ptm_scores.complex_ptm
+ 0.8 * ptm_scores.interface_ptm
- 100 * clash_scores.has_clashes.float()
- 100 * clash_scores.has_inter_chain_clashes.float()
)

_, asyms = rutils.get_chain_masks_and_asyms(
@@ -122,8 +122,9 @@ def get_scores(ranking_data: SampleRanking) -> dict[str, np.ndarray]:
"iptm": ranking_data.ptm_scores.interface_ptm,
"per_chain_ptm": ranking_data.ptm_scores.per_chain_ptm,
"per_chain_pair_iptm": ranking_data.ptm_scores.per_chain_pair_iptm,
"has_clashes": ranking_data.clash_scores.total_clashes,
"per_chain_intra_clashes": ranking_data.clash_scores.per_chain_intra_clashes,
"per_chain_pair_inter_clashes": ranking_data.clash_scores.per_chain_pair_clashes,
"has_inter_chain_clashes": ranking_data.clash_scores.has_inter_chain_clashes,
# TODO replace with just one tensor that contains both
"chain_intra_clashes": ranking_data.clash_scores.chain_intra_clashes,
"chain_chain_inter_clashes": ranking_data.clash_scores.chain_chain_inter_clashes,
}
return {k: v.cpu().numpy() for k, v in scores.items()}
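For context on the renamed flag: the ranking formula in the hunk above uses has_inter_chain_clashes as a hard filter, since the -100 penalty drops any flagged sample below every clash-free one. A small illustrative calculation with invented values:

import torch

complex_ptm = torch.tensor([0.85, 0.90])
interface_ptm = torch.tensor([0.80, 0.88])
has_inter_chain_clashes = torch.tensor([False, True])
aggregate_score = (
    0.2 * complex_ptm
    + 0.8 * interface_ptm
    - 100 * has_inter_chain_clashes.float()
)
# tensor([  0.8100, -99.1160]) -- the clashing sample always ranks last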
3 changes: 1 addition & 2 deletions chai_lab/ranking/utils.py
@@ -15,8 +15,7 @@ def get_chain_masks_and_asyms(
"""
Returns a mask for each chain and the unique asym ids
"""
unique_asyms = torch.unique(asym_id[mask])
sorted_unique_asyms, _ = torch.sort(unique_asyms)
sorted_unique_asyms = torch.unique(asym_id[mask])
# shape: (..., max_num_chains, n)
chain_masks = rearrange(asym_id, "... n -> ... 1 n") == rearrange(
sorted_unique_asyms, "nc -> nc 1"
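The utils.py simplification relies on torch.unique returning its values already sorted (its sorted argument defaults to True), which is why the explicit torch.sort could be dropped. A quick illustrative check:

import torch

asym_id = torch.tensor([3, 1, 2, 1, 3])
print(torch.unique(asym_id))  # tensor([1, 2, 3]) -- already sorted ascending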
