import gradio as gr
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import PyPDF2
from io import BytesIO
import torch
# Read configuration from environment variables.
# Hugging Face access token; None when the HF_TOKEN variable is unset.
# NOTE(review): not referenced anywhere else in the visible code — presumably
# intended for authenticated model downloads; confirm or remove.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Intro text rendered with gr.Markdown at the top of the app.
DESCRIPTION = '''
Academic Paper Improver
This Space helps you improve a selected content of your academic paper using the XtraGPT model series, ensuring controllability on criteria following and in-context ability.
Upload your PDF paper, select a section of text you want to improve, and specify your requirements.
'''
# BibTeX entry for XtraGPT, displayed at the bottom of the app.
CITATION = """
@misc{XtraGPT,
title = {XtraGPT},
url = {https://huggingface.co/Xtra-Computing/XtraGPT-7B},
author = {Nuo Chen, Andre Lin HuiKai, Junyi Hou, Zining Zhang, Qian Wang, Xidong Wang, Bingsheng He},
month = {March},
year = {2025}
}
"""
# Footer text. NOTE(review): not referenced anywhere in the visible code.
LICENSE = """
---
Built with XtraGPT models
"""
# Custom CSS passed to gr.Blocks: centers h1 headings and styles an element with
# id "duplicate-button" (no element with that id is created in the visible code).
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Default paper content: a sample abstract (describing the Transformer
# architecture) that serves as the paper context until a PDF is uploaded, and
# as the fallback when no PDF is given or PDF text extraction fails.
default_paper_content = """
The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.
"""
# Selectable models: UI display name -> Hugging Face Hub repository id.
# Dict insertion order is preserved, so the dropdown lists sizes smallest-first.
AVAILABLE_MODELS = {
    f"XtraGPT-{size}": f"Xtra-Computing/XtraGPT-{size}"
    for size in ("1.5B", "3B", "7B", "14B")
}
# Single-slot cache holding whichever model/tokenizer pair load_model last
# loaded, plus the display name it was loaded under. All start empty.
current_model = current_tokenizer = current_model_name = None
def extract_text_from_pdf(pdf_bytes):
    """Extract the plain text of an uploaded PDF.

    Args:
        pdf_bytes: Raw PDF bytes (as delivered by ``gr.File(type="binary")``),
            an already-extracted ``str`` (returned unchanged), or ``None``.

    Returns:
        The text of every page, each page followed by a blank line. Falls back
        to ``default_paper_content`` when ``pdf_bytes`` is ``None`` or the PDF
        cannot be parsed.
    """
    if pdf_bytes is None:
        return default_paper_content
    if isinstance(pdf_bytes, str):
        # Already text; nothing to parse.
        return pdf_bytes
    try:
        pdf_reader = PyPDF2.PdfReader(BytesIO(pdf_bytes))
        # Guard against a None result from extract_text() on pages without a
        # text layer: a TypeError here would otherwise be caught below and the
        # user's whole paper silently replaced by the default sample text.
        text = "".join(
            (page.extract_text() or "") + "\n\n" for page in pdf_reader.pages
        )
    except Exception as e:
        print(f"PDF extraction error: {str(e)}")
        return default_paper_content
    if not text.strip():
        # Typically an image-only (scanned) PDF; surface it in the logs.
        print("PDF extraction warning: no extractable text found in the uploaded PDF")
    return text
def load_model(model_name):
    """Return ``(tokenizer, model)`` for *model_name*, loading it on demand.

    Only one model is cached at a time (in the module-level ``current_*``
    globals). Requesting a different model evicts the cached one first.

    Args:
        model_name: A key of ``AVAILABLE_MODELS``.

    Returns:
        Tuple ``(tokenizer, model)``.

    Raises:
        KeyError: If *model_name* is not a key of ``AVAILABLE_MODELS``.
    """
    import gc

    global current_model, current_tokenizer, current_model_name
    # Cache hit: the requested model is already loaded.
    if (
        current_model_name == model_name
        and current_model is not None
        and current_tokenizer is not None
    ):
        return current_tokenizer, current_model
    # Resolve the repo id before evicting anything, so an unknown name does not
    # throw away a perfectly good cached model.
    model_path = AVAILABLE_MODELS[model_name]
    if current_model is not None:
        # Rebind to None instead of `del`: deleting the globals left the names
        # undefined, so if the load below failed, every later call hit NameError.
        current_model = None
        current_tokenizer = None
        current_model_name = None
        # Break lingering reference cycles so the weights are really freed
        # before the CUDA caching allocator is asked to release memory.
        gc.collect()
        torch.cuda.empty_cache()
    # Load into locals first so the cache is never left half-updated on failure.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # torch_dtype="auto" keeps the checkpoint's stored precision instead of
    # upcasting to float32, which would double the memory of the larger models.
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype="auto"
    )
    current_tokenizer, current_model, current_model_name = tokenizer, model, model_name
    return current_tokenizer, current_model
# Upper bound (in tokens) on the chat-templated prompt fed to the model.
_MAX_INPUT_TOKENS = 10000


def _build_prompt(paper_content, selected_content, improvement_prompt):
    """Return the XtraGPT instruction with each input in its own tagged section.

    The instruction refers to **PAPER_CONTENT**, **SELECTED_CONTENT** and
    **QUESTION**; the matching tags are what let the model tell the three
    inputs apart instead of receiving them pasted back-to-back.
    """
    return (
        "\n"
        "Please improve the selected content based on the following. "
        "Act as an expert model for improving articles **PAPER_CONTENT**.\n"
        "The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. "
        "Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.\n"
        "Focus on clear, concise, and evidence-based improvements that align with "
        "the overall context of the paper.\n"
        f"<PAPER_CONTENT>\n{paper_content}\n</PAPER_CONTENT>\n"
        f"<SELECTED_CONTENT>\n{selected_content}\n</SELECTED_CONTENT>\n"
        f"<QUESTION>\n{improvement_prompt}\n</QUESTION>\n"
    )


def _render_chat(tokenizer, paper_content, selected_content, improvement_prompt):
    """Apply the tokenizer's chat template to a single-turn user request."""
    messages = [
        {
            "role": "user",
            "content": _build_prompt(paper_content, selected_content, improvement_prompt),
        }
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def _count_tokens(tokenizer, text):
    """Token count of *text* as-is (the chat template already holds any special tokens)."""
    return len(tokenizer.encode(text, add_special_tokens=False))


def _render_within_budget(tokenizer, paper_content, selected_content, improvement_prompt, max_tokens):
    """Render the chat prompt, shortening only the paper content if it is too long.

    The selected text, the user's request and the assistant generation prompt are
    always kept intact. Truncating the fully rendered prompt (the previous
    approach) cut those off first, because they come after the paper. Token
    counts of separately encoded pieces are not perfectly additive, so the
    result fits *max_tokens* approximately (within a few tokens).
    """
    text = _render_chat(tokenizer, paper_content, selected_content, improvement_prompt)
    if _count_tokens(tokenizer, text) <= max_tokens:
        return text
    overhead = _count_tokens(
        tokenizer, _render_chat(tokenizer, "", selected_content, improvement_prompt)
    )
    budget = max(max_tokens - overhead, 0)
    paper_ids = tokenizer.encode(paper_content, add_special_tokens=False)[:budget]
    print(f"Paper content truncated to {budget} tokens to fit the {max_tokens}-token input limit")
    return _render_chat(
        tokenizer, tokenizer.decode(paper_ids), selected_content, improvement_prompt
    )


@spaces.GPU(duration=200)
def improve_paper_section(model_name, paper_content, selected_content, improvement_prompt, temperature=0.1, max_new_tokens=512, progress=gr.Progress()):
    """
    Improve a section of an academic paper - non-streaming generation.

    Args:
        model_name: Key of ``AVAILABLE_MODELS`` selecting the model to use.
        paper_content: Full paper text used as context (truncated if too long).
        selected_content: The passage the user wants improved.
        improvement_prompt: The user's instruction for the improvement.
        temperature: Sampling temperature; ``0`` means greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The generated improvement, or a human-readable error message. This
        function never raises; errors are logged and reported in the output.
    """
    # Reject missing or whitespace-only inputs up front.
    if not (selected_content and selected_content.strip()) or not (
        improvement_prompt and improvement_prompt.strip()
    ):
        return "Please provide both text to improve and improvement requirements."
    try:
        progress(0.1, desc="Loading model...")
        tokenizer, model = load_model(model_name)
        progress(0.3, desc="Processing input...")
        text = _render_within_budget(
            tokenizer, paper_content or "", selected_content, improvement_prompt, _MAX_INPUT_TOKENS
        )
        progress(0.5, desc="Generating improved text...")
        # add_special_tokens=False: the chat template already contains them, so
        # re-adding (e.g. a BOS) would duplicate them. The returned encoding also
        # carries the attention_mask that generate() expects.
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)
        do_sample = temperature > 0
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                # Gradio sliders may deliver floats; generate() needs an int.
                max_new_tokens=int(max_new_tokens),
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Only keep the newly generated part
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
        progress(1.0, desc="Complete!")
        return response
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Generation error: {str(e)}\n{error_details}")
        return f"Error generating text: {str(e)}\n\nPlease try with different parameters or input."
# Create Gradio interface
with gr.Blocks(fill_height=True, css=css) as demo:
    # Text of the most recently uploaded PDF. This state (not the display
    # textbox) is what the submit handler sends to the model.
    extracted_pdf_text = gr.State(default_paper_content)
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Step 1: Upload PDF
            with gr.Group():
                gr.Markdown("### Step 1: Upload your academic paper")
                pdf_file = gr.File(
                    label="Upload PDF",
                    file_types=[".pdf"],
                    type="binary",  # Get binary data directly
                )
            # Model selection
            with gr.Group():
                gr.Markdown("### Select Model")
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value="XtraGPT-7B",  # Default selection
                    label="Select XtraGPT Model",
                )
            # Step 2: Enter the text to improve
            with gr.Group():
                gr.Markdown("### Step 2: Enter the text section to improve")
                selected_content = gr.Textbox(
                    label="Text to improve",
                    placeholder="Paste the section of text you want to improve...",
                    lines=5,
                    value="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.",
                )
            # Step 3: Specify improvement requirements
            with gr.Group():
                gr.Markdown("### Step 3: Specify your improvement requirements")
                improvement_prompt = gr.Textbox(
                    label="Improvement requirements",
                    placeholder="e.g., 'Make this more concise', 'Add more technical details', 'Redefine this concept'...",
                    lines=3,
                    value="help me make it more concise.",
                )
            with gr.Accordion("⚙️ Parameters", open=False):
                temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature")
                max_tokens = gr.Slider(minimum=128, maximum=1024, step=32, value=512, label="Max Tokens")
            submit_btn = gr.Button("Improve Text")
        with gr.Column():
            # Output
            output = gr.Textbox(label="Improved Text", lines=20)
            # Display extracted PDF text (collapsible)
            with gr.Accordion("Extracted PDF Content (for reference)", open=False):
                pdf_content_display = gr.Textbox(
                    label="Paper Content",
                    lines=10,
                    value=default_paper_content,
                )

    def update_pdf_content(pdf_bytes):
        """Extract the uploaded PDF's text for both the model state and the display.

        extract_text_from_pdf already maps None (file cleared) to the default
        paper content, so no separate branch is needed here.
        """
        content = extract_text_from_pdf(pdf_bytes)
        return content, content

    # Automatically extract text when PDF is uploaded (or cleared)
    pdf_file.change(
        fn=update_pdf_content,
        inputs=[pdf_file],
        outputs=[extracted_pdf_text, pdf_content_display],
    )
    # Process text improvement
    submit_btn.click(
        fn=improve_paper_section,
        inputs=[model_dropdown, extracted_pdf_text, selected_content, improvement_prompt, temperature, max_tokens],
        outputs=[output],
    )
    # Render the BibTeX as a code block: gr.HTML collapses its line breaks into
    # one run-on line, making the citation hard to read and copy.
    gr.Markdown(f"```bibtex\n{CITATION.strip()}\n```")

if __name__ == "__main__":
    demo.launch()