diff --git a/conkit/command_line/conkit_validate.py b/conkit/command_line/conkit_validate.py index 83eb201a..71b4e63c 100644 --- a/conkit/command_line/conkit_validate.py +++ b/conkit/command_line/conkit_validate.py @@ -94,7 +94,7 @@ def create_argument_parser(): help="Number of iterations") parser.add_argument("--moltype", dest="moltype", default="Protein", type=str, help="Type of molecule") - parser.add_argument("--run_svm", dest="RUN_SVM", default='yes', type=str, + parser.add_argument("--run_svm", dest="RUN_SVM", default='yes if prediction not pdb or mmcif', type=str, help="Whether to run the support vector machine validation") parser.add_argument("--run_map_align", dest="RUN_MAP_ALIGN", default='yes', type=str, help="Whether to run the contactmap alignment validation") @@ -275,7 +275,8 @@ def main(): logger.info(os.linesep + "Working directory: %s", os.getcwd()) logger.info("Reading input sequence: %s", args.seqfile) - sequence = conkit.io.read(args.seqfile, args.seqformat).top + sequencefile = conkit.io.read(args.seqfile, args.seqformat) + sequence = sequencefile.top if len(sequence) < 5: raise ValueError('Cannot validate model with less than 5 residues') @@ -290,6 +291,7 @@ def main(): else: prediction_file = conkit.io.read(args.distfile, args.distformat) prediction = prediction_file.top + logger.info("Reading input PDB model: %s", args.pdbfile) model = conkit.io.read(args.pdbfile, args.pdbformat, distance_cutoff=cutoff, atom_type=rep_atom).top @@ -300,6 +302,12 @@ def main(): validation = conkit.plot.ModelValidationFigure(model, prediction, sequence) + if args.RUN_SVM=='yes if prediction not pdb or mmcif': #don't run the svm if prediction is a structure by default + if args.distformat in ['pdb', 'mmcif']: + args.RUN_SVM='no' + else: + args.RUN_SVM='yes' + if args.RUN_SVM=='yes': logger.info(os.linesep + "Running Support Vector Machine.") @@ -320,7 +328,7 @@ def main(): if args.RUN_FILTERS=='yes': logger.info(os.linesep + "Running Filters.") - validation.count_contacts() + validation.count_contacts(cutoff=cutoff) if (prediction.plddt != None) and (args.PLDDT_IN_DISTFILE == 'yes'): ##turn into check for plddt diff --git a/conkit/io/pdb.py b/conkit/io/pdb.py index 7799c472..ca2965d7 100644 --- a/conkit/io/pdb.py +++ b/conkit/io/pdb.py @@ -71,6 +71,7 @@ def _build_plddts(self, chain): for residue in chain: for atom in residue.get_atoms(): plddts[residue.get_id()[1]] = atom.get_bfactor() + return plddts @@ -112,14 +113,28 @@ def _chain_contacts(self, chain1, chain2): def _remove_atom(self, chain, type): """Tidy up a chain removing all HETATM entries""" - for residue in chain.copy(): - for atom in residue.copy(): - if atom.is_disordered(): - chain[residue.id].detach_child(atom.id) - elif residue.resname == "GLY" and type == "CB" and atom.id == "CA": - continue - elif atom.id != type: - chain[residue.id].detach_child(atom.id) + + if type == 'BASEPAIRING': + #handle special request for contacts/distances based on basepairing atoms in NA rather than backbone atoms + #this could be improved to handle hoogsteen pairs + for residue in chain.copy(): + for atom in residue.copy(): + if atom.is_disordered(): + chain[residue.id].detach_child(atom.id) + else: + atom_needed = (atom.id == 'N1' and residue.resname in ['A', 'G', 'DA', 'DG']) + atom_needed = atom_needed or (atom.id == 'N3' and residue.resname in ['C', 'T', 'U', 'DC', 'DT','DU']) + if not atom_needed: + chain[residue.id].detach_child(atom.id) + else: + for residue in chain.copy(): + for atom in residue.copy(): + if atom.is_disordered(): + chain[residue.id].detach_child(atom.id) + elif residue.resname == "GLY" and type == "CB" and atom.id == "CA": + continue + elif atom.id != type: + chain[residue.id].detach_child(atom.id) def _remove_hetatm(self, chain): """Tidy up a chain removing all HETATM entries""" diff --git a/conkit/plot/modelvalidation.py b/conkit/plot/modelvalidation.py index 1bbc4d0b..5b0bd642 100644 --- a/conkit/plot/modelvalidation.py +++ b/conkit/plot/modelvalidation.py @@ -303,7 +303,7 @@ def _parse_data(self, predicted_dict, *metrics): self.data['SCORE'] = 0 self.data['CONTACTS'] = 0 self.data['PLDDT'] = 0 - self.data['Q_IN_ERROR'] = '' + self.data['Q_IN_ERROR'] = '' @@ -383,9 +383,9 @@ def map_align(self,map_align_exe=None): else: self.data['MISALIGNED'] = False - def count_contacts(self): + def count_contacts(self,cutoff): - cmap = self.prediction.as_contactmap() + cmap = self.prediction.as_contactmap(distance_cutoff=cutoff) cmap_dict = cmap.as_dict() self.data['CONTACTS'] = self.data['RESNUM'].apply(lambda x: len(cmap_dict[int(x)])) @@ -428,9 +428,11 @@ def Run_gesamt_filter(self, experimentfile, predictionfile, gesamt_exe, moltype= chain_experiment = chain.get_id() for region in flagged_regions: + Q_region = tools.Gesamt_Q_score(predictionfile,experimentfile,region,gesamt_exe=gesamt_exe, chain_experiment = chain_experiment, chain_prediction = 'A', moltype=moltype) self.data.loc[ (self.data['RESNUM'] <= region[1]) & (self.data['RESNUM'] >= region[0]), 'Q_IN_ERROR'] = Q_region + return 0 @@ -487,7 +489,6 @@ def draw(self,RUN_SVM=True,RUN_MAP_ALIGN=True,RUN_FILTERS=True,n_contacts_per_re if 'Q_IN_ERROR' in self.data.columns: Qs = self.data.set_index('RESNUM')['Q_IN_ERROR'].to_dict() - color_scheme = tools.ColorDefinitions.Q_COLORS thresholds = list(color_scheme.keys()) thresholds.sort(reverse=True)