Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions pfaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@
12 September 2022
'''

from pathlib import Path

from predict import Mash, call_features, threshold, model_predict
from tools import cmd_parser
from tools import exporters as exp


def screen_fasta(fasta, db = 'sketch_k70.pkl'):
screener = Mash.MashScreen('ref/sketch/{}'.format(db))
db_path = Path(__file__).parent / "ref/sketch" / db
screener = Mash.MashScreen(db_path)
screener.screen(fasta)
matches = screener.ref_counts
return list(matches.values())

def run_model(mash_results, model = 'model.rfm'):
model = 'model/{}'.format(model)
st = model_predict.RFClassify(mash_results, model).serotype
model_path = Path(__file__).parent / "model" / model
st = model_predict.RFClassify(mash_results, model_path).serotype
return st # tuple - serotype, probability

def call_serotype(fasta, outdir):
Expand All @@ -43,9 +46,8 @@ def call_serotype(fasta, outdir):
prediction[3] = prediction[3] + feature_check.flag
prediction[1] = pred_sero
prediction[2] = prob
except:
prediction = [fasta, 'not typed', 'N/A', 'failed to predict serotype']
return prediction
except Exception as e:
raise RuntimeError("calling serotype") from e
#probability thresholding
high_confidence = threshold.ThresholdCall(prediction[1], prediction[2]).valid
if not high_confidence:
Expand Down
3 changes: 2 additions & 1 deletion predict/call_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
'''

import os
from pathlib import Path
import re
import xml.etree.ElementTree as ET

Expand All @@ -22,7 +23,7 @@ def __init__(self, fasta, serogroup):

# run blast search for causal gene
def blast(self):
ref = 'ref/blast/causal_genes.fasta'
ref = Path(__file__).parent.parent / 'ref/blast/causal_genes.fasta'
self.blst = self.fasta.split('.f')[0] + '_blast.xml'
cmd = 'blastn -db {0} -query {1} -out {2} -outfmt 5'.format(ref, self.fasta, self.blst)
os.system(cmd)
Expand Down
3 changes: 2 additions & 1 deletion predict/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
'''

import csv
from pathlib import Path

class ThresholdCall:

def __init__(self, serotype, prob):
self.serotype = serotype
self.prob = float(prob)
self.threshold_file = 'ref/threshold/prob_thresholds.csv'
self.threshold_file = Path(__file__).parent.parent / "ref/threshold/prob_thresholds.csv"
self.thresholds = None
self.valid = True
self.run_check()
Expand Down