14 changes: 11 additions & 3 deletions conkit/command_line/conkit_validate.py
@@ -94,7 +94,7 @@ def create_argument_parser():
help="Number of iterations")
parser.add_argument("--moltype", dest="moltype", default="Protein", type=str,
help="Type of molecule")
parser.add_argument("--run_svm", dest="RUN_SVM", default='yes', type=str,
parser.add_argument("--run_svm", dest="RUN_SVM", default='yes if prediction not pdb or mmcif', type=str,
help="Whether to run the support vector machine validation")
parser.add_argument("--run_map_align", dest="RUN_MAP_ALIGN", default='yes', type=str,
help="Whether to run the contactmap alignment validation")
@@ -275,7 +275,8 @@ def main():

logger.info(os.linesep + "Working directory: %s", os.getcwd())
logger.info("Reading input sequence: %s", args.seqfile)
sequence = conkit.io.read(args.seqfile, args.seqformat).top
sequencefile = conkit.io.read(args.seqfile, args.seqformat)
sequence = sequencefile.top

if len(sequence) < 5:
raise ValueError('Cannot validate model with less than 5 residues')
@@ -290,6 +291,7 @@
else:
prediction_file = conkit.io.read(args.distfile, args.distformat)
prediction = prediction_file.top

logger.info("Reading input PDB model: %s", args.pdbfile)
model = conkit.io.read(args.pdbfile, args.pdbformat, distance_cutoff=cutoff, atom_type=rep_atom).top

@@ -300,6 +302,12 @@

validation = conkit.plot.ModelValidationFigure(model, prediction, sequence)

if args.RUN_SVM=='yes if prediction not pdb or mmcif': #don't run the svm if prediction is a structure by default
if args.distformat in ['pdb', 'mmcif']:
args.RUN_SVM='no'
else:
args.RUN_SVM='yes'

if args.RUN_SVM=='yes':
logger.info(os.linesep + "Running Support Vector Machine.")

@@ -320,7 +328,7 @@
if args.RUN_FILTERS=='yes':
logger.info(os.linesep + "Running Filters.")

validation.count_contacts()
validation.count_contacts(cutoff=cutoff)

if (prediction.plddt != None) and (args.PLDDT_IN_DISTFILE == 'yes'): ##turn into check for plddt

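For clarity, the new sentinel default for --run_svm resolves later in main() as shown in the hunk above: the SVM step is skipped when the "prediction" is actually a structure parsed as pdb/mmcif. The following is an illustrative sketch of that resolution, not code from the PR; the helper name resolve_run_svm is hypothetical.

    # Sketch (not part of the PR) of how the --run_svm sentinel default behaves.
    def resolve_run_svm(run_svm, distformat):
        # The sentinel means "run the SVM only when the prediction file is a
        # genuine prediction, not a structure supplied in pdb/mmcif format".
        if run_svm == 'yes if prediction not pdb or mmcif':
            return 'no' if distformat in ('pdb', 'mmcif') else 'yes'
        # An explicit 'yes'/'no' from the user is left untouched.
        return run_svm

    assert resolve_run_svm('yes if prediction not pdb or mmcif', 'pdb') == 'no'
    assert resolve_run_svm('yes if prediction not pdb or mmcif', 'npz') == 'yes'
    assert resolve_run_svm('no', 'npz') == 'no'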
31 changes: 23 additions & 8 deletions conkit/io/pdb.py
@@ -71,6 +71,7 @@ def _build_plddts(self, chain):
for residue in chain:
for atom in residue.get_atoms():
plddts[residue.get_id()[1]] = atom.get_bfactor()

return plddts


@@ -112,14 +113,28 @@ def _chain_contacts(self, chain1, chain2):

def _remove_atom(self, chain, type):
"""Tidy up a chain removing all HETATM entries"""
for residue in chain.copy():
for atom in residue.copy():
if atom.is_disordered():
chain[residue.id].detach_child(atom.id)
elif residue.resname == "GLY" and type == "CB" and atom.id == "CA":
continue
elif atom.id != type:
chain[residue.id].detach_child(atom.id)

if type == 'BASEPAIRING':
#handle special request for contacts/distances based on basepairing atoms in NA rather than backbone atoms
#this could be improved to handle hoogsteen pairs
for residue in chain.copy():
for atom in residue.copy():
if atom.is_disordered():
chain[residue.id].detach_child(atom.id)
else:
atom_needed = (atom.id == 'N1' and residue.resname in ['A', 'G', 'DA', 'DG'])
atom_needed = atom_needed or (atom.id == 'N3' and residue.resname in ['C', 'T', 'U', 'DC', 'DT','DU'])
if not atom_needed:
chain[residue.id].detach_child(atom.id)
else:
for residue in chain.copy():
for atom in residue.copy():
if atom.is_disordered():
chain[residue.id].detach_child(atom.id)
elif residue.resname == "GLY" and type == "CB" and atom.id == "CA":
continue
elif atom.id != type:
chain[residue.id].detach_child(atom.id)

def _remove_hetatm(self, chain):
"""Tidy up a chain removing all HETATM entries"""
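The BASEPAIRING branch above keeps a single base-pairing atom per nucleotide: N1 for purines (A, G, DA, DG) and N3 for pyrimidines (C, T, U, DC, DT, DU), with disordered atoms always removed. Below is a standalone sketch of that selection rule for reference; the function name and sets are illustrative assumptions, not ConKit API.

    # Sketch of the atom kept by _remove_atom(chain, 'BASEPAIRING') for
    # Watson-Crick style base-pairing distances (Hoogsteen pairs not handled).
    PURINES = {'A', 'G', 'DA', 'DG'}
    PYRIMIDINES = {'C', 'T', 'U', 'DC', 'DT', 'DU'}

    def basepairing_atom_kept(resname, atom_id):
        """Return True if this (ordered) atom survives the BASEPAIRING filter."""
        if resname in PURINES:
            return atom_id == 'N1'
        if resname in PYRIMIDINES:
            return atom_id == 'N3'
        return False  # non-nucleotide residues keep nothing

    assert basepairing_atom_kept('DG', 'N1')
    assert not basepairing_atom_kept('DG', 'N3')
    assert basepairing_atom_kept('U', 'N3')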
9 changes: 5 additions & 4 deletions conkit/plot/modelvalidation.py
@@ -303,7 +303,7 @@ def _parse_data(self, predicted_dict, *metrics):
self.data['SCORE'] = 0
self.data['CONTACTS'] = 0
self.data['PLDDT'] = 0
self.data['Q_IN_ERROR'] = ''
self.data['Q_IN_ERROR'] = ''



@@ -383,9 +383,9 @@ def map_align(self,map_align_exe=None):
else:
self.data['MISALIGNED'] = False

def count_contacts(self):
def count_contacts(self,cutoff):

cmap = self.prediction.as_contactmap()
cmap = self.prediction.as_contactmap(distance_cutoff=cutoff)
cmap_dict = cmap.as_dict()
self.data['CONTACTS'] = self.data['RESNUM'].apply(lambda x: len(cmap_dict[int(x)]))

@@ -428,9 +428,11 @@ def Run_gesamt_filter(self, experimentfile, predictionfile, gesamt_exe, moltype=
chain_experiment = chain.get_id()

for region in flagged_regions:

Q_region = tools.Gesamt_Q_score(predictionfile,experimentfile,region,gesamt_exe=gesamt_exe, chain_experiment = chain_experiment, chain_prediction = 'A', moltype=moltype)
self.data.loc[ (self.data['RESNUM'] <= region[1]) & (self.data['RESNUM'] >= region[0]), 'Q_IN_ERROR'] = Q_region


return 0


@@ -487,7 +489,6 @@ def draw(self,RUN_SVM=True,RUN_MAP_ALIGN=True,RUN_FILTERS=True,n_contacts_per_re

if 'Q_IN_ERROR' in self.data.columns:
Qs = self.data.set_index('RESNUM')['Q_IN_ERROR'].to_dict()

color_scheme = tools.ColorDefinitions.Q_COLORS
thresholds = list(color_scheme.keys())
thresholds.sort(reverse=True)
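Threading the cutoff through count_contacts matters because the per-residue contact count depends on the distance threshold used when the predicted distogram is collapsed to a contact map; using the same cutoff as the PDB parsing keeps the two consistent. A toy illustration of that dependence follows; the names are hypothetical and this is not ConKit API.

    # Toy sketch: contact counts per residue change with the distance cutoff.
    def count_contacts_per_residue(distances, cutoff):
        """distances: {(res_i, res_j): predicted distance in angstroms}."""
        counts = {}
        for (i, j), d in distances.items():
            if d <= cutoff:
                counts[i] = counts.get(i, 0) + 1
                counts[j] = counts.get(j, 0) + 1
        return counts

    toy = {(1, 5): 6.2, (1, 9): 9.5, (2, 7): 7.9}
    assert count_contacts_per_residue(toy, 8.0) == {1: 1, 5: 1, 2: 1, 7: 1}
    assert count_contacts_per_residue(toy, 10.0)[1] == 2  # (1, 9) now counts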