# app.py for Hugging Face Space # Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'), # and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository. # Using gr.DataFrame does not require adding pandas if using list-of-lists format. from huggingface_hub import login import gradio as gr import torch # Or tensorflow/flax depending on backend from transformers import AutoModelForCausalLM, AutoTokenizer from huggingface_hub import hf_hub_download # Import hub download function import json # Import json library import os # Import os library for path joining # --- Configuration --- MODEL_NAME = "google/txgemma-2b-predict" #MODEL_NAME = "google/txgemma-9b-predict" PROMPT_FILENAME = "tdc_prompts.json" MODEL_CACHE = "model_cache" # Optional: define a cache directory # MAX_EXAMPLES is no longer strictly limiting the display, but can be used if needed later MAX_EXAMPLES = 600 # Keep variable definition, but DataFrame handles scrolling EXAMPLE_SMILES = "C1=CC=CC=C1" # Default SMILES for examples (Benzene) DATAFRAME_HEADERS = ["Task Name", "Prompt Template"] DATAFRAME_ROW_COUNT = 8 # Number of rows to display initially in the DataFrame hf_token = os.getenv("HF_TOKEN") login(token=hf_token) # --- Load Model, Tokenizer, and Prompts --- print(f"Loading model: {MODEL_NAME}...") tdc_prompts_data = None # Initialize as None dataframe_data = [] # Initialize empty list for DataFrame content try: # Check if GPU is available and use it, otherwise use CPU device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load the tokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE) print("Tokenizer loaded.") # Load the model model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, cache_dir=MODEL_CACHE, device_map="auto" # Automatically distribute model across available devices (GPU/CPU) ) print("Model loaded.") # Download and load the prompts JSON file print(f"Downloading {PROMPT_FILENAME}...") prompts_file_path = hf_hub_download( repo_id=MODEL_NAME, filename=PROMPT_FILENAME, cache_dir=MODEL_CACHE, ) print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}") # Load the JSON data with open(prompts_file_path, 'r') as f: tdc_prompts_data = json.load(f) print(f"Loaded prompts data from {PROMPT_FILENAME}.") # --- Prepare data for Gradio DataFrame --- # Updated logic: Parse the dictionary format from tdc_prompts.json # Create a list of lists for the DataFrame: [[task_name, prompt_template], ...] if isinstance(tdc_prompts_data, dict): print(f"Processing {len(tdc_prompts_data)} prompts from dictionary for DataFrame...") for task_name, prompt_template in tdc_prompts_data.items(): if isinstance(prompt_template, str) and isinstance(task_name, str): # Add task name and the raw template to the list dataframe_data.append([task_name, prompt_template]) else: print(f"Warning: Skipping invalid item in prompts dictionary: key={task_name}, value_type={type(prompt_template)}") print(f"Prepared {len(dataframe_data)} rows for DataFrame.") else: print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.") # dataframe_data remains empty except Exception as e: print(f"Error loading model, tokenizer, or prompts: {e}") # Ensure dataframe_data is empty on error during setup dataframe_data = [] raise gr.Error(f"Failed during setup. Check logs for details. Error: {e}") # --- Prediction Function --- def predict(prompt, max_new_tokens=100, temperature=0.7): """ Generates text based on the input prompt using the loaded model. (Function remains the same as before) """ print(f"Received prompt: {prompt}") print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}") try: # Prepare the input for the model inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Move inputs to the model's device # Generate text with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=int(max_new_tokens), # Ensure it's an integer temperature=float(temperature), # Ensure it's a float do_sample=True if float(temperature) > 0 else False, # Only sample if temp > 0 pad_token_id=tokenizer.eos_token_id # Set pad token id ) # Decode the generated tokens generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) print(f"Generated text (raw): {generated_text}") # Remove the prompt from the beginning of the generated text if generated_text.startswith(prompt): prompt_length = len(prompt) result_text = generated_text[prompt_length:].lstrip() else: common_prefix = os.path.commonprefix([prompt, generated_text]) if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8: result_text = generated_text[len(common_prefix):].lstrip() else: result_text = generated_text print(f"Generated text (processed): {result_text}") return result_text except Exception as e: print(f"Error during prediction: {e}") return f"An error occurred during generation: {e}" # --- Function to handle DataFrame selection --- def select_prompt_from_df(evt: gr.SelectData): """ Triggered when a row is selected in the DataFrame. Updates the main prompt input with the selected template, replacing the placeholder. """ if evt.index is None or evt.index[0] >= len(dataframe_data): print("Invalid selection event or index out of bounds.") return gr.update() # No change selected_row_index = evt.index[0] # Get the prompt template from the second column (index 1) of the selected row prompt_template = dataframe_data[selected_row_index][1] # Replace the placeholder with the example SMILES string selected_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES) print(f"Selected prompt template from row {selected_row_index}, updated input.") # Return the processed prompt to update the prompt_input textbox return selected_prompt # --- Gradio Interface --- print("Creating Gradio interface...") with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( f""" # 🤖 TXGemma-2B-Predict Property Prediction Enter a prompt below, or select a task from the table to load its template, and the model ({MODEL_NAME}) will generate text. Adjust the parameters for different results. Prompt templates loaded from `{PROMPT_FILENAME}`. Selected templates will use the SMILES string `{EXAMPLE_SMILES}` (Benzene) as a placeholder. """ ) with gr.Row(): with gr.Column(scale=2): prompt_input = gr.Textbox( label="Your Prompt", placeholder="Enter your text prompt here, or select a template from the table below...", lines=5, elem_id="prompt_input_box" # Add elem_id for clarity if needed ) with gr.Row(): max_tokens_slider = gr.Slider( minimum=10, maximum=500, value=100, step=10, label="Max New Tokens", info="Maximum number of tokens to generate after the prompt." ) temperature_slider = gr.Slider( minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="Controls randomness (0=deterministic, >0=random)." ) submit_button = gr.Button("Generate Text", variant="primary") with gr.Column(scale=3): output_text = gr.Textbox( label="Generated Text", lines=10, # Adjust height if needed interactive=False ) # --- Add DataFrame for Prompt Templates --- gr.Markdown("### Select a Prompt Template") prompt_df = gr.DataFrame( value=dataframe_data, headers=DATAFRAME_HEADERS, row_count=(DATAFRAME_ROW_COUNT, "dynamic"), # Show fixed rows initially, allow scrolling col_count=(len(DATAFRAME_HEADERS), "fixed"), # Fixed number of columns wrap=True, # Wrap text in cells label="Prompt Templates" ) # --- Connect Components --- # Connect submit button to prediction function submit_button.click( fn=predict, inputs=[prompt_input, max_tokens_slider, temperature_slider], outputs=output_text, api_name="predict" ) # Connect DataFrame selection to update prompt input # The `select` event triggers the `select_prompt_from_df` function. # The event data (evt: gr.SelectData) is implicitly passed to the function. # The function returns the value to update the `prompt_input` component. prompt_df.select( fn=select_prompt_from_df, inputs=None, # No explicit inputs needed, event data is passed automatically outputs=prompt_input, show_progress="hidden" # Hide progress bar for this quick update ) # --- Launch the App --- print("Launching Gradio app...") demo.queue().launch(debug=True) # Set debug=False for production