# Provenance (copied from a Hugging Face file view): commit bc61879
# "log to huggingface" by timeki; file size 11.5 kB.
import gradio as gr
from typing import TypedDict, List, Optional
import os
import pandas as pd
from climateqa.engine.talk_to_data.main import ask_drias
from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
class DriasUIElements(TypedDict):
    """Handles to every Gradio component of the DRIAS tab.

    Produced by ``create_drias_ui`` and consumed by ``setup_drias_events``,
    which looks each component up by these keys.
    """

    # Tab container, intro text and example questions
    tab: gr.Tab
    details_accordion: gr.Accordion
    examples_hidden: gr.Textbox
    examples: gr.Examples

    # Question input, status text and list of relevant indicators
    drias_direct_question: gr.Textbox
    result_text: gr.Textbox
    table_names_display: gr.DataFrame

    # Result panels: SQL query, chart (with model filter) and raw data
    query_accordion: gr.Accordion
    drias_sql_query: gr.Textbox
    chart_accordion: gr.Accordion
    model_selection: gr.Dropdown
    drias_display: gr.Plot
    table_accordion: gr.Accordion
    drias_table: gr.DataFrame

    # Pagination across the results of one query
    pagination_display: gr.Markdown
    prev_button: gr.Button
    next_button: gr.Button
async def ask_drias_query(query: str, index_state: int, user_id: str):
    """Forward a question to the DRIAS talk-to-data engine and return its result unchanged.

    Thin async adapter so the engine's ``ask_drias`` can be used directly as a
    Gradio event handler; the returned value is whatever ``ask_drias`` yields.
    """
    return await ask_drias(query, index_state, user_id)
def show_results(sql_queries_state, dataframes_state, plots_state):
    """Toggle visibility of the "no result" label versus the result widgets.

    Results are considered present only when all three state lists are
    non-empty; if ANY of them is empty the "no result" label is shown and
    everything else is hidden.

    Returns:
        An 8-tuple of ``gr.update`` objects: first the "no result" label,
        then seven result widgets that share the opposite visibility.
    """
    has_results = all((sql_queries_state, dataframes_state, plots_state))
    no_result_label = gr.update(visible=not has_results)
    # One independent update per result widget (query, table and chart
    # accordions, prev/next buttons, pagination label, indicator list in the
    # order the callers in this module wire them).
    result_widgets = tuple(gr.update(visible=has_results) for _ in range(7))
    return (no_result_label, *result_widgets)
def filter_by_model(dataframes, figures, index_state, model_selection):
    """Restrict the currently displayed result to one model and redraw its chart.

    Args:
        dataframes: List of result DataFrames, one per query result.
        figures: List of callables; ``figures[i](df)`` builds the chart for
            result ``i`` from a (possibly filtered) DataFrame.
        index_state: Position of the result currently on screen.
        model_selection: Value of the model dropdown; ``"ALL"`` disables filtering.

    Returns:
        ``(df, figure)``: the (filtered) DataFrame and its chart. ``figure`` is
        ``None`` whenever there is nothing to plot. When there is no result at
        ``index_state`` an empty DataFrame is returned.
    """
    # The state list starts as [] and may be empty or shorter than the index
    # (e.g. after a query with no results); indexing blindly would raise.
    if not dataframes or not 0 <= index_state < len(dataframes):
        return pd.DataFrame(), None

    df = dataframes[index_state]
    if df.empty:
        return df, None

    # Only filter when the table actually carries a model column and a
    # specific model was chosen.
    if "model" in df.columns and model_selection != "ALL":
        df = df[df["model"] == model_selection]
        if df.empty:
            return df, None

    return df, figures[index_state](df)
def update_pagination(index, sql_queries):
    """Return the "current/total" page label (1-based), or "" when there are no queries."""
    if not sql_queries:
        return ""
    return f"{index + 1}/{len(sql_queries)}"
def show_previous(index, sql_queries, dataframes, plots):
    """Move one result back (staying on the first one) and render it.

    ``plots[i]`` is called with ``dataframes[i]`` to build the chart.

    Returns:
        ``(sql_query, dataframe, figure, new_index)`` for the selected result.
    """
    target = index - 1 if index > 0 else index
    frame = dataframes[target]
    return sql_queries[target], frame, plots[target](frame), target
def show_next(index, sql_queries, dataframes, plots):
    """Move one result forward (staying on the last one) and render it.

    ``plots[i]`` is called with ``dataframes[i]`` to build the chart.

    Returns:
        ``(sql_query, dataframe, figure, new_index)`` for the selected result.
    """
    last = len(sql_queries) - 1
    target = index + 1 if index < last else index
    frame = dataframes[target]
    return sql_queries[target], frame, plots[target](frame), target
def display_table_names(table_names):
    """Wrap the indicator names as a single table row (one name per column)."""
    single_row = [table_names]
    return single_row
def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
    """Show the result whose position matches the clicked indicator-table column.

    The indicator names are displayed as a single row, so the clicked column
    index (``evt.index[1]``) is used as the result index. ``table_names`` is
    part of the wired inputs but is not read here.

    Returns:
        ``(sql_query, dataframe, figure, index)`` for the clicked result.
    """
    selected = evt.index[1]
    frame = dataframes[selected]
    chart = plots[selected](frame)
    return sql_queries[selected], frame, chart, selected
def create_drias_ui() -> DriasUIElements:
    """Create and return all UI elements for the DRIAS tab.

    Components are created in on-page order; the ``with`` nesting decides which
    accordion or row each one lives in. Every result widget (indicator list,
    SQL query, chart, data table, pagination label and buttons) starts hidden
    (``visible=False``); only the status textbox ``result_text`` starts visible.
    No event handlers are attached here.

    Returns:
        DriasUIElements: handles to every component, keyed for event wiring.
    """
    with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
        # Explanatory text; returned so the event code can collapse it once a
        # question is asked.
        with gr.Accordion(label="Details") as details_accordion:
            gr.Markdown(DRIAS_UI_TEXT)

        # Add examples for common questions
        # Selecting an example writes it into this invisible textbox; the
        # textbox (not gr.Examples itself) is what the event code listens to.
        examples_hidden = gr.Textbox(visible=False, elem_id="drias-examples-hidden")
        examples = gr.Examples(
            examples=[
                ["What will the temperature be like in Paris?"],
                ["What will be the total rainfall in France in 2030?"],
                ["How frequent will extreme events be in Lyon?"],
                ["Comment va évoluer la température en France entre 2030 et 2050 ?"]
            ],
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )

        with gr.Row():
            drias_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        # Read-only status text (e.g. the "no result" message).
        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )

        # One-row table of relevant indicators; clicking a cell selects the
        # corresponding result.
        table_names_display = gr.DataFrame(
            [], label="List of relevant indicators", headers=None, interactive=False, elem_id="table-names", visible=False
        )

        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            drias_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        # Chart of the current result, with a per-model filter ("ALL" = no filter).
        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            model_selection = gr.Dropdown(
                label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
            )
            drias_display = gr.Plot(elem_id="vanna-plot")

        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            drias_table = gr.DataFrame([], elem_id="vanna-table")

        # "current/total" label for navigating between results of one query.
        pagination_display = gr.Markdown(
            value="", visible=False, elem_id="pagination-display"
        )

        with gr.Row():
            prev_button = gr.Button("Previous", visible=False)
            next_button = gr.Button("Next", visible=False)

    return DriasUIElements(
        tab=tab,
        details_accordion=details_accordion,
        examples_hidden=examples_hidden,
        examples=examples,
        drias_direct_question=drias_direct_question,
        result_text=result_text,
        table_names_display=table_names_display,
        query_accordion=query_accordion,
        drias_sql_query=drias_sql_query,
        chart_accordion=chart_accordion,
        model_selection=model_selection,
        drias_display=drias_display,
        table_accordion=table_accordion,
        drias_table=drias_table,
        pagination_display=pagination_display,
        prev_button=prev_button,
        next_button=next_button
    )
def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
    """Set up all event handlers for the DRIAS tab.

    Creates the per-session ``gr.State`` holders, then wires:
      * example selection and direct-question submission, which both run the
        same query pipeline;
      * the model filter dropdown;
      * Previous/Next pagination and clicks in the indicator list.

    Args:
        ui_elements: Components returned by ``create_drias_ui``.
        share_client: Accepted for interface compatibility; not used here.
        user_id: Initial value of the user-id state sent to ``ask_drias_query``.
    """
    # Per-session state shared by the handlers below.
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    # Wrapped in a gr.State so it can be passed as an event input; named
    # separately so the ``user_id`` argument is not shadowed.
    user_id_state = gr.State(user_id)

    def _then_update_pagination(dependency):
        """Chain a refresh of the "i/n" pagination label onto ``dependency``."""
        return dependency.then(
            update_pagination,
            inputs=[index_state, sql_queries_state],
            outputs=[ui_elements["pagination_display"]],
        )

    def _run_query_pipeline(dependency, question):
        """Chain ask -> show_results -> pagination -> indicator list onto ``dependency``.

        ``question`` is the component whose value is sent to the engine.
        """
        asked = dependency.then(
            ask_drias_query,
            inputs=[question, index_state, user_id_state],
            outputs=[
                ui_elements["drias_sql_query"],
                ui_elements["drias_table"],
                ui_elements["drias_display"],
                sql_queries_state,
                dataframes_state,
                plots_state,
                index_state,
                table_names_list,
                ui_elements["result_text"],
            ],
        )
        shown = asked.then(
            show_results,
            inputs=[sql_queries_state, dataframes_state, plots_state],
            outputs=[
                ui_elements["result_text"],
                ui_elements["query_accordion"],
                ui_elements["table_accordion"],
                ui_elements["chart_accordion"],
                ui_elements["prev_button"],
                ui_elements["next_button"],
                ui_elements["pagination_display"],
                ui_elements["table_names_display"],
            ],
        )
        return _then_update_pagination(shown).then(
            display_table_names,
            inputs=[table_names_list],
            outputs=[ui_elements["table_names_display"]],
        )

    def _navigation_outputs():
        """Fresh list of the components updated when switching the displayed result."""
        return [
            ui_elements["drias_sql_query"],
            ui_elements["drias_table"],
            ui_elements["drias_display"],
            index_state,
        ]

    # Example selected: collapse the details, copy the example into the
    # question box, then run the query on the example text.
    example_selected = ui_elements["examples_hidden"].change(
        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
        inputs=[ui_elements["examples_hidden"]],
        outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]],
    )
    _run_query_pipeline(example_selected, ui_elements["examples_hidden"])

    # Direct question submitted: collapse the details, then run the query.
    question_submitted = ui_elements["drias_direct_question"].submit(
        lambda: gr.Accordion(open=False),
        inputs=None,
        outputs=[ui_elements["details_accordion"]],
    )
    _run_query_pipeline(question_submitted, ui_elements["drias_direct_question"])

    # Model filter for the currently displayed result.
    ui_elements["model_selection"].change(
        filter_by_model,
        inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
        outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
    )

    # Pagination buttons.
    _then_update_pagination(
        ui_elements["prev_button"].click(
            show_previous,
            inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
            outputs=_navigation_outputs(),
        )
    )
    _then_update_pagination(
        ui_elements["next_button"].click(
            show_next,
            inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
            outputs=_navigation_outputs(),
        )
    )

    # Click in the indicator list selects the matching result.
    _then_update_pagination(
        ui_elements["table_names_display"].select(
            fn=on_table_click,
            inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
            outputs=_navigation_outputs(),
        )
    )
def create_drias_tab(share_client=None, user_id=None):
    """Build the DRIAS tab layout and attach all of its event handlers.

    ``share_client`` and ``user_id`` are passed straight through to
    ``setup_drias_events``.
    """
    setup_drias_events(
        create_drias_ui(),
        share_client=share_client,
        user_id=user_id,
    )