#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Enhanced Gradio UI for the Salesforce/codet5-large model using the Hugging Face Inference API.
Adheres to best practices, PEP8, flake8, and the Zen of Python.
"""
import gradio as gr

MODEL_ID = "Salesforce/codet5-large"
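
# CodeT5 is a T5-style encoder-decoder model: masked spans in the input are
# marked with sentinel tokens (<extra_id_0>, <extra_id_1>, ...) and the model
# generates the missing code for each sentinel.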


def prepare_payload(prompt: str, max_tokens: int) -> dict:
"""
Prepare the payload dictionary for the Hugging Face inference call.
Args:
prompt (str): The input code containing `<extra_id_0>`.
max_tokens (int): Maximum number of tokens for generation.
Returns:
dict: Payload for the model API call.
"""
return {"inputs": prompt, "parameters": {"max_length": max_tokens}}
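
# Example payload (illustrative):
#   prepare_payload("def f(): <extra_id_0>", 32)
#   -> {"inputs": "def f(): <extra_id_0>", "parameters": {"max_length": 32}}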


def extract_generated_text(api_response) -> str:
    """
    Extract the generated text from the API response.

    Args:
        api_response: The model API response; the HF Inference API
            typically returns a list of dicts for generation tasks.

    Returns:
        str: The generated text, or the string form of the response.
    """
    if isinstance(api_response, list) and api_response:
        api_response = api_response[0]
    if isinstance(api_response, dict):
        return api_response.get("generated_text", str(api_response))
    return str(api_response)
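
# Example (illustrative): the Inference API usually answers generation calls
# with [{"generated_text": "world"}], so the function above returns "world".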


def main():
    with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
        with gr.Sidebar():
            gr.Markdown("## 🤖 Inference Provider")
            gr.Markdown(
                (
                    "This Space showcases the `{}` model, served via the Hugging Face Inference API.\n\n"
                    "Sign in with your Hugging Face account to access the model."
                ).format(MODEL_ID)
            )
            login_button = gr.LoginButton("🔐 Sign in")
            gr.Markdown("---")
            gr.Markdown(f"**Model:** `{MODEL_ID}`")
            gr.Markdown("[📄 View Model Card](https://huggingface.co/Salesforce/codet5-large)")
        gr.Markdown("# 🧠 CodeT5 Inference UI")
        gr.Markdown("Enter your Python code snippet with `<extra_id_0>` as the mask token.")
        with gr.Row():
            with gr.Column(scale=1):
                code_input = gr.Code(
                    label="Input Code",
                    language="python",
                    value="def greet(user): print(f'hello <extra_id_0>!')",
                    lines=10,
                    autofocus=True,
                )
                max_tokens = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8, label="Max Tokens"
                )
                submit_btn = gr.Button("🚀 Run Inference")
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Inference Output",
                    lines=10,
                    interactive=False,
                    placeholder="Model output will appear here...",
                )

        # Load the model from the Hugging Face Inference API. The object
        # returned by gr.load() can also be called like a function (Gradio's
        # "Blocks as functions" pattern), which is how the handler below
        # runs the prediction.
        model_iface = gr.load(
            f"models/{MODEL_ID}",
            accept_token=login_button,
            provider="hf-inference",
        )
        # Carries the prepared payload between the two chained steps.
        payload_state = gr.State()

        def call_model_and_extract(payload: dict) -> str:
            """Invoke the loaded model and extract the generated text.

            Assumes the loaded demo exposes a single prompt -> text function;
            the extra generation parameters in the payload are not forwarded.
            """
            return extract_generated_text(model_iface(payload["inputs"]))

        # Chain click events: prepare payload -> API call -> extract output.
        submit_btn.click(
            fn=prepare_payload,
            inputs=[code_input, max_tokens],
            outputs=payload_state,
            api_name="prepare_payload",
        ).then(
            fn=call_model_and_extract,
            inputs=payload_state,
            outputs=output_text,
            api_name="extract_output",
        )
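
        # The api_name values above also expose these steps over the
        # Space's HTTP API. A hypothetical client call (the Space URL is
        # a placeholder; requires the `gradio_client` package):
        #
        #   from gradio_client import Client
        #   client = Client("https://<user>-<space>.hf.space")
        #   client.predict("def f(): <extra_id_0>", 32,
        #                  api_name="/prepare_payload")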
    demo.launch()


if __name__ == "__main__":
    main()