import gradio as gr
from typing import TypedDict, List, Optional
import os
import pandas as pd
from climateqa.engine.talk_to_data.main import ask_drias
from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT


class DriasUIElements(TypedDict):
    """Handles to every Gradio component created by ``create_drias_ui``."""

    tab: gr.Tab
    details_accordion: gr.Accordion
    examples_hidden: gr.Textbox
    examples: gr.Examples
    drias_direct_question: gr.Textbox
    result_text: gr.Textbox
    table_names_display: gr.DataFrame
    query_accordion: gr.Accordion
    drias_sql_query: gr.Textbox
    chart_accordion: gr.Accordion
    model_selection: gr.Dropdown
    drias_display: gr.Plot
    table_accordion: gr.Accordion
    drias_table: gr.DataFrame
    pagination_display: gr.Markdown
    prev_button: gr.Button
    next_button: gr.Button


async def ask_drias_query(query: str, index_state: int, user_id: str):
    """Forward a DRIAS question to ``ask_drias`` and return its result unchanged.

    The event wiring in ``setup_drias_events`` maps the returned value onto
    nine outputs (SQL query, table, plot, the three result-list states, the
    page index, the table-name list and the "no result" text).
    """
    result = await ask_drias(query, index_state, user_id)
    return result


def show_results(sql_queries_state, dataframes_state, plots_state):
    """Toggle visibility of the result components.

    Returns eight ``gr.update`` objects, in this order: result text, query
    accordion, table accordion, chart accordion, previous button, next button,
    pagination display, table-names display.
    """
    has_results = bool(sql_queries_state) and bool(dataframes_state) and bool(plots_state)
    if not has_results:
        # If any of the result lists is empty, show only the "No result" label
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    # Show the result components and hide the "No result" label
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )


def filter_by_model(dataframes, figures, index_state, model_selection):
    """Return the current page's dataframe filtered by model, and its figure.

    Args:
        dataframes: List of result dataframes, one per result page.
        figures: List of callables; ``figures[i](df)`` builds the plot for page ``i``.
        index_state: Index of the page currently displayed.
        model_selection: Value of the model dropdown; ``"ALL"`` disables filtering.

    Returns:
        ``(dataframe, figure)``. The figure is ``None`` when the (filtered)
        dataframe is empty. When there is no valid page to show (e.g. the
        dropdown changes before any query returned results), two no-op
        ``gr.update()`` objects are returned so the table and chart are left
        untouched instead of raising ``IndexError``.
    """
    if not dataframes or not 0 <= index_state < len(dataframes):
        return gr.update(), gr.update()

    df = dataframes[index_state]
    if df.empty:
        return df, None
    if "model" not in df.columns:
        # Nothing to filter on: render the page as-is
        return df, figures[index_state](df)
    if model_selection != "ALL":
        df = df[df["model"] == model_selection]
        if df.empty:
            return df, None
    return df, figures[index_state](df)


def update_pagination(index, sql_queries):
    """Return the "current/total" pagination label, or "" when there are no results."""
    pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
    return pagination


def _render_page(index, sql_queries, dataframes, plots):
    """Build the ``(sql, dataframe, figure, index)`` outputs for result page ``index``."""
    return (
        sql_queries[index],
        dataframes[index],
        plots[index](dataframes[index]),
        index,
    )


def show_previous(index, sql_queries, dataframes, plots):
    """Move to the previous result page (staying on the first one) and render it.

    With no results, the displayed components are left unchanged.
    """
    if not sql_queries:
        return gr.update(), gr.update(), gr.update(), index
    if index > 0:
        index -= 1
    return _render_page(index, sql_queries, dataframes, plots)


def show_next(index, sql_queries, dataframes, plots):
    """Move to the next result page (staying on the last one) and render it.

    With no results, the displayed components are left unchanged.
    """
    if not sql_queries:
        return gr.update(), gr.update(), gr.update(), index
    if index < len(sql_queries) - 1:
        index += 1
    return _render_page(index, sql_queries, dataframes, plots)


def display_table_names(table_names):
    """Wrap the table names as a single row for the table-names DataFrame."""
    return [table_names]


def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
    """Render the result page whose table name was clicked.

    The table names are displayed as a single row (see ``display_table_names``),
    so the clicked column index (``evt.index[1]``) is used as the page index.
    """
    index = evt.index[1]
    return _render_page(index, sql_queries, dataframes, plots)


def create_drias_ui() -> DriasUIElements:
    """Create and return all UI elements for the DRIAS tab."""
    with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
        with gr.Accordion(label="Details") as details_accordion:
            gr.Markdown(DRIAS_UI_TEXT)

        # Add examples for common questions; selecting one writes it into the
        # hidden textbox, whose change event triggers the query pipeline.
        examples_hidden = gr.Textbox(visible=False, elem_id="drias-examples-hidden")
        examples = gr.Examples(
            examples=[
                ["What will the temperature be like in Paris?"],
                ["What will be the total rainfall in France in 2030?"],
                ["How frequent will extreme events be in Lyon?"],
                ["Comment va évoluer la température en France entre 2030 et 2050 ?"],
            ],
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )

        with gr.Row():
            drias_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )

        table_names_display = gr.DataFrame(
            [],
            label="List of relevant indicators",
            headers=None,
            interactive=False,
            elem_id="table-names",
            visible=False,
        )

        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            drias_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            model_selection = gr.Dropdown(
                label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
            )
            drias_display = gr.Plot(elem_id="vanna-plot")

        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            drias_table = gr.DataFrame([], elem_id="vanna-table")

        pagination_display = gr.Markdown(
            value="", visible=False, elem_id="pagination-display"
        )

        with gr.Row():
            prev_button = gr.Button("Previous", visible=False)
            next_button = gr.Button("Next", visible=False)

    return DriasUIElements(
        tab=tab,
        details_accordion=details_accordion,
        examples_hidden=examples_hidden,
        examples=examples,
        drias_direct_question=drias_direct_question,
        result_text=result_text,
        table_names_display=table_names_display,
        query_accordion=query_accordion,
        drias_sql_query=drias_sql_query,
        chart_accordion=chart_accordion,
        model_selection=model_selection,
        drias_display=drias_display,
        table_accordion=table_accordion,
        drias_table=drias_table,
        pagination_display=pagination_display,
        prev_button=prev_button,
        next_button=next_button,
    )


def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
    """Set up all event handlers for the DRIAS tab.

    Args:
        ui_elements: Components returned by ``create_drias_ui``.
        share_client: Accepted for interface compatibility; not used here.
        user_id: Initial value of a ``gr.State`` forwarded to ``ask_drias_query``.
    """
    # Create state variables
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    user_id_state = gr.State(user_id)

    # Outputs filled by ask_drias_query (order must match its return value)
    query_outputs = [
        ui_elements["drias_sql_query"],
        ui_elements["drias_table"],
        ui_elements["drias_display"],
        sql_queries_state,
        dataframes_state,
        plots_state,
        index_state,
        table_names_list,
        ui_elements["result_text"],
    ]
    # Outputs toggled by show_results (order must match its return value)
    visibility_outputs = [
        ui_elements["result_text"],
        ui_elements["query_accordion"],
        ui_elements["table_accordion"],
        ui_elements["chart_accordion"],
        ui_elements["prev_button"],
        ui_elements["next_button"],
        ui_elements["pagination_display"],
        ui_elements["table_names_display"],
    ]
    # Inputs/outputs shared by the page-navigation handlers
    page_inputs = [index_state, sql_queries_state, dataframes_state, plots_state]
    page_outputs = [
        ui_elements["drias_sql_query"],
        ui_elements["drias_table"],
        ui_elements["drias_display"],
        index_state,
    ]

    def then_update_pagination(event):
        """Chain a refresh of the "current/total" label after ``event``."""
        return event.then(
            update_pagination,
            inputs=[index_state, sql_queries_state],
            outputs=[ui_elements["pagination_display"]],
        )

    def then_reapply_model_filter(event):
        """Chain the model filter after ``event``.

        Page changes and new queries render the page unfiltered; re-applying
        the filter keeps the table/chart consistent with the dropdown value.
        """
        return event.then(
            filter_by_model,
            inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
            outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
        )

    def then_run_query(event, question_box):
        """Chain the full query pipeline after ``event``, reading ``question_box``."""
        event = event.then(
            ask_drias_query,
            inputs=[question_box, index_state, user_id_state],
            outputs=query_outputs,
        ).then(
            show_results,
            inputs=[sql_queries_state, dataframes_state, plots_state],
            outputs=visibility_outputs,
        )
        event = then_update_pagination(event).then(
            display_table_names,
            inputs=[table_names_list],
            outputs=[ui_elements["table_names_display"]],
        )
        return then_reapply_model_filter(event)

    # Handle example selection
    example_event = ui_elements["examples_hidden"].change(
        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
        inputs=[ui_elements["examples_hidden"]],
        outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]],
    )
    then_run_query(example_event, ui_elements["examples_hidden"])

    # Handle direct question submission
    submit_event = ui_elements["drias_direct_question"].submit(
        lambda: gr.Accordion(open=False),
        inputs=None,
        outputs=[ui_elements["details_accordion"]],
    )
    then_run_query(submit_event, ui_elements["drias_direct_question"])

    # Handle model selection change
    ui_elements["model_selection"].change(
        filter_by_model,
        inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
        outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
    )

    # Handle pagination buttons
    for button, handler in (
        (ui_elements["prev_button"], show_previous),
        (ui_elements["next_button"], show_next),
    ):
        click_event = button.click(handler, inputs=page_inputs, outputs=page_outputs)
        then_reapply_model_filter(then_update_pagination(click_event))

    # Handle table selection
    select_event = ui_elements["table_names_display"].select(
        fn=on_table_click,
        inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
        outputs=page_outputs,
    )
    then_reapply_model_filter(then_update_pagination(select_event))


def create_drias_tab(share_client=None, user_id=None):
    """Create the DRIAS tab with all its components and event handlers."""
    ui_elements = create_drias_ui()
    setup_drias_events(ui_elements, share_client=share_client, user_id=user_id)