import json
import os
import random
import urllib.parse
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from fastembed import SparseEmbedding, SparseTextEmbedding
from google import genai
from google.genai import types
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels
from sentence_transformers import CrossEncoder, SentenceTransformer
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
VLLM_MODEL_NAME = os.getenv("VLLM_MODEL_NAME")
VLLM_GPU_MEMORY_UTILIZATION = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION"))
VLLM_MAX_SEQ_LEN = int(os.getenv("VLLM_MAX_SEQ_LEN"))
VLLM_DTYPE = os.getenv("VLLM_DTYPE")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"

client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"
num_chunks_base = 500
alpha = 0.5
top_k = 5  # recommend only the top 5 genres
youtube_query_template = "{genre} music playlist"

# -------------------------------- HELPERS -------------------------------------


def load_text_resource(path: Path) -> str:
    with path.open("r") as file:
        return file.read()


def youtube_search_link_for_genre(genre: str) -> str:
    base_url = "https://www.youtube.com/results"
    params = {
        "search_query": youtube_query_template.format(
            genre=genre.replace("_", " ").lower()
        )
    }
    return f"{base_url}?{urllib.parse.urlencode(params)}"


def generate_recommendation_string(ranking: dict[str, float]) -> str:
    recommendation_string = "## Recommendations for You\n\n"
    for idx, (genre, score) in enumerate(ranking.items(), start=1):
        youtube_link = youtube_search_link_for_genre(genre=genre)
        recommendation_string += (
            f"{idx}. **{genre.replace('_', ' ').capitalize()}** ({score:.2f}); "
            f"[YouTube link]({youtube_link})\n"
        )
    return recommendation_string

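# For illustration (genre name hypothetical): youtube_search_link_for_genre(
# "dark_ambient") yields
# "https://www.youtube.com/results?search_query=dark+ambient+music+playlist",
# which generate_recommendation_string embeds as a markdown link per genre.
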
# ------------------------------ DATA MODELS -----------------------------------


class StructuredQueryRewriteResponse(BaseModel):
    general: str | None
    subjective: str | None
    purpose: str | None
    technical: str | None
    curiosity: str | None


class QueryRewrite(BaseModel):
    rewrites: list[str] | None = None
    structured: StructuredQueryRewriteResponse | None = None


class APIGenreRecommendation(BaseModel):
    name: str = Field(description="Name of the music genre.")
    score: float = Field(
        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
    )


class APIGenreRecommendationResponse(BaseModel):
    genres: list[APIGenreRecommendation]


class RetrievalResult(BaseModel):
    chunk: str
    genre: str
    score: float


class RerankingResult(BaseModel):
    query: str
    genre: str
    chunk: str
    score: float


class Recommendation(BaseModel):
    name: str
    rank: int
    score: Optional[float] = None


class PipelineResult(BaseModel):
    query: str
    rewrite: Optional[QueryRewrite] = None
    retrieval_result: Optional[list[RetrievalResult]] = None
    reranking_result: Optional[list[RerankingResult]] = None
    recommendations: Optional[dict[str, Recommendation]] = None

    def to_ranking(self) -> dict[str, float]:
        if not self.recommendations:
            return {}
        return {
            genre: recommendation.score
            for genre, recommendation in self.recommendations.items()
        }


# -------------------------------- VLLM ----------------------------------------

local_llm = LLM(
    model=VLLM_MODEL_NAME,
    max_model_len=VLLM_MAX_SEQ_LEN,
    gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION,
    hf_token=HF_TOKEN,
    enforce_eager=True,
    dtype=VLLM_DTYPE,
)

# Guided decoding constrains the local model's output to the rewrite schema.
json_schema = StructuredQueryRewriteResponse.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json,
    temperature=0.7,
    top_p=0.8,
    repetition_penalty=1.05,
    max_tokens=1024,
)
vllm_system_prompt = (
    "You are a search query optimization assistant built into"
    " a music genre search engine, helping users discover novel music genres."
)
vllm_prompt = load_text_resource(Path("./resources/prompt_vllm.md"))

# -------------------------------- GEMINI --------------------------------------

gemini_config = types.GenerateContentConfig(
    response_mime_type="application/json",
    response_schema=APIGenreRecommendationResponse,
    temperature=0.7,
    max_output_tokens=1024,
    system_instruction=(
        "You are a helpful music genre recommendation assistant built into"
        " a music genre search engine, helping users discover novel music genres."
    ),
)
gemini_llm = genai.Client(
    api_key=GEMINI_API_KEY,
    http_options={"api_version": "v1alpha"},
)
gemini_prompt = load_text_resource(Path("./resources/prompt_api.md"))

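# For illustration, a Gemini response satisfying the schema above parses into
# APIGenreRecommendationResponse roughly like this (values hypothetical):
# {"genres": [{"name": "dark ambient", "score": 0.92},
#             {"name": "drone", "score": 0.81}]}
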
# ---------------------------- EMBEDDING MODELS --------------------------------

dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
reranker = CrossEncoder(
    model_name_or_path="BAAI/bge-reranker-v2-m3",
    max_length=1024,
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
reranker_batch_size = 128

# -------------------------------- RETRIEVAL -----------------------------------


def run_query_rewrite(query: str) -> QueryRewrite:
    prompt = vllm_prompt.format(query=query)
    messages = [
        {"role": "system", "content": vllm_system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = local_llm.chat(
        messages=messages,
        sampling_params=sampling_params_json,
    )
    rewrite_json = json.loads(outputs[0].outputs[0].text)
    return QueryRewrite(
        rewrites=[x for x in rewrite_json.values() if x is not None],
        structured=rewrite_json,
    )


def prepare_queries_for_retrieval(
    query: str, rewrite: QueryRewrite
) -> list[dict[str, str | None]]:
    queries_to_retrieve: list[dict[str, str | None]] = [{"text": query, "topic": None}]
    for category, rewrite_text in rewrite.structured.model_dump().items():
        if rewrite_text is None:
            continue
        # Only these categories map onto topic-filtered payloads; the rest are
        # searched without a topic filter.
        topic = category if category in ["subjective", "purpose", "technical"] else None
        queries_to_retrieve.append({"text": rewrite_text, "topic": topic})
    return queries_to_retrieve


def run_retrieval(
    queries: list[dict[str, str | None]],
) -> list[RetrievalResult]:
    queries_to_embed = [query["text"] for query in queries]
    dense_queries = list(
        dense_encoder.encode(
            queries_to_embed, convert_to_numpy=True, normalize_embeddings=True
        )
    )
    sparse_queries = list(sparse_encoder.query_embed(queries_to_embed))

    prefetches: list[qmodels.Prefetch] = []
    for query, dense_query, sparse_query in zip(queries, dense_queries, sparse_queries):
        assert dense_query is not None and sparse_query is not None
        assert isinstance(dense_query, np.ndarray) and isinstance(
            sparse_query, SparseEmbedding
        )
        topic = query.get("topic", None)
        topic_filter = (
            qmodels.Filter(
                must=[
                    qmodels.FieldCondition(
                        key="topic", match=qmodels.MatchValue(value=topic)
                    )
                ]
            )
            if topic is not None
            else None
        )
        prefetches.extend(
            [
                qmodels.Prefetch(
                    query=dense_query,
                    using="dense",
                    filter=topic_filter,
                    limit=num_chunks_base,
                ),
                qmodels.Prefetch(
                    query=qmodels.SparseVector(**sparse_query.as_object()),
                    using="sparse",
                    filter=topic_filter,
                    limit=num_chunks_base,
                ),
            ]
        )

    # Fuse all dense and sparse prefetches with Reciprocal Rank Fusion.
    retrieval_results = client.query_points(
        collection_name=collection_name,
        prefetch=prefetches,
        query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
        limit=num_chunks_base,
    )
    return [
        RetrievalResult(
            chunk=hit.payload["text"], genre=hit.payload["genre"], score=hit.score
        )
        for hit in retrieval_results.points
    ]


def run_reranking(
    query: str, retrieval_result: list[RetrievalResult]
) -> list[RerankingResult]:
    hit_texts: list[str] = [result.chunk for result in retrieval_result]
    hit_genres: list[str] = [result.genre for result in retrieval_result]
    hit_rerank = reranker.rank(
        query=query,
        documents=hit_texts,
        batch_size=reranker_batch_size,
    )
    ranking = [
        RerankingResult(
            query=query,
            genre=hit_genres[hit["corpus_id"]],
            chunk=hit_texts[hit["corpus_id"]],
            score=hit["score"],
        )
        for hit in hit_rerank
    ]
    ranking.sort(key=lambda x: x.score, reverse=True)
    return ranking

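# Genre aggregation below combines two signals per genre: how many of its
# chunks survived retrieval (size) and the sum of their min-max-normalized
# re-ranker scores (score), each scaled by the per-genre maximum:
#   weighted_score = alpha * size / max(size) + (1 - alpha) * score / max(score)
# With the module-level alpha = 0.5, both signals contribute equally.
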
def get_top_genres(
    df: pd.DataFrame,
    column: str,
    alpha: float = 1.0,
    top_k: int | None = None,
) -> pd.Series:
    assert 0 <= alpha <= 1.0
    # Min-max normalize re-ranker scores before aggregation.
    task_scores = df[column]
    min_score = task_scores.min()
    max_score = task_scores.max()
    if max_score > min_score:  # avoid division by zero
        df.loc[:, column] = (task_scores - min_score) / (max_score - min_score)

    tg_df = df.groupby("genre").agg(size=("chunk", "size"), score=(column, "sum"))
    tg_df["weighted_score"] = alpha * (tg_df["size"] / tg_df["size"].max()) + (
        1 - alpha
    ) * (tg_df["score"] / tg_df["score"].max())
    tg = tg_df.sort_values("weighted_score", ascending=False)["weighted_score"]
    if top_k:
        tg = tg.head(top_k)
    return tg


def get_recommendations(
    reranking_result: list[RerankingResult],
) -> dict[str, Recommendation]:
    ranking_df = pd.DataFrame([x.model_dump(mode="python") for x in reranking_result])
    top_genres_series = get_top_genres(
        df=ranking_df, column="score", alpha=alpha, top_k=top_k
    )
    return {
        genre: Recommendation(name=genre, rank=rank, score=score)
        for rank, (genre, score) in enumerate(
            top_genres_series.to_dict().items(), start=1
        )
    }


# ----------------------- GENERATE RECOMMENDATIONS -----------------------------


def recommend_sadaimrec(query: str) -> str:
    result = PipelineResult(query=query)
    print("Running query processing...", flush=True)
    result.rewrite = run_query_rewrite(query=query)
    queries_to_retrieve = prepare_queries_for_retrieval(
        query=query, rewrite=result.rewrite
    )
    print("Running retrieval...", flush=True)
    result.retrieval_result = run_retrieval(queries_to_retrieve)
    print("Running re-ranking...", flush=True)
    result.reranking_result = run_reranking(
        query=query, retrieval_result=result.retrieval_result
    )
    print("Aggregating recommendations...", flush=True)
    result.recommendations = get_recommendations(result.reranking_result)
    return generate_recommendation_string(result.to_ranking())


def recommend_gemini(query: str) -> str:
    print("Generating recommendations using Gemini...", flush=True)
    prompt = gemini_prompt.format(query=query)
    response = gemini_llm.models.generate_content(
        model="gemini-2.0-flash",
        contents=prompt,
        config=gemini_config,
    )
    parsed_content: APIGenreRecommendationResponse = response.parsed
    parsed_content.genres.sort(key=lambda x: x.score, reverse=True)
    ranking = {x.name.lower(): x.score for x in parsed_content.genres}
    return generate_recommendation_string(ranking)


# -------------------------------- INTERFACE -----------------------------------

pipelines = {
    "sadaimrec": recommend_sadaimrec,
    "gemini": recommend_gemini,
}


def generate_responses(query):
    # Randomize pipeline order so the vote is blind.
    pipeline_names = list(pipelines.keys())
    random.shuffle(pipeline_names)
    # Generate both responses.
    resp1 = pipelines[pipeline_names[0]](query)
    resp2 = pipelines[pipeline_names[1]](query)
    # Return the texts plus the hidden pipeline labels.
    return resp1, resp2, pipeline_names[0], pipeline_names[1]


# Callback to capture a vote. The radio choices are "Option 1 (left)" and
# "Option 2 (right)", so match by prefix rather than exact equality.
def handle_vote(selected, label1, label2, resp1, resp2):
    chosen_name = label1 if selected.startswith("Option 1") else label2
    chosen_resp = resp1 if selected.startswith("Option 1") else resp2
    print(f"User voted for {chosen_name}: '{chosen_resp}'")
    return (
        "Thank you for your vote! Restarting in 2 seconds...",
        gr.update(active=True),
    )

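# Voting flow: handle_vote logs the winner to stdout and activates reset_timer;
# when the timer ticks (after 2 seconds), reset_ui below restores the initial
# UI state so the next query starts fresh.
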
def reset_ui():
    return (
        gr.update(visible=False),  # hide response row
        gr.update(value=""),  # clear query
        gr.update(visible=False),  # hide radio
        gr.update(visible=False),  # hide vote button
        gr.update(value="**Generating...**"),  # reset Option 1 text
        gr.update(value="**Generating...**"),  # reset Option 2 text
        gr.update(value=""),  # clear result
        gr.update(active=False),  # disarm the reset timer
    )


app_description = load_text_resource(Path("./resources/description.md"))
app_instructions = load_text_resource(Path("./resources/instructions.md"))

with gr.Blocks(
    title="sadai-mrec", theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)
) as demo:
    gr.Markdown(app_description)
    with gr.Accordion("Detailed usage instructions", open=False):
        gr.Markdown(app_instructions)
    query = gr.Textbox(
        label="Your Query",
        placeholder="Calming music for deep relaxation with echoing sounds and deep bass",
    )
    submit_btn = gr.Button("Submit")
    # Timer that resets the UI after feedback is sent.
    reset_timer = gr.Timer(value=2.0, active=False)

    # Hidden components that store the model responses and names.
    with gr.Row(visible=False) as response_row:
        response_1 = gr.Markdown(value="**Generating...**", label="Option 1")
        response_2 = gr.Markdown(value="**Generating...**", label="Option 2")
    model_label_1 = gr.Textbox(visible=False)
    model_label_2 = gr.Textbox(visible=False)

    # Feedback widgets.
    vote = gr.Radio(
        ["Option 1 (left)", "Option 2 (right)"],
        label="Select Best Response",
        visible=False,
    )
    vote_btn = gr.Button("Vote", visible=False)
    result = gr.Textbox(label="Console", interactive=False)

    # On submit: generate both responses...
    submit_btn.click(
        fn=generate_responses,
        inputs=[query],
        outputs=[response_1, response_2, model_label_1, model_label_2],
        show_progress="full",
    )
    # ...and reveal the response row and voting widgets.
    submit_btn.click(
        fn=lambda: (
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
        ),
        inputs=None,
        outputs=[response_row, vote, vote_btn],
    )

    # Feedback handling.
    vote_btn.click(
        fn=handle_vote,
        inputs=[vote, model_label_1, model_label_2, response_1, response_2],
        outputs=[result, reset_timer],
    )
    reset_timer.tick(
        fn=reset_ui,
        inputs=None,
        outputs=[
            response_row,
            query,
            vote,
            vote_btn,
            response_1,
            response_2,
            result,
            reset_timer,
        ],
        trigger_mode="once",
    )

if __name__ == "__main__":
    demo.queue(max_size=10, default_concurrency_limit=1).launch(
        server_name="0.0.0.0", server_port=7860
    )