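"""RAG pipeline built on LangChain, Chroma, and LangGraph.

Loads PDFs from ./documents, chunks and embeds them into a persistent
Chroma store, and answers questions through a ConversationalRetrievalChain
with contextual compression, wrapped in a single-node LangGraph workflow.
"""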
import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma  # updated import
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pdfplumber
from concurrent.futures import ThreadPoolExecutor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langgraph.graph import Graph
from langchain.prompts import PromptTemplate

# Load environment variables
load_dotenv()

# Set OpenAI API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("API key not found. Please set the OPENAI_API_KEY environment variable.")
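# Re-export so LangChain clients that read the key from the environment can find it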
os.environ["OPENAI_API_KEY"] = api_key

def load_retrieval_qa_chain():
    # Load embeddings
    embeddings = OpenAIEmbeddings()
    
    # Load vector store
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

    # Initialize ChatOpenAI model
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    # Create an LLM-based compressor that extracts only the query-relevant passages
    compressor = LLMChainExtractor.from_llm(llm)
    
    # Create a ContextualCompressionRetriever
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=vectorstore.as_retriever()
    )

    # Define your instruction/prompt
    instruction = """You are a RAG (Retrieval-Augmented Generation) based AI assistant. Follow these guidelines when answering user questions:

1. Use the retrieved results: analyze the provided search results and answer using the relevant information.
2. Stay accurate: verify the accuracy of the information, and state explicitly when something is uncertain.
3. Answer concisely: address the question directly and focus on the key points.
4. Offer additional information: mention related supplementary information when it exists.
5. Mind ethics: maintain an objective and neutral stance.
6. Acknowledge limits: admit honestly when you cannot answer.
7. Sustain the conversation: keep the dialogue natural and suggest follow-up questions when needed.

Always aim to provide accurate and useful information."""

    # Create a prompt template
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template=instruction + "\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:"
    )

    # Create ConversationalRetrievalChain with the new retriever and prompt
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=compression_retriever,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": prompt_template}
    )

    return qa_chain

def extract_text_from_pdf(file_path):
    documents = []
    # Reuse one splitter for every page instead of rebuilding it per page
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    with pdfplumber.open(file_path) as pdf:
        for page_num, page in enumerate(pdf.pages):
            text = page.extract_text()
            if text:
                # Split the page text into overlapping chunks
                chunks = text_splitter.split_text(text)
                for chunk in chunks:
                    doc = Document(page_content=chunk, metadata={"source": os.path.basename(file_path), "page": page_num + 1})
                    documents.append(doc)
    return documents

def embed_documents():
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    
    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in pdf_files])
        for result in results:
            documents.extend(result)
    if documents:  # Chroma rejects an empty batch
        vectorstore.add_documents(documents)

def update_embeddings():
    embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
    
    # Collect the source filenames already present in the store.
    # similarity_search("") only returns the top-k hits and can miss files,
    # so read every stored metadata entry via Chroma's get() instead.
    existing_files = set()
    for metadata in vectorstore.get().get("metadatas", []):
        if metadata and "source" in metadata:
            existing_files.add(metadata["source"])
    
    pdf_files = [f for f in os.listdir("./documents") if f.endswith('.pdf')]
    
    new_files = [f for f in pdf_files if f not in existing_files]
    documents = []
    with ThreadPoolExecutor() as executor:
        results = executor.map(extract_text_from_pdf, [f"./documents/{pdf_file}" for pdf_file in new_files])
        for result in results:
            documents.extend(result)
    if documents:  # No new files is a valid state; Chroma rejects an empty batch
        vectorstore.add_documents(documents)

def create_rag_graph():
    qa_chain = load_retrieval_qa_chain()

    def retrieve_and_generate(inputs):
        question = inputs["question"]
        chat_history = inputs["chat_history"]
        result = qa_chain.invoke({"question": question, "chat_history": chat_history})
        
        # Ensure source documents have the correct metadata
        sources = []
        for doc in result.get("source_documents", []):
            if "source" in doc.metadata and "page" in doc.metadata:
                sources.append(f"{os.path.basename(doc.metadata['source'])} (Page {doc.metadata['page']})")
            else:
                print(f"Warning: Document missing metadata: {doc.metadata}")
        
        return {
            "answer": result["answer"],
            "sources": sources
        }

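    # Wrap the chain in a minimal single-node LangGraph workflow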
    workflow = Graph()
    workflow.add_node("retrieve_and_generate", retrieve_and_generate)
    workflow.set_entry_point("retrieve_and_generate")
    workflow.set_finish_point("retrieve_and_generate")

    chain = workflow.compile()
    return chain

rag_chain = create_rag_graph()

def get_answer(query, chat_history):
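    # chat_history arrives flat as [q1, a1, q2, a2, ...]; pair it into (question, answer) tuples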
    formatted_history = [(q, a) for q, a in zip(chat_history[::2], chat_history[1::2])]
    
    response = rag_chain.invoke({"question": query, "chat_history": formatted_history})
    
    # Validate response format
    if "answer" not in response or "sources" not in response:
        print("Warning: Unexpected response format")
        return {"answer": "Error in processing", "sources": []}
    
    return {"answer": response["answer"], "sources": response["sources"]}

# Example usage
if __name__ == "__main__":
    update_embeddings()  # Update embeddings with new documents
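    # For a first-time build of the index, embed_documents() ingests every PDF instead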
    question = "Please explain the RAG system."
    response = get_answer(question, [])
    print(f"Question: {question}")
    print(f"Answer: {response['answer']}")
    print(f"Sources: {response['sources']}")
    
    # Validate source format
    for source in response['sources']:
        if not (source.endswith(')') and ' (Page ' in source):
            print(f"Warning: Unexpected source format: {source}")