Write predictions in cif format (#45)
* write cif

* headers + fix

* fix stubs

* remove redundant manipulation

* delete unused code

* restrict modelcif from below

* force-check all tensors are on CPU

* update docstring

* use entity_name for agreement with rest of codebase

---------

Co-authored-by: Alex Rogozhnikov <[email protected]>
jacquesboitreaud and arogozhnikov authored Sep 13, 2024
1 parent 76f0f26 commit ecd62ff
Showing 8 changed files with 305 additions and 48 deletions.
chai_lab/chai1.py (12 additions & 8 deletions)
@@ -72,7 +72,7 @@
 from chai_lab.data.features.generators.token_pair_pocket_restraint import (
     TokenPairPocketRestraint,
 )
-from chai_lab.data.io.pdb_utils import write_pdbs_from_outputs
+from chai_lab.data.io.cif_utils import outputs_to_cif
 from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule
 from chai_lab.model.utils import center_random_augmentation
 from chai_lab.ranking.frames import get_frames_and_mask
@@ -271,7 +271,7 @@ def run_inference(
         constraint_context=constraint_context,
     )

-    output_pdb_paths, _, _, _ = run_folding_on_context(
+    output_cif_paths, _, _, _ = run_folding_on_context(
         feature_context,
         output_dir=output_dir,
         num_trunk_recycles=num_trunk_recycles,
@@ -280,7 +280,7 @@
         device=device,
     )

-    return output_pdb_paths
+    return output_cif_paths


 def _bin_centers(min_bin: float, max_bin: float, no_bins: int) -> Tensor:
@@ -661,20 +661,24 @@ def avg_per_token_1d(x):
         ## Write output files
         ##

-        pdb_out_path = output_dir.joinpath(f"pred.model_idx_{idx}.pdb")
+        cif_out_path = output_dir.joinpath(f"pred.model_idx_{idx}.cif")

-        print(f"Writing output to {pdb_out_path}")
+        print(f"Writing output to {cif_out_path}")

         # use 0-100 scale for pLDDT in pdb outputs
         scaled_plddt_scores_per_atom = 100 * plddt_scores_atom[idx : idx + 1]

-        write_pdbs_from_outputs(
+        outputs_to_cif(
             coords=atom_pos[idx : idx + 1],
             bfactors=scaled_plddt_scores_per_atom,
             output_batch=move_data_to_device(inputs, torch.device("cpu")),
-            write_path=pdb_out_path,
+            write_path=cif_out_path,
+            entity_names={
+                c.entity_data.entity_id: c.entity_data.entity_name
+                for c in feature_context.chains
+            },
         )
-        output_paths.append(pdb_out_path)
+        output_paths.append(cif_out_path)

         scores_basename = f"scores.model_idx_{idx}.npz"
         scores_out_path = output_dir / scores_basename
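An aside on the new argument, not part of the diff: entity_names is a plain mapping from integer entity id to the name parsed from the FASTA header (see the inference_dataset.py changes below). A sketch of what the dict comprehension above produces, with hypothetical ids and names:

    entity_names = {
        1: "example-peptide",
        2: "example-binder",
    }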
chai_lab/data/dataset/inference_dataset.py (5 additions & 2 deletions)
@@ -32,6 +32,7 @@
 class Input:
     sequence: str
     entity_type: int
+    entity_name: str


 def get_lig_residues(
@@ -131,7 +132,7 @@ def raw_inputs_to_entitites_data(
             release_datetime=datetime.now(),
             pdb_id=identifier,
             source_pdb_chain_id=_synth_subchain_id(i),
-            entity_name=f"entity_{i}_{entity_type.name}",
+            entity_name=input.entity_name,
             entity_id=entity_id,
             method="none",
             entity_type=entity_type,
@@ -204,6 +205,8 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list
         logger.info(f"[fasta] [{fasta_file}] {desc} {len(sequence)}")
         # get the type of the sequence
         entity_str = desc.split("|")[0].strip().lower()
+        entity_name = desc.split("|")[1].strip().lower()
+
         match entity_str:
             case "protein":
                 entity_type = EntityType.PROTEIN
@@ -225,7 +228,7 @@
f"Provided {sequence=} is likely {types_fmt}, not {entity_type.name}"
)

retval.append(Input(sequence, entity_type.value))
retval.append(Input(sequence, entity_type.value, entity_name))
total_length += len(sequence)

if length_limit is not None and total_length > length_limit:
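With this change, each FASTA header is expected to carry two pipe-separated fields: the entity type and an entity name. A minimal sketch of a conforming input file (sequences and names are hypothetical):

    >protein|example-peptide
    MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
    >protein|example-binder
    GSHMGSGS

For the first record, read_inputs yields Input(sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", entity_type=EntityType.PROTEIN.value, entity_name="example-peptide"); note that the parser lower-cases both fields, so mixed-case names come back lower-cased.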
chai_lab/data/io/cif_utils.py (new file, 200 additions)
@@ -0,0 +1,200 @@
import logging
from pathlib import Path

import gemmi
import modelcif
import torch
from ihm import ChemComp, DNAChemComp, LPeptideChemComp, RNAChemComp
from modelcif import Assembly, AsymUnit, Entity, dumper, model
from torch import Tensor

from chai_lab.data.io.pdb_utils import (
    PDBContext,
    entity_to_pdb_atoms,
    get_pdb_chain_name,
    pdb_context_from_batch,
)
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.data.residue_constants import restype_3to1
from chai_lab.utils.typing import Float

logger = logging.getLogger(__name__)


class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
    name = "pLDDT"
    software = None
    description = "Predicted lddt"


def token_centre_plddts(
    context: PDBContext,
    asym_id: int,
) -> tuple[list[float], list[int]]:
    """Returns pLDDTs at token centre atoms for one chain, plus their residue indices."""
    assert context.atom_bfactor_or_plddt is not None
    mask = context.token_asym_id == asym_id

    atom_idces = context.token_centre_atom_index[mask]

    # residue indices for tokens
    residue_indices = context.token_residue_index[mask]

    return context.atom_bfactor_or_plddt[atom_idces].tolist(), residue_indices.tolist()


def get_chains_metadata(context: PDBContext) -> list[dict]:
    # for each chain, get chain id, entity id, full sequence
    token_res_names = context.token_res_names_to_string

    records = []

    for asym_id in torch.unique(context.token_asym_id):
        if asym_id == 0:  # padding
            continue

        token_indices = torch.where(context.token_asym_id == asym_id)[0]
        chain_token_res_names = [token_res_names[i] for i in token_indices]

        residue_indices = context.token_residue_index[token_indices]

        # check no missing residues
        max_res = torch.max(residue_indices).item()
        tokens_in_residue = (
            residue_indices == torch.arange(max_res + 1)[..., None]
        )  # (residues, tokens)
        assert tokens_in_residue.any(dim=-1).all()

        first_token_in_resi = torch.argmax(tokens_in_residue.int(), dim=-1)

        sequence = [chain_token_res_names[i] for i in first_token_in_resi]
        entity_id = context.token_entity_id[token_indices][0]

        entity_type: int = context.get_chain_entity_type(asym_id)

        records.append(
            {
                "sequence": sequence,
                "entity_id": entity_id.item(),
                "asym_id": asym_id.item(),
                "entity_type": entity_type,
            }
        )

    return records


def _to_chem_component(res_name_3: str, entity_type: int):
    match entity_type:
        case EntityType.LIGAND.value:
            code = res_name_3
            return ChemComp(res_name_3, code, code_canonical=code)
        case EntityType.PROTEIN.value:
            code = restype_3to1.get(res_name_3, res_name_3)
            one_letter_code = gemmi.find_tabulated_residue(res_name_3).one_letter_code
            return LPeptideChemComp(res_name_3, code, code_canonical=one_letter_code)
        case EntityType.DNA.value:
            code = res_name_3
            # canonical code is DA -> A for DNA in cif files from wwpdb
            return DNAChemComp(res_name_3, code, code_canonical=code[-1])
        case EntityType.RNA.value:
            code = res_name_3
            return RNAChemComp(res_name_3, code, code_canonical=code)
        case _:
            raise NotImplementedError


def sequence_to_chem_comps(sequence: list[str], entity_type: int) -> list[ChemComp]:
    return [_to_chem_component(resi, entity_type) for resi in sequence]


def context_to_cif(context: PDBContext, outpath: Path, entity_names: dict[int, str]):
    records = get_chains_metadata(context)

    entities_map = {}
    for record in records:
        entity_id = record["entity_id"]
        if entity_id in entities_map:
            continue

        chem_components = sequence_to_chem_comps(
            record["sequence"], record["entity_type"]
        )
        cif_entity = Entity(chem_components, description=entity_names[entity_id])

        entities_map[entity_id] = cif_entity

    chains_map = {}

    def _make_chain(record: dict) -> AsymUnit:
        asym_id = record["asym_id"]
        chain_id_str = get_pdb_chain_name(asym_id)
        entity_id = record["entity_id"]

        return AsymUnit(
            entities_map[entity_id],
            details=f"Chain {chain_id_str}",
            id=chain_id_str,
        )

    chains_map = {r["asym_id"]: _make_chain(r) for r in records}

    pdb_atoms: list[list] = entity_to_pdb_atoms(context)

    _assembly = Assembly(chains_map.values(), name="Assembly 1")

    class PredModel(model.AbInitioModel):
        def get_atoms(self):
            for atoms_list in pdb_atoms:
                for a in atoms_list:
                    yield model.Atom(
                        asym_unit=chains_map[a.asym_id],
                        type_symbol=a.element,
                        seq_id=int(a.residue_index),
                        atom_id=a.atom_name,
                        x=a.pos[0],
                        y=a.pos[1],
                        z=a.pos[2],
                        het=False,
                        biso=a.b_factor,
                        occupancy=1.00,
                    )

    _model = PredModel(_assembly, name="pred_model_1")

    for asym_id, cif_asym_unit in chains_map.items():
        entity_type = context.get_chain_entity_type(asym_id)

        if entity_type != EntityType.LIGAND.value:
            # token centres are Ca or C1' for nucleic acids
            chain_plddts, residue_indices = token_centre_plddts(
                context,
                asym_id,
            )

            for residue_idx, plddt in zip(
                residue_indices,
                chain_plddts,
            ):
                _model.qa_metrics.append(
                    _LocalPLDDT(cif_asym_unit.residue(residue_idx + 1), plddt)
                )

    system = modelcif.System(title="Chai-1 predicted structure")
    system.authors = ["Chai team"]
    model_group = model.ModelGroup([_model], name="pred")
    system.model_groups.append(model_group)

    dumper.write(open(outpath, "w"), systems=[system])


def outputs_to_cif(
    coords: Float[Tensor, "1 n_atoms 3"],
    output_batch: dict,
    write_path: Path,
    entity_names: dict[int, str],
    bfactors: Float[Tensor, "1 n_atoms"] | None = None,
):
    context = pdb_context_from_batch(output_batch, coords, plddt=bfactors)
    write_path.parent.mkdir(parents=True, exist_ok=True)
    context_to_cif(context, write_path, entity_names)
    logger.info(f"saved cif file to {write_path}")
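A usage sketch, not part of the commit: the chemical-component mapping above can be exercised on its own, since sequence_to_chem_comps needs only residue names and an entity type. The residue names below are standard wwPDB component ids; everything else is hypothetical.

    from chai_lab.data.io.cif_utils import sequence_to_chem_comps
    from chai_lab.data.parsing.structure.entity_type import EntityType

    # DNA keeps its two-letter component id; the canonical code drops the leading "D" (DA -> A)
    dna_comps = sequence_to_chem_comps(["DA", "DC"], EntityType.DNA.value)
    print([c.code_canonical for c in dna_comps])  # ['A', 'C']

    # protein residues become LPeptideChemComp entries keyed by their three-letter ids
    protein_comps = sequence_to_chem_comps(["ALA", "GLY"], EntityType.PROTEIN.value)
    print([c.id for c in protein_comps])  # ['ALA', 'GLY']

outputs_to_cif itself is harder to demonstrate in isolation: output_batch must be the CPU-side batch dict that pdb_context_from_batch expects, and that is assembled inside the inference pipeline (see the chai1.py hunk above).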
