From 7ba30060f8e2174d5d424a7b4461cf5561729463 Mon Sep 17 00:00:00 2001 From: Robert Baldwin Date: Fri, 22 Mar 2024 17:22:50 +0000 Subject: [PATCH] fixed paths --- pfaster.py | 14 ++++++++------ predict/call_features.py | 3 ++- predict/threshold.py | 3 ++- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pfaster.py b/pfaster.py index 4a9bdb6..fa49dc7 100644 --- a/pfaster.py +++ b/pfaster.py @@ -4,20 +4,23 @@ 12 September 2022 ''' +from pathlib import Path + from predict import Mash, call_features, threshold, model_predict from tools import cmd_parser from tools import exporters as exp def screen_fasta(fasta, db = 'sketch_k70.pkl'): - screener = Mash.MashScreen('ref/sketch/{}'.format(db)) + db_path = Path(__file__).parent / "ref/sketch" / db + screener = Mash.MashScreen(db_path) screener.screen(fasta) matches = screener.ref_counts return list(matches.values()) def run_model(mash_results, model = 'model.rfm'): - model = 'model/{}'.format(model) - st = model_predict.RFClassify(mash_results, model).serotype + model_path = Path(__file__).parent / "model" / model + st = model_predict.RFClassify(mash_results, model_path).serotype return st # tuple - serotype, probability def call_serotype(fasta, outdir): @@ -43,9 +46,8 @@ def call_serotype(fasta, outdir): prediction[3] = prediction[3] + feature_check.flag prediction[1] = pred_sero prediction[2] = prob - except: - prediction = [fasta, 'not typed', 'N/A', 'failed to predict serotype'] - return prediction + except Exception as e: + raise RuntimeError("calling serotype") from e #probability thresholding high_confidence = threshold.ThresholdCall(prediction[1], prediction[2]).valid if not high_confidence: diff --git a/predict/call_features.py b/predict/call_features.py index 28a0531..b2211a3 100644 --- a/predict/call_features.py +++ b/predict/call_features.py @@ -5,6 +5,7 @@ ''' import os +from pathlib import Path import re import xml.etree.ElementTree as ET @@ -22,7 +23,7 @@ def __init__(self, fasta, serogroup): # run blast search for causal gene def blast(self): - ref = 'ref/blast/causal_genes.fasta' + ref = Path(__file__).parent.parent / 'ref/blast/causal_genes.fasta' self.blst = self.fasta.split('.f')[0] + '_blast.xml' cmd = 'blastn -db {0} -query {1} -out {2} -outfmt 5'.format(ref, self.fasta, self.blst) os.system(cmd) diff --git a/predict/threshold.py b/predict/threshold.py index 00fc6f6..e92a0ed 100644 --- a/predict/threshold.py +++ b/predict/threshold.py @@ -5,13 +5,14 @@ ''' import csv +from pathlib import Path class ThresholdCall: def __init__(self, serotype, prob): self.serotype = serotype self.prob = float(prob) - self.threshold_file = 'ref/threshold/prob_thresholds.csv' + self.threshold_file = Path(__file__).parent.parent / "ref/threshold/prob_thresholds.csv" self.thresholds = None self.valid = True self.run_check()