# NOTE: removed scraped Hugging Face Space status banner ("Spaces:" /
# "Runtime error" x2) that was captured with the page — it is not source code.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Gradio UI for the Salesforce/codet5-large model, served via the
Hugging Face Inference API.
"""
import gradio as gr

# Hugging Face model repository this UI runs inference against.
MODEL_ID = "Salesforce/codet5-large"
def prepare_payload(prompt: str, max_tokens: int) -> dict:
    """
    Prepare the payload dictionary for the Hugging Face inference call.

    Args:
        prompt (str): The input code containing the ``<extra_id_0>``
            mask token expected by CodeT5.
        max_tokens (int): Maximum number of tokens for generation
            (mapped to the API's ``max_length`` parameter).

    Returns:
        dict: Payload for the model API call.
    """
    return {"inputs": prompt, "parameters": {"max_length": max_tokens}}
def extract_generated_text(api_response: "dict | list") -> str:
    """
    Extract generated text from the API response.

    The Hugging Face Inference API may return either a single dict or a
    list of dicts (one per generated sequence), each carrying a
    ``generated_text`` key; both shapes are handled.

    Args:
        api_response: The response from the model API call.

    Returns:
        str: The generated text, or the string representation of the
            response when no ``generated_text`` field is present.
    """
    # List responses: use the first (top-ranked) candidate sequence.
    if isinstance(api_response, list) and api_response:
        api_response = api_response[0]
    if isinstance(api_response, dict):
        return api_response.get("generated_text", str(api_response))
    # Fall back to a readable representation instead of raising.
    return str(api_response)
def main():
    """Build and launch the Gradio Blocks UI for CodeT5 inference."""
    with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as demo:
        # Sidebar: provider info, login, and model links.
        with gr.Sidebar():
            gr.Markdown("## 🤖 Inference Provider")
            gr.Markdown(
                (
                    "This Space showcases the `{}` model, served via the Hugging Face Inference API.\n\n"
                    "Sign in with your Hugging Face account to access the model."
                ).format(MODEL_ID)
            )
            login_button = gr.LoginButton("🔐 Sign in")
            gr.Markdown("---")
            gr.Markdown(f"**Model:** `{MODEL_ID}`")
            gr.Markdown("[📄 View Model Card](https://huggingface.co/Salesforce/codet5-large)")
        gr.Markdown("# 🧠 CodeT5 Inference UI")
        gr.Markdown("Enter your Python code snippet with `<extra_id_0>` as the mask token.")
        with gr.Row():
            with gr.Column(scale=1):
                code_input = gr.Code(
                    label="Input Code",
                    language="python",
                    value="def greet(user): print(f'hello <extra_id_0>!')",
                    lines=10,
                    autofocus=True,
                )
                max_tokens = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8, label="Max Tokens"
                )
                submit_btn = gr.Button("🚀 Run Inference")
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Inference Output",
                    lines=10,
                    interactive=False,
                    placeholder="Model output will appear here...",
                )
        # Load the model from the Hugging Face Inference API.
        # NOTE(review): the loaded interface is used directly as an event
        # output/input below — confirm this wiring against the installed
        # Gradio version's `gr.load` documentation.
        model_iface = gr.load(
            f"models/{MODEL_ID}",
            accept_token=login_button,
            provider="hf-inference",
        )
        # Chain click events: prepare payload -> API call -> extract output.
        submit_btn.click(
            fn=prepare_payload,
            inputs=[code_input, max_tokens],
            outputs=model_iface,
            api_name="prepare_payload",
        ).then(
            fn=extract_generated_text,
            inputs=model_iface,
            outputs=output_text,
            api_name="extract_output",
        )
    demo.launch()
# Script entry point: launch the UI only when run directly, not on import.
if __name__ == "__main__":
    main()