ifire commited on
Commit
4d040cc
·
1 Parent(s): 6f75f83

Add example space.

Browse files
Files changed (2) hide show
  1. app.py +57 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoModelForCausalLM, AutoProcessor
5
+ from starvector.data.util import process_and_rasterize_svg
6
+
7
+ # Load model and processor
8
+ model = AutoModelForCausalLM.from_pretrained(
9
+ "starvector/starvector-8b-im2svg",
10
+ trust_remote_code=True,
11
+ torch_dtype=torch.float16
12
+ ).cuda()
13
+ processor = AutoProcessor.from_pretrained("starvector/starvector-8b-im2svg")
14
+
15
+ def generate_svg(input_data, input_type):
16
+ if input_type == "image":
17
+ # Process image input
18
+ image = processor(input_data, return_tensors="pt")['pixel_values'].cuda()
19
+ raw_svg = model.generate_im2svg({"image": image}, max_length=4000)[0]
20
+ else:
21
+ # Process text input
22
+ raw_svg = model.generate_text2svg(input_data, max_length=4000)[0]
23
+
24
+ svg_code, raster_image = process_and_rasterize_svg(raw_svg)
25
+ return svg_code, raster_image
26
+
27
+ with gr.Blocks() as demo:
28
+ gr.Markdown("# 💫 StarVector SVG Generator")
29
+
30
+ with gr.Tab("Image to SVG"):
31
+ gr.Markdown("Upload an image to convert to SVG")
32
+ with gr.Row():
33
+ image_input = gr.Image(type="pil", label="Input Image")
34
+ image_output = gr.Image(label="SVG Preview")
35
+ svg_code = gr.Code(label="Generated SVG Code")
36
+ image_button = gr.Button("Convert to SVG")
37
+
38
+ with gr.Tab("Text to SVG"):
39
+ gr.Markdown("Enter text to generate SVG")
40
+ with gr.Row():
41
+ text_input = gr.Textbox(label="Text Prompt")
42
+ text_output = gr.Image(label="SVG Preview")
43
+ text_svg_code = gr.Code(label="Generated SVG Code")
44
+ text_button = gr.Button("Generate SVG")
45
+
46
+ image_button.click(
47
+ fn=lambda x: generate_svg(x, "image"),
48
+ inputs=image_input,
49
+ outputs=[svg_code, image_output]
50
+ )
51
+ text_button.click(
52
+ fn=lambda x: generate_svg(x, "text"),
53
+ inputs=text_input,
54
+ outputs=[text_svg_code, text_output]
55
+ )
56
+
57
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ accelerate
5
+ starvector