-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
87 lines (65 loc) · 2 KB
/
main.py
File metadata and controls
87 lines (65 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import os
import sys
from load import load_data
from fea import generate_features
import tfsuppress
from tensorflow.keras.models import load_model
import pandas as pd
def main():
    """Run the triplex-forming-potential prediction command-line tool.

    Command-line arguments (all required):
        --input   Path to the input FASTA file.
        --type    Sequence type, ``RNA`` or ``DNA``; selects the model file.
        --output  Path of the CSV file to write; ``.csv`` is appended when
                  the given path does not already end with it.

    The sequences are loaded with ``load_data``, turned into a feature
    vector with ``generate_features``, and scored with the Keras model for
    the chosen sequence type. The result is written as a CSV with the
    columns ``Metadata`` and ``TriplexCL``.

    Exits with status 1, after printing an error message, when the input
    file, the output directory, or the model file cannot be found.
    """
    parser = argparse.ArgumentParser(
        description="Predict triplex-forming potential from FASTA file"
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Full path to input FASTA file",
    )
    parser.add_argument(
        "--type",
        required=True,
        choices=["RNA", "DNA"],
        help="Specify sequence type: RNA or DNA",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Full path to output CSV file",
    )
    args = parser.parse_args()

    fasta_path = args.input
    seq_type = args.type.upper()
    output_path = args.output

    if not os.path.exists(fasta_path):
        print("Error: FASTA file not found.")
        sys.exit(1)

    if not output_path.endswith(".csv"):
        output_path += ".csv"

    # A bare filename such as "out.csv" has an empty dirname. Checking the
    # empty string for existence always fails, so fall back to the current
    # directory, which is where such a file would actually be written.
    output_dir = os.path.dirname(output_path) or os.curdir
    if not os.path.isdir(output_dir):
        print("Error: Output directory does not exist.")
        sys.exit(1)

    print(f"\nInput file: {fasta_path}")
    print(f"Sequence type: {seq_type}")
    print(f"Output file: {output_path}")

    metadata, sequences = load_data(fasta_path)
    input_vector = generate_features(sequences, seq_type)

    # The model paths are relative, so they are resolved against the
    # current working directory of the process.
    model_paths = {
        "RNA": "models/rna.keras",
        "DNA": "models/dna.keras",
    }
    model_path = model_paths[seq_type]

    if not os.path.exists(model_path):
        print(f"Error: Model file '{model_path}' not found.")
        sys.exit(1)

    model = load_model(model_path)
    # Flatten the model output to 1-D so it can be used as a single column.
    predictions = model.predict(input_vector).flatten()

    df = pd.DataFrame({
        "Metadata": metadata,
        "TriplexCL": predictions,
    })
    df.to_csv(output_path, index=False)

    print("\nPrediction complete.")
    print(f"Results saved to: {output_path}")
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()