-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp.py
More file actions
56 lines (44 loc) · 1.88 KB
/
app.py
File metadata and controls
56 lines (44 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
load_dotenv()

st.title("RAG Application with Gemini API")


@st.cache_resource
def _build_retriever(pdf_path="attentionisallyouneed.pdf", k=10):
    """Load *pdf_path*, index its chunks in Chroma, and return a retriever.

    Streamlit re-executes this whole script on every user interaction.
    Without ``st.cache_resource`` the PDF would be re-parsed, every chunk
    re-embedded through the Gemini embeddings API, and all chunks written to
    the persisted ``./chroma_db`` collection again on each rerun.

    Args:
        pdf_path: Path of the PDF file to index.
        k: Number of chunks the retriever returns per query.

    Returns:
        A similarity-search retriever over the Chroma vector store.
    """
    data = PyPDFLoader(pdf_path).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000)
    docs = text_splitter.split_documents(data)
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    # Deterministic ids: without them Chroma assigns random ids, so every
    # re-index (e.g. after a server restart) would store a second copy of each
    # chunk in the persisted collection and retrieval would return duplicates.
    # With stable ids the chunks are overwritten in place.
    ids = [f"{pdf_path}-{i}" for i in range(len(docs))]
    vectorstore = Chroma.from_documents(
        documents=docs,
        embedding=embeddings,
        ids=ids,
        persist_directory="./chroma_db",
    )
    return vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": k}
    )


# Built once per server process thanks to st.cache_resource.
retriever = _build_retriever()

llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    temperature=0.3,
    max_tokens=500,
)
query = st.chat_input("Say something:")

# Adjacent string literals are concatenated verbatim, so each sentence must
# carry its own terminating punctuation and trailing space; otherwise the
# model receives run-on text such as "...tasksUse the following...".
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer. "
    "If you don't know the answer, say that you don't know. "
    "Use three sentences maximum and keep the answer concise."
    "\n\n"
    "{context}"
)

# "{context}" is filled with the retrieved documents by the stuff-documents
# chain; "{input}" receives the user's question.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

if query:
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)
    # Answer the question the user actually typed; the original passed a
    # hard-coded "What is encoder?" and ignored the chat input entirely.
    response = rag_chain.invoke({"input": query})
    st.write(response["answer"])