# factbot / rag_system.py
import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma  # updated this line
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pdfplumber
from concurrent.futures import ThreadPoolExecutor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langgraph.graph import Graph
from langchain.prompts import PromptTemplate
# Load environment variables
load_dotenv()
# Set OpenAI API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")
os.environ["OPENAI_API_KEY"] = api_key
def load_retrieval_qa_chain():
# Load embeddings
embeddings = OpenAIEmbeddings()
# Load vector store
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
# Initialize ChatOpenAI model
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
# Create a compressor for re-ranking
compressor = LLMChainExtractor.from_llm(llm)
# Create a ContextualCompressionRetriever
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=vectorstore.as_retriever()
)
# Define your instruction/prompt
instruction = """당신은 RAG(Retrieval-Augmented Generation) 기반 AI μ–΄μ‹œμŠ€ν„΄νŠΈμž…λ‹ˆλ‹€. λ‹€μŒ 지침을 따라 μ‚¬μš©μž μ§ˆλ¬Έμ— λ‹΅ν•˜μ„Έμš”:
1. 검색 κ²°κ³Ό ν™œμš©: 제곡된 검색 κ²°κ³Όλ₯Ό λΆ„μ„ν•˜κ³  κ΄€λ ¨ 정보λ₯Ό μ‚¬μš©ν•΄ λ‹΅λ³€ν•˜μ„Έμš”.
2. μ •ν™•μ„± μœ μ§€: μ •λ³΄μ˜ 정확성을 ν™•μΈν•˜κ³ , λΆˆν™•μ‹€ν•œ 경우 이λ₯Ό λͺ…μ‹œν•˜μ„Έμš”.
3. κ°„κ²°ν•œ 응닡: μ§ˆλ¬Έμ— 직접 λ‹΅ν•˜κ³  핡심 λ‚΄μš©μ— μ§‘μ€‘ν•˜μ„Έμš”.
4. μΆ”κ°€ 정보 μ œμ•ˆ: κ΄€λ ¨λœ μΆ”κ°€ 정보가 μžˆλ‹€λ©΄ μ–ΈκΈ‰ν•˜μ„Έμš”.
5. μœ€λ¦¬μ„± κ³ λ €: 객관적이고 쀑립적인 νƒœλ„λ₯Ό μœ μ§€ν•˜μ„Έμš”.
6. ν•œκ³„ 인정: λ‹΅λ³€ν•  수 μ—†λŠ” 경우 μ†”μ§νžˆ μΈμ •ν•˜μ„Έμš”.
7. λŒ€ν™” μœ μ§€: μžμ—°μŠ€λŸ½κ²Œ λŒ€ν™”λ₯Ό 이어가고, ν•„μš”μ‹œ 후속 μ§ˆλ¬Έμ„ μ œμ•ˆν•˜μ„Έμš”.
항상 μ •ν™•ν•˜κ³  μœ μš©ν•œ 정보λ₯Ό μ œκ³΅ν•˜λŠ” 것을 λͺ©ν‘œλ‘œ ν•˜μ„Έμš”."""
# Create a prompt template
prompt_template = PromptTemplate(
input_variables=["context", "question"],
template=instruction + "\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:"
)
    # Create a ConversationalRetrievalChain with the compression retriever and custom prompt
qa_chain = ConversationalRetrievalChain.from_llm(
llm,
retriever=compression_retriever,
return_source_documents=True,
combine_docs_chain_kwargs={"prompt": prompt_template}
)
return qa_chain
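# Illustrative sketch (not called anywhere in this module): a hypothetical
# helper showing what the compression step above does to the raw retrieval
# hits. It assumes the same ./chroma_db store and OpenAI credentials as the
# chain built in load_retrieval_qa_chain.
def _debug_compression(query: str):
    embeddings = OpenAIEmbeddings()
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    retriever = ContextualCompressionRetriever(
        base_compressor=LLMChainExtractor.from_llm(llm),
        base_retriever=vectorstore.as_retriever(),
    )
    raw_docs = vectorstore.as_retriever().invoke(query)
    compressed_docs = retriever.invoke(query)
    # The LLM-based compressor extracts only the query-relevant passages, so
    # the compressed documents are typically shorter (or fewer) than raw hits.
    print(f"raw: {len(raw_docs)} docs, compressed: {len(compressed_docs)} docs")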
def extract_text_from_pdf(file_path):
    documents = []
    # Create the splitter once instead of re-instantiating it for every page.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    with pdfplumber.open(file_path) as pdf:
        for page_num, page in enumerate(pdf.pages):
            text = page.extract_text()
            if text:
                # Split the page text into overlapping chunks.
                chunks = text_splitter.split_text(text)
                for chunk in chunks:
                    doc = Document(
                        page_content=chunk,
                        metadata={"source": os.path.basename(file_path), "page": page_num + 1},
                    )
                    documents.append(doc)
    return documents
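# Usage sketch for extract_text_from_pdf, assuming a file named
# "example.pdf" exists under ./documents (the name is hypothetical):
#
#   docs = extract_text_from_pdf("./documents/example.pdf")
#   print(docs[0].page_content[:80])   # first 80 chars of the first chunk
#   print(docs[0].metadata)            # {'source': 'example.pdf', 'page': 1}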
def embed_documents():
embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
documents = []
with ThreadPoolExecutor() as executor:
results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in pdf_files])
for result in results:
documents.extend(result)
    if documents:  # Chroma's add_documents raises on an empty list
        vectorstore.add_documents(documents)
def update_embeddings():
embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    # Enumerate the source files already stored in the vector store.
    # similarity_search("") only returns the top-k nearest chunks, so it can
    # miss files; Chroma's get() returns the metadata of every stored chunk.
    existing_files = set()
    for metadata in vectorstore.get(include=["metadatas"])["metadatas"]:
        existing_files.add(metadata["source"])
pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
new_files = [f for f in pdf_files if f not in existing_files]
documents = []
with ThreadPoolExecutor() as executor:
results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in new_files])
for result in results:
documents.extend(result)
    if documents:  # guard against an empty add when there are no new files
        vectorstore.add_documents(documents)
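# Illustrative counterpart to update_embeddings (not called anywhere in this
# module): a hypothetical cleanup that removes chunks whose source PDF has
# been deleted from ./documents. It relies on Chroma's get()/delete() API.
def _remove_stale_embeddings():
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    pdf_files = set(f for f in os.listdir("./documents") if f.endswith(".pdf"))
    stored = vectorstore.get(include=["metadatas"])
    # Collect the ids of chunks whose source file no longer exists on disk.
    stale_ids = [
        doc_id
        for doc_id, metadata in zip(stored["ids"], stored["metadatas"])
        if metadata.get("source") not in pdf_files
    ]
    if stale_ids:
        vectorstore.delete(ids=stale_ids)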
def create_rag_graph():
qa_chain = load_retrieval_qa_chain()
def retrieve_and_generate(inputs):
question = inputs["question"]
chat_history = inputs["chat_history"]
        # Use .invoke(); calling the chain directly is deprecated in recent LangChain.
        result = qa_chain.invoke({"question": question, "chat_history": chat_history})
# Ensure source documents have the correct metadata
sources = []
for doc in result.get("source_documents", []):
if "source" in doc.metadata and "page" in doc.metadata:
sources.append(f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})")
else:
print(f"Warning: Document missing metadata: {doc.metadata}")
return {
"answer": result["answer"],
"sources": sources
}
workflow = Graph()
workflow.add_node("retrieve_and_generate", retrieve_and_generate)
workflow.set_entry_point("retrieve_and_generate")
chain = workflow.compile()
return chain
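# Direct-invocation sketch for the compiled graph (the question and history
# values are hypothetical examples):
#
#   chain = create_rag_graph()
#   out = chain.invoke({"question": "What is RAG?", "chat_history": []})
#   print(out["answer"], out["sources"])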
rag_chain = create_rag_graph()
def get_answer(query, chat_history):
    # Pair up the flat history list into (user, assistant) tuples.
    formatted_history = [(q, a) for q, a in zip(chat_history[::2], chat_history[1::2])]
response = rag_chain.invoke({"question": query, "chat_history": formatted_history})
# Validate response format
if "answer" not in response or "sources" not in response:
print("Warning: Unexpected response format")
return {"answer": "Error in processing", "sources": []}
return {"answer": response["answer"], "sources": response["sources"]}
# Example usage
if __name__ == "__main__":
update_embeddings() # Update embeddings with new documents
question = "RAG μ‹œμŠ€ν…œμ— λŒ€ν•΄ μ„€λͺ…ν•΄μ£Όμ„Έμš”."
response = get_answer(question, [])
print(f"Question: {question}")
print(f"Answer: {response['answer']}")
print(f"Sources: {response['sources']}")
# Validate source format
for source in response['sources']:
if not (source.endswith(')') and ' (Page ' in source):
print(f"Warning: Unexpected source format: {source}")