import gradio as gr
import numpy as np
import pandas as pd
import faiss
import torch
from transformers import AutoTokenizer, AutoModel, pipeline
# Load the prebuilt FAISS index and the product texts it was built from
index = faiss.read_index("deberta_faiss.index")
text_data = pd.read_csv("deberta_text_data.csv")["Retrieved Text"].tolist()
# Load DeBERTa model and tokenizer for embedding
deberta_model_name = "microsoft/deberta-v3-base"
deberta_tokenizer = AutoTokenizer.from_pretrained(deberta_model_name)
deberta_model = AutoModel.from_pretrained(deberta_model_name).to("cpu")
# Load LLaMA-2 tokenizer and text-generation pipeline on CPU (device=-1).
# Note: this model is gated on the Hugging Face Hub and requires an
# authenticated session with an accepted license.
llama_model_name = "meta-llama/Llama-2-7b-chat-hf"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
llama_pipeline = pipeline("text-generation", model=llama_model_name, tokenizer=llama_tokenizer, device=-1)
# Generate mean-pooled DeBERTa embeddings for a list of queries
def generate_embeddings(queries):
    tokens = deberta_tokenizer(queries, return_tensors="pt", padding=True, truncation=True).to("cpu")
    with torch.no_grad():
        # Mean-pool the last hidden state across tokens; FAISS expects float32
        outputs = deberta_model(**tokens).last_hidden_state.mean(dim=1).cpu().numpy().astype("float32")
    return outputs
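
# The index and CSV above are loaded as prebuilt artifacts. The helper below is
# a minimal sketch of how they could have been created, assuming an
# inner-product index over L2-normalized embeddings (which matches the
# normalize_L2 call applied to queries at search time). It is illustrative
# only and is never called by the app; `product_texts` is a hypothetical list
# of product description strings.
def build_faiss_index(product_texts):
    embeddings = generate_embeddings(product_texts)
    faiss.normalize_L2(embeddings)
    idx = faiss.IndexFlatIP(embeddings.shape[1])  # inner product == cosine similarity on normalized vectors
    idx.add(embeddings)
    faiss.write_index(idx, "deberta_faiss.index")
    pd.DataFrame({"Retrieved Text": product_texts}).to_csv("deberta_text_data.csv", index=False)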
# RAG response logic: embed the query, retrieve nearest products, prompt LLaMA
def generate_response(user_query):
    query_embedding = generate_embeddings([user_query])
    faiss.normalize_L2(query_embedding)
    distances, indices = index.search(query_embedding, k=5)
    retrieved_docs = [text_data[idx] for idx in indices[0]]
    # Deduplicate while preserving retrieval order (set() would shuffle the ranking)
    context = ", ".join(dict.fromkeys(retrieved_docs))
prompt = f""" | |
Using the following product descriptions: | |
{context} | |
Carefully craft a well-structured response to the following question: | |
**Question:** {user_query} | |
**Instructions:** | |
1. Incorporate **all** retrieved product descriptions. | |
2. Use a **formal yet engaging** tone. | |
3. Provide **practical & creative** luxury decor ideas. | |
4. Ensure a **cohesive & detailed response.** | |
**Your response:** | |
""" | |
    # max_length counts the prompt tokens too, so a long prompt can leave no room
    # for the answer; max_new_tokens bounds only the generated continuation.
    # return_full_text=False keeps the prompt out of the returned string.
    result = llama_pipeline(prompt, max_new_tokens=512, truncation=True, do_sample=True, return_full_text=False)[0]["generated_text"]
    return result
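
# Quick sanity check without the UI (hypothetical query), e.g. from a shell:
#   print(generate_response("How can I style a marble console table?"))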
# Gradio UI
interface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about luxury home decor..."),
    outputs="text",
    title="Luxury Decor Assistant (RAG)",
    description="Ask your luxury decor questions based on real product descriptions!",
)

interface.launch()