import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma  # updated import: Chroma moved to the langchain-chroma package
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pdfplumber
from concurrent.futures import ThreadPoolExecutor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langgraph.graph import Graph
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.prompts import PromptTemplate
# Load environment variables
load_dotenv()
# Set OpenAI API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")
os.environ["OPENAI_API_KEY"] = api_key
def load_retrieval_qa_chain():
    # Load embeddings
    embeddings = OpenAIEmbeddings()
    # Load vector store
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    # Initialize ChatOpenAI model
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    # Create a compressor for re-ranking
    compressor = LLMChainExtractor.from_llm(llm)
    # Create a ContextualCompressionRetriever
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=vectorstore.as_retriever()
    )
    # Define the system instruction/prompt
instruction = """λΉμ μ RAG(Retrieval-Augmented Generation) κΈ°λ° AI μ΄μμ€ν΄νΈμ
λλ€. λ€μ μ§μΉ¨μ λ°λΌ μ¬μ©μ μ§λ¬Έμ λ΅νμΈμ:
1. κ²μ κ²°κ³Ό νμ©: μ 곡λ κ²μ κ²°κ³Όλ₯Ό λΆμνκ³ κ΄λ ¨ μ 보λ₯Ό μ¬μ©ν΄ λ΅λ³νμΈμ.
2. μ νμ± μ μ§: μ 보μ μ νμ±μ νμΈνκ³ , λΆνμ€ν κ²½μ° μ΄λ₯Ό λͺ
μνμΈμ.
3. κ°κ²°ν μλ΅: μ§λ¬Έμ μ§μ λ΅νκ³ ν΅μ¬ λ΄μ©μ μ§μ€νμΈμ.
4. μΆκ° μ 보 μ μ: κ΄λ ¨λ μΆκ° μ λ³΄κ° μλ€λ©΄ μΈκΈνμΈμ.
5. μ€λ¦¬μ± κ³ λ €: κ°κ΄μ μ΄κ³ μ€λ¦½μ μΈ νλλ₯Ό μ μ§νμΈμ.
6. νκ³ μΈμ : λ΅λ³ν μ μλ κ²½μ° μμ§ν μΈμ νμΈμ.
7. λν μ μ§: μμ°μ€λ½κ² λνλ₯Ό μ΄μ΄κ°κ³ , νμμ νμ μ§λ¬Έμ μ μνμΈμ.
νμ μ ννκ³ μ μ©ν μ 보λ₯Ό μ 곡νλ κ²μ λͺ©νλ‘ νμΈμ."""
    # Create a prompt template
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template=instruction + "\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:"
    )
    # Create ConversationalRetrievalChain with the compression retriever and prompt
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=compression_retriever,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": prompt_template}
    )
    return qa_chain
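# Sketch of direct use (assumes ./chroma_db was already populated by
# embed_documents() below; question text is illustrative only). "question"
# and "chat_history" are the input keys ConversationalRetrievalChain expects:
#
#     chain = load_retrieval_qa_chain()
#     out = chain.invoke({"question": "What does chapter 1 cover?", "chat_history": []})
#     print(out["answer"], out["source_documents"])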
def extract_text_from_pdf(file_path):
    documents = []
    # One splitter instance is enough; reuse it across pages
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    with pdfplumber.open(file_path) as pdf:
        for page_num, page in enumerate(pdf.pages):
            text = page.extract_text()
            if text:
                # Split the page text into overlapping chunks
                chunks = text_splitter.split_text(text)
                for chunk in chunks:
                    doc = Document(page_content=chunk, metadata={"source": os.path.basename(file_path), "page": page_num + 1})
                    documents.append(doc)
    return documents
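# Sketch of the output shape (file name and content are illustrative only):
#
#     Document(
#         page_content="...first ~1000 characters of page 1...",
#         metadata={"source": "report.pdf", "page": 1},
#     )
#
# The "source"/"page" metadata is what create_rag_graph() later formats
# into citations like "report.pdf (Page 1)".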
def embed_documents():
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in pdf_files])
        for result in results:
            documents.extend(result)
    vectorstore.add_documents(documents)
def update_embeddings():
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    # Collect the sources already in the store. Note: similarity_search("")
    # only returns the top-k (4 by default) results, so it cannot enumerate
    # the whole collection; Chroma's get() returns every stored record.
    existing_files = set()
    for metadata in vectorstore.get(include=["metadatas"])["metadatas"]:
        if metadata and "source" in metadata:
            existing_files.add(metadata["source"])
    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    new_files = [f for f in pdf_files if f not in existing_files]
    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in new_files])
        for result in results:
            documents.extend(result)
    if documents:  # adding an empty list raises in some Chroma versions
        vectorstore.add_documents(documents)
def create_rag_graph():
    qa_chain = load_retrieval_qa_chain()
    def retrieve_and_generate(inputs):
        question = inputs["question"]
        chat_history = inputs["chat_history"]
        result = qa_chain.invoke({"question": question, "chat_history": chat_history})
        # Ensure source documents have the correct metadata
        sources = []
        for doc in result.get("source_documents", []):
            if "source" in doc.metadata and "page" in doc.metadata:
                sources.append(f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})")
            else:
                print(f"Warning: Document missing metadata: {doc.metadata}")
        return {
            "answer": result["answer"],
            "sources": sources
        }
    workflow = Graph()
    workflow.add_node("retrieve_and_generate", retrieve_and_generate)
    workflow.set_entry_point("retrieve_and_generate")
    workflow.set_finish_point("retrieve_and_generate")  # single-node graph: entry and finish coincide
    chain = workflow.compile()
    return chain
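# Sketch: the compiled graph is invoked the same way get_answer() does below;
# "question"/"chat_history" are the keys retrieve_and_generate reads, and the
# return value mirrors its dict (values shown are illustrative):
#
#     graph = create_rag_graph()
#     out = graph.invoke({"question": "...", "chat_history": []})
#     # out == {"answer": "...", "sources": ["report.pdf (Page 3)", ...]}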
rag_chain = create_rag_graph()
def get_answer(query, chat_history):
    # Pair the flat history [q1, a1, q2, a2, ...] into (question, answer) tuples
    formatted_history = [(q, a) for q, a in zip(chat_history[::2], chat_history[1::2])]
    response = rag_chain.invoke({"question": query, "chat_history": formatted_history})
    # Validate response format
    if "answer" not in response or "sources" not in response:
        print("Warning: Unexpected response format")
        return {"answer": "Error in processing", "sources": []}
    return {"answer": response["answer"], "sources": response["sources"]}
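# Illustrative call (strings are made up): a flat two-item history becomes a
# single (question, answer) tuple before reaching the chain:
#
#     get_answer("What is chapter 2 about?",
#                ["What is RAG?", "RAG combines retrieval with generation."])
#     # history sent: [("What is RAG?", "RAG combines retrieval with generation.")]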
# Example usage
if __name__ == "__main__":
    update_embeddings()  # Update embeddings with new documents
    question = "Please explain the RAG system."
    response = get_answer(question, [])
    print(f"Question: {question}")
    print(f"Answer: {response['answer']}")
    print(f"Sources: {response['sources']}")
    # Validate source format, e.g. "report.pdf (Page 3)"
    for source in response['sources']:
        if not (source.endswith(')') and ' (Page ' in source):
            print(f"Warning: Unexpected source format: {source}")