-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathanalyze.py
More file actions
128 lines (104 loc) · 3.58 KB
/
analyze.py
File metadata and controls
128 lines (104 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from db import db, Post, Comment
from vars import (
REDDIT_CLIENT_ID,
REDDIT_CLIENT_SECRET,
REDDIT_USERNAME,
REDDIT_PASSWORD,
SUBREDDITS_FILE,
SCRAPE_LIMIT,
OPENROUTER_API_KEY,
OPENROUTER_MODEL,
HUGGINGFACE_TOKEN,
HUGGINGFACE_MODEL,
)
import requests
import json
import os
import re
from utils import HuggingFaceClient, OpenRouterClient
# Module-level LLM clients shared by main():
# - huggingface_client supplies embeddings (see generate_embedding in main)
# - openrouter_client supplies chat completions (see chat calls in main)
# Both are constructed at import time from credentials loaded in vars.
huggingface_client = HuggingFaceClient(
    hf_token=HUGGINGFACE_TOKEN,
    model=HUGGINGFACE_MODEL,
)
openrouter_client = OpenRouterClient(model=OPENROUTER_MODEL, api_key=OPENROUTER_API_KEY)
def strip_json_code_block(llm_output_string):
    """Return the contents of the first Markdown code fence in *llm_output_string*.

    Handles both ```json-tagged and bare ``` fences, trimming surrounding
    whitespace inside the fence. If no fenced block is present, the input
    string is returned unchanged so callers can pass it straight to a JSON
    parser either way.
    """
    fence_pattern = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)
    found = fence_pattern.search(llm_output_string)
    if found is None:
        # No fenced block — assume the model emitted raw JSON (or garbage
        # the caller will reject when parsing).
        return llm_output_string
    return found.group(1)
def initialize_database():
    """Open the database, create the Post and Comment tables, and close it.

    The close is in a ``finally`` so the connection is released even if
    table creation raises (the original leaked the connection on error).
    """
    db.connect()
    try:
        db.create_tables([Post, Comment])
    finally:
        db.close()
def main():
    """Run the end-to-end Reddit comment analysis pipeline.

    Steps:
      1. Ensure the database tables exist.
      2. Ask the chat LLM for semantic-search keywords for a fixed question.
      3. Embed those keywords and run a vector-similarity search over comments.
      4. Feed the matching comments back to the chat LLM to answer the question.

    Prints intermediate and final results to stdout; returns None.
    """
    initialize_database()

    question = "What problems are people experiencing with AI?"

    # Ask for search keywords rather than searching on the question verbatim,
    # so the vector search is not biased toward the question's own wording.
    prompt = f"""Analyze the intent behind this user question: {question}.
Generate a list of keywords that will be used to do semantic search using a vector database for relevant comments. Do not use words from the question itself.
Provide nothing but the keywords separated by commas."""
    response = openrouter_client.chat(prompt)
    print(response)

    embedding = huggingface_client.generate_embedding(response)

    # Cosine-distance cutoff: only comments strictly closer than this to the
    # keyword embedding are treated as relevant.
    DISTANCE_THRESHOLD = 0.7
    # `<=>` is the pgvector cosine-distance operator; both the embedding and
    # the threshold are bound as query parameters, never interpolated.
    sql_query = """
    SELECT * FROM (
        SELECT
            *,
            (embedding <=> %s::vector) AS distance
        FROM Comment
    ) AS subquery
    WHERE distance < %s
    ORDER BY distance ASC
    LIMIT 100
    """
    comments = Comment.raw(sql_query, embedding, DISTANCE_THRESHOLD).dicts()

    relevant_comments = [
        {
            "comment_id": comment["comment_id"],
            "content": comment["content"],
            "distance": comment["distance"],
        }
        for comment in comments
    ]
    if not relevant_comments:
        # Nothing matched the threshold — stop here instead of asking the
        # LLM to answer from an empty comment list.
        print("No relevant comments found that meet the threshold.")
        return

    json_output = json.dumps(relevant_comments)
    print(json_output)

    prompt = f"""Please answer the question: `{question}` using the JSON array of these comments from Reddit:
{json_output}
Please provide your response in Markdown."""
    response = openrouter_client.chat(prompt)
    print(response)
# Run the full analysis pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()