# SigmaTriple / app.py
import streamlit as st
import json
import torch
import os
import tempfile
import networkx as nx
from pyvis.network import Network
import markdown
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
# Try to import vllm, but don't fail if it's not available
try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
# Set page configuration
st.set_page_config(
page_title="SigmaTriple - Knowledge Graph Extractor",
page_icon="πŸ”",
layout="wide"
)
# Cache the model loading to avoid reloading on each interaction
@st.cache_resource
def load_model():
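"""Load the SciPhi/Triplex model and tokenizer once per session, preferring vllm on GPU with a transformers fallback."""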
with st.spinner("Loading model..."):
# Check if GPU is available
gpu_available = torch.cuda.is_available()
st.info(f"GPU available: {gpu_available}")
# Optimized for T4 GPU with vllm
if gpu_available and VLLM_AVAILABLE:
try:
# Configure vllm for T4 GPU
model = LLM(
model="sciphi/triplex",
trust_remote_code=True,
tensor_parallel_size=1,
gpu_memory_utilization=0.9, # Higher utilization for T4
max_model_len=8192, # Increased context length
)
tokenizer = AutoTokenizer.from_pretrained("sciphi/triplex", trust_remote_code=True)
st.success("βœ… Successfully loaded model with vllm on T4 GPU")
return model, tokenizer, True # True indicates vllm is used
except Exception as e:
st.warning(f"Failed to load model with vllm: {e}. Falling back to standard transformers.")
else:
if not VLLM_AVAILABLE:
st.warning("vllm is not available. Using standard transformers.")
elif not gpu_available:
st.warning("No GPU available. vllm requires a GPU. Using standard transformers.")
# Fallback to standard transformers
device = "cuda" if gpu_available else "cpu"
st.info(f"Loading model on {device} using standard transformers.")
# Load with standard transformers
if device == "cuda":
# Optimized for GPU
model = AutoModelForCausalLM.from_pretrained(
"sciphi/triplex",
trust_remote_code=True,
device_map="auto",
torch_dtype=torch.float16 # Use half precision for better GPU performance
)
else:
# CPU fallback with quantization
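# Note: bitsandbytes 8-bit loading generally requires a CUDA GPU, so on a pure CPU host this usually falls through to the plain model below.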
try:
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
"sciphi/triplex",
trust_remote_code=True,
device_map=None,
quantization_config=quantization_config
)
except Exception as e:
st.warning(f"Failed to load 8-bit model: {e}. Using standard model.")
model = AutoModelForCausalLM.from_pretrained(
"sciphi/triplex",
trust_remote_code=True,
device_map=None
)
# Move the model to the target device only if accelerate has not already dispatched it.
# Models loaded with device_map="auto" expose an hf_device_map and are already placed, so they must not be moved again.
if getattr(model, "hf_device_map", None) is None:
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("sciphi/triplex", trust_remote_code=True)
return model, tokenizer, False # False indicates standard transformers is used
def triplextract(model, tokenizer, text, entity_types, predicates, use_vllm=True):
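"""Run the Triplex extraction prompt over a single chunk of text and return the raw model output."""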
input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates.
**Entity Types:**
{entity_types}
**Predicates:**
{predicates}
**Text:**
{text}
"""
message = input_format.format(
entity_types = json.dumps({"entity_types": entity_types}),
predicates = json.dumps({"predicates": predicates}),
text = text)
start_time = time.time()
if use_vllm and VLLM_AVAILABLE:
# Use vllm for inference
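# Greedy decoding (temperature 0) keeps the extraction deterministic; vllm returns only the completion, not the prompt.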
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=2048,
)
outputs = model.generate([message], sampling_params)
output = outputs[0].outputs[0].text
else:
# Use standard transformers
messages = [{'role': 'user', 'content': message}]
# Place the prompt on the same device as the model's parameters; this also covers
# accelerate-dispatched models, whose first parameters sit on the entry device.
device = next(model.parameters()).device
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(device)
# Generate up to 2048 new tokens and decode only the completion, mirroring the vllm branch.
generated = model.generate(input_ids=input_ids, max_new_tokens=2048)
output = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)
processing_time = time.time() - start_time
st.info(f"Processing time: {processing_time:.2f} seconds")
return output
def batch_process_markdown(model, tokenizer, markdown_text, entity_types, predicates, use_vllm=True, chunk_size=1000, overlap=100):
"""Process large markdown text in batches"""
# Convert markdown to plain text
html = markdown.markdown(markdown_text)
from bs4 import BeautifulSoup
text = BeautifulSoup(html, features="html.parser").get_text()
# Split text into chunks with overlap
chunks = []
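# Step by (chunk_size - overlap) so consecutive chunks share `overlap` characters and entities that straddle a boundary are seen by at least one chunk.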
for i in range(0, len(text), chunk_size - overlap):
chunk = text[i:i + chunk_size]
chunks.append(chunk)
# If there are too many chunks, inform the user
if len(chunks) > 20:
st.info(f"πŸ“Š Your text will be processed in {len(chunks)} chunks.")
# Process each chunk with progress bar
all_results = []
progress_bar = st.progress(0)
status_text = st.empty()
time_estimate = st.empty()
# Process first chunk to estimate time
start_time = time.time()
for i, chunk in enumerate(chunks):
# Update progress
progress = (i + 1) / len(chunks)
progress_bar.progress(progress)
status_text.text(f"Processing chunk {i+1}/{len(chunks)} ({int(progress*100)}%)")
# Process chunk with timeout protection
try:
with st.spinner(f"Processing chunk {i+1}/{len(chunks)}..."):
chunk_start_time = time.time()
result = triplextract(model, tokenizer, chunk, entity_types, predicates, use_vllm)
chunk_time = time.time() - chunk_start_time
# After first chunk, estimate total time
if i == 0:
estimated_total_time = chunk_time * len(chunks)
time_estimate.info(f"⏱️ Estimated total processing time: {estimated_total_time:.1f} seconds ({estimated_total_time/60:.1f} minutes)")
all_results.append(result)
# Show time taken for this chunk
st.success(f"βœ… Chunk {i+1}/{len(chunks)} processed in {chunk_time:.1f} seconds")
except Exception as e:
st.error(f"Error processing chunk {i+1}: {e}")
all_results.append(f"Error processing this chunk: {e}")
# Show total time taken
total_time = time.time() - start_time
st.info(f"Total processing time: {total_time:.1f} seconds ({total_time/60:.1f} minutes)")
# Clear progress indicators
progress_bar.empty()
status_text.empty()
time_estimate.empty()
# Combine results
combined_result = "\n\n".join(all_results)
return combined_result
def parse_triplets(output):
"""Parse the model output to extract triplets"""
try:
# Find the JSON part in the output
start_idx = output.find('{')
end_idx = output.rfind('}')
# rfind returns -1 when '}' is absent, so check both indices before slicing.
if start_idx != -1 and end_idx > start_idx:
json_str = output[start_idx:end_idx + 1]
data = json.loads(json_str)
return data
else:
# If no JSON found, try to parse the text format
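# Expected line shape: "subject -> predicate <- object".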
triplets = []
lines = output.split('\n')
for line in lines:
if '->' in line and '<-' in line:
parts = line.split('->')
if len(parts) >= 2:
subject = parts[0].strip()
rest = parts[1].split('<-')
if len(rest) >= 2:
predicate = rest[0].strip()
object_ = rest[1].strip()
triplets.append({
"subject": subject,
"predicate": predicate,
"object": object_
})
if triplets:
return {"triplets": triplets}
# If still no triplets found, return empty result
return {"triplets": []}
except Exception as e:
st.error(f"Error parsing triplets: {e}")
return {"triplets": []}
def visualize_knowledge_graph(triplets):
"""Create a network visualization of the knowledge graph"""
G = nx.DiGraph()
# Add nodes and edges
for triplet in triplets:
subject = triplet.get("subject", "")
predicate = triplet.get("predicate", "")
object_ = triplet.get("object", "")
if subject and object_:
G.add_node(subject)
G.add_node(object_)
G.add_edge(subject, object_, title=predicate, label=predicate)
# Create pyvis network
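# pyvis renders an interactive vis.js graph; it is saved to a standalone HTML file and embedded with st.components.v1.html.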
net = Network(notebook=True, height="600px", width="100%", directed=True)
# Add nodes with different colors based on type if available
for node in G.nodes():
net.add_node(node, label=node, title=node)
# Add edges
for edge in G.edges(data=True):
net.add_edge(edge[0], edge[1], title=edge[2].get('title', ''), label=edge[2].get('label', ''))
# Generate HTML file
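# delete=False keeps the file on disk after the context manager exits; the caller removes it once it has been rendered.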
with tempfile.NamedTemporaryFile(delete=False, suffix='.html') as tmp:
net.save_graph(tmp.name)
return tmp.name
def main():
st.title("πŸ” SigmaTriple - Knowledge Graph Extractor")
st.markdown("""
Extract knowledge graphs from markdown text using the SciPhi/Triplex model.
""")
# Load model (spinner is inside the load_model function)
model, tokenizer, use_vllm = load_model()
# Add a note about performance
if torch.cuda.is_available():
if use_vllm:
st.success("""
πŸš€ Running on GPU with vllm for optimal performance!
""")
else:
st.success("""
πŸš€ Running on GPU for improved performance!
""")
else:
st.warning("""
⚠️ You are running on CPU which can be very slow for the SciPhi/Triplex model.
Processing may take 10+ minutes for even small texts.
""")
# Sidebar for configuration
st.sidebar.title("Configuration")
# Entity types and predicates input
st.sidebar.subheader("Entity Types")
entity_types_default = ["PERSON", "ORGANIZATION", "LOCATION", "DATE", "EVENT", "PRODUCT", "TECHNOLOGY"]
entity_types_input = st.sidebar.text_area("Enter entity types (one per line)",
"\n".join(entity_types_default),
height=150)
entity_types = [et.strip() for et in entity_types_input.split("\n") if et.strip()]
st.sidebar.subheader("Predicates")
predicates_default = ["WORKS_AT", "LOCATED_IN", "FOUNDED", "DEVELOPED", "USES", "RELATED_TO", "PART_OF", "CREATED", "MEMBER_OF"]
predicates_input = st.sidebar.text_area("Enter predicates (one per line)",
"\n".join(predicates_default),
height=150)
predicates = [p.strip() for p in predicates_input.split("\n") if p.strip()]
# Add option to adjust chunk size
st.sidebar.subheader("Performance Settings")
chunk_size = st.sidebar.slider("Chunk Size", 500, 2000, 1000,
help="Larger chunks capture more context but may take longer to process")
# Input method selection
input_method = st.radio("Select input method:", ["Text Input", "File Upload"])
if input_method == "Text Input":
markdown_text = st.text_area("Enter markdown text:", height=300)
process_button = st.button("Extract Knowledge Graph")
if process_button and markdown_text:
with st.spinner("Processing text..."):
result = batch_process_markdown(model, tokenizer, markdown_text, entity_types, predicates, use_vllm, chunk_size=chunk_size)
# Display raw output in an expandable section
with st.expander("Raw Model Output"):
st.text(result)
# Parse and visualize triplets
parsed_data = parse_triplets(result)
triplets = parsed_data.get("triplets", [])
if triplets:
st.subheader(f"Extracted {len(triplets)} Knowledge Graph Triplets:")
# Display triplets in a table
triplet_data = []
for t in triplets:
triplet_data.append({
"Subject": t.get("subject", ""),
"Predicate": t.get("predicate", ""),
"Object": t.get("object", "")
})
st.table(triplet_data)
# Visualize the knowledge graph
if len(triplets) > 0:
html_file = visualize_knowledge_graph(triplets)
st.subheader("Knowledge Graph Visualization:")
with open(html_file, 'r') as f:
st.components.v1.html(f.read(), height=600)
os.unlink(html_file) # Clean up the temporary file
else:
st.warning("No triplets were extracted from the text.")
else: # File Upload
uploaded_file = st.file_uploader("Upload a markdown file", type=["md", "markdown", "txt"])
if uploaded_file is not None:
markdown_text = uploaded_file.read().decode("utf-8")
st.subheader("File Preview:")
with st.expander("Show file content"):
st.markdown(markdown_text)
process_button = st.button("Extract Knowledge Graph")
if process_button:
with st.spinner("Processing file..."):
result = batch_process_markdown(model, tokenizer, markdown_text, entity_types, predicates, use_vllm, chunk_size=chunk_size)
# Display raw output in an expandable section
with st.expander("Raw Model Output"):
st.text(result)
# Parse and visualize triplets
parsed_data = parse_triplets(result)
triplets = parsed_data.get("triplets", [])
if triplets:
st.subheader(f"Extracted {len(triplets)} Knowledge Graph Triplets:")
# Display triplets in a table
triplet_data = []
for t in triplets:
triplet_data.append({
"Subject": t.get("subject", ""),
"Predicate": t.get("predicate", ""),
"Object": t.get("object", "")
})
st.table(triplet_data)
# Visualize the knowledge graph
if len(triplets) > 0:
html_file = visualize_knowledge_graph(triplets)
st.subheader("Knowledge Graph Visualization:")
with open(html_file, 'r') as f:
st.components.v1.html(f.read(), height=600)
os.unlink(html_file) # Clean up the temporary file
else:
st.warning("No triplets were extracted from the file.")
# Add information about the model
st.sidebar.markdown("---")
st.sidebar.subheader("About")
st.sidebar.info("""
This app uses the SciPhi/Triplex model to extract knowledge graphs from text.
The model performs Named Entity Recognition (NER) and extracts relationships between entities.
Using vllm: {}
""".format("Yes" if use_vllm else "No (using standard transformers)"))
if __name__ == "__main__":
main()