Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
Oufan75 authored Jul 12, 2023
2 parents d0e0ba5 + 1794566 commit 005fb50
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 30 deletions.
14 changes: 13 additions & 1 deletion src/mcsce/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def read_structure_and_check(file_name, retain_idx=[], verbose=False):
"""
s = Structure(file_name)
s.build()
his_id = [str(resid) for resid, res_type in zip(s.residues, s.residue_types) if res_type == "HIS"]
if len(his_id) > 0:
message = f"WARNING! Undefined histidine protonation status found for structure [{file_name}]:\n"
message += " ".join(his_id)
message += "\nThese residues have been modified to HIP\n"
res_labels = s.res_labels
res_labels[res_labels == "HIS"] = "HIP"
s.res_labels = res_labels
if verbose:
print(message)
else:
print("HIS residues modified to HIP\n")
missing_backbone_atoms = s.check_backbone_atom_completeness()
if len(missing_backbone_atoms) > 0:
message = f"WARNING! These atoms are missing from the current backbone structure [{file_name}]:"
Expand Down Expand Up @@ -133,7 +145,7 @@ def main(input_structure, n_conf, n_worker, output_dir, logfile, mode, fix, batc
return
# remove added sidechains in sections to be processed
s = s.remove_side_chains(fix_idxs)

initialize_func_calc(partial(prepare_energy_function, batch_size=batch_size,
forcefield=ff_obj, terms=["lj", "clash", "coulomb"]),
structure=s, retain_idxs=fix_idxs)
Expand Down
14 changes: 11 additions & 3 deletions src/mcsce/core/side_chain_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,31 @@ def initialize_func_calc(efunc_creator, aa_seq=None, structure=None, retain_idxs
# extract amino acid sequence from structure object
aa_seq = structure.residue_types
structure = deepcopy(structure)
chain_ids = structure.residue_chains

sidechain_placeholders.append(deepcopy(structure))
n_terms, c_terms = structure.get_terminal_res_atom_arr()
energy_calculators.append(efunc_creator(structure.atom_labels,
structure.res_nums,
structure.res_labels))
for idx, resname in tqdm(enumerate(aa_seq), total=len(aa_seq)):
structure.res_labels,
n_terms,
c_terms))

for idx, resname, chain_id in tqdm(zip(range(len(aa_seq)), aa_seq, chain_ids), total=len(aa_seq)):
if idx + structure.res_nums[0] not in retain_idxs:
template = sidechain_templates[resname]
structure.add_side_chain(idx + structure.res_nums[0], template)
structure.add_side_chain(idx + structure.res_nums[0], template, chain_id)
sidechain_placeholders.append(deepcopy(structure))

if resname not in ["GLY", "ALA"] and idx + structure.res_nums[0] not in retain_idxs:
n_sidechain_atoms = len(template[1])
all_indices = np.arange(len(structure.atom_labels))
n_terms, c_terms = structure.get_terminal_res_atom_arr()
energy_func = efunc_creator(structure.atom_labels,
structure.res_nums,
structure.res_labels,
n_terms,
c_terms,
partial_indices=[all_indices[-n_sidechain_atoms:],
all_indices[:-n_sidechain_atoms]])
energy_calculators.append(energy_func)
Expand Down
26 changes: 17 additions & 9 deletions src/mcsce/libs/libenergy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def prepare_energy_function(
atom_labels,
residue_numbers,
residue_labels,
N_terminal_indicators,
C_terminal_indicators,
forcefield,
batch_size=16,
partial_indices=None,
Expand Down Expand Up @@ -162,6 +164,8 @@ def prepare_energy_function(
new_indices,
old_indices,
forcefield.forcefield,
N_terminal_indicators,
C_terminal_indicators,
)

# 0.2 as 0.4
Expand Down Expand Up @@ -387,7 +391,7 @@ def create_bonds_apart_mask_for_ij_pairs_old(
residue_labels_ij_gen,
residue_numbers_ij_gen,
atom_labels_ij_gen,
bonds_intra,
bonds_atom_labelsintra,
bonds_inter,
)

Expand All @@ -407,15 +411,18 @@ def create_LJ_params_raw(
new_indices,
old_indices,
force_field,
n_terminal_indicators,
c_terminal_indicators,
):
"""Create ACOEFF and BCOEFF parameters.
Borrowed from IDP Conformer Generator package (https://github.com/julie-forman-kay-lab/IDPConformerGenerator) developed by Joao M. C. Teixeira"""

sigmas_ii_new = extract_ff_params_for_seq(
atom_labels[new_indices],
residue_numbers[new_indices],
residue_labels[new_indices],
min(residue_numbers),
max(residue_numbers),
n_terminal_indicators[new_indices],
c_terminal_indicators[new_indices],
force_field,
'sigma',
)
Expand All @@ -424,8 +431,8 @@ def create_LJ_params_raw(
atom_labels[old_indices],
residue_numbers[old_indices],
residue_labels[old_indices],
min(residue_numbers),
max(residue_numbers),
n_terminal_indicators[old_indices],
c_terminal_indicators[old_indices],
force_field,
'sigma',
)
Expand All @@ -434,8 +441,8 @@ def create_LJ_params_raw(
atom_labels[new_indices],
residue_numbers[new_indices],
residue_labels[new_indices],
min(residue_numbers),
max(residue_numbers),
n_terminal_indicators[new_indices],
c_terminal_indicators[new_indices],
force_field,
'epsilon',
)
Expand All @@ -444,8 +451,8 @@ def create_LJ_params_raw(
atom_labels[old_indices],
residue_numbers[old_indices],
residue_labels[old_indices],
min(residue_numbers),
max(residue_numbers),
n_terminal_indicators[old_indices],
c_terminal_indicators[old_indices],
force_field,
'epsilon',
)
Expand All @@ -463,6 +470,7 @@ def create_LJ_params_raw(

# mixing rules
epsilons_ij = epsilons_ij_pre ** 0.5

# mixing + nm to Angstrom conversion
# / 2 and * 10
sigmas_ij = sigmas_ij_pre * 5
Expand Down
11 changes: 6 additions & 5 deletions src/mcsce/libs/libparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,15 +415,16 @@ def extract_ff_params_for_seq(
params_l = []
params_append = params_l.append

zipit = zip(atom_labels, residue_numbers, residue_labels)
for atom_name, res_num, res_label in zipit:
#print(atom_name, res_num, res_label)
zipit = zip(atom_labels, residue_numbers, residue_labels, n_terminal_idx, c_terminal_idx)
for atom_name, res_num, res_label, is_nterm, is_cterm in zipit:

# adds N and C to the terminal residues
if res_num == n_terminal_idx:

if is_nterm:
res = 'N' + res_label
assert res.isupper() and len(res) == 4, res

elif res_num == c_terminal_idx:
elif is_cterm:
res = 'C' + res_label
assert res.isupper() and len(res) == 4, res
else:
Expand Down
70 changes: 58 additions & 12 deletions src/mcsce/libs/libstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,29 @@ def coords(self):
def coords(self, coords):
self.data_array[:, cols_coords] = \
np.round(coords, decimals=3).astype('<U8')


def get_terminal_res_atom_arr(self):
    """Flag the atoms that belong to chain-terminal residues.

    For every chain, the first and the last residue number encountered
    (in order of first appearance among the filtered atoms) are taken as
    the N-terminal and C-terminal residue, respectively.

    Returns a tuple ``(N_term_indicator, C_term_indicator)`` of two boolean
    arrays, each with one entry per filtered atom, that are True where the
    atom sits in an N-/C-terminal residue of its chain.
    """
    atoms = self.filtered_atoms

    # chain id -> residue numbers; the inner dict acts as an ordered set so
    # that residues stay in the order in which they were first seen
    residues_per_chain = defaultdict(dict)
    for atom in atoms:
        residues_per_chain[atom[col_chainID]].setdefault(
            atom[col_resSeq], atom[col_resName])

    n_atoms = len(atoms)
    n_term_mask = np.zeros(n_atoms, dtype=bool)
    c_term_mask = np.zeros(n_atoms, dtype=bool)

    for chain_name, chain_residues in residues_per_chain.items():
        ordered_ids = list(chain_residues)
        in_chain = atoms[:, col_chainID] == chain_name
        res_ids = atoms[:, col_resSeq]
        # accumulate the per-chain terminal atoms into the global masks
        n_term_mask |= in_chain & (res_ids == ordered_ids[0])
        c_term_mask |= in_chain & (res_ids == ordered_ids[-1])

    return n_term_mask, c_term_mask



Expand Down Expand Up @@ -196,7 +219,23 @@ def residue_types(self):
chains = defaultdict(dict)
for row in self.filtered_atoms:
chains[row[c]].setdefault(row[rs], row[rn])
return list(list(chains.values())[0].values())
restypes = []
for chain_residues in chains.values():
restypes.extend(list(chain_residues.values()))
return restypes

@property
def residue_chains(self):
    """Return the chain identifier of every residue, one entry per residue.

    Chains appear in the order in which they are first met among the
    filtered atoms; each chain id is repeated once for every distinct
    residue number found in that chain.
    """
    # chain id -> {residue number: residue name}, both in first-seen order
    residues_per_chain = defaultdict(dict)
    for atom in self.filtered_atoms:
        residues_per_chain[atom[col_chainID]].setdefault(
            atom[col_resSeq], atom[col_resName])

    return [
        chain_id
        for chain_id, residues in residues_per_chain.items()
        for _ in residues
    ]

@property
def filtered_residues(self):
Expand All @@ -215,15 +254,15 @@ def residues(self):

@property
def atom_labels(self):
return self.data_array[:, col_name]
return self.filtered_atoms[:, col_name]

@property
def res_nums(self):
return self.data_array[:, col_resSeq].astype(int)
return self.filtered_atoms[:, col_resSeq].astype(int)

@property
def res_labels(self):
return self.data_array[:, col_resName]
return self.filtered_atoms[:, col_resName]

@residues.setter
def residues(self, residue_idx):
Expand All @@ -235,6 +274,10 @@ def residues(self, residue_idx):
"""
self.data_array[:, col_resSeq] = str(residue_idx)

# Setter for `res_labels`: writes `res_labels` into the residue-name column
# (`col_resName`) of every row of `data_array`.
# NOTE(review): this assigns to the full `data_array`, not to a filtered
# view — presumably it is only meant to be used when the supplied labels
# match the full array (e.g. no filters active); confirm against callers.
@res_labels.setter
def res_labels(self, res_labels):
self.data_array[:, col_resName] = res_labels

def pop_last_filter(self):
"""Pop last filter."""
self._filters.pop()
Expand Down Expand Up @@ -412,24 +455,25 @@ def check_backbone_atom_completeness(self):
Run check of backbone atom completeness and return a list containing all missing atoms from expected backbone atom list
'''
all_residue_atoms = {}
for atom_label, res_num, res_label in zip(self.atom_labels, self.res_nums, self.res_labels):
N_res_indicators, C_res_indicators = self.get_terminal_res_atom_arr()
for atom_label, res_num, res_label, is_N_res, is_C_res in \
zip(self.atom_labels, self.res_nums, self.res_labels, N_res_indicators, C_res_indicators):
if res_num not in all_residue_atoms:
all_residue_atoms[res_num] = {"label": res_label, "atoms": [atom_label]}
all_residue_atoms[res_num] = {"label": res_label, "atoms": [atom_label], "is_N_res": is_N_res, "is_C_res": is_C_res}
else:
all_residue_atoms[res_num]["atoms"].append(atom_label)
n_term_idx = min(all_residue_atoms)
c_term_idx = max(all_residue_atoms)

missing_atoms = []
for idx in all_residue_atoms:
if idx == n_term_idx:
if all_residue_atoms[idx]["is_N_res"]:
expected_atoms = ["N", "CA", "C", "O", "H1", "H2"]
if all_residue_atoms[idx]["label"] not in ["PRO", "HYP"]:
expected_atoms.append("H3")
else:
expected_atoms = ["N", "CA", "C", "O"]
if all_residue_atoms[idx]["label"] not in ["PRO", "HYP"]:
expected_atoms.append("H")
if idx == c_term_idx:
if all_residue_atoms[idx]["is_C_res"]:
expected_atoms.append("OXT")
residue_missing_atom = [item for item in expected_atoms \
if item not in all_residue_atoms[idx]["atoms"]]
Expand All @@ -451,19 +495,21 @@ def remove_side_chains(self, retain_idxs=[]):
copied_structure._data_array = copied_structure.data_array[retained_atoms_filter]
return copied_structure #, None if np.all(retained_atoms_filter) else retained_atoms_filter

def add_side_chain(self, res_idx, sidechain_template):
def add_side_chain(self, res_idx, sidechain_template, chain_id='A'):
template_structure, sc_atoms = sidechain_template
self.add_filter_resnum(res_idx)
N_CA_C_coords = self.get_sorted_minimal_backbone_coords(filtered=True)
sc_all_atom_coords = place_sidechain_template(N_CA_C_coords, template_structure.coords)
sidechain_data_arr = template_structure.data_array.copy()
sidechain_data_arr[:, cols_coords] = np.round(sc_all_atom_coords, decimals=3).astype('<U8')
sidechain_data_arr[:, col_resSeq] = str(res_idx)

# conform to backbone residue labels but conform to sidechain records
res_mask = (self.data_array[:, col_resSeq].astype(int) == res_idx)
self.data_array[res_mask, col_record] = sidechain_data_arr[0, col_record]
sidechain_data_arr[:, col_segid] = str(self.filtered_atoms[0, col_segid])
sidechain_data_arr[:, col_chainID] = str(self.filtered_atoms[0, col_chainID])
sidechain_data_arr[:, col_chainID] = chain_id

self.pop_last_filter()
self._data_array = np.concatenate([self.data_array, sidechain_data_arr[sc_atoms]])

Expand Down

0 comments on commit 005fb50

Please sign in to comment.