|
import gradio as gr |
|
from typing import TypedDict, List, Optional |
|
import os |
|
import pandas as pd |
|
|
|
from climateqa.engine.talk_to_data.main import ask_drias |
|
from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT |
|
|
|
class DriasUIElements(TypedDict):
    """Handles to every Gradio component of the DRIAS tab.

    Built by ``create_drias_ui`` and consumed by ``setup_drias_events`` so the
    event wiring can reference components by name.
    """

    # Tab container for the whole DRIAS page.
    tab: gr.Tab
    # Accordion holding the DRIAS_UI_TEXT explanation; collapsed when a question is asked.
    details_accordion: gr.Accordion
    # Invisible textbox that receives the clicked example question.
    examples_hidden: gr.Textbox
    # Example questions shown to the user.
    examples: gr.Examples
    # Free-text question input (submitted with Enter).
    drias_direct_question: gr.Textbox
    # Textbox shown when there are no results.
    result_text: gr.Textbox
    # Single-row table of relevant indicators; clicking a cell selects that result.
    table_names_display: gr.DataFrame
    # Accordion wrapping the SQL query of the displayed result.
    query_accordion: gr.Accordion
    # SQL query of the displayed result.
    drias_sql_query: gr.Textbox
    # Accordion wrapping the model selector and the chart.
    chart_accordion: gr.Accordion
    # Model filter for the displayed data ("ALL" disables filtering).
    model_selection: gr.Dropdown
    # Chart of the displayed result.
    drias_display: gr.Plot
    # Accordion wrapping the raw data table.
    table_accordion: gr.Accordion
    # Data of the displayed result.
    drias_table: gr.DataFrame
    # "current/total" pagination label.
    pagination_display: gr.Markdown
    # Navigate to the previous result.
    prev_button: gr.Button
    # Navigate to the next result.
    next_button: gr.Button
|
|
|
|
|
async def ask_drias_query(query: str, index_state: int, user_id: str):
    """Forward a user question to the DRIAS talk-to-data engine.

    Thin async adapter so a Gradio event can await ``ask_drias`` directly.

    Args:
        query: The natural-language question.
        index_state: Current result index held in the UI state.
        user_id: Identifier of the requesting user.

    Returns:
        Whatever ``ask_drias`` returns, unchanged. Presumably a tuple matching
        the nine outputs wired in ``setup_drias_events`` — verify against
        ``ask_drias``.
    """
    return await ask_drias(query, index_state, user_id)
|
|
|
|
|
def show_results(sql_queries_state, dataframes_state, plots_state):
    """Toggle visibility between the "no result" message and the result widgets.

    Results count as present only when all three state lists are non-empty.

    Returns:
        Eight ``gr.update`` objects, in this output order:
        ``result_text`` (visible only without results), then ``query_accordion``,
        ``table_accordion``, ``chart_accordion``, ``prev_button``, ``next_button``,
        ``pagination_display`` and ``table_names_display`` (visible only with results).
    """
    has_results = all((sql_queries_state, dataframes_state, plots_state))
    no_result_update = gr.update(visible=not has_results)
    # One fresh update object per result component (seven of them).
    result_updates = [gr.update(visible=has_results) for _ in range(7)]
    return (no_result_update, *result_updates)
|
|
|
|
|
def filter_by_model(dataframes, figures, index_state, model_selection):
    """Filter the currently displayed result by model and rebuild its chart.

    Args:
        dataframes: List of result DataFrames, one per SQL query.
        figures: List of callables parallel to ``dataframes``. Each one takes a
            DataFrame and returns the chart for it.
        index_state: Index of the result currently shown.
        model_selection: Selected model name, or ``"ALL"`` to disable filtering.

    Returns:
        ``(df, figure)``: the (possibly filtered) DataFrame and its chart.
        ``figure`` is ``None`` when the DataFrame is empty, either initially or
        after filtering. If there is no result at ``index_state`` (e.g. the
        dropdown is changed before any question was asked), an empty DataFrame
        and ``None`` are returned instead of raising ``IndexError``.
    """
    # The dropdown is interactive before any query has populated the state
    # lists, so guard against indexing into empty/short lists.
    available = min(len(dataframes or []), len(figures or []))
    if not 0 <= index_state < available:
        return pd.DataFrame(), None

    df = dataframes[index_state]
    make_figure = figures[index_state]
    if df.empty:
        return df, None

    # Only filter when the data actually carries a "model" column.
    if "model" in df.columns and model_selection != "ALL":
        df = df[df["model"] == model_selection]
        if df.empty:
            return df, None

    return df, make_figure(df)
|
|
|
|
|
def update_pagination(index, sql_queries):
    """Return the "current/total" pagination label, 1-based.

    An empty label is returned when there are no queries.
    """
    if not sql_queries:
        return ""
    return "{}/{}".format(index + 1, len(sql_queries))
|
|
|
|
|
def show_previous(index, sql_queries, dataframes, plots):
    """Step back to the previous query result; stay on the first one.

    Returns:
        ``(sql_query, dataframe, figure, index)`` for the result now shown. The
        figure is built by calling ``plots[index]`` on that result's DataFrame.
    """
    target = index - 1 if index > 0 else index
    query = sql_queries[target]
    frame = dataframes[target]
    chart = plots[target](dataframes[target])
    return query, frame, chart, target
|
|
|
|
|
def show_next(index, sql_queries, dataframes, plots):
    """Advance to the next query result; stay on the last one.

    Returns:
        ``(sql_query, dataframe, figure, index)`` for the result now shown. The
        figure is built by calling ``plots[index]`` on that result's DataFrame.
    """
    last = len(sql_queries) - 1
    target = index + 1 if index < last else index
    query = sql_queries[target]
    frame = dataframes[target]
    chart = plots[target](dataframes[target])
    return query, frame, chart, target
|
|
|
|
|
def display_table_names(table_names):
    """Wrap the indicator names as a single table row for the DataFrame component."""
    single_row = [table_names]
    return single_row
|
|
|
|
|
def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
    """Show the result whose indicator cell was clicked in the table-names row.

    The indicator names sit in a single row, so the column of the clicked cell
    (``evt.index[1]``) is the position of the result in the state lists.
    ``table_names`` is accepted to match the wired inputs but is not used.

    Returns:
        ``(sql_query, dataframe, figure, index)`` for the selected result.
    """
    selected = evt.index[1]
    chart = plots[selected](dataframes[selected])
    return sql_queries[selected], dataframes[selected], chart, selected
|
|
|
|
|
def create_drias_ui() -> DriasUIElements:
    """Create and return all UI elements for the DRIAS tab.

    Components are instantiated inside Gradio layout contexts, so the order and
    nesting below define the on-screen layout. Result-related components (query,
    chart and table accordions, pagination, navigation buttons, indicator list)
    start hidden. Their visibility is toggled later by ``show_results``.

    Returns:
        A ``DriasUIElements`` dict referencing every component the event
        handlers need.
    """
    with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
        # Introductory text; collapsed once a question is asked.
        with gr.Accordion(label="Details") as details_accordion:
            gr.Markdown(DRIAS_UI_TEXT)

        # Clicking an example writes it into this hidden textbox; its .change
        # event (wired in setup_drias_events) triggers the query.
        examples_hidden = gr.Textbox(visible=False, elem_id="drias-examples-hidden")
        examples = gr.Examples(
            examples=[
                ["What will the temperature be like in Paris?"],
                ["What will be the total rainfall in France in 2030?"],
                ["How frequent will extreme events be in Lyon?"],
                ["Comment va évoluer la température en France entre 2030 et 2050 ?"]
            ],
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )

        with gr.Row():
            drias_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )

        # Shown when there are no results (see show_results).
        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )

        # Single-row table of indicator names; selecting a cell switches result.
        table_names_display = gr.DataFrame(
            [], label="List of relevant indicators", headers=None, interactive=False, elem_id="table-names", visible=False
        )

        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
            drias_sql_query = gr.Textbox(
                label="", elem_id="sql-query", interactive=False
            )

        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            # "ALL" means no model filtering (see filter_by_model).
            model_selection = gr.Dropdown(
                label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
            )
            drias_display = gr.Plot(elem_id="vanna-plot")

        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            drias_table = gr.DataFrame([], elem_id="vanna-table")

        # "current/total" label, filled by update_pagination.
        pagination_display = gr.Markdown(
            value="", visible=False, elem_id="pagination-display"
        )

        with gr.Row():
            prev_button = gr.Button("Previous", visible=False)
            next_button = gr.Button("Next", visible=False)

    return DriasUIElements(
        tab=tab,
        details_accordion=details_accordion,
        examples_hidden=examples_hidden,
        examples=examples,
        drias_direct_question=drias_direct_question,
        result_text=result_text,
        table_names_display=table_names_display,
        query_accordion=query_accordion,
        drias_sql_query=drias_sql_query,
        chart_accordion=chart_accordion,
        model_selection=model_selection,
        drias_display=drias_display,
        table_accordion=table_accordion,
        drias_table=drias_table,
        pagination_display=pagination_display,
        prev_button=prev_button,
        next_button=next_button
    )
|
|
|
|
|
|
|
def _chain_query_pipeline(
    event,
    question_input,
    ui_elements: DriasUIElements,
    *,
    sql_queries_state,
    dataframes_state,
    plots_state,
    index_state,
    table_names_list,
    user_id,
):
    """Chain the ask → show → paginate → list-indicators steps after ``event``.

    ``question_input`` is the component whose value is sent to ``ask_drias_query``.
    Returns the last dependency of the chain.
    """
    return (
        event.then(
            ask_drias_query,
            inputs=[question_input, index_state, user_id],
            outputs=[
                ui_elements["drias_sql_query"],
                ui_elements["drias_table"],
                ui_elements["drias_display"],
                sql_queries_state,
                dataframes_state,
                plots_state,
                index_state,
                table_names_list,
                ui_elements["result_text"],
            ],
        )
        .then(
            show_results,
            inputs=[sql_queries_state, dataframes_state, plots_state],
            outputs=[
                ui_elements["result_text"],
                ui_elements["query_accordion"],
                ui_elements["table_accordion"],
                ui_elements["chart_accordion"],
                ui_elements["prev_button"],
                ui_elements["next_button"],
                ui_elements["pagination_display"],
                ui_elements["table_names_display"],
            ],
        )
        .then(
            update_pagination,
            inputs=[index_state, sql_queries_state],
            outputs=[ui_elements["pagination_display"]],
        )
        .then(
            display_table_names,
            inputs=[table_names_list],
            outputs=[ui_elements["table_names_display"]],
        )
    )


def _wire_navigation(listener, handler, inputs, ui_elements: DriasUIElements, index_state, sql_queries_state):
    """Bind a result-navigation ``handler`` to ``listener``, then refresh pagination.

    ``listener`` is a bound Gradio event method (e.g. ``button.click``). The
    handler must return ``(sql_query, dataframe, figure, index)``.
    """
    return listener(
        handler,
        inputs=inputs,
        outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
    ).then(
        update_pagination,
        inputs=[index_state, sql_queries_state],
        outputs=[ui_elements["pagination_display"]],
    )


def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
    """Set up all event handlers for the DRIAS tab.

    Args:
        ui_elements: Components returned by ``create_drias_ui``.
        share_client: Accepted for interface compatibility; currently unused.
        user_id: Initial value of the per-session user-id state passed to the engine.
    """
    # Per-session state: parallel lists (one entry per SQL query), the index of
    # the displayed result, and the indicator names.
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    user_id = gr.State(user_id)

    pipeline_states = dict(
        sql_queries_state=sql_queries_state,
        dataframes_state=dataframes_state,
        plots_state=plots_state,
        index_state=index_state,
        table_names_list=table_names_list,
        user_id=user_id,
    )

    # Example click: collapse details, copy the example into the question box, run it.
    examples_event = ui_elements["examples_hidden"].change(
        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
        inputs=[ui_elements["examples_hidden"]],
        outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
    )
    _chain_query_pipeline(examples_event, ui_elements["examples_hidden"], ui_elements, **pipeline_states)

    # Direct question: collapse details, then run it.
    submit_event = ui_elements["drias_direct_question"].submit(
        lambda: gr.Accordion(open=False),
        inputs=None,
        outputs=[ui_elements["details_accordion"]]
    )
    _chain_query_pipeline(submit_event, ui_elements["drias_direct_question"], ui_elements, **pipeline_states)

    ui_elements["model_selection"].change(
        filter_by_model,
        inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
        outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
    )

    # Result navigation: previous/next buttons and clicks in the indicator row.
    _wire_navigation(
        ui_elements["prev_button"].click,
        show_previous,
        [index_state, sql_queries_state, dataframes_state, plots_state],
        ui_elements,
        index_state,
        sql_queries_state,
    )
    _wire_navigation(
        ui_elements["next_button"].click,
        show_next,
        [index_state, sql_queries_state, dataframes_state, plots_state],
        ui_elements,
        index_state,
        sql_queries_state,
    )
    _wire_navigation(
        ui_elements["table_names_display"].select,
        on_table_click,
        [table_names_list, sql_queries_state, dataframes_state, plots_state],
        ui_elements,
        index_state,
        sql_queries_state,
    )
|
|
|
def create_drias_tab(share_client=None, user_id=None):
    """Build the DRIAS tab: lay out its components, then wire their events.

    Must be called inside a Gradio ``Blocks`` context. Returns ``None``.
    """
    setup_drias_events(create_drias_ui(), share_client=share_client, user_id=user_id)
|
|
|
|