Upload 27 files

- Dockerfile +20 -0
- README.md +14 -12
- api.py +444 -0
- background_tasks.py +197 -0
- codegen_api.py +23 -0
- configs.py +206 -0
- constants.py +449 -0
- extensions.py +252 -0
- image_to_3d_api.py +32 -0
- imagegen_api.py +33 -0
- main.py +118 -0
- model_loader.py +674 -0
- models.py +96 -0
- musicgen_api.py +35 -0
- requirements.txt +40 -0
- sadtalker_api.py +202 -0
- sadtalker_utils.py +866 -0
- sentiment_api.py +27 -0
- stt_api.py +36 -0
- summarization_api.py +29 -0
- text_generation.py +152 -0
- text_to_video_api.py +37 -0
- tokenxxx.py +161 -0
- translation_api.py +27 -0
- tts_api.py +23 -0
- utils.py +190 -0
- xxx.py +142 -0
Dockerfile
ADDED
@@ -0,0 +1,20 @@
FROM python:3.11-slim-buster

ENV DEBIAN_FRONTEND=noninteractive
ENV NUMBA_DISABLE_CACHE=1
WORKDIR /app

RUN apt-get update && apt-get upgrade -y
RUN apt-get install libgl1-mesa-glx ffmpeg -y

RUN mkdir -p /.cache/huggingface/hub && chmod -R 777 /.cache/huggingface/hub
RUN mkdir -p /.config/matplotlib && chmod -R 777 /.config/matplotlib
RUN mkdir -p /nltk_data && chmod -R 777 /nltk_data

RUN pip install --no-cache-dir accelerate retry asyncio basicsr beautifulsoup4 bs4 opencv-python deep-translator duckduckgo-search fastapi flask flask-cors facexlib ffmpeg-python gfpgan imageio imageio-ffmpeg langdetect librosa nltk numpy Pillow pydub pytorch-lightning PyYAML retry safetensors scikit-learn scipy scikit-image soundfile torch torchaudio torchvision tqdm wget yacs numba

COPY . .

EXPOSE 7860

CMD ["python", "main.py"]
README.md
CHANGED
@@ -1,12 +1,14 @@
----
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk: docker
-
-
-
-
-
-
+---
+title: Ggggggc
+emoji: 📈
+colorFrom: yellow
+colorTo: indigo
+sdk: docker
+sdk_version: 5.18.0
+app_file: main.py
+pinned: false
+license: apache-2.0
+short_description: Apache
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api.py
ADDED
@@ -0,0 +1,444 @@
from main import *
from tts_api import *
from stt_api import *
from sentiment_api import *
from imagegen_api import *
from musicgen_api import *
from translation_api import *
from codegen_api import *
from text_to_video_api import *
from summarization_api import *
from image_to_3d_api import *
from flask import Flask, request, jsonify, Response, send_file, stream_with_context
from flask_cors import CORS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
from PIL import Image
import io
import tempfile
import queue
import json
import base64

app = Flask(__name__)
CORS(app)
html_code = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AI Text Generation</title>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
    <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
    <style>
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            background: #f0f0f0;
            color: #333;
            margin: 0;
            padding: 0;
            display: flex;
            flex-direction: column;
            align-items: center;
            min-height: 100vh;
        }
        .container {
            width: 95%;
            max-width: 900px;
            padding: 20px;
            background-color: #fff;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            border-radius: 8px;
            margin-top: 20px;
            margin-bottom: 20px;
            display: flex;
            flex-direction: column;
        }
        .header {
            text-align: center;
            margin-bottom: 20px;
        }
        .header h1 {
            font-size: 2em;
            color: #333;
        }
        .form-group {
            margin-bottom: 15px;
        }
        .form-group textarea {
            width: 100%;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            font-size: 16px;
            box-sizing: border-box;
            resize: vertical;
        }
        button {
            padding: 10px 15px;
            border: none;
            border-radius: 5px;
            background-color: #007bff;
            color: white;
            font-size: 18px;
            cursor: pointer;
            transition: background-color 0.3s ease;
        }
        button:hover {
            background-color: #0056b3;
        }
        #output {
            margin-top: 20px;
            padding: 15px;
            border: 1px solid #ddd;
            border-radius: 5px;
            background-color: #f9f9f9;
            white-space: pre-wrap;
            word-break: break-word;
            overflow-y: auto;
            max-height: 100vh;
        }
        #output strong {
            font-weight: bold;
        }
        .animated-text {
            position: fixed;
            top: 20px;
            left: 20px;
            font-size: 1.5em;
            color: rgba(0, 0, 0, 0.1);
            pointer-events: none;
            z-index: -1;
        }
        @media (max-width: 768px) {
            .container {
                width: 98%;
                margin-top: 10px;
                margin-bottom: 10px;
                padding: 15px;
            }
            .header h1 {
                font-size: 1.8em;
            }
            .form-group textarea, .form-group input[type="text"] {
                font-size: 14px;
                padding: 8px;
            }
            button {
                font-size: 16px;
                padding: 8px 12px;
            }
            #output {
                font-size: 14px;
                padding: 10px;
                margin-top: 15px;
            }
        }
    </style>
</head>
<body>
    <div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
    <div class="container">
        <div class="header animate__animated animate__fadeInDown">
        </div>
        <div class="form-group animate__animated animate__fadeInLeft">
            <textarea id="text" rows="5" placeholder="Enter text"></textarea>
        </div>
        <button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
        <div id="output" class="animate__animated">
            <strong>Response:</strong><br>
            <span id="generatedText"></span>
        </div>
    </div>
    <script>
        let eventSource = null;
        let accumulatedText = "";
        let lastResponse = "";
        async function generateText() {
            const inputText = document.getElementById("text").value;
            document.getElementById("generatedText").innerText = "";
            accumulatedText = "";
            if (eventSource) {
                eventSource.close();
            }
            const temp = 0.7;
            const top_k_val = 40;
            const top_p_val = 0.0;
            const repetition_penalty_val = 1.2;
            const requestData = {
                text: inputText,
                temp: temp,
                top_k: top_k_val,
                top_p: top_p_val,
                reppenalty: repetition_penalty_val
            };
            const params = new URLSearchParams(requestData).toString();
            eventSource = new EventSource('/api/v1/generate_stream?' + params);
            eventSource.onmessage = function(event) {
                if (event.data === "<END_STREAM>") {
                    eventSource.close();
                    const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
                    if (currentResponse === lastResponse.trim()) {
                        accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
                    } else {
                        lastResponse = currentResponse;
                    }
                    document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
                    return;
                }
                accumulatedText += event.data;
                let partialText = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
                document.getElementById("generatedText").innerHTML = marked.parse(partialText);
            };
            eventSource.onerror = function(error) {
                console.error("SSE error", error);
                eventSource.close();
            };
            const outputDiv = document.getElementById("output");
            outputDiv.classList.add("show");
        }
        function base64ToBlob(base64Data, contentType) {
            contentType = contentType || '';
            const sliceSize = 1024;
            const byteCharacters = atob(base64Data);
            const bytesLength = byteCharacters.length;
            const slicesCount = Math.ceil(bytesLength / sliceSize);
            const byteArrays = new Array(slicesCount);
            for (let sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) {
                const begin = sliceIndex * sliceSize;
                const end = Math.min(begin + sliceSize, bytesLength);
                const bytes = new Array(end - begin);
                for (let offset = begin, i = 0; offset < end; ++i, ++offset) {
                    bytes[i] = byteCharacters[offset].charCodeAt(0);
                }
                byteArrays[sliceIndex] = new Uint8Array(bytes);
            }
            return new Blob(byteArrays, { type: contentType });
        }
    </script>
</body>
</html>
"""
feedback_queue = queue.Queue()

class TextGenerationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(TextGenerationModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out)
        return out, hidden

vocab = ["hola", "mundo", "este", "es", "un", "ejemplo", "de", "texto", "generado", "con", "torch"]
vocab_size = len(vocab)
embed_dim = 16
hidden_dim = 32
text_model = TextGenerationModel(vocab_size, embed_dim, hidden_dim)
text_model.eval()

def tokenize(text):
    tokens = text.lower().split()
    indices = [vocab.index(token) if token in vocab else 0 for token in tokens]
    return torch.tensor(indices, dtype=torch.long).unsqueeze(0)

def perform_reasoning_stream(text, temperature, top_k, top_p, repetition_penalty):
    input_tensor = tokenize(text)
    hidden = None
    for _ in range(20):
        outputs, hidden = text_model(input_tensor, hidden)
        logits = outputs[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_indices = torch.topk(probs, min(top_k, logits.shape[-1]))
        chosen_index = topk_indices[0, torch.multinomial(topk_probs[0], 1).item()].item()
        token_str = vocab[chosen_index]
        yield token_str
        input_tensor = torch.cat([input_tensor, torch.tensor([[chosen_index]], dtype=torch.long)], dim=1)
    yield "<END_STREAM>"

class SentimentModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SentimentModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

sentiment_model = SentimentModel(10, 16, 2)
sentiment_model.eval()

@app.route("/")
def index():
    return html_code

@app.route("/api/v1/generate_stream", methods=["GET"])
def generate_stream():
    text = request.args.get("text", "")
    temp = float(request.args.get("temp", 0.7))
    top_k = int(request.args.get("top_k", 40))
    top_p = float(request.args.get("top_p", 0.0))
    reppenalty = float(request.args.get("reppenalty", 1.2))
    @stream_with_context
    def event_stream():
        try:
            for token in perform_reasoning_stream(text, temperature=temp, top_k=top_k, top_p=top_p, repetition_penalty=reppenalty):
                if token == "<END_STREAM>":
                    yield "data: <END_STREAM>\n\n"
                    break
                yield "data: " + token + "\n\n"
        except Exception as e:
            yield "data: <ERROR>\n\n"
    return Response(event_stream(), mimetype="text/event-stream")

@app.route("/api/v1/generate", methods=["POST"])
def generate():
    data = request.get_json()
    text = data.get("text", "")
    temp = float(data.get("temp", 0.7))
    top_k = int(data.get("top_k", 40))
    top_p = float(data.get("top_p", 0.0))
    reppenalty = float(data.get("reppenalty", 1.2))
    result = ""
    try:
        for token in perform_reasoning_stream(text, temperature=temp, top_k=top_k, top_p=top_p, repetition_penalty=reppenalty):
            if token == "<END_STREAM>":
                break
            result += token + " "
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    return jsonify({"solidity": result.strip()})

@app.route("/api/v1/feedback", methods=["POST"])
def feedback():
    data = request.get_json()
    feedback_text = data.get("feedback_text")
    correct_category = data.get("correct_category")
    if feedback_text and correct_category:
        feedback_queue.put((feedback_text, correct_category))
        return jsonify({"status": "feedback received"})
    return jsonify({"status": "feedback failed"}), 400

@app.route("/api/v1/tts", methods=["POST"])
def tts_api():
    data = request.get_json()
    text = data.get("text", "")
    sr = 22050
    duration = 3.0
    t = torch.linspace(0, duration, int(sr * duration))
    frequency = 440.0
    audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
    audio = audio.unsqueeze(0)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, audio, sr)
        tmp_path = tmp.name
    return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="output.wav")

@app.route("/api/v1/stt", methods=["POST"])
def stt_api():
    data = request.get_json()
    audio_b64 = data.get("audio", "")
    if audio_b64:
        audio_bytes = base64.b64decode(audio_b64)
        buf = io.BytesIO(audio_bytes)
        waveform, sr = torchaudio.load(buf)
        mean_amp = waveform.abs().mean().item()
        recognized_text = f"Audio processed with mean amplitude {mean_amp:.3f}"
        return jsonify({"text": recognized_text})
    return jsonify({"text": ""})

@app.route("/api/v1/sentiment", methods=["POST"])
def sentiment_api():
    data = request.get_json()
    text = data.get("text", "")
    if not text:
        return jsonify({"sentiment": "neutral"})
    ascii_vals = [ord(c) for c in text[:10]]
    while len(ascii_vals) < 10:
        ascii_vals.append(0)
    features = torch.tensor(ascii_vals, dtype=torch.float32).unsqueeze(0)
    output = sentiment_model(features)
    sentiment_idx = torch.argmax(output, dim=1).item()
    sentiment = "positivo" if sentiment_idx == 1 else "negativo"
    return jsonify({"sentiment": sentiment})

@app.route("/api/v1/imagegen", methods=["POST"])
def imagegen_api():
    data = request.get_json()
    prompt = data.get("prompt", "")
    image_tensor = torch.rand(3, 256, 256)
    np_image = image_tensor.mul(255).clamp(0, 255).byte().numpy().transpose(1, 2, 0)
    img = Image.fromarray(np_image)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png", as_attachment=True, download_name="image.png")

@app.route("/api/v1/musicgen", methods=["POST"])
def musicgen_api():
    data = request.get_json()
    prompt = data.get("prompt", "")
    sr = 22050
    duration = 5.0
    t = torch.linspace(0, duration, int(sr * duration))
    frequency = 440.0
    audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
    audio = audio.unsqueeze(0)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, audio, sr)
        tmp_path = tmp.name
    return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="music.wav")

@app.route("/api/v1/translation", methods=["POST"])
def translation_api():
    data = request.get_json()
    text = data.get("text", "")
    translated = " ".join(text.split()[::-1])
    return jsonify({"translated_text": translated})

@app.route("/api/v1/codegen", methods=["POST"])
def codegen_api():
    data = request.get_json()
    prompt = data.get("prompt", "")
    generated_code = f"# Generated code based on prompt: {prompt}\nprint('Hello from Torch-generated code')"
    return jsonify({"code": generated_code})

@app.route("/api/v1/text_to_video", methods=["POST"])
def text_to_video_api():
    data = request.get_json()
    prompt = data.get("prompt", "")
    video_tensor = torch.randint(0, 255, (10, 3, 64, 64), dtype=torch.uint8)
    video_bytes = video_tensor.numpy().tobytes()
    buf = io.BytesIO(video_bytes)
    return send_file(buf, mimetype="video/mp4", as_attachment=True, download_name="video.mp4")

@app.route("/api/v1/summarization", methods=["POST"])
def summarization_api():
    data = request.get_json()
    text = data.get("text", "")
    sentences = text.split('.')
    summary = sentences[0] if sentences[0] else text
    return jsonify({"summary": summary})

@app.route("/api/v1/image_to_3d", methods=["POST"])
def image_to_3d_api():
    data = request.get_json()
    prompt = data.get("prompt", "")
    obj_data = "o Cube\nv 0 0 0\nv 1 0 0\nv 1 1 0\nv 0 1 0\nf 1 2 3 4"
    buf = io.BytesIO(obj_data.encode("utf-8"))
    return send_file(buf, mimetype="text/plain", as_attachment=True, download_name="model.obj")

@app.route("/api/v1/sadtalker", methods=["GET"])
def sadtalker():
    return jsonify({"message": "Respuesta de sadtalker"})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
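For reference, a minimal client sketch for the two generation endpoints above, assuming the Space is reachable at http://localhost:7860; the base URL and the `requests` dependency are assumptions, not part of this commit:

import requests

BASE = "http://localhost:7860"  # assumed local address of the running Space

# Non-streaming generation: POST JSON, read back the accumulated result.
resp = requests.post(f"{BASE}/api/v1/generate", json={"text": "hola mundo", "temp": 0.7})
print(resp.json())  # {"solidity": "..."}

# Streaming generation: the endpoint is a GET SSE stream, so tokens arrive
# as "data: <token>" lines terminated by a "data: <END_STREAM>" sentinel.
with requests.get(f"{BASE}/api/v1/generate_stream",
                  params={"text": "hola mundo"}, stream=True) as r:
    for raw in r.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        token = raw[len("data: "):]
        if token == "<END_STREAM>":
            break
        print(token, end=" ")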
background_tasks.py
ADDED
@@ -0,0 +1,197 @@
import time
import threading
import queue
import uuid
import unicodedata
import re
from deep_translator import GoogleTranslator
from duckduckgo_search import DDGS
import nltk
import torch
import torch.nn as nn
import math

nltk.download('punkt')

categories = ['News', 'Sports', 'Entertainment']
TEXT_GENERATION_RATE = 10
text_queue = queue.Queue()
reasoning_queue = queue.Queue()
feedback_queue = queue.Queue()
vocabulary = ["<PAD>", "<EOS>"]
word_to_index = {word: idx for idx, word in enumerate(vocabulary)}
seen_responses = set()
news_clf = None

class SimpleClassifier(nn.Module):
    def __init__(self, vocab_size, num_classes, embedding_dim=128):
        super(SimpleClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        pooled = embedded.mean(dim=1)
        out = self.fc(pooled)
        return out

def tokenize_text(text):
    return nltk.word_tokenize(text)

def update_vocabulary(tokens):
    global vocabulary, word_to_index
    for token in tokens:
        if token not in word_to_index:
            word_to_index[token] = len(vocabulary)
            vocabulary.append(token)

def text_to_vector(text):
    tokens = tokenize_text(text)
    update_vocabulary(tokens)
    indices = [word_to_index.get(token, 0) for token in tokens]
    return torch.tensor(indices, dtype=torch.long)

def generate_and_queue_text(language):
    global categories, text_queue
    num_categories = len(categories)
    num_texts_per_category = TEXT_GENERATION_RATE // (2 * num_categories)
    while True:
        for category in categories:
            for _ in range(num_texts_per_category):
                uid = uuid.uuid4()
                base_text = f"Category: {category}. ID:{uid}"
                try:
                    translator = GoogleTranslator(source='auto', target=language)
                    text = translator.translate(base_text)
                except Exception:
                    text = base_text
                processed_text = ''.join(c for c in unicodedata.normalize('NFKC', text) if c.isprintable())
                text_queue.put((processed_text, category))
        time.sleep(0)

def background_training():
    global categories, news_clf, feedback_queue, vocabulary
    if categories is None:
        categories = ['DefaultCategory']
    num_classes = len(categories)
    learning_rate = 0.01
    epochs = 1
    if news_clf is None:
        news_clf = SimpleClassifier(len(vocabulary), num_classes)
    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    while True:
        try:
            feedback_item = feedback_queue.get(timeout=10)
            if feedback_item:
                input_text, generated_text = feedback_item
                input_vector = text_to_vector(input_text)
                if len(vocabulary) == 0:
                    vocabulary.extend(["<PAD>", "<EOS>"])
                    news_clf = SimpleClassifier(len(vocabulary), num_classes)
                    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
                if input_vector.size(0) != len(vocabulary) and len(vocabulary) > 0:
                    news_clf = SimpleClassifier(len(vocabulary), num_classes)
                    optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
                    input_vector = text_to_vector(input_text)
                tokens = tokenize_text(input_text)
                update_vocabulary(tokens)
                tokens_indices = [word_to_index.get(word, 0) for word in tokens]
                input_tensor = torch.tensor([tokens_indices], dtype=torch.long)
                target_index = categories.index(generated_text) if generated_text in categories else 0
                target_category_index = torch.tensor([target_index], dtype=torch.long)
                if num_classes <= 1:
                    num_classes = 2
                    news_clf.fc = nn.Linear(128, num_classes)
                for _ in range(epochs):
                    optimizer.zero_grad()
                    output = news_clf(input_tensor)
                    loss = criterion(output, target_category_index)
                    loss.backward()
                    optimizer.step()
            feedback_queue.task_done()
        except queue.Empty:
            pass
        except Exception:
            time.sleep(5)

class ReasoningModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128):
        super(ReasoningModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        emb = self.embedding(x)
        output, hidden = self.rnn(emb, hidden)
        logits = self.fc(output)
        return logits, hidden

    def generate(self, input_seq, max_length=50, temperature=1.0):
        self.eval()
        tokens = input_seq.copy()
        hidden = None
        generated = []
        for _ in range(max_length):
            input_tensor = torch.tensor([tokens], dtype=torch.long)
            logits, hidden = self.forward(input_tensor, hidden)
            next_token_logits = logits[0, -1, :] / temperature
            probabilities = torch.softmax(next_token_logits, dim=0)
            next_token = torch.multinomial(probabilities, 1).item()
            tokens.append(next_token)
            generated.append(next_token)
            if next_token == word_to_index.get("<EOS>"):
                break
        return generated

reasoning_model = ReasoningModel(len(vocabulary))

def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
    tokens = tokenize_text(text_input)
    update_vocabulary(tokens)
    tokens_indices = [word_to_index.get(token, 0) for token in tokens]
    generated_indices = reasoning_model.generate(tokens_indices, max_length=50, temperature=temperature)
    for idx in generated_indices:
        yield vocabulary[idx] + " "
    yield "<END_STREAM>"

def background_reasoning_queue():
    global reasoning_queue, seen_responses
    while True:
        try:
            item = reasoning_queue.get(timeout=1)
            if item is None:
                reasoning_queue.task_done()
                continue
            text_input = item.get('text_input')
            temperature = item.get('temperature', 0.7)
            top_k = item.get('top_k', 40)
            top_p = item.get('top_p', 0.0)
            repetition_penalty = item.get('repetition_penalty', 1.2)
            resp_queue = item.get('response_queue', queue.Queue())
            if not text_input:
                resp_queue.put({"error": "Empty text input received."})
                reasoning_queue.task_done()
                continue
            generated_text_stream = perform_reasoning_stream(text_input, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
            full_response = ""
            for chunk in generated_text_stream:
                if chunk == "<END_STREAM>":
                    break
                full_response += chunk
            cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "")).strip()
            if cleaned_response in seen_responses:
                final_response = "**Response is repetitive. Please try again or rephrase your query.**"
                resp_queue.put({"text": final_response})
            else:
                seen_responses.add(cleaned_response)
                final_response = cleaned_response
                resp_queue.put({"text": final_response})
            reasoning_queue.task_done()
        except queue.Empty:
            pass
        except Exception as e:
            try:
                resp_queue.put({"error": str(e)})
            except Exception:
                pass
            if reasoning_queue and not reasoning_queue.empty():
                reasoning_queue.task_done()
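background_tasks.py only defines the worker loops; nothing in this commit shows them being started. A hedged sketch of how they could be wired up, assuming daemon threads are acceptable (the thread setup below is an assumption, not the author's code):

import threading
import queue

from background_tasks import (background_training, background_reasoning_queue,
                              generate_and_queue_text, reasoning_queue)

# Each worker loops forever, so run it on a daemon thread that dies with the process.
for target, args in [(background_training, ()),
                     (background_reasoning_queue, ()),
                     (generate_and_queue_text, ("en",))]:
    threading.Thread(target=target, args=args, daemon=True).start()

# Submit one reasoning job and block for its result, matching the item keys
# that background_reasoning_queue() reads.
result_q = queue.Queue()
reasoning_queue.put({"text_input": "hello world", "response_queue": result_q})
print(result_q.get())  # {"text": "..."} or {"error": "..."}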
codegen_api.py
ADDED
@@ -0,0 +1,23 @@
from flask import jsonify, send_file, request
from main import *
# from main import codegen_model, codegen_tokenizer, device

def generate_code(prompt, output_path="output_code.py"):
    if codegen_model is None:
        return "Code generation model not initialized."
    input_ids = codegen_tokenizer.encode(prompt, return_tensors='pt').to(device)
    output = codegen_model.generate(input_ids, max_length=512, temperature=0.7, top_p=0.9)
    code = codegen_tokenizer.decode(output[0], skip_special_tokens=True)
    with open(output_path, "w") as file:
        file.write(code)
    return output_path

def codegen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_code(prompt)
    if output_file == "Code generation model not initialized.":
        return jsonify({"error": "Code generation failed"}), 500
    return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")
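Note that codegen_api() reads flask.request but is not decorated as a route in this file, so some other module presumably binds it. One way that binding could look is sketched below; the URL rule is an assumption (api.py already serves /api/v1/codegen with its own handler):

from flask import Flask
from codegen_api import codegen_api

app = Flask(__name__)
# Bind the undecorated view function to a URL without using a decorator.
# "/api/v1/codegen_file" is a hypothetical rule chosen to avoid the route in api.py.
app.add_url_rule("/api/v1/codegen_file", view_func=codegen_api, methods=["POST"])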
configs.py
ADDED
@@ -0,0 +1,206 @@
from constants import *

class GPT2Config:
    def __init__(self, vocab_size_or_config_json_file=50257, n_positions=MAX_LENGTH, n_ctx=MAX_LENGTH, n_embd=768, n_layer=12, n_head=12, layer_norm_epsilon=1e-05, initializer_range=0.02):
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class MBartConfig:
    def __init__(self, vocab_size, d_model, num_layers, num_heads, pad_token_id, eos_token_id):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.encoder_layers = num_layers
        self.decoder_layers = num_layers
        self.encoder_attention_heads = num_heads
        self.decoder_attention_heads = num_heads
        self.encoder_ffn_dim = d_model * 4
        self.decoder_ffn_dim = d_model * 4
        self.dropout = 0.1
        self.attention_dropout = 0.0
        self.activation_dropout = 0.0
        self.max_position_embeddings = 1024
        self.init_std = 0.02
        self.layer_norm_eps = 1e-5
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.bos_token_id = 0
        self.decoder_start_token_id = 2
        self.output_past = True
        self.scale_embedding = True
        self.use_cache = True
        self.num_hidden_layers = num_layers

class CodeGenConfig:
    def __init__(self, vocab_size, n_embd, n_layer, n_head):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_positions = 2048
        self.resid_pdrop = 0.1
        self.embd_pdrop = 0.1
        self.attn_pdrop = 0.1
        self.activation_function = "gelu_new"
        self.n_ctx = 2048
        self.pad_token_id = 50256
        self.eos_token_id = 50256
        self.initializer_range = 0.02

class SummarizationConfig:
    def __init__(self):
        self.vocab_size = 10000
        self.embedding_dim = 256
        self.hidden_dim = 512
        self.num_layers = 2
        self.max_seq_len = 512

class Clip4ClipConfig:
    def __init__(self, vocab_size=30522, hidden_size=512, num_hidden_layers=6, num_attention_heads=8, intermediate_size=2048, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.all_head_size = self.num_attention_heads * self.hidden_size
        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class MusicGenConfig:
    def __init__(self, vocab_size=2048, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-05, initializer_range=0.02, pad_token_id=0, bos_token_id=1, eos_token_id=2, n_positions=2048, n_ctx=2048, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.all_head_size = self.num_attention_heads * self.hidden_size
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class BartConfig:
    def __init__(self, vocab_size=50265, max_position_embeddings=1024, encoder_layers=12, encoder_ffn_dim=4096, encoder_attention_heads=16, decoder_layers=12, decoder_ffn_dim=4096, decoder_attention_heads=16, encoder_layerdrop=0.0, decoder_layerdrop=0.0, activation_function="gelu", d_model=1024, dropout=0.1, attention_dropout=0.0, activation_dropout=0.0, init_std=0.02, classifier_dropout=0.0, num_labels=3, pad_token_id=1, bos_token_id=0, eos_token_id=2, layer_norm_eps=1e-05, num_beams=4, early_stopping=True, max_length=100, min_length=30, scale_embedding=False, **kwargs):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.encoder_layers = encoder_layers
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_layers = decoder_layers
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_attention_heads = decoder_attention_heads
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.activation_function = activation_function
        self.d_model = d_model
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.init_std = init_std
        self.classifier_dropout = classifier_dropout
        self.num_labels = num_labels
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.layer_norm_eps = layer_norm_eps
        self.num_beams = num_beams
        # Use the constructor arguments instead of hard-coding True/False here.
        self.early_stopping = early_stopping
        self.max_length = max_length
        self.min_length = min_length
        self.scale_embedding = scale_embedding
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class OpenLRMConfig:
    def __init__(self, obj_dim=1024, hidden_dim=512, num_layers=6, num_heads=8, dropout_prob=0.1, **kwargs):
        self.obj_dim = obj_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob
        self.all_head_size = self.num_heads * self.hidden_dim
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class UNet2DConditionModelConfig:
    def __init__(self, sample_size=64, layers_per_block=2, block_out_channels=[320, 640, 1280, 1280], downsample=[2, 2, 2, 2], upsample=[2, 2, 2, 2], cross_attention_dim=768, act_fn="silu", norm_num_groups=32, num_attention_heads=8, in_channels=4, out_channels=4, attention_head_dim=64, **kwargs):
        self.sample_size = sample_size
        self.layers_per_block = layers_per_block
        self.block_out_channels = block_out_channels
        self.downsample = downsample
        self.upsample = upsample
        self.cross_attention_dim = cross_attention_dim
        self.act_fn = act_fn
        self.norm_num_groups = norm_num_groups
        self.num_attention_heads = num_attention_heads
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.attention_head_dim = attention_head_dim
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class AutoencoderKLConfig:
    def __init__(self, **kwargs):
        self.sample_size = 64
        self.latent_channels = 4
        self.layers_per_block = 2
        self.block_out_channels = [128, 256, 512, 512]
        self.downsample = [2, 2, 2, 2]
        self.upsample = [2, 2, 2, 2]
        self.act_fn = "silu"
        self.norm_num_groups = 32
        self.num_channels_every_n_layers = 2
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)
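These config classes mirror the layout of Hugging Face config.json files, so from_dict can hydrate them straight from a downloaded file. A small sketch, assuming GPT-2's config.json sits in the GPT2_FOLDER defined in constants.py (the path and the key filter are assumptions; keys must match the constructor's parameter names, and GPT2Config's vocab parameter is named vocab_size_or_config_json_file rather than vocab_size):

import json
from configs import GPT2Config

# config.json downloaded next to the GPT-2 weights (hypothetical path).
with open("./GPT2/config.json") as f:
    raw = json.load(f)

# from_dict forwards the dict as keyword arguments, so drop any keys that
# do not correspond to GPT2Config.__init__ parameters.
allowed = {"n_positions", "n_ctx", "n_embd", "n_layer", "n_head",
           "layer_norm_epsilon", "initializer_range"}
cfg = GPT2Config.from_dict({k: v for k, v in raw.items() if k in allowed})
print(cfg.n_embd, cfg.n_layer)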
constants.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
TEXT_GENERATION_RATE = 40000
|
4 |
+
MAX_LENGTH = 2048
|
5 |
+
MAX_XDD = 5
|
6 |
+
END_OF_TEXT_TOKEN = "<|endoftext|>"
|
7 |
+
SYSTEM_PROMPT = """Eres un asistente experto con habilidades avanzadas en diversas áreas. Responde de manera amigable, educada y razonada. Siempre piensa cuidadosamente antes de responder para asegurar la claridad y completitud. Posees la capacidad de autoaprendizaje continuo y recuerdas interacciones pasadas para mejorar tus respuestas y evitar errores repetidos."""
|
8 |
+
XML_COT_FORMAT = """<reasoning>\n{reasoning}\n</reasoning>\n<answer>\n{answer}\n</answer>\n"""
|
9 |
+
|
10 |
+
html_code = """<!DOCTYPE html>
|
11 |
+
<html lang="en">
|
12 |
+
<head>
|
13 |
+
<meta charset="UTF-8">
|
14 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
15 |
+
<title>AI Text Generation</title>
|
16 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
|
17 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
|
18 |
+
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
19 |
+
<style>
|
20 |
+
body {
|
21 |
+
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
22 |
+
background: #f0f0f0;
|
23 |
+
color: #333;
|
24 |
+
margin: 0;
|
25 |
+
padding: 0;
|
26 |
+
display: flex;
|
27 |
+
flex-direction: column;
|
28 |
+
align-items: center;
|
29 |
+
min-height: 100vh;
|
30 |
+
}
|
31 |
+
.container {
|
32 |
+
width: 95%;
|
33 |
+
max-width: 900px;
|
34 |
+
padding: 20px;
|
35 |
+
background-color: #fff;
|
36 |
+
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
|
37 |
+
border-radius: 8px;
|
38 |
+
margin-top: 20px;
|
39 |
+
margin-bottom: 20px;
|
40 |
+
display: flex;
|
41 |
+
flex-direction: column;
|
42 |
+
}
|
43 |
+
.header {
|
44 |
+
text-align: center;
|
45 |
+
margin-bottom: 20px;
|
46 |
+
}
|
47 |
+
.header h1 {
|
48 |
+
font-size: 2em;
|
49 |
+
color: #333;
|
50 |
+
}
|
51 |
+
.form-group {
|
52 |
+
margin-bottom: 15px;
|
53 |
+
}
|
54 |
+
.form-group textarea {
|
55 |
+
width: 100%;
|
56 |
+
padding: 10px;
|
57 |
+
border: 1px solid #ccc;
|
58 |
+
border-radius: 5px;
|
59 |
+
font-size: 16px;
|
60 |
+
box-sizing: border-box;
|
61 |
+
resize: vertical;
|
62 |
+
}
|
63 |
+
button {
|
64 |
+
padding: 10px 15px;
|
65 |
+
border: none;
|
66 |
+
border-radius: 5px;
|
67 |
+
background-color: #007bff;
|
68 |
+
color: white;
|
69 |
+
font-size: 18px;
|
70 |
+
cursor: pointer;
|
71 |
+
transition: background-color 0.3s ease;
|
72 |
+
}
|
73 |
+
button:hover {
|
74 |
+
background-color: #0056b3;
|
75 |
+
}
|
76 |
+
#output {
|
77 |
+
margin-top: 20px;
|
78 |
+
padding: 15px;
|
79 |
+
border: 1px solid #ddd;
|
80 |
+
border-radius: 5px;
|
81 |
+
background-color: #f9f9f9;
|
82 |
+
white-space: pre-wrap;
|
83 |
+
word-break: break-word;
|
84 |
+
overflow-y: auto;
|
85 |
+
max-height: 100vh;
|
86 |
+
}
|
87 |
+
#output strong {
|
88 |
+
font-weight: bold;
|
89 |
+
}
|
90 |
+
.animated-text {
|
91 |
+
position: fixed;
|
92 |
+
top: 20px;
|
93 |
+
left: 20px;
|
94 |
+
font-size: 1.5em;
|
95 |
+
color: rgba(0, 0, 0, 0.1);
|
96 |
+
pointer-events: none;
|
97 |
+
z-index: -1;
|
98 |
+
}
|
99 |
+
@media (max-width: 768px) {
|
100 |
+
.container {
|
101 |
+
width: 98%;
|
102 |
+
margin-top: 10px;
|
103 |
+
margin-bottom: 10px;
|
104 |
+
padding: 15px;
|
105 |
+
}
|
106 |
+
.header h1 {
|
107 |
+
font-size: 1.8em;
|
108 |
+
}
|
109 |
+
.form-group textarea, .form-group input[type="text"] {
|
110 |
+
font-size: 14px;
|
111 |
+
padding: 8px;
|
112 |
+
}
|
113 |
+
button {
|
114 |
+
font-size: 16px;
|
115 |
+
padding: 8px 12px;
|
116 |
+
}
|
117 |
+
#output {
|
118 |
+
font-size: 14px;
|
119 |
+
padding: 10px;
|
120 |
+
margin-top: 15px;
|
121 |
+
}
|
122 |
+
}
|
123 |
+
</style>
|
124 |
+
</head>
|
125 |
+
<body>
|
126 |
+
<div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
|
127 |
+
<div class="container">
|
128 |
+
<div class="header animate__animated animate__fadeInDown">
|
129 |
+
</div>
|
130 |
+
<div class="form-group animate__animated animate__fadeInLeft">
|
131 |
+
<textarea id="text" rows="5" placeholder="Enter text"></textarea>
|
132 |
+
</div>
|
133 |
+
<button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
|
134 |
+
<div id="output" class="animate__animated">
|
135 |
+
<strong >Response:</strong><br>
|
136 |
+
<span id="generatedText"></span>
|
137 |
+
</div>
|
138 |
+
</div>
|
139 |
+
<script>
|
140 |
+
let eventSource = null;
|
141 |
+
let accumulatedText = "";
|
142 |
+
let lastResponse = "";
|
143 |
+
async function generateText() {
|
144 |
+
const inputText = document.getElementById("text").value;
|
145 |
+
document.getElementById("generatedText").innerText = "";
|
146 |
+
accumulatedText = "";
|
147 |
+
if (eventSource) {
|
148 |
+
eventSource.close();
|
149 |
+
}
|
150 |
+
const temp = 0.7;
|
151 |
+
const top_k_val = 40;
|
152 |
+
const top_p_val = 0.0;
|
153 |
+
const repetition_penalty_val = 1.2;
|
154 |
+
const requestData = {
|
155 |
+
text: inputText,
|
156 |
+
temp: temp,
|
157 |
+
top_k: top_k_val,
|
158 |
+
top_p: top_p_val,
|
159 |
+
reppenalty: repetition_penalty_val
|
160 |
+
};
|
161 |
+
eventSource = new EventSource('/generate_stream', {
|
162 |
+
headers: {
|
163 |
+
'Content-Type': 'application/json'
|
164 |
+
},
|
165 |
+
method: 'POST',
|
166 |
+
body: JSON.stringify(requestData)
|
167 |
+
});
|
168 |
+
eventSource.onmessage = function(event) {
|
169 |
+
if (event.data === "<END_STREAM>") {
|
170 |
+
eventSource.close();
|
171 |
+
const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(re.compile(r'\\s+(?=[.,,。])'), '').trim();
|
172 |
+
if (currentResponse === lastResponse.trim()) {
|
173 |
+
accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
|
174 |
+
} else {
|
175 |
+
lastResponse = currentResponse;
|
176 |
+
}
|
177 |
+
document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
|
178 |
+
return;
|
179 |
+
}
|
180 |
+
accumulatedText += event.data;
|
181 |
+
let partialText = accumulatedText.replace("<|endoftext|>", "").replace(re.compile(r'\\s+(?=[.,,。])'), '').trim();
|
182 |
+
document.getElementById("generatedText").innerHTML = marked.parse(partialText);
|
183 |
+
};
|
184 |
+
eventSource.onerror = function(error) {
|
185 |
+
console.error("SSE error", error);
|
186 |
+
eventSource.close();
|
187 |
+
};
|
188 |
+
const outputDiv = document.getElementById("output");
|
189 |
+
outputDiv.classList.add("show");
|
190 |
+
}
|
191 |
+
function base64ToBlob(base64Data, contentType) {
|
192 |
+
contentType = contentType || '';
|
193 |
+
const sliceSize = 1024;
|
194 |
+
const byteCharacters = atob(base64Data);
|
195 |
+
const bytesLength = byteCharacters.length;
|
196 |
+
const slicesCount = Math.ceil(bytesLength / sliceSize);
|
197 |
+
const byteArrays = new Array(slicesCount);
|
198 |
+
for (let sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) {
|
199 |
+
const begin = sliceIndex * sliceSize;
|
200 |
+
const end = Math.min(begin + sliceSize, bytesLength);
|
201 |
+
const bytes = new Array(end - begin);
|
202 |
+
for (let offset = begin, i = 0; offset < end; ++i, ++offset) {
|
203 |
+
bytes[i] = byteCharacters[offset].charCodeAt(0);
|
204 |
+
}
|
205 |
+
byteArrays[sliceIndex] = new Uint8Array(bytes);
|
206 |
+
}
|
207 |
+
return new Blob(byteArrays, { type: contentType });
|
208 |
+
}
|
209 |
+
</script>
|
210 |
+
</body>
|
211 |
+
</html>
|
212 |
+
"""
|
213 |
+
|
214 |
+
HTML_CODE = html_code
|
215 |
+
|
216 |
+
# =============================================================================
|
217 |
+
# Constantes definidas por el usuario
|
218 |
+
# =============================================================================
|
219 |
+
|
220 |
+
# GPT-2
|
221 |
+
GPT2_FOLDER = "./GPT2"
|
222 |
+
MODEL_FILE = "gpt2-pytorch_model.bin"
|
223 |
+
ENCODER_FILE = "encoder.json"
|
224 |
+
VOCAB_FILE = "vocab.bpe"
|
225 |
+
CONFIG_FILE = "config.json"
|
226 |
+
GPT2CONFHG = "https://huggingface.co/openai-community/gpt2/resolve/main/config.json"
|
227 |
+
MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"
|
228 |
+
ENCODER_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/encoder.json"
|
229 |
+
VOCAB_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/vocab.bpe"

# Translation (MBart)
TRANSLATION_FOLDER = "./TranslationModel"
TRANSLATION_MODEL_WEIGHTS_FILE = "pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_FILE = "config.json"
TRANSLATION_MODEL_VOCAB_FILE = "sentencepiece.bpe.model"
TRANSLATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/config.json"
TRANSLATION_MODEL_VOCAB_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_MODEL_FILES_URLS = [
    (TRANSLATION_MODEL_WEIGHTS_URL, TRANSLATION_MODEL_WEIGHTS_FILE),
    (TRANSLATION_MODEL_CONFIG_URL, TRANSLATION_MODEL_CONFIG_FILE),
    (TRANSLATION_MODEL_VOCAB_URL, TRANSLATION_MODEL_VOCAB_FILE),
]

# CodeGen
CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
    (CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
    (CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
    (CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
    (CODEGEN_MERGES_URL, CODEGEN_MERGES),
]

# MusicGen (the weights below come from facebook/musicgen-small)
MUSICGEN_FOLDER = "./MusicGenModel"
MUSICGEN_MODEL_NAME = "melody"
MUSICGEN_MODEL_WEIGHTS = "pytorch_model.bin"
MUSICGEN_CONFIG = "config.json"
MUSICGEN_SAMPLE_RATE = 32000
MUSICGEN_DURATION = 8
MUSICGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/pytorch_model.bin"
MUSICGEN_CONFIG_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json"
MUSICGEN_FILES_URLS = [
    (MUSICGEN_MODEL_WEIGHTS_URL, MUSICGEN_MODEL_WEIGHTS),
    (MUSICGEN_CONFIG_URL, MUSICGEN_CONFIG)
]

# Summarization (Bart)
SUMMARIZATION_FOLDER = "./SummarizationModel"
SUMMARIZATION_MODEL_WEIGHTS = "pytorch_model.bin"
SUMMARIZATION_CONFIG = "config.json"
SUMMARIZATION_VOCAB = "vocab.json"
SUMMARIZATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/pytorch_model.bin"
SUMMARIZATION_CONFIG_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json"
SUMMARIZATION_VOCAB_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json"
SUMMARIZATION_FILES_URLS = [
    (SUMMARIZATION_MODEL_WEIGHTS_URL, SUMMARIZATION_MODEL_WEIGHTS),
    (SUMMARIZATION_CONFIG_URL, SUMMARIZATION_CONFIG),
    (SUMMARIZATION_VOCAB_URL, SUMMARIZATION_VOCAB)
]

# TTS
TTS_FOLDER = "./TTSModel"
TTS_MODEL_NAME = "vits"
TTS_MODEL_CONFIG = "config.json"
TTS_MODEL_WEIGHTS = "pytorch_model.bin"
TTS_VOCAB = "vocab.json"
TTS_CONFIG_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/config.json"
TTS_MODEL_WEIGHTS_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/pytorch_model.bin"
TTS_VOCAB_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/vocab.json"
TTS_FILES_URLS = [
    (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
    (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
    (TTS_VOCAB_URL, TTS_VOCAB)
]

# STT
STT_FOLDER = "./STTModel"
STT_MODEL_NAME = "wav2vec2"
STT_MODEL_WEIGHTS = "pytorch_model.bin"
STT_CONFIG = "config.json"
STT_VOCAB = "vocab.json"
STT_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin"
STT_CONFIG_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json"
STT_VOCAB_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
STT_FILES_URLS = [
    (STT_MODEL_WEIGHTS_URL, STT_MODEL_WEIGHTS),
    (STT_CONFIG_URL, STT_CONFIG),
    (STT_VOCAB_URL, STT_VOCAB)
]

# Sentiment Analysis
SENTIMENT_FOLDER = "./SentimentModel"
SENTIMENT_MODEL_WEIGHTS = "pytorch_model.bin"
SENTIMENT_VOCAB = "vocab.json"
SENTIMENT_CONFIG_FILE = "config.json"
SENTIMENT_MODEL_WEIGHTS_URL = "https://huggingface.co/climatebert/distilroberta-base-climate-sentiment/resolve/main/pytorch_model.bin"
SENTIMENT_VOCAB_URL = "https://huggingface.co/climatebert/distilroberta-base-climate-sentiment/resolve/main/vocab.json"
SENTIMENT_CONFIG_URL = "https://huggingface.co/climatebert/distilroberta-base-climate-sentiment/resolve/main/config.json"
SENTIMENT_FILES_URLS = [
    (SENTIMENT_MODEL_WEIGHTS_URL, SENTIMENT_MODEL_WEIGHTS),
    (SENTIMENT_VOCAB_URL, SENTIMENT_VOCAB),
    (SENTIMENT_CONFIG_URL, SENTIMENT_CONFIG_FILE)
]

# Image Generation (VAE)
IMAGEGEN_FOLDER = "./ImageGenModel"
IMAGEGEN_MODEL_WEIGHTS = "diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG = "config.json"
IMAGEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json"
IMAGEGEN_FILES_URLS = [
    (IMAGEGEN_MODEL_WEIGHTS_URL, IMAGEGEN_MODEL_WEIGHTS),
    (IMAGEGEN_CONFIG_URL, IMAGEGEN_CONFIG)
]

# Image to 3D
IMAGE_TO_3D_FOLDER = "./ImageTo3DModel"
IMAGE_TO_3D_MODEL_WEIGHTS = "pytorch_model.bin"
IMAGE_TO_3D_CONFIG = "config.json"
IMAGE_TO_3D_MODEL_WEIGHTS_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/pytorch_model.bin"
IMAGE_TO_3D_CONFIG_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/config.json"
IMAGE_TO_3D_FILES_URLS = [
    (IMAGE_TO_3D_MODEL_WEIGHTS_URL, IMAGE_TO_3D_MODEL_WEIGHTS),
    (IMAGE_TO_3D_CONFIG_URL, IMAGE_TO_3D_CONFIG)
]

# Text to Video
TEXT_TO_VIDEO_FOLDER = "./TextToVideoModel"
TEXT_TO_VIDEO_MODEL_WEIGHTS = "diffusion_pytorch_model.bin"        # local name for the VAE weights
TEXT_TO_VIDEOX_MODEL_WEIGHTS = "diffusion_pytorch_model.fp16.bin"  # local name for the UNet weights
TEXT_TO_VIDEO_CONFIG = "config.json"  # shared by the UNet and the VAE
TEXT_TO_VIDEO_VOCAB = "vocab.json"
TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_UNET = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/unet/diffusion_pytorch_model.fp16.bin"
TEXT_TO_VIDEO_CONFIG_URL_UNET = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/unet/config.json"
TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_VAE = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/vae/diffusion_pytorch_model.fp16.bin"
TEXT_TO_VIDEO_CONFIG_URL_VAE = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/vae/config.json"
TEXT_TO_VIDEO_VOCAB_URL = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/tokenizer/vocab.json"
# Each URL now maps to a distinct local filename. The original list downloaded
# the UNet and VAE weights to the same two filenames, so whichever file landed
# first was silently reused for both models (download_file skips existing files).
TEXT_TO_VIDEO_FILES_URLS = [
    (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_UNET, TEXT_TO_VIDEOX_MODEL_WEIGHTS),
    (TEXT_TO_VIDEO_CONFIG_URL_UNET, TEXT_TO_VIDEO_CONFIG),
    (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_VAE, TEXT_TO_VIDEO_MODEL_WEIGHTS),
    (TEXT_TO_VIDEO_VOCAB_URL, TEXT_TO_VIDEO_VOCAB),
]

# SadTalker
# ============================================================================
# Restoration models for SadTalker (face restoration / super-resolution)
# ============================================================================
# GFPGAN
GFPGAN_FOLDER = "./GFPGAN"
GFPGAN_MODEL_FILE = "GFPGANv1.4.pth"
GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"

# RestoreFormer
RESTOREFORMER_FOLDER = "./RestoreFormer"
RESTOREFORMER_MODEL_FILE = "RestoreFormer.pth"
RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"

# CodeFormer
CODEFORMER_FOLDER = "./CodeFormer"
CODEFORMER_MODEL_FILE = "codeformer.pth"
CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"

# RealESRGAN
REALESRGAN_FOLDER = "./RealESRGAN"
REALESRGAN_MODEL_FILE = "RealESRGAN_x2plus.pth"
REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"


kp = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
kp_file = "kp_detector.safetensors"
aud = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
aud_file = "auido2pose_00140-model.pth"  # the "auido" typo is in the upstream filename
wav = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
wav_file = "wav2vec2.bin"
gen = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
gen_file = "generator.bin"
mapx = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
mapx_file = "mapping.pth"
den = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
den_file = "dense_motion.pth"

# --- Constants for the additional SadTalker models ---
# All of these are assumed to live in SadTalker's main checkpoints folder.
SADTALKER_KP_FOLDER = "checkpoints"
SADTALKER_KP_MODEL_FILE = kp_file
SADTALKER_KP_URL = kp

SADTALKER_AUD_FOLDER = "checkpoints"
SADTALKER_AUD_MODEL_FILE = aud_file
SADTALKER_AUD_URL = aud

SADTALKER_WAV_FOLDER = "checkpoints"
SADTALKER_WAV_MODEL_FILE = wav_file
SADTALKER_WAV_URL = wav

SADTALKER_GEN_FOLDER = "checkpoints"
SADTALKER_GEN_MODEL_FILE = gen_file
SADTALKER_GEN_URL = gen

SADTALKER_MAPX_FOLDER = "checkpoints"
SADTALKER_MAPX_MODEL_FILE = mapx_file
SADTALKER_MAPX_URL = mapx

SADTALKER_DEN_FOLDER = "checkpoints"
SADTALKER_DEN_MODEL_FILE = den_file
SADTALKER_DEN_URL = den


# =============================================================================
# SadTalker
# =============================================================================
SADTALKER_CHECKPOINTS_FOLDER = "./checkpoints"
SADTALKER_CONFIG_FOLDER = "./src/config"
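# Each *_FILES_URLS constant above is a list of (url, local_filename) pairs
# consumed by download_files() in model_loader.py, which skips files already on
# disk. A minimal sketch of that contract (assumes this module is importable
# as `constants`):
#
#     import os, urllib.request
#     from constants import TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS
#     os.makedirs(TRANSLATION_FOLDER, exist_ok=True)
#     for url, filename in TRANSLATION_MODEL_FILES_URLS:
#         target = os.path.join(TRANSLATION_FOLDER, filename)
#         if not os.path.exists(target):
#             urllib.request.urlretrieve(url, target)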
extensions.py
ADDED
@@ -0,0 +1,252 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import librosa
from pydub import AudioSegment
import imageio
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm
import numpy as np
import math
import torchvision
import os
import re
import shutil
from yacs.config import CfgNode as CN
import requests
import subprocess
import cv2
from collections import OrderedDict
# RealESRGANer below builds an RRDBNet when no model is passed in; basicsr
# (already pinned in requirements.txt) provides that architecture.
from basicsr.archs.rrdbnet_arch import RRDBNet
# Several imports above (yaml, safetensors, librosa, AudioSegment, ...) are
# unused here but are re-exported through `from extensions import *` in main.py.

def img2tensor(imgs, bgr2rgb=True, float32=True):
    if isinstance(imgs, np.ndarray):
        if imgs.ndim == 2:  # grayscale HxW -> HxWx1 (the original tested ndim == 3, which broke 3-channel input)
            imgs = imgs[..., np.newaxis]
        imgs = torch.from_numpy(imgs.transpose((2, 0, 1)))
    elif isinstance(imgs, Image.Image):
        imgs = torch.from_numpy(np.array(imgs)).permute(2, 0, 1)
    else:
        raise TypeError(f'Type `{type(imgs)}` is not suitable for img2tensor')
    if bgr2rgb:
        if imgs.shape[0] == 3:
            imgs = imgs[[2, 1, 0], :, :]
    if float32:
        imgs = imgs.float() / 255.
    return imgs

def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f'Input tensor should be torch.Tensor, but got {type(tensor)}')
    tensor = tensor.float().cpu()
    tensor = tensor.clamp_(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
    output_img = tensor.mul(255).round()
    output_img = np.transpose(output_img.numpy(), (1, 2, 0))
    output_img = np.clip(output_img, 0, 255).astype(np.uint8)
    if rgb2bgr:
        output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
    return output_img if out_type == np.uint8 else output_img.astype(out_type) / 255.

class RealESRGANer():
    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=0, half=False, device=None, gpu_id=None):
        self.scale = scale
        self.tile = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        if model is None:
            # Default architecture matches the RealESRGAN_x2plus checkpoint from constants.py.
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
        if half:
            model.half()
        loadnet = torch.load(model_path, map_location=lambda storage, loc: storage)
        if 'params' in loadnet:
            model.load_state_dict(loadnet['params'], strict=True)
        elif 'params_ema' in loadnet:
            model.load_state_dict(loadnet['params_ema'], strict=True)
        else:
            model.load_state_dict(loadnet, strict=True)
        model.eval()
        self.model = model.to(self.device)

    def enhance(self, img, outscale=None, tile=None, tile_pad=None, pre_pad=None, half=None):
        h_input, w_input = img.shape[0:2]
        if outscale is None:
            outscale = self.scale
        if tile is None:
            tile = self.tile
        if tile_pad is None:
            tile_pad = self.tile_pad
        if pre_pad is None:
            pre_pad = self.pre_pad
        if half is None:
            half = self.half
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        img_tensor = img2tensor(img)
        img_tensor = img_tensor.unsqueeze(0).to(self.device)
        if half:
            img_tensor = img_tensor.half()
        mod_scale = self.mod_scale
        h_pad, w_pad = 0, 0
        if mod_scale is not None:
            h_pad, w_pad = int(np.ceil(h_input / mod_scale) * mod_scale - h_input), int(np.ceil(w_input / mod_scale) * mod_scale - w_input)
            img_tensor = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'reflect')
        window_size = 256
        scale = self.scale
        overlap_ratio = 0.5
        if w_input * h_input < window_size**2:
            tile = None  # small inputs are processed in one shot
        if tile is not None and tile > 0:
            tile_overlap = tile * overlap_ratio
            sf = scale
            stride_w = math.ceil(tile - tile_overlap)
            stride_h = math.ceil(tile - tile_overlap)
            numW = math.ceil((w_input + tile_overlap) / stride_w)
            numH = math.ceil((h_input + tile_overlap) / stride_h)
            paddingW = (numW - 1) * stride_w + tile - w_input
            paddingH = (numH - 1) * stride_h + tile - h_input
            padding_bottom = int(max(paddingH, 0))
            padding_right = int(max(paddingW, 0))
            padding_left, padding_top = 0, 0
            img_tensor = F.pad(img_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode='reflect')
            output_h, output_w = padding_top + h_input * scale + padding_bottom, padding_left + w_input * scale + padding_right
            output_tensor = torch.zeros([1, 3, output_h, output_w], dtype=img_tensor.dtype, device=self.device)
            # Track how many tiles wrote to each pixel so overlaps can be
            # averaged; the original summed overlapping tiles, brightening seams.
            weight_tensor = torch.zeros_like(output_tensor)
            windows = []
            for row in range(numH):
                for col in range(numW):
                    start_x = col * stride_w
                    start_y = row * stride_h
                    end_x = min(start_x + tile, img_tensor.shape[3])
                    end_y = min(start_y + tile, img_tensor.shape[2])
                    windows.append(img_tensor[:, :, start_y:end_y, start_x:end_x])
            results = []
            batch_size = 8
            for i in range(0, len(windows), batch_size):
                # cat (not stack): each window already carries a batch dimension.
                batch_windows = torch.cat(windows[i:min(i + batch_size, len(windows))], dim=0)
                with torch.no_grad():
                    results.append(self.model(batch_windows))
            results = torch.cat(results, dim=0)
            count = 0
            for row in range(numH):
                for col in range(numW):
                    start_x = col * stride_w
                    start_y = row * stride_h
                    end_x = min(start_x + tile, img_tensor.shape[3])
                    end_y = min(start_y + tile, img_tensor.shape[2])
                    out_start_x, out_start_y = start_x * sf, start_y * sf
                    out_end_x, out_end_y = end_x * sf, end_y * sf
                    output_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += results[count][:, :end_y * sf - out_start_y, :end_x * sf - out_start_x]
                    weight_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += 1
                    count += 1
            output_tensor = output_tensor / weight_tensor.clamp(min=1)
            forward_img = output_tensor[:, :, :h_input * sf, :w_input * sf]
        else:
            with torch.no_grad():
                forward_img = self.model(img_tensor)
        if half:
            forward_img = forward_img.float()
        output_img = tensor2img(forward_img.squeeze(0).clamp_(0, 1))
        if mod_scale is not None:
            output_img = output_img[:h_input * self.scale, :w_input * self.scale, ...]
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
        return [output_img, None]

def save_video_with_watermark(video_frames, audio_path, output_path, watermark_path='./assets/sadtalker_logo.png'):
    try:
        watermark = imageio.imread(watermark_path)
    except FileNotFoundError:
        watermark = None
    writer = imageio.get_writer(output_path, fps=25)
    try:
        for frame in tqdm(video_frames, 'Generating video'):
            if watermark is not None:
                frame_h, frame_w = frame.shape[:2]
                watermark_h, watermark_w = watermark.shape[:2]
                if watermark_h > frame_h or watermark_w > frame_w:
                    # skimage's resize returns floats in [0, 1]; convert back to
                    # uint8 before pasting into the uint8 frame.
                    watermark = img_as_ubyte(transform.resize(watermark, (frame_h // 4, frame_w // 4)))
                    watermark_h, watermark_w = watermark.shape[:2]
                start_h = frame_h - watermark_h - 10
                start_w = frame_w - watermark_w - 10
                frame[start_h:start_h+watermark_h, start_w:start_w+watermark_w, :] = watermark
            writer.append_data(img_as_ubyte(frame))
    except Exception as e:
        print(f"Error in video writing: {e}")
    finally:
        writer.close()
    if audio_path is not None:
        try:
            command = "ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}".format(audio_path, output_path, output_path.replace('.mp4', '_with_audio.mp4'))
            subprocess.call(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            os.remove(output_path)
            os.rename(output_path.replace('.mp4', '_with_audio.mp4'), output_path)
        except Exception as e:
            print(f"Error adding audio to video: {e}")

def paste_pic(video_path, pic_path, crop_info, audio_path, output_path):
    try:
        y_start, y_end, x_start, x_end, old_size, cropped_size = crop_info[0][0], crop_info[0][1], crop_info[1][0], crop_info[1][1], crop_info[2], crop_info[3]
        source_image_h, source_image_w = old_size
        cropped_h, cropped_w = cropped_size
        delta_h, delta_w = source_image_h - cropped_h, source_image_w - cropped_w
        box = [x_start, y_start, source_image_w - x_end, source_image_h - y_end]
        # Filter label fixed: "crop=...,[s]" is invalid filtergraph syntax; the
        # crop output must be labelled "[s]" directly.
        command = "ffmpeg -y -i {} -i {} -filter_complex \"[1]crop=w={}:h={}:x={}:y={}[s];[0][s]overlay=x={}:y={}\" -codec:a copy {}".format(video_path, pic_path, cropped_w, cropped_h, box[0], box[1], box[0], box[1], output_path)
        subprocess.call(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        print(f"Error pasting picture to video: {e}")

def color_transfer_batch(source, target, mode='numpy'):
    source_np = tensor2img(source)
    target_np = tensor2img(target)
    source_lab = cv2.cvtColor(source_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    target_lab = cv2.cvtColor(target_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    source_mu = np.mean(source_lab, axis=(0, 1), keepdims=True)
    source_std = np.std(source_lab, axis=(0, 1), keepdims=True)
    target_mu = np.mean(target_lab, axis=(0, 1), keepdims=True)
    target_std = np.std(target_lab, axis=(0, 1), keepdims=True)
    transfer_lab = (target_lab - target_mu) * (source_std / target_std) + source_mu
    transfer_rgb = cv2.cvtColor(np.clip(transfer_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2RGB)
    transfer_rgb_tensor = img2tensor(transfer_rgb)
    return transfer_rgb_tensor.unsqueeze(0).to(source.device)

def load_video_to_cv2(path, resize=None):
    video = []
    try:
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise Exception("Error opening video stream or file")
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if resize is not None:
                    frame_rgb = cv2.resize(frame_rgb, resize)
                video.append(frame_rgb)
            else:
                break
        cap.release()
    except Exception as e:
        print(f"Error loading video: {e}")
    return video

def get_prior_from_bfm(bfm_path):
    mat_path = os.path.join(bfm_path, 'BFM_prior.mat')
    C = loadmat(mat_path)
    pc_tex = torch.tensor(C['pc_tex'].astype(np.float32)).unsqueeze(0)
    pc_exp = torch.tensor(C['pc_exp'].astype(np.float32)).unsqueeze(0)
    u_tex = torch.tensor(C['u_tex'].astype(np.float32)).unsqueeze(0)
    u_exp = torch.tensor(C['u_exp'].astype(np.float32)).unsqueeze(0)
    prior_coeff = {
        'pc_tex': pc_tex,
        'pc_exp': pc_exp,
        'u_tex': u_tex,
        'u_exp': u_exp
    }
    return prior_coeff
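# Sketch of wiring RealESRGANer to the checkpoint from constants.py. The default
# RRDBNet built above matches RealESRGAN_x2plus; `rgb_image` stands in for any
# HxWx3 uint8 RGB array, and the paths assume the file was already downloaded:
#
#     import os
#     from constants import REALESRGAN_FOLDER, REALESRGAN_MODEL_FILE
#     upsampler = RealESRGANer(scale=2, model_path=os.path.join(REALESRGAN_FOLDER, REALESRGAN_MODEL_FILE))
#     restored, _ = upsampler.enhance(rgb_image)   # returns [upscaled RGB array, None]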
image_to_3d_api.py
ADDED
@@ -0,0 +1,32 @@
import os
import uuid
from flask import jsonify, send_file, request
from main import *  # pulls in image_to_3d_model and device, among others
from PIL import Image
import torch
import numpy as np

def image_to_3d_func(image_path, output_path="output_3d.obj"):
    if image_to_3d_model is None:
        return "Image-to-3D model not initialized."
    pil_image = Image.open(image_path).convert("RGB")
    image = torch.tensor(np.array(pil_image)).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    image = image.to(device)
    with torch.no_grad():
        mesh_obj = image_to_3d_model(image)
    # The simplified OpenLRM in model_loader.py returns a feature tensor, not an
    # OBJ string; cast defensively so the write below cannot raise a TypeError.
    if not isinstance(mesh_obj, str):
        mesh_obj = str(mesh_obj)
    with open(output_path, 'w') as f:
        f.write(mesh_obj)
    return output_path

def image_to_3d_api():
    if 'image' not in request.files:
        return jsonify({"error": "Image file is required"}), 400
    image_file = request.files['image']
    temp_image_path = f"temp_image_{uuid.uuid4()}.png"
    image_file.save(temp_image_path)
    output_file = image_to_3d_func(temp_image_path)
    os.remove(temp_image_path)
    if output_file == "Image-to-3D model not initialized.":
        return jsonify({"error": "Image to 3D failed"}), 500
    return send_file(output_file, mimetype="model/obj", as_attachment=True, download_name="output_3d.obj")
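# Client-side sketch. The route path is a placeholder; api.py owns the actual
# URL mapping for image_to_3d_api. The handler expects multipart form data with
# the file under the "image" field:
#
#     import requests
#     with open("photo.png", "rb") as fh:
#         r = requests.post("http://localhost:7860/image_to_3d", files={"image": fh})
#     open("output_3d.obj", "wb").write(r.content)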
imagegen_api.py
ADDED
@@ -0,0 +1,33 @@
import os
from flask import jsonify, send_file, request
from io import BytesIO
from PIL import Image
from main import *  # pulls in imagegen_model and device, among others
import torch

def generate_image(prompt, output_path="output_image.png"):
    if imagegen_model is None:
        return "Image generation model not initialized."
    # Assumes imagegen_model exposes a diffusers-style pipeline interface
    # (callable with a prompt, returning an object with an .images list). The
    # bare AutoencoderKL built in model_loader.py does not satisfy this, so the
    # endpoint only works once imagegen_model is a full text-to-image pipeline.
    generator = torch.Generator(device=device).manual_seed(0)
    image = imagegen_model(
        prompt,
        generator=generator,
    ).images[0]
    image.save(output_path)
    return output_path

def imagegen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_image(prompt)
    if output_file == "Image generation model not initialized.":
        return jsonify({"error": "Image generation failed"}), 500
    image_io = BytesIO()
    pil_image = Image.open(output_file)
    pil_image.save(image_io, 'PNG')
    image_io.seek(0)
    return send_file(image_io, mimetype='image/png', as_attachment=True, download_name="output.png")
main.py
ADDED
@@ -0,0 +1,118 @@
import threading
import queue
import time
import os
import nltk
import re
import json
from flask import Flask
from flask_cors import CORS
from api import *  # provides the Flask `app` used at the bottom of this file
from extensions import *
from constants import *
from configs import *
from tokenxxx import *
from models import *
from model_loader import *
from utils import *
from background_tasks import generate_and_queue_text, background_training, background_reasoning_queue
from text_generation import *
from sadtalker_utils import *
import torch

state_dict = None
enc = None
config = None
model_gpt2 = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
news_clf = None
tfidf_vectorizer = None
text_queue = queue.Queue()
categories = None
background_threads = []
feedback_queue = queue.Queue()
reasoning_queue = queue.Queue()
seen_responses = set()
dialogue_history = []
vocabulary = set()
word_to_index = {}
index_to_word = []
translation_model = None
sp = None
codegen_model = None
codegen_tokenizer = None
codegen_vocabulary = None
codegen_index_to_word = None
codegen_word_to_index = None
summarization_model = None
summarization_vocabulary = set()
summarization_word_to_index = {}
summarization_index_to_word = []
sadtalker_instance = None
imagegen_model = None
image_to_3d_model = None
text_to_video_model = None
stream_type = "text"
sentiment_model = None
stt_model = None
tts_model = None
musicgen_model = None

def load_models():
    # Downloads every checkpoint on first run (several GB), then loads them.
    global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model, checkpoint_path, gfpgan_model_file, restoreformer_model_file, codeformer_model_file, realesrgan_model_file, kp_file, aud_file, wav_file, gen_file, mapx_file, den_file
    model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
    translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
    codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
    summarization_model, _, _, _ = initialize_summarization_model(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)
    imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
    image_to_3d_model = initialize_image_to_3d_model(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)
    text_to_video_model = initialize_text_to_video_model(TEXT_TO_VIDEO_FOLDER, TEXT_TO_VIDEO_FILES_URLS)
    sentiment_model = initialize_sentiment_model(SENTIMENT_FOLDER, SENTIMENT_FILES_URLS)
    stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
    tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
    musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)

class SimpleClassifier(torch.nn.Module):
    def __init__(self, vocab_size, num_classes):
        super(SimpleClassifier, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, 128)
        self.linear = torch.nn.Linear(128, num_classes)
    def forward(self, x):
        embedded = self.embedding(x)
        pooled = torch.mean(embedded, dim=1)  # mean-pool token embeddings
        return self.linear(pooled)

def tokenize_text(text):
    global vocabulary, word_to_index, index_to_word
    tokens = text.lower().split()
    for token in tokens:
        if token not in vocabulary:
            vocabulary.add(token)
            word_to_index[token] = len(index_to_word)
            index_to_word.append(token)
    return tokens

def text_to_vector(text):
    # Bag-of-words count vector over the (growing) global vocabulary.
    global vocabulary, word_to_index
    tokens = tokenize_text(text)
    vector = torch.zeros(len(vocabulary))
    for token in tokens:
        if token in word_to_index:
            vector[word_to_index[token]] += 1
    return vector

if __name__ == "__main__":
    nltk.download('punkt')
    load_models()
    categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']
    # background_tasks reads these module attributes at runtime, so inject the
    # shared queues before starting the worker threads.
    import background_tasks
    background_tasks.categories = categories
    background_tasks.text_queue = text_queue
    background_tasks.reasoning_queue = reasoning_queue
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
    background_threads.append(threading.Thread(target=background_training, daemon=True))
    background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
    for thread in background_threads:
        thread.start()
    app.run(host='0.0.0.0', port=7860)
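# Smoke test once the container is up (no route names assumed; api.py defines
# the endpoints, so any HTTP status code proves the server is listening):
#
#     import requests
#     print(requests.get("http://localhost:7860/").status_code)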
model_loader.py
ADDED
@@ -0,0 +1,674 @@
import os
import json
import urllib.request
import urllib.parse
import torch
import hashlib
from tqdm import tqdm
from skimage import img_as_ubyte
from torch import nn
import torch.nn.functional as F
import inspect

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def filter_kwargs(cls, kwargs):
    # Keep only the kwargs that cls.__init__ actually accepts.
    sig = inspect.signature(cls.__init__)
    accepted = set(sig.parameters.keys()) - {"self"}
    return {k: v for k, v in kwargs.items() if k in accepted}

def sanitize_filename(name, url=None):
    for c in '<>:"/\\|?*':
        name = name.replace(c, '')
    if not name and url is not None:
        name = hashlib.md5(url.encode()).hexdigest()
    return name

def download_file(url, filepath):
    d = os.path.dirname(filepath)
    if d and not os.path.exists(d):
        os.makedirs(d, exist_ok=True)
    if not os.path.exists(filepath):  # existing files are never re-downloaded
        def prog(t):
            last = [0]
            def inner(n, bs, ts):
                if ts > 0:
                    t.total = ts
                t.update(n * bs - last[0])
                last[0] = n * bs
            return inner
        with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t:
            urllib.request.urlretrieve(url, filepath, reporthook=prog(t))

def download_files(folder, files_spec):
    # Accepts three spec shapes: {filename: url}, [url, ...] and
    # [(url, filename), ...], plus [{"url": ..., "filename": ...}] dicts.
    if isinstance(files_spec, dict):
        for fn, url in files_spec.items():
            fn = sanitize_filename(fn, url)
            fp = os.path.join(folder, fn)
            download_file(url, fp)
    elif isinstance(files_spec, list):
        for item in files_spec:
            if isinstance(item, str):
                url = item
                parsed = urllib.parse.urlparse(url)
                fn = os.path.basename(parsed.path)
                if not fn:
                    fn = hashlib.md5(url.encode()).hexdigest()
                fn = sanitize_filename(fn, url)
            elif isinstance(item, (list, tuple)) and len(item) == 2:
                url, fn = item
                fn = sanitize_filename(fn, url)
            elif isinstance(item, dict) and "filename" in item and "url" in item:
                fn = sanitize_filename(item["filename"], item["url"])
                url = item["url"]
            else:
                raise ValueError("Invalid file specification")
            fp = os.path.join(folder, fn)
            download_file(url, fp)
    else:
        raise ValueError("files_spec must be dict or list")

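# All three spec shapes accepted by download_files (URLs hypothetical):
#
#     download_files("./Demo", {"a.bin": "https://example.com/a.bin"})       # dict: filename -> url
#     download_files("./Demo", ["https://example.com/b.bin"])                # bare URL: filename from path
#     download_files("./Demo", [("https://example.com/c.bin", "c.bin")])     # (url, filename) pair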
def read_json(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_codegen_tokenizer(vocab_path, merges_path):
    # Whitespace lookup only: the merges file is read but BPE is never applied,
    # so out-of-vocabulary words all map to id 0.
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    with open(merges_path, 'r', encoding='utf-8') as f:
        merges = f.read().splitlines()
    def tokenizer(text):
        toks = text.split()
        return [vocab.get(t, 0) for t in toks]
    return tokenizer

def simple_tokenizer(text, vocab, max_length=77):
    # Pad with 0 / truncate to max_length; unknown tokens map to id 1.
    toks = text.split()
    ids = [vocab.get(t, 1) for t in toks]
    if len(ids) < max_length:
        ids = ids + [0] * (max_length - len(ids))
    else:
        ids = ids[:max_length]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)

def load_state_dict_safe(model, loaded_state_dict):
    # Copy tensors whose name and shape match; keep the model's own (randomly
    # initialized) parameters for everything else, so loading never raises.
    model_state = model.state_dict()
    new_state = {}
    for key, value in model_state.items():
        if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
            new_state[key] = loaded_state_dict[key]
        else:
            new_state[key] = value
    model.load_state_dict(new_state, strict=False)

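# Behavior illustration for load_state_dict_safe: matching keys are copied,
# missing keys keep their random initialization, and nothing ever raises:
#
#     m = nn.Linear(4, 2)
#     load_state_dict_safe(m, {"weight": torch.zeros(2, 4)})  # "bias" absent: kept as-is
#     assert m.weight.sum().item() == 0.0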
class GPT2Config:
    def __init__(self, vocab_size=50257, **kwargs):
        self.vocab_size = vocab_size
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class MBartConfig:
    def __init__(self, vocab_size=50265, **kwargs):
        self.vocab_size = vocab_size
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class CodeGenConfig:
    def __init__(self, vocab_size=50257, **kwargs):
        self.vocab_size = vocab_size
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class BartConfig:
    def __init__(self, vocab_size=50265, **kwargs):
        self.vocab_size = vocab_size
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class AutoencoderKLConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class OpenLRMConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class UNet2DConditionModelConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class MusicGenConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

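# Example: any downloaded config.json can be hydrated through from_dict; unknown
# keys simply become attributes (the values here are illustrative):
#
#     cfg = GPT2Config.from_dict({"vocab_size": 50257, "n_layer": 12})
#     print(cfg.vocab_size, cfg.n_layer)  # 50257 12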
# The four model classes below are lightweight stand-ins built from stock
# nn.Transformer layers. Their parameter names do not line up with the Hugging
# Face checkpoints downloaded in constants.py, so load_state_dict_safe leaves
# most of their weights randomly initialized.
class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.transformer = nn.TransformerEncoder(layer, num_layers=12)
        self.lm_head = nn.Linear(768, config.vocab_size)
    def forward(self, x):
        return self.lm_head(self.transformer(x))

class MBartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.encoder = nn.TransformerEncoder(layer, num_layers=6)
        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
        self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
        self.output_layer = nn.Linear(768, config.vocab_size)
    def forward(self, src, tgt):
        return self.output_layer(self.decoder(tgt, self.encoder(src)))

class CodeGenForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        d_model = getattr(config, "d_model", 1024)
        n_head = getattr(config, "n_head", 16)
        num_layers = getattr(config, "num_layers", 12)
        dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
        self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
        self.lm_head = nn.Linear(d_model, config.vocab_size)
    def forward(self, tgt, memory=None):
        if memory is None:
            memory = torch.zeros_like(tgt)  # decoder-only use: dummy cross-attention memory
        return self.lm_head(self.transformer_decoder(tgt, memory))

class BartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.encoder = nn.TransformerEncoder(layer, num_layers=6)
        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
        self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
        self.output_layer = nn.Linear(768, config.vocab_size)
    def forward(self, src, tgt):
        return self.output_layer(self.decoder(tgt, self.encoder(src)))

class ResnetBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
    def forward(self, x):
        sc = self.conv_shortcut(x)
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        h = F.silu(self.norm2(h))
        h = self.conv2(h)
        return h + sc

class Downsample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
    def forward(self, x):
        return self.conv(x)

class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_res):
        super().__init__()
        self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
        self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
    def forward(self, x):
        for r in self.resnets:
            x = r(x)
        for ds in self.downsamplers:
            x = ds(x)
        return x

class Upsample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
    def forward(self, x):
        return self.conv(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_res):
        super().__init__()
        self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
        self.upsampler = Upsample(out_ch, out_ch)
    def forward(self, x):
        for r in self.resnets:
            x = r(x)
        return self.upsampler(x)

class AttentionBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)
        self.query = nn.Conv2d(ch, ch, 1)
        self.key = nn.Conv2d(ch, ch, 1)
        self.value = nn.Conv2d(ch, ch, 1)
        self.proj_attn = nn.Conv2d(ch, ch, 1)
    def forward(self, x):
        b, c, h, w = x.shape
        xn = self.norm(x)
        q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
        k = self.key(xn).view(b, c, -1)
        v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
        attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
        out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
        return x + self.proj_attn(out)

class Encoder(nn.Module):
    def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
        super().__init__()
        self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        self.down_blocks = nn.ModuleList([
            DownBlock(base_ch, base_ch, 2),
            DownBlock(base_ch, base_ch * 2, 2),
            DownBlock(base_ch * 2, base_ch * 4, 2),
            DownBlock(base_ch * 4, base_ch * 4, 2)
        ])
        self.mid_block = nn.ModuleList([
            ResnetBlock(base_ch * 4, base_ch * 4),
            AttentionBlock(base_ch * 4),
            ResnetBlock(base_ch * 4, base_ch * 4)
        ])
        self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
        self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
        self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
    def forward(self, x):
        x = self.conv_in(x)
        for blk in self.down_blocks:
            x = blk(x)
        for m in self.mid_block:
            x = m(x)
        x = self.conv_norm_out(x)
        x = self.conv_out(x)
        return self.quant_conv(x)

class Decoder(nn.Module):
    def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
        super().__init__()
        self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1)
        # post_quant_conv doubles the channel count, so conv_in must accept
        # latent_ch * 2 channels (the original took latent_ch and crashed).
        self.conv_in = nn.Conv2d(latent_ch * 2, base_ch * 4, 3, padding=1)
        self.mid_block = nn.ModuleList([
            ResnetBlock(base_ch * 4, base_ch * 4),
            AttentionBlock(base_ch * 4),
            ResnetBlock(base_ch * 4, base_ch * 4)
        ])
        self.up_blocks = nn.ModuleList([
            UpBlock(base_ch * 4, base_ch * 4, 3),
            UpBlock(base_ch * 4, base_ch * 2, 3),
            UpBlock(base_ch * 2, base_ch, 3),
            UpBlock(base_ch, base_ch, 3)
        ])
        self.conv_norm_out = nn.GroupNorm(32, base_ch)
        self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
    def forward(self, x):
        x = self.post_quant_conv(x)
        x = self.conv_in(x)
        for m in self.mid_block:
            x = m(x)
        for up in self.up_blocks:
            x = up(x)
        x = self.conv_norm_out(x)
        return self.conv_out(x)

class AutoencoderKL(nn.Module):
    def __init__(self, config):
        super().__init__()
        # config may be a plain dict or a config object; read both uniformly.
        in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
        out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3)
        base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
        latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4)
        self.encoder = Encoder(in_ch, base_ch, latent_ch)
        self.decoder = Decoder(out_ch, base_ch, latent_ch)
    def forward(self, x):
        return self.decoder(self.encoder(x))
    def decode(self, x):
        return self.decoder(x)

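# Shape walk for the defaults (base_ch=128, latent_ch=4): four DownBlocks halve
# the spatial size, so a 3x256x256 image encodes to a 4x16x16 latent and decodes
# back to 3x256x256 (relies on the conv_in channel fix noted above):
#
#     vae = AutoencoderKL({"in_channels": 3, "out_channels": 3})
#     z = vae.encoder(torch.randn(1, 3, 256, 256))   # -> torch.Size([1, 4, 16, 16])
#     x = vae.decode(z)                              # -> torch.Size([1, 3, 256, 256])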
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = embed_dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim)
        )
    def forward(self, x):
        res = x
        x = self.norm1(x)
        x = x.transpose(0, 1)  # nn.MultiheadAttention expects (seq, batch, dim) by default
        attn, _ = self.attn(x, x, x)
        x = attn.transpose(0, 1)
        x = res + x
        return x + self.mlp(self.norm2(x))

class VisionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        if isinstance(config, dict):
            self.img_size = config.get("img_size", 592)
            self.patch_size = config.get("patch_size", 16)
            self.embed_dim = config.get("hidden_size", 768)
            depth = config.get("depth", 12)
            num_heads = config.get("num_heads", 12)
        else:
            self.img_size = config.__dict__.get("img_size", 592)
            self.patch_size = config.__dict__.get("patch_size", 16)
            self.embed_dim = config.__dict__.get("hidden_size", 768)
            depth = config.__dict__.get("depth", 12)
            num_heads = config.__dict__.get("num_heads", 12)
        num_patches = (self.img_size // self.patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
        self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(self.embed_dim)
        self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
        self._init_weights()
    def _init_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
    def forward(self, x):
        # Input resolution must equal img_size so the patch count matches pos_embed.
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)[:, 0]

class OpenLRM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
        hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
        self.linear = nn.Linear(hidden, hidden)
    def forward(self, x):
        return self.linear(self.encoder["model"](x))

class VideoUNet(nn.Module):
    def __init__(self, in_ch=4, out_ch=4, features=None):
        super().__init__()
        if features is None:
            features = [64, 128, 256]
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool3d(2, 2)
        self.decoder = nn.ModuleList()
        for f in features:
            self.encoder.append(nn.Sequential(
                nn.Conv3d(in_ch, f, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv3d(f, f, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
            in_ch = f
        # Each decoder stage consumes the upsampled map from the previous stage
        # (prev_ch channels) concatenated with its skip connection (f channels);
        # the original hard-coded f * 2, which only matched the first stage.
        prev_ch = features[-1]
        for f in reversed(features):
            self.decoder.append(nn.Sequential(
                nn.Conv3d(prev_ch + f, f, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv3d(f, f, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
            prev_ch = f
        self.final_conv = nn.Conv3d(features[0], out_ch, 1)
    def forward(self, x, t, encoder_hidden_states):
        # t and encoder_hidden_states are accepted for pipeline compatibility
        # but unused in this simplified UNet.
        skips = []
        for enc in self.encoder:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)
        for dec in self.decoder:
            skip = skips.pop()
            x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)
        return self.final_conv(x)

class SentimentClassifierModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )
    def forward(self, x):
        return self.classifier(x)

class STTModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 768)
        )
    def forward(self, x):
        return self.net(x)

class TTSModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 768)
        )
    def forward(self, x):
        return self.net(x)

class MusicGenModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.transformer = nn.TransformerEncoder(layer, num_layers=12)
        self.linear = nn.Linear(768, 768)
    def forward(self, x):
        return self.linear(self.transformer(x))

class SimpleTextEncoder(nn.Module):
    def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.max_length = max_length
    def forward(self, text_tokens):
        return self.embedding(text_tokens)

class DiffusionScheduler:
    # Ad-hoc linearly decaying schedule applied as plain subtraction; a
    # simplification, not the DDPM/DDIM update rule.
    def __init__(self, steps):
        self.steps = steps
        self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
    def step(self, noise, t, sample):
        beta = self.betas[t]
        return sample - beta * noise

class VideoOutput:
    def __init__(self, frames):
        self.frames = [img_as_ubyte(frame) for frame in frames[0]]

class VideoPipeline(nn.Module):
    def __init__(self, unet, vae, text_encoder, vocab):
        super().__init__()
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.vocab = vocab
    def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
        token_ids = simple_tokenizer(prompt, self.vocab)
        text_emb = self.text_encoder(token_ids)
        # fp16 latent to match the half-precision UNet (note: CPU half support is limited).
        latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
        sched = DiffusionScheduler(steps)
        for t in range(steps):
            noise = self.unet(latent, t, text_emb)
            latent = sched.step(noise, t, latent)
        # The AutoencoderKL above is 2D, so fold the frame axis into the batch
        # before decoding (the original fed it a 5D tensor, which Conv2d rejects).
        b, c, nf, h, w = latent.shape
        lat2d = (latent / 0.18215).permute(0, 2, 1, 3, 4).reshape(b * nf, c, h, w)
        dec = self.vae.decode(lat2d)
        dec = dec.reshape(b, nf, *dec.shape[1:]).permute(0, 1, 3, 4, 2)  # (b, frames, H, W, C)
        frames = dec.clamp(0, 1).float().cpu().numpy()
        return VideoOutput(frames)

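# Worked step for DiffusionScheduler: sample_{t+1} = sample_t - beta_t * noise,
# with beta decreasing linearly from 0.1 to 0.001 over `steps`:
#
#     sched = DiffusionScheduler(steps=25)
#     x = torch.ones(1, device=device)
#     x = sched.step(noise=torch.full((1,), 0.5, device=device), t=0, sample=x)
#     # beta_0 = 0.1, so x == 1 - 0.1 * 0.5 = 0.95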
def initialize_gpt2_model(folder, files):
    download_files(folder, files)
    config = GPT2Config()  # defaults; the downloaded config.json is kept but not parsed here
    model = GPT2LMHeadModel(config).to(device)
    sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
    return model, enc

def initialize_translation_model(folder, files):
    download_files(folder, files)
    config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = MBartForConditionalGeneration(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        vocab = read_json(vp)
        model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
    else:
        # The download list ships sentencepiece.bpe.model rather than vocab.json,
        # so this identity fallback is the branch normally taken.
        model.tokenizer = lambda txt: txt
    model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
    return model

def initialize_codegen_model(folder, files):
    download_files(folder, files)
    config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = CodeGenForCausalLM(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
    vocab = read_json(os.path.join(folder, "vocab.json"))
    idx2w = {v: k for k, v in vocab.items()}
    model.tokenizer = tok
    # Returned as (model, tokenizer, vocabulary, index_to_word, word_to_index);
    # the vocab dict doubles as both the vocabulary and the word-to-index map.
    return model, tok, vocab, idx2w, vocab

def initialize_summarization_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = BartForConditionalGeneration(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        vocab_json = read_json(vp)
        vocab = set(vocab_json.keys())
        return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
    return model, None, None, None

def initialize_imagegen_model(folder, files):
    download_files(folder, files)
    config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    vae = AutoencoderKL(config).to(device)
    sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
    load_state_dict_safe(vae, sd)
    vae.eval()
    return vae

def initialize_image_to_3d_model(folder, files):
    download_files(folder, files)
    config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model3d = OpenLRM(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model3d, sd)
    model3d.eval()
    return model3d

def initialize_text_to_video_model(folder, files):
    download_files(folder, files)
    # UNet and VAE share the single downloaded config.json; filter_kwargs keeps
    # only the keys each constructor understands.
    unet_cfg = read_json(os.path.join(folder, "config.json"))
    unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
    unet = VideoUNet(**unet_cfg).half().to(device)
    sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
    load_state_dict_safe(unet, sd_unet)
    unet.eval()
    vae_cfg = read_json(os.path.join(folder, "config.json"))
    vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg)
    vae = AutoencoderKL(vae_cfg).half().to(device)
    sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
    load_state_dict_safe(vae, sd_vae)
    vae.eval()
    vp = os.path.join(folder, "vocab.json")
    text_vocab = read_json(vp) if os.path.exists(vp) else {}
    # Build the text encoder once; the optional weights file is not part of the
    # download list, so the load only happens if someone provides it manually.
    text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values()) + 1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
    te_path = os.path.join(folder, "text_encoder.bin")
    if os.path.exists(te_path):
        load_state_dict_safe(text_encoder, torch.load(te_path, map_location=device))
    text_encoder.eval()
    return VideoPipeline(unet, vae, text_encoder, text_vocab)

def initialize_sentiment_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = SentimentClassifierModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    vp = os.path.join(folder, "vocab.json")
vp = os.path.join(folder, "vocab.json")
|
639 |
+
if os.path.exists(vp):
|
640 |
+
read_json(vp)
|
641 |
+
return model
|
642 |
+
|
643 |
+
def initialize_stt_model(folder, files):
|
644 |
+
download_files(folder, files)
|
645 |
+
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
646 |
+
model = STTModel(config).to(device)
|
647 |
+
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
648 |
+
load_state_dict_safe(model, sd)
|
649 |
+
model.eval()
|
650 |
+
vp = os.path.join(folder, "vocab.json")
|
651 |
+
if os.path.exists(vp):
|
652 |
+
read_json(vp)
|
653 |
+
return model
|
654 |
+
|
655 |
+
def initialize_tts_model(folder, files):
|
656 |
+
download_files(folder, files)
|
657 |
+
config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
658 |
+
model = TTSModel(config).to(device)
|
659 |
+
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
660 |
+
load_state_dict_safe(model, sd)
|
661 |
+
model.eval()
|
662 |
+
vp = os.path.join(folder, "vocab.json")
|
663 |
+
if os.path.exists(vp):
|
664 |
+
read_json(vp)
|
665 |
+
return model
|
666 |
+
|
667 |
+
def initialize_musicgen_model(folder, files):
|
668 |
+
download_files(folder, files)
|
669 |
+
config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
|
670 |
+
model = MusicGenModel(config).to(device)
|
671 |
+
sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
|
672 |
+
load_state_dict_safe(model, sd)
|
673 |
+
model.eval()
|
674 |
+
return model
|
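
All of these initializers follow one pattern: download the checkpoint files, build the module from a JSON config, load the state dict tolerantly, and switch to eval mode. load_state_dict_safe is defined earlier in this upload; assuming it behaves like PyTorch's strict=False loading, the pattern reduces to this runnable sketch:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    partial_sd = {"weight": torch.zeros(2, 4)}              # deliberately missing "bias"
    missing, unexpected = model.load_state_dict(partial_sd, strict=False)
    print(missing, unexpected)                              # ['bias'] [] — mismatches tolerated
    model.eval()                                            # inference mode, as every loader above sets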
models.py
ADDED
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
#from configs import GPT2Config, MBartConfig, CodeGenConfig, SummarizationConfig, OpenLRMConfig, UNet2DConditionModelConfig, AutoencoderKLConfig, BartConfig, MusicGenConfig
from configs import *
#from extensions import gelu, LayerNorm, Conv1D, Attention, MLP, Block, GPT2Model, GPT2LMHead, MBartEncoderLayer, MBartDecoderLayer, MBartEncoder, MBartDecoder, MBartModel, MBartForConditionalGeneration, CodeGenAttention, CodeGenBlock, CodeGenModel, CodeGenForCausalLM, SummarizationModel, OpenLRM, OpenLRMLayer, OpenLRMAttention, OpenLRMFeedForward, AutoencoderKL, Encoder_, Decoder_, DownBlock, UpBlock, ResnetBlock, MidBlock, Downsample2D, Upsample2D, UNet2DConditionModel, UNetMidBlock2DConditionModel, UNetDownBlock2DConditionModel, UNetUpBlock2DConditionModel, ResnetBlock2D, CrossAttentionBlock2D, CrossAttention, SimpleClassifier
from extensions import *

class SentimentClassifierModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(config.d_model * 2, 3)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        pooled = output[:, -1, :]
        logits = self.fc(pooled)
        return logits

class STTModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.lstm = nn.LSTM(32 * (config.max_position_embeddings // 8), 128, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(128 * 2, config.vocab_size)

    def forward(self, audio_data):
        x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.transpose(1, 2).contiguous()
        x = x.view(x.size(0), -1, x.size(2))
        packed_output = nn.utils.rnn.pack_padded_sequence(x, lengths=[x.size(1)]*x.size(0), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_output)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.fc(output)
        return logits

class TTSModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(config.d_model * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.fc(output)
        audio = self.sigmoid(logits)
        return audio

class MusicGenModel(nn.Module):
    def __init__(self, config: MusicGenConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.transformer_layers = nn.ModuleList([CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
        self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids):
        embedded_tokens = self.embedding(input_ids)
        hidden_states = embedded_tokens
        for layer in self.transformer_layers:
            hidden_states = layer(hidden_states)
        logits = self.fc_out(hidden_states)
        return logits

    def sample(self, attributes, sample_rate, duration):
        input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long).to(device)
        audio_output = []
        num_steps = int(duration * sample_rate / 1024)
        for _ in tqdm(range(num_steps), desc="Generating music"):
            logits = self.forward(input_tokens)
            predicted_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            audio_output.append(predicted_token.cpu())
            input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
        audio_output = torch.cat(audio_output, dim=1).float()
        return audio_output
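
MusicGenModel.sample above is greedy autoregressive decoding: feed the running token sequence through the model, take the argmax of the last position's logits, append, repeat. The same pattern with a stand-in model (an untrained nn.Embedding plus nn.Linear, standing in for the CodeGenBlock stack from extensions.py):

    import torch
    import torch.nn as nn

    vocab_size, hidden = 128, 32
    embed = nn.Embedding(vocab_size, hidden)
    head = nn.Linear(hidden, vocab_size)
    tokens = torch.randint(0, vocab_size, (1, 1))
    with torch.no_grad():
        for _ in range(16):                                   # 16 greedy steps
            logits = head(embed(tokens))                      # (1, T, vocab_size)
            nxt = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            tokens = torch.cat((tokens, nxt), dim=1)
    print(tokens.shape)                                       # torch.Size([1, 17])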
musicgen_api.py
ADDED
@@ -0,0 +1,35 @@
from flask import jsonify, send_file, request
from main import *
#from main import musicgen_model, device
import torch
import soundfile as sf
import numpy as np
import io

def generate_music(prompt, output_path="output_music.wav"):
    if musicgen_model is None:
        return "Music generation model not initialized."

    attributes = [prompt]
    sample_rate = 32000
    duration = 8
    audio_values = musicgen_model.sample(
        attributes=attributes,
        sample_rate=sample_rate,
        duration=duration,
    )
    output_audio = audio_values.cpu().numpy().squeeze()
    sf.write(output_path, output_audio, sample_rate)
    return output_path

def musicgen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_music(prompt)
    if output_file == "Music generation model not initialized.":
        return jsonify({"error": "Music generation failed"}), 500
    with open(output_file, 'rb') as f:
        audio_content = f.read()
    return send_file(io.BytesIO(audio_content), mimetype="audio/wav", as_attachment=True, download_name="output.wav")
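
musicgen_api is a bare Flask view function with no route attached in this file; a sketch of one way it could be registered (the app object and URL path here are assumptions, since the actual wiring lives in main.py):

    from flask import Flask
    from musicgen_api import musicgen_api

    app = Flask(__name__)
    # hypothetical registration; main.py may wire this differently
    app.add_url_rule("/musicgen", view_func=musicgen_api, methods=["POST"])

    # client side:
    #   curl -X POST http://localhost:7860/musicgen \
    #        -H "Content-Type: application/json" \
    #        -d '{"prompt": "calm piano"}' -o output.wav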
requirements.txt
ADDED
@@ -0,0 +1,38 @@
accelerate
retry
asyncio
basicsr
beautifulsoup4
bs4
opencv-python
deep-translator
duckduckgo-search
fastapi
faker
flask
flask-cors
facexlib
ffmpeg-python
gfpgan
imageio
imageio-ffmpeg
langdetect
librosa
nltk
numpy
Pillow
pydub
pytorch-lightning
PyYAML
safetensors
scikit-learn
scipy
scikit-image
soundfile
torch
torchaudio
torchvision
tqdm
wget
yacs
numba
sadtalker_api.py
ADDED
@@ -0,0 +1,202 @@
import os
import tempfile
import uuid
import asyncio
import functools  # run_in_executor forwards only positional arguments, so kwargs go through partial
import shutil
import requests
from urllib.parse import urlparse
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, WebSocket
from fastapi.responses import JSONResponse
#from fastapi.middleware.cors import CORSMiddleware
from fastapi import APIRouter
from extensions import *
from main import *
#from main import sadtalker_instance
from tts_api import *
from sadtalker_utils import *
import base64
from stt_api import *
from text_generation import *

router = APIRouter()

@router.post("/sadtalker")
async def create_video(
    source_image: str = Form(None),
    source_image_file: UploadFile = File(None),
    driven_audio: str = Form(None),
    driven_audio_file: UploadFile = File(None),
    preprocess: str = Form('crop'),
    still_mode: bool = Form(False),
    use_enhancer: bool = Form(False),
    batch_size: int = Form(1),
    size: int = Form(256),
    pose_style: int = Form(0),
    exp_scale: float = Form(1.0),
    use_ref_video: bool = Form(False),
    ref_video: str = Form(None),
    ref_video_file: UploadFile = File(None),
    ref_info: str = Form(None),
    use_idle_mode: bool = Form(False),
    length_of_audio: int = Form(0),
    use_blink: bool = Form(True),
    checkpoint_dir: str = Form('checkpoints'),
    config_dir: str = Form('src/config'),
    old_version: bool = Form(False),
    tts_text: str = Form(None),
    tts_lang: str = Form('en'),
):
    if source_image_file and source_image:
        raise HTTPException(status_code=400, detail="Provide either source_image or source_image_file, not both")
    if driven_audio and driven_audio_file:
        raise HTTPException(status_code=400, detail="Provide either driven_audio or driven_audio_file, not both")
    if ref_video and ref_video_file:
        raise HTTPException(status_code=400, detail="Provide either ref_video or ref_video_file, not both")
    tmp_source_image = None
    if source_image_file:
        tmp_source_image = tempfile.NamedTemporaryFile(suffix=os.path.splitext(source_image_file.filename)[1], delete=False)
        content = await source_image_file.read()
        tmp_source_image.write(content)
        tmp_source_image.close()  # flush buffered bytes to disk before the model reads the file
        source_image_path = tmp_source_image.name
    elif source_image:
        if urlparse(source_image).scheme in ["http", "https"]:
            response = requests.get(source_image, stream=True)
            response.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_source_image:
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_source_image.write(chunk)
            source_image_path = tmp_source_image.name
        else:
            source_image_path = source_image
    else:
        raise HTTPException(status_code=400, detail="source_image not provided")
    tmp_driven_audio = None
    if driven_audio_file:
        tmp_driven_audio = tempfile.NamedTemporaryFile(suffix=os.path.splitext(driven_audio_file.filename)[1], delete=False)
        content = await driven_audio_file.read()
        tmp_driven_audio.write(content)
        tmp_driven_audio.close()
        driven_audio_path = tmp_driven_audio.name
    elif driven_audio:
        if urlparse(driven_audio).scheme in ["http", "https"]:
            response = requests.get(driven_audio, stream=True)
            response.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_driven_audio:
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_driven_audio.write(chunk)
            driven_audio_path = tmp_driven_audio.name
        else:
            driven_audio_path = driven_audio
    else:
        driven_audio_path = None
    tmp_ref_video = None
    if ref_video_file:
        tmp_ref_video = tempfile.NamedTemporaryFile(suffix=os.path.splitext(ref_video_file.filename)[1], delete=False)
        content = await ref_video_file.read()
        tmp_ref_video.write(content)
        tmp_ref_video.close()
        ref_video_path = tmp_ref_video.name
    elif ref_video:
        if urlparse(ref_video).scheme in ["http", "https"]:
            response = requests.get(ref_video, stream=True)
            response.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_ref_video:
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_ref_video.write(chunk)
            ref_video_path = tmp_ref_video.name
        else:
            ref_video_path = ref_video
    else:
        ref_video_path = None
    try:
        loop = asyncio.get_running_loop()
        # wrap in functools.partial so tts_text/tts_lang survive run_in_executor,
        # which accepts positional arguments only
        output_path = await loop.run_in_executor(None, functools.partial(
            sadtalker_instance.test,
            source_image_path,
            driven_audio_path,
            preprocess,
            still_mode,
            use_enhancer,
            batch_size,
            size,
            pose_style,
            exp_scale,
            use_ref_video,
            ref_video_path,
            ref_info,
            use_idle_mode,
            length_of_audio,
            use_blink,
            './results/',
            tts_text=tts_text,
            tts_lang=tts_lang,
        ))
        return {"video_url": output_path}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if tmp_source_image:
            os.remove(tmp_source_image.name)
        if tmp_driven_audio:
            os.remove(tmp_driven_audio.name)
        if tmp_ref_video:
            os.remove(tmp_ref_video.name)

@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    tts_model = TTSTalker()
    try:
        while True:
            data = await websocket.receive_json()
            text = data.get("text")
            audio_base64 = data.get("audio")
            if text:
                audio_path = await asyncio.get_running_loop().run_in_executor(None, tts_model.test, text)
            elif audio_base64:
                try:
                    audio_bytes = base64.b64decode(audio_base64)
                    tmp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                    tmp_audio_file.write(audio_bytes)
                    tmp_audio_file.close()  # flush before the STT model opens the file
                    audio_path = tmp_audio_file.name
                    transcription_text_file = speech_to_text_func(tmp_audio_file.name)
                    with open(transcription_text_file, 'r') as f:
                        transcription_text = f.read()
                    response_stream = perform_reasoning_stream(f"respond to this sentence in 10 words or less {transcription_text}", 0.7, 40, 0.0, 1.2)
                    response_text = ""
                    for chunk in response_stream:
                        if chunk == "<END_STREAM>":
                            break
                        response_text += chunk
                    audio_path = await asyncio.get_running_loop().run_in_executor(None, tts_model.test, response_text)
                except Exception as e:
                    await websocket.send_json({"error": str(e)})
                    continue
                finally:
                    if 'tmp_audio_file' in locals() and tmp_audio_file:
                        os.remove(tmp_audio_file.name)
            else:
                continue
            source_image_path = './examples/source_image/cyarh.png'
            ref_video_path = './examples/driven_video/vid_xdd.mp4'
            loop = asyncio.get_running_loop()
            output = await loop.run_in_executor(None, sadtalker_instance.test,
                                                source_image_path,
                                                audio_path,
                                                'full',
                                                True,
                                                True,
                                                1,
                                                256,
                                                0,
                                                1,
                                                True,
                                                ref_video_path,
                                                "pose+blink",
                                                False,
                                                0,
                                                True,
                                                './results/')
            await websocket.send_json({"video_url": output})
    except Exception as e:
        print(e)
        await websocket.send_json({"error": str(e)})
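
A sketch of a client call against the /sadtalker endpoint defined above; the host, port, and image file name are assumptions:

    import requests

    resp = requests.post(
        "http://localhost:7860/sadtalker",                    # hypothetical host/port
        data={"preprocess": "crop", "tts_lang": "en",
              "tts_text": "Hello from SadTalker"},
        files={"source_image_file": open("face.png", "rb")},  # any portrait image
    )
    resp.raise_for_status()
    print(resp.json()["video_url"])                           # path of the rendered talking-head video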
sadtalker_utils.py
ADDED
@@ -0,0 +1,866 @@
import os
import shutil
import uuid
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors.torch  # load_file used below lives in the safetensors.torch submodule
import librosa
from pydub import AudioSegment
import imageio
from scipy import signal
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm
import math
import torchaudio
import urllib.request

REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
kp_file = "kp_detector.safetensors"
aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
aud_file = "auido2pose_00140-model.pth"
wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
wav_file = "wav2vec2.pth"
gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
gen_file = "generator.pth"
mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
mapx_file = "mapping.pth"
den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
den_file = "dense_motion.pth"


def download_model(url, filename, checkpoint_dir):
    if not os.path.exists(os.path.join(checkpoint_dir, filename)):
        print(f"Downloading {filename}...")
        os.makedirs(checkpoint_dir, exist_ok=True)
        urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
        print(f"{filename} downloaded.")
    else:
        print(f"{filename} already exists.")


def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
    AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")


def load_wav_util(path, sr):
    return librosa.core.load(path, sr=sr)[0]


def save_wav_util(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    wavfile.write(path, sr, wav.astype(np.int16))


class OcclusionAwareKPDetector(nn.Module):

    def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
        super(OcclusionAwareKPDetector, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.conv2(x)
        kp = {'value': x.view(x.size(0), -1)}
        return kp


class Wav2Vec2Model(nn.Module):

    def __init__(self):
        super(Wav2Vec2Model, self).__init__()
        self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
        self.bn = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64, 2048)

    def forward(self, audio):
        x = audio.unsqueeze(1)
        x = self.relu(self.bn(self.conv(x)))
        x = torch.mean(x, dim=-1)
        x = self.fc(x)
        return x


class AudioCoeffsPredictor(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(AudioCoeffsPredictor, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, audio_embedding):
        return self.linear(audio_embedding)


class MappingNet(nn.Module):

    def __init__(self, num_coeffs, num_layers, hidden_dim):
        super(MappingNet, self).__init__()
        layers = []
        input_dim = num_coeffs * 2
        for _ in range(num_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_coeffs))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class DenseMotionNetwork(nn.Module):

    def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
        super(DenseMotionNetwork, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)

    def forward(self, kp_source, kp_driving, jacobian):
        x = self.relu(self.conv1(kp_source))
        x = self.conv2(x)
        sparse_motion = {'dense_motion': x}
        return sparse_motion


class Hourglass(nn.Module):

    def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
        super(Hourglass, self).__init__()
        self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
                                     nn.BatchNorm2d(max_features), nn.ReLU())
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())

    def forward(self, source_image, kp_driving, **kwargs):
        x = self.encoder(source_image)
        x = self.decoder(x)
        B, C, H, W = x.size()
        video = []
        for _ in range(10):
            frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(
                np.uint8)
            video.append(frame)
        return video


class Face3DHelper:

    def __init__(self, local_pca_path, device):
        self.local_pca_path = local_pca_path
        self.device = device

    def run(self, source_image):
        # placeholder detector: returns the centered half-size box of the image
        h, w, _ = source_image.shape
        x_min = w // 4
        y_min = h // 4
        x_max = x_min + w // 2
        y_max = y_min + h // 2
        return [x_min, y_min, x_max, y_max]


class Face3DHelperOld(Face3DHelper):

    def __init__(self, local_pca_path, device):
        super(Face3DHelperOld, self).__init__(local_pca_path, device)


class MouthDetector:

    def __init__(self):
        pass

    def detect(self, image):
        h, w = image.shape[:2]
        return (w // 2, h // 2)


class KeypointNorm(nn.Module):

    def __init__(self, device):
        super(KeypointNorm, self).__init__()
        self.device = device

    def forward(self, kp_driving):
        return kp_driving


def save_video_with_watermark(video_frames, audio_path, output_path):
    H, W, _ = video_frames[0].shape
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
    for frame in video_frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()


def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
    shutil.copy(video_path, output_path)


class TTSTalker:

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_model = None

    def load_model(self):
        self.tts_model = self

    def tokenizer(self, text):
        return [ord(c) for c in text]

    def __call__(self, input_tokens):
        return torch.zeros(1, 16000, device=self.device)  # one second of silence as a stub waveform

    def test(self, text, lang='en'):
        if self.tts_model is None:
            self.load_model()
        output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
        os.makedirs('./results', exist_ok=True)
        tokens = self.tokenizer(text)
        input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
        with torch.no_grad():
            audio_output = self(input_tokens)
        torchaudio.save(output_path, audio_output.cpu(), 16000)
        return output_path


class SadTalker:

    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
                 old_version=False):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = self.get_cfg_defaults()
        self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
        self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
        self.cfg['MODEL']['CONFIG_DIR'] = config_path
        self.cfg['MODEL']['DEVICE'] = self.device
        self.cfg['INPUT_IMAGE'] = {}
        self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
        self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
        self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
        self.cfg['INPUT_IMAGE']['SIZE'] = size
        self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version

        download_model(kp_url, kp_file, checkpoint_path)
        download_model(aud_url, aud_file, checkpoint_path)
        download_model(wav_url, wav_file, checkpoint_path)
        download_model(gen_url, gen_file, checkpoint_path)
        download_model(mapx_url, mapx_file, checkpoint_path)
        download_model(den_url, den_file, checkpoint_path)
        download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
        download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)

        self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])

    def get_cfg_defaults(self):
        return {
            'MODEL': {
                'CHECKPOINTS_DIR': '',
                'CONFIG_DIR': '',
                'DEVICE': self.device,
                'SCALE': 64,
                'NUM_VOXEL_FRAMES': 8,
                'NUM_MOTION_FRAMES': 10,
                'MAX_FEATURES': 256,
                'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
                'VIDEO_FPS': 25,
                'OUTPUT_VIDEO_FPS': None,
                'OUTPUT_AUDIO_SAMPLE_RATE': None,
                'USE_ENHANCER': False,
                'ENHANCER_NAME': '',
                'BG_UPSAMPLER': None,
                'IS_HALF': False
            },
            'INPUT_IMAGE': {}
        }

    def merge_from_file(self, filepath):
        if os.path.exists(filepath):
            with open(filepath, 'r') as f:
                cfg_from_file = yaml.safe_load(f)
            self.cfg.update(cfg_from_file)

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
             tts_text=None, tts_lang='en'):
        self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
                                  pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
                                  length_of_audio, use_blink, result_dir, tts_text, tts_lang)
        return self.sadtalker_model.save_result()


class SadTalkerModel:

    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
        self.preprocesser = self.sadtalker.preprocesser
        self.kp_extractor = self.sadtalker.kp_extractor
        self.generator = self.sadtalker.generator
        self.mapping = self.sadtalker.mapping
        self.he_estimator = self.sadtalker.he_estimator
        self.audio_to_coeff = self.sadtalker.audio_to_coeff
        self.animate_from_coeff = self.sadtalker.animate_from_coeff
        self.face_enhancer = self.sadtalker.face_enhancer

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
             tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
        self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                                         batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
                                         use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
                                         jitter_amount, jitter_source_image)
        return self.inner_test.test()

    def save_result(self):
        return self.inner_test.save_result()


class SadTalkerInner:

    def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                 batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
                 length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
        self.sadtalker_model = sadtalker_model
        self.source_image = source_image
        self.driven_audio = driven_audio
        self.preprocess = preprocess
        self.still_mode = still_mode
        self.use_enhancer = use_enhancer
        self.batch_size = batch_size
        self.size = size
        self.pose_style = pose_style
        self.exp_scale = exp_scale
        self.use_ref_video = use_ref_video
        self.ref_video = ref_video
        self.ref_info = ref_info
        self.use_idle_mode = use_idle_mode
        self.length_of_audio = length_of_audio
        self.use_blink = use_blink
        self.result_dir = result_dir
        self.tts_text = tts_text
        self.tts_lang = tts_lang
        self.jitter_amount = jitter_amount
        self.jitter_source_image = jitter_source_image
        self.device = self.sadtalker_model.device
        self.output_path = None

    def get_test_data(self):
        proc = self.sadtalker_model.preprocesser
        if self.tts_text is not None:
            temp_dir = tempfile.mkdtemp()
            audio_path = os.path.join(temp_dir, 'audio.wav')
            tts = TTSTalker()
            # copy the synthesized speech into the temp dir so driven_audio points at real data
            shutil.copy(tts.test(self.tts_text, self.tts_lang), audio_path)
            self.driven_audio = audio_path
        source_image_pil = Image.open(self.source_image).convert('RGB')
        if self.jitter_source_image:
            jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            source_image_pil = Image.fromarray(
                np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
        source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
        if self.still_mode:
            ref_pose_coeff = proc.generate_still_pose(self.pose_style)
            ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
        elif self.use_idle_mode:
            # idle mode: interpolate pose/expression over the requested audio length
            ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
            ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
        else:
            ref_pose_coeff = None
            ref_expression_coeff = None
        audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
                                                             self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
        batch = {
            'source_image': source_image_tensor.unsqueeze(0).to(self.device),
            'audio': audio_tensor.unsqueeze(0).to(self.device),
            'ref_pose_coeff': ref_pose_coeff,
            'ref_expression_coeff': ref_expression_coeff,
            'source_image_crop': cropped_image,
            'crop_info': crop_info,
            'use_blink': self.use_blink,
            'pose_style': self.pose_style,
            'exp_scale': self.exp_scale,
            'ref_video': self.ref_video,
            'use_ref_video': self.use_ref_video,
            'ref_info': self.ref_info,
        }
        return batch, audio_sample_rate

    def run_inference(self, batch):
        kp_extractor = self.sadtalker_model.kp_extractor
        generator = self.sadtalker_model.generator
        mapping = self.sadtalker_model.mapping
        he_estimator = self.sadtalker_model.he_estimator
        audio_to_coeff = self.sadtalker_model.audio_to_coeff
        animate_from_coeff = self.sadtalker_model.animate_from_coeff
        proc = self.sadtalker_model.preprocesser
        with torch.no_grad():
            kp_source = kp_extractor(batch['source_image'])
            if self.still_mode or self.use_idle_mode:
                ref_pose_coeff = batch['ref_pose_coeff']
                ref_expression_coeff = batch['ref_expression_coeff']
                pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
            else:
                if self.use_ref_video:
                    kp_ref = kp_extractor(batch['source_image'])
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
                                                               use_ref_info=batch['ref_info'])
                else:
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
            coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
            if self.use_blink:
                coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
            else:
                coeff['blink_coeff'] = None
            kp_driving = audio_to_coeff(batch['audio'])[0]
            kp_norm = animate_from_coeff.normalize_kp(kp_driving)
            coeff['kp_driving'] = kp_norm
            coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
            face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
            output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
                                                       he_estimator, batch['audio'], batch['source_image_crop'],
                                                       face_enhancer=face_enhancer)
        return output_video

    def post_processing(self, output_video, audio_sample_rate, batch):
        proc = self.sadtalker_model.preprocesser
        base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
        audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
        output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
        self.output_path = output_video_path
        video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL'][
            'OUTPUT_VIDEO_FPS'] is None else \
            self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
        audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if \
            self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else \
            self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
        if self.use_enhancer:
            enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
            save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
            paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
                      output_video_path)
            os.remove(enhanced_path)
        else:
            save_video_with_watermark(output_video, self.driven_audio, output_video_path)
        if self.tts_text is not None:
            shutil.rmtree(os.path.dirname(self.driven_audio))

    def save_result(self):
        return self.output_path

    def __call__(self):
        return self.output_path

    def test(self):
        batch, audio_sample_rate = self.get_test_data()
        output_video = self.run_inference(batch)
        self.post_processing(output_video, audio_sample_rate, batch)
        return self.save_result()


class SadTalkerInnerModel:

    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
        self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
        self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
        self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
        self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL'][
            'USE_ENHANCER'] else None
        self.generator = Generator(sadtalker_cfg, self.device)
        self.mapping = Mapping(sadtalker_cfg, self.device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)


class Preprocesser:

    def __init__(self, sadtalker_cfg, device):
        self.cfg = sadtalker_cfg
        self.device = device
        if self.cfg['INPUT_IMAGE'].get('OLD_VERSION', False):
            self.face3d_helper = Face3DHelperOld(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
        else:
            self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
        self.mouth_detector = MouthDetector()

    def crop(self, source_image_pil, preprocess_type, size=256):
        source_image = np.array(source_image_pil)
        face_info = self.face3d_helper.run(source_image)
        if face_info is None:
            raise Exception("No face detected")
        x_min, y_min, x_max, y_max = face_info[:4]
        old_size = (x_max - x_min, y_max - y_min)
        x_center = (x_max + x_min) / 2
        y_center = (y_max + y_min) / 2
        if preprocess_type == 'crop':
            face_size = max(x_max - x_min, y_max - y_min)
            x_min = int(x_center - face_size / 2)
            y_min = int(y_center - face_size / 2)
            x_max = int(x_center + face_size / 2)
            y_max = int(y_center + face_size / 2)
        else:
            x_min -= int((x_max - x_min) * 0.1)
            y_min -= int((y_max - y_min) * 0.1)
            x_max += int((x_max - x_min) * 0.1)
            y_max += int((y_max - y_min) * 0.1)
        h, w = source_image.shape[:2]
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(w, x_max)
        y_max = min(h, y_max)
        cropped_image = source_image[y_min:y_max, x_min:x_max]
        cropped_image_pil = Image.fromarray(cropped_image)
        if size is not None and size != 0:
            cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
        source_image_tensor = self.img2tensor(cropped_image_pil)
        return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
            self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))

    def img2tensor(self, img):
        img = np.array(img).astype(np.float32) / 255.0
        img = np.transpose(img, (2, 0, 1))
        return torch.FloatTensor(img)

    def video_to_tensor(self, video, device):
        video_tensor_list = []
        import torchvision.transforms as transforms
        transform_func = transforms.ToTensor()
        for frame in video:
            frame_pil = Image.fromarray(frame)
            frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
            video_tensor_list.append(frame_tensor)
        video_tensor = torch.cat(video_tensor_list, dim=0)
        return video_tensor

    def process_audio(self, audio_path, sample_rate):
        wav = load_wav_util(audio_path, sample_rate)
        wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
        return wav_tensor, sample_rate

    def generate_still_pose(self, pose_style):
        ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
        return ref_pose_coeff

    def generate_still_expression(self, exp_scale):
        ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
        return ref_expression_coeff

    def generate_idles_pose(self, length_of_audio, pose_style):
        num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
        ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_pose = self.generate_still_pose(pose_style)
        end_pose = self.generate_still_pose(pose_style)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose + alpha * end_pose
        return ref_pose_coeff

    def generate_idles_expression(self, length_of_audio):
        num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
        ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_exp = self.generate_still_expression(1.0)
        end_exp = self.generate_still_expression(1.0)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp + alpha * end_exp
        return ref_expression_coeff


class KeyPointExtractor(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(KeyPointExtractor, self).__init__()
        self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
                                                     num_kp=10,
                                                     num_dilation_blocks=2,
                                                     dropout_rate=0.1).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
        self.load_kp_detector(checkpoint_path, device)

    def load_kp_detector(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}))
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, x):
        kp = self.kp_extractor(x)
        return kp


class Audio2Coeff(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(Audio2Coeff, self).__init__()
        self.audio_model = Wav2Vec2Model().to(device)
        checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
        self.load_audio_model(checkpoint_path, device)
        self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
        # use aud_file so the name matches what download_model saved (upstream spells it "auido2pose")
        mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], aud_file)
        self.load_mapping_model(mapping_checkpoint, device)

    def load_audio_model(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}))
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def load_mapping_model(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}))
            self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}))
            self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}))
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        audio_embedding = self.audio_model(audio_tensor)
        pose_coeff = self.pose_mapper(audio_embedding)
        if ref_pose_coeff is not None:
            pose_coeff = ref_pose_coeff
        if kp_ref is not None and use_ref_info == 'pose':
            ref_pose_6d = kp_ref['value'][:, :6]
            pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
        return pose_coeff

    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        audio_embedding = self.audio_model(audio_tensor)
        expression_coeff = self.exp_mapper(audio_embedding)
        if ref_expression_coeff is not None:
            expression_coeff = ref_expression_coeff
        return expression_coeff

    def get_blink_coeff(self, audio_tensor):
        audio_embedding = self.audio_model(audio_tensor)
        blink_coeff = self.blink_mapper(audio_embedding)
        return blink_coeff

    def forward(self, audio):
        audio_embedding = self.audio_model(audio)
        pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(
            audio_embedding), self.blink_mapper(audio_embedding)
        return pose_coeff, expression_coeff, blink_coeff

    def mean_std_normalize(self, coeff):
        mean = coeff.mean(dim=1, keepdim=True)
        std = coeff.std(dim=1, keepdim=True)
        return (coeff - mean) / std


class AnimateFromCoeff(nn.Module):

    def __init__(self, sadtalker_cfg, device):
        super(AnimateFromCoeff, self).__init__()
        self.generator = Generator(sadtalker_cfg, device)
        self.mapping = Mapping(sadtalker_cfg, device)
        self.kp_norm = KeypointNorm(device=device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)

    def normalize_kp(self, kp_driving):
        return self.kp_norm(kp_driving)

    def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
                 face_enhancer=None):
        kp_driving = coeff['kp_driving']
        jacobian = coeff['jacobian']
        pose_coeff = coeff['pose_coeff']
        expression_coeff = coeff['expression_coeff']
        blink_coeff = coeff['blink_coeff']
        with torch.no_grad():
            if blink_coeff is not None:
                sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
                dense_motion = sparse_motion['dense_motion']
                video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
                face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                                     face_3d_param=face_3d)
                video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
                video_output = self.make_animation(video_output)
            else:
                sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
                dense_motion = sparse_motion['dense_motion']
                face_3d = mapping(expression_coeff, pose_coeff)
                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                                     face_3d_param=face_3d)
                video_output = video_3d['video_3d']
                video_output = self.make_animation(video_output)
            if face_enhancer is not None:
                video_output_enhanced = []
                for frame in tqdm(video_output, 'Face enhancer running'):
                    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                    enhanced_image = face_enhancer.enhance(np.array(pil_image))[0]
                    video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
|
738 |
+
video_output = video_output_enhanced
|
739 |
+
return video_output
|
740 |
+
|
741 |
+
def make_animation(self, video_array):
|
742 |
+
H, W, _ = video_array[0].shape
|
743 |
+
out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
|
744 |
+
for img in video_array:
|
745 |
+
out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
746 |
+
out.release()
|
747 |
+
video = imageio.mimread('./tmp.mp4')
|
748 |
+
os.remove('./tmp.mp4')
|
749 |
+
return video
|
750 |
+
|
751 |
+
|
752 |
+
class Generator(nn.Module):
|
753 |
+
|
754 |
+
def __init__(self, sadtalker_cfg, device):
|
755 |
+
super(Generator, self).__init__()
|
756 |
+
self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
|
757 |
+
num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
|
758 |
+
max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
|
759 |
+
num_channels=3,
|
760 |
+
kp_size=10,
|
761 |
+
num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
|
762 |
+
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
|
763 |
+
self.load_generator(checkpoint_path, device)
|
764 |
+
|
765 |
+
def load_generator(self, checkpoint_path, device):
|
766 |
+
if os.path.exists(checkpoint_path):
|
767 |
+
if checkpoint_path.endswith('safetensors'):
|
768 |
+
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
|
769 |
+
else:
|
770 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
771 |
+
self.generator.load_state_dict(checkpoint.get('generator', {}))
|
772 |
+
else:
|
773 |
+
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
|
774 |
+
|
775 |
+
def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
|
776 |
+
if face_3d_param is not None:
|
777 |
+
video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
|
778 |
+
face_3d_param=face_3d_param)
|
779 |
+
else:
|
780 |
+
video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
|
781 |
+
return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
|
782 |
+
|
783 |
+
|
784 |
+
class Mapping(nn.Module):
|
785 |
+
|
786 |
+
def __init__(self, sadtalker_cfg, device):
|
787 |
+
super(Mapping, self).__init__()
|
788 |
+
self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
|
789 |
+
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
|
790 |
+
self.load_mapping_net(checkpoint_path, device)
|
791 |
+
self.f_3d_mean = torch.zeros(1, 64, device=device)
|
792 |
+
|
793 |
+
def load_mapping_net(self, checkpoint_path, device):
|
794 |
+
if os.path.exists(checkpoint_path):
|
795 |
+
if checkpoint_path.endswith('safetensors'):
|
796 |
+
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
|
797 |
+
else:
|
798 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
799 |
+
self.mapping_net.load_state_dict(checkpoint.get('mapping', {}))
|
800 |
+
else:
|
801 |
+
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
|
802 |
+
|
803 |
+
def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
|
804 |
+
coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
|
805 |
+
face_3d = self.mapping_net(coeff) + self.f_3d_mean
|
806 |
+
if blink_coeff is not None:
|
807 |
+
face_3d[:, -1:] = blink_coeff
|
808 |
+
return face_3d
|
809 |
+
|
810 |
+
|
811 |
+
class OcclusionAwareDenseMotion(nn.Module):
|
812 |
+
|
813 |
+
def __init__(self, sadtalker_cfg, device):
|
814 |
+
super(OcclusionAwareDenseMotion, self).__init__()
|
815 |
+
self.dense_motion_network = DenseMotionNetwork(num_kp=10,
|
816 |
+
num_channels=3,
|
817 |
+
block_expansion=sadtalker_cfg['MODEL']['SCALE'],
|
818 |
+
num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
|
819 |
+
max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
|
820 |
+
checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
|
821 |
+
self.load_dense_motion_network(checkpoint_path, device)
|
822 |
+
|
823 |
+
def load_dense_motion_network(self, checkpoint_path, device):
|
824 |
+
if os.path.exists(checkpoint_path):
|
825 |
+
if checkpoint_path.endswith('safetensors'):
|
826 |
+
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
|
827 |
+
else:
|
828 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
829 |
+
self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}))
|
830 |
+
else:
|
831 |
+
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
|
832 |
+
|
833 |
+
def forward(self, kp_source, kp_driving, jacobian):
|
834 |
+
sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
|
835 |
+
return sparse_motion
|
836 |
+
|
837 |
+
|
838 |
+
class FaceEnhancer(nn.Module):
|
839 |
+
|
840 |
+
def __init__(self, sadtalker_cfg, device):
|
841 |
+
super(FaceEnhancer, self).__init__()
|
842 |
+
enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
|
843 |
+
bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
|
844 |
+
if enhancer_name == 'gfpgan':
|
845 |
+
from gfpgan import GFPGANer
|
846 |
+
self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
|
847 |
+
upscale=1,
|
848 |
+
arch='clean',
|
849 |
+
channel_multiplier=2,
|
850 |
+
bg_upsampler=bg_upsampler)
|
851 |
+
elif enhancer_name == 'realesrgan':
|
852 |
+
from realesrgan import RealESRGANer
|
853 |
+
half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
|
854 |
+
self.face_enhancer = RealESRGANer(scale=2,
|
855 |
+
model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
|
856 |
+
'RealESRGAN_x2plus.pth'),
|
857 |
+
tile=0,
|
858 |
+
tile_pad=10,
|
859 |
+
pre_pad=0,
|
860 |
+
half=half,
|
861 |
+
device=device)
|
862 |
+
else:
|
863 |
+
self.face_enhancer = None
|
864 |
+
|
865 |
+
def forward(self, x):
|
866 |
+
return self.face_enhancer.enhance(x, outscale=1)[0]
|
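
Each of these wrappers repeats the same checkpoint-loading pattern: use safetensors.torch.load_file for .safetensors files, torch.load otherwise, then pull a named sub-dict out of the checkpoint. A minimal standalone sketch of that pattern follows; the function name, paths, and the 'generator' key are illustrative, not tied to the actual checkpoints.

import os
import torch
import safetensors.torch

def load_component_state(checkpoint_path, device, key):
    # Dispatch on extension: .safetensors files are flat tensor dicts,
    # .pth files may hold arbitrary pickled objects.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    if checkpoint_path.endswith('safetensors'):
        checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    # .get(key, {}) silently yields an empty state_dict when the key is
    # absent; load_state_dict then fails on missing keys unless strict=False.
    return checkpoint.get(key, {})
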
sentiment_api.py
ADDED
@@ -0,0 +1,27 @@
+from flask import jsonify, request
+from main import *
+#from main import sentiment_model, device
+import torch
+
+def analyze_sentiment(text, output_path="output_sentiment.json"):
+    if sentiment_model is None:
+        return "Sentiment model not initialized."
+
+    input_tokens = sentiment_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
+    with torch.no_grad():
+        sentiment_logits = sentiment_model(input_tokens['input_ids'])
+        predicted_class_id = torch.argmax(sentiment_logits, dim=-1).item()
+        sentiment_label = sentiment_model.config.id2label[predicted_class_id]
+        probability = torch.softmax(sentiment_logits, dim=-1)[0][predicted_class_id].item()
+
+    return {"sentiment": sentiment_label, "probability": probability}
+
+def sentiment_api():
+    data = request.get_json()
+    text = data.get('text')
+    if not text:
+        return jsonify({"error": "Text is required"}), 400
+    output_file = analyze_sentiment(text)
+    if output_file == "Sentiment model not initialized.":
+        return jsonify({"error": "Sentiment analysis failed"}), 500
+    return jsonify(output_file)
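
These view functions get bound to URL routes in main.py. A client call would look like the sketch below; the /sentiment path and port 7860 are assumptions here, so check main.py for the actual URL map.

import requests

# Hypothetical route; the real path is whatever main.py binds sentiment_api() to.
resp = requests.post("http://localhost:7860/sentiment",
                     json={"text": "This space is great!"})
print(resp.json())  # e.g. {"sentiment": "POSITIVE", "probability": 0.98}
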
stt_api.py
ADDED
@@ -0,0 +1,36 @@
+import os
+import uuid
+from flask import jsonify, send_file, request
+from main import *
+#from main import stt_model, device
+import torch
+import torchaudio
+
+def speech_to_text_func(audio_path, output_path="output_stt.txt"):
+    if stt_model is None:
+        return "STT model not initialized."
+
+    waveform, sample_rate = torchaudio.load(audio_path)
+    if waveform.ndim > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    waveform = waveform.to(device)
+    with torch.no_grad():
+        logits = stt_model(waveform)
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist())
+
+    with open(output_path, "w") as file:
+        file.write(transcription)
+    return output_path
+
+def stt_api():
+    if 'audio' not in request.files:
+        return jsonify({"error": "Audio file is required"}), 400
+    audio_file = request.files['audio']
+    temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
+    audio_file.save(temp_audio_path)
+    output_file = speech_to_text_func(temp_audio_path)
+    os.remove(temp_audio_path)
+    if output_file == "STT model not initialized.":
+        return jsonify({"error": "STT failed"}), 500
+    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output.txt")
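
The STT route expects a multipart upload under the audio field. Assuming main.py binds it to /stt (the path is an assumption), a client call looks like:

import requests

# Hypothetical route name; see main.py for the real URL map.
with open("sample.wav", "rb") as f:
    resp = requests.post("http://localhost:7860/stt", files={"audio": f})
print(resp.text)  # the transcription, returned as a text attachment
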
summarization_api.py
ADDED
@@ -0,0 +1,29 @@
+from flask import jsonify, send_file, request
+from main import *
+#from main import summarization_model, summarization_word_to_index, device
+import torch
+
+def summarize_text(text, output_path="output_summary.txt"):
+    if summarization_model is None:
+        return "Summarization model not initialized."
+
+    input_tokens = [summarization_word_to_index.get(word.lower(), 1) for word in text.split()]
+    input_tensor = torch.tensor([input_tokens], dtype=torch.long).to(device)
+
+    with torch.no_grad():
+        summary_ids = summarization_model.generate(input_tensor, num_beams=4, max_length=100, early_stopping=True)
+        summary_text = summarization_model.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+    with open(output_path, "w") as file:
+        file.write(summary_text)
+    return output_path
+
+def summarization_api():
+    data = request.get_json()
+    text = data.get('text')
+    if not text:
+        return jsonify({"error": "Text is required"}), 400
+    output_file = summarize_text(text)
+    if output_file == "Summarization model not initialized.":
+        return jsonify({"error": "Summarization failed"}), 500
+    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")
text_generation.py
ADDED
@@ -0,0 +1,152 @@
+import torch
+import torch.nn.functional as F
+from tqdm import trange
+import time
+from tokenxxx import *
+from main import *
+#from main import model_gpt2, enc, codegen_model, codegen_tokenizer, summarization_model, device, system_prompt, MAX_LENGTH, summarize_text as summarize_func
+from duckduckgo_search import DDGS
+
+def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    top_k = min(top_k, logits.size(-1))
+    if top_k > 0:
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]
+        logits[indices_to_remove] = filter_value
+    if top_p > 0.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+    return logits
+
+def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
+    start_time = time.time()
+    context_tokens = enc.encode(prompt)
+    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
+    generated = context_tokens
+    past = None
+    text_generated_count = 0
+
+    with torch.no_grad():
+        # The original imported trange but never looped, so only a single token
+        # was ever produced; the loop below restores the intended sampling loop
+        # implied by the stop conditions. After the first step only the newly
+        # sampled token is fed back, with past_key_values carrying the context.
+        for _ in trange(length):
+            outputs = model(context_tokens_tensor, past_key_values=past)
+            next_token_logits = outputs[0][:, -1, :] / temperature
+            past = outputs[1]
+            for token_index in set(generated):
+                next_token_logits[0, token_index] /= repetition_penalty
+            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+            if temperature == 0:
+                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
+            else:
+                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+            generated += next_token.tolist()[0]
+            context_tokens_tensor = next_token
+            text_generated_count += 1
+            token = next_token.tolist()[0][0]
+            yield enc.decode([token])
+            if token == enc.encoder[END_OF_TEXT_TOKEN]:
+                yield "<END_STREAM>"
+                return
+            if text_generated_count > length:
+                yield "<END_STREAM>"
+                return
+            if (time.time() - start_time) * 1000 > 5000:
+                yield "<END_STREAM>"
+                return
+
+def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
+    start_time = time.time()
+    context_tokens = tokenizer.encode(prompt)
+    # The original applied an extra unsqueeze(0) here, producing a [1, 1, N]
+    # batch; a single batch dimension is what a GPT-style model expects.
+    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
+    generated = context_tokens
+    past = None
+    text_generated_count = 0
+    with torch.no_grad():
+        for _ in trange(length):
+            outputs = model(input_ids=context_tokens_tensor, past_key_values=past, labels=None)
+            next_token_logits = outputs[0][:, -1, :] / temperature
+            past = outputs[1]
+            for token_index in set(generated):
+                next_token_logits[0, token_index] /= repetition_penalty
+            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+            if temperature == 0:
+                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
+            else:
+                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+            generated.append(next_token.tolist()[0][0])
+            context_tokens_tensor = next_token
+            text_generated_count += 1
+            token = next_token.tolist()[0][0]
+            yield tokenizer.decode([token])
+            if token == 50256:
+                yield "<END_STREAM>"
+                return
+            if text_generated_count > length:
+                yield "<END_STREAM>"
+                return
+            if (time.time() - start_time) * 1000 > 5000:
+                yield "<END_STREAM>"
+                return
+
+def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
+    try:
+        prompt_text = system_prompt + "\n\n"
+        prompt_text += "User: " + text_input + "\nCyrah: "
+        reasoning_prompt = prompt_text
+
+        ddgs = DDGS()
+        search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
+        if search_results:
+            prompt_text += "\nWeb Search Results:\n"
+            for result in search_results:
+                prompt_text += f"- {result['body']}\n"
+            prompt_text += "\n"
+
+        generated_text_stream = []
+        stream_type = "text"
+
+        if "code" in text_input.lower() or "program" in text_input.lower():
+            if codegen_model and codegen_tokenizer:
+                generated_text_stream = sample_sequence_codegen(
+                    prompt=reasoning_prompt,
+                    model=codegen_model,
+                    tokenizer=codegen_tokenizer,
+                    length=MAX_LENGTH,
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    repetition_penalty=repetition_penalty,
+                    device=device
+                )
+                stream_type = "text"
+        elif "summarize" in text_input.lower() or "summary" in text_input.lower():
+            if summarization_model:
+                summary = summarize_func(text_input)
+                yield f"SUMMARY_TEXT:{summary}"
+                yield "<END_STREAM>"
+                stream_type = "summary"
+        else:
+            if model_gpt2 and enc:
+                generated_text_stream = sample_sequence(
+                    prompt=reasoning_prompt,
+                    model=model_gpt2,
+                    enc=enc,
+                    length=MAX_LENGTH,
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                    repetition_penalty=repetition_penalty,
+                    device=device
+                )
+                stream_type = "text"
+
+        accumulated_text = ""
+        if stream_type == "text":
+            for token in generated_text_stream:
+                if token == "<END_STREAM>":
+                    yield accumulated_text
+                    yield "<END_STREAM>"
+                    return
+                if token == END_OF_TEXT_TOKEN:
+                    accumulated_text += END_OF_TEXT_TOKEN
+                    continue
+                if token:
+                    accumulated_text += token
+    except Exception as e:
+        print(f"Reasoning Error: {e}")
+        yield "Error during reasoning. Please try again."
+        yield "<END_STREAM>"
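
top_k_top_p_filtering is the standard top-k / nucleus masking step. A quick self-contained check on dummy logits (the values are arbitrary, chosen only for illustration; run it in this module's namespace):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
# Only the two largest logits survive; the rest become -inf, so softmax
# puts all probability mass on tokens 0 and 1.
print(F.softmax(filtered, dim=-1))
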
text_to_video_api.py
ADDED
@@ -0,0 +1,37 @@
+import os
+import uuid
+from flask import jsonify, send_file, request
+from main import *
+#from main import text_to_video_model
+import torch
+import io
+from skimage import img_as_ubyte
+import imageio
+
+def text_to_video_func(prompt, output_path="output_video.mp4"):
+    if text_to_video_model is None:
+        return "Text-to-Video model not initialized."
+    video_frames_list = text_to_video_model(prompt)
+    if video_frames_list and hasattr(video_frames_list, 'frames'):
+        video_frames = video_frames_list.frames
+        export_to_video_pure(video_frames, output_video=output_path)
+        return output_path
+    return "Video generation failed."
+
+def export_to_video_pure(video_frames, output_video="output_video.mp4", fps=25):
+    writer = imageio.get_writer(output_video, fps=fps)
+    for frame in video_frames:
+        writer.append_data(img_as_ubyte(frame))
+    writer.close()
+
+def text_to_video_api():
+    data = request.get_json()
+    prompt = data.get('prompt')
+    if not prompt:
+        return jsonify({"error": "Prompt is required"}), 400
+    output_file = text_to_video_func(prompt)
+    if output_file == "Text-to-Video model not initialized." or output_file == "Video generation failed.":
+        return jsonify({"error": "Text to video failed"}), 500
+    with open(output_file, 'rb') as f:
+        video_content = f.read()
+    return send_file(io.BytesIO(video_content), mimetype='video/mp4', as_attachment=True, download_name="output_video.mp4")
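
export_to_video_pure only needs an iterable of float or uint8 frames, so it can be smoke-tested without the model. A minimal sketch with synthetic noise frames (the shapes, count, and filename are arbitrary):

import numpy as np

# Ten random 64x64 RGB frames in [0, 1]; img_as_ubyte rescales them to uint8.
frames = [np.random.rand(64, 64, 3) for _ in range(10)]
export_to_video_pure(frames, output_video="noise.mp4", fps=25)
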
tokenxxx.py
ADDED
@@ -0,0 +1,161 @@
+import json
+import re
+import regex  # needed for the \p{...} classes in the codegen tokenizer pattern
+import unicodedata
+from functools import lru_cache
+import wget
+import os
+import torch
+from constants import *
+import nltk
+
+@lru_cache()
+def bytes_to_unicode():
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+def get_pairs(word):
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+class Encoder:
+    def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
+        self.encoder = encoder
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        if tokenize is None:
+            self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
+            self.tokenize = lambda text: re.findall(self.pat, text)
+        else:
+            self.tokenize = tokenize
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+        if not pairs:
+            return token
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        normalized_text = unicodedata.normalize('NFKC', text)
+        normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
+        normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
+        for token in self.tokenize(normalized_text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
+        decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
+        sentences = nltk.sent_tokenize(decoded_text)
+        return ' '.join(sentences).replace("<br>", "<br>\n")
+
+def get_encoder_gpt2():
+    encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
+    vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
+    if not os.path.exists(GPT2_FOLDER):
+        os.makedirs(GPT2_FOLDER)
+    if not os.path.exists(encoder_path):
+        wget.download(ENCODER_URL, out=encoder_path)
+    if not os.path.exists(vocab_path):
+        wget.download(VOCAB_URL, out=vocab_path)
+
+    with open(encoder_path, 'r') as f:
+        encoder = json.load(f)
+    with open(vocab_path, 'r', encoding="utf-8") as f:
+        bpe_data = f.read()
+    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+    encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
+    encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
+    encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
+    return encoder_obj
+
+def get_codegen_tokenizer_pure(vocab_file, merges_file):
+    vocab = json.load(open(vocab_file))
+    merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]
+    bpe_merges = [tuple(m.split()) for m in merges]
+    # \p{L}/\p{N} classes are not supported by the stdlib re module, so the
+    # pattern is compiled with the third-party regex module instead.
+    tokenizer_regex = regex.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
+    tokenize = lambda text: regex.findall(tokenizer_regex, text)
+    # Encoder derives byte_encoder/byte_decoder itself, so only the vocab,
+    # merge list and tokenizer callable are passed (the original also passed
+    # byte_encoder/byte_decoder kwargs that Encoder.__init__ does not accept).
+    encoder_obj = Encoder(
+        encoder=vocab,
+        bpe_merges=bpe_merges,
+        tokenize=tokenize
+    )
+    return encoder_obj
+
+def codegen_tokenize(text, tokenizer):
+    return tokenizer.encode(text)
+
+def codegen_decode(tokens, tokenizer):
+    return tokenizer.decode(tokens)
+
+def tokenize_text(text):
+    # vocabulary, word_to_index and index_to_word are module-level containers
+    # expected to be provided elsewhere (e.g. via the star import above).
+    global vocabulary, word_to_index, index_to_word
+    tokens = text.lower().split()
+    for token in tokens:
+        if token not in vocabulary:
+            vocabulary.add(token)
+            word_to_index[token] = len(index_to_word)
+            index_to_word.append(token)
+    return tokens
+
+def text_to_vector(text):
+    global vocabulary, word_to_index
+    tokens = tokenize_text(text)
+    vector = torch.zeros(len(vocabulary))
+    for token in tokens:
+        if token in word_to_index:
+            vector[word_to_index[token]] += 1
+    return vector
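
The bytes_to_unicode table is GPT-2's reversible byte-to-printable-character map: every one of the 256 byte values gets a unique printable stand-in so BPE can operate on strings. A quick illustration (these outputs are properties of the algorithm itself, not of this repo):

m = bytes_to_unicode()
print(m[ord('A')])   # 'A'  - printable ASCII maps to itself
print(m[ord(' ')])   # 'Ġ'  - the space byte is shifted to a printable stand-in
print(len(m))        # 256  - every byte value gets a unique character
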
translation_api.py
ADDED
@@ -0,0 +1,27 @@
+from flask import jsonify, send_file, request
+from main import *
+#from main import translation_model, device
+
+def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX', output_path="output_translation.txt"):
+    if translation_model is None:
+        return "Translation model not initialized."
+
+    encoded_text = translation_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
+    generated_tokens = translation_model.generate(input_ids=encoded_text['input_ids'], attention_mask=encoded_text['attention_mask'], forced_bos_token_id=translation_model.config.lang_code_to_id[target_language_code])
+    translation = translation_model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+
+    with open(output_path, "w") as file:
+        file.write(translation)
+    return output_path
+
+def translation_api():
+    data = request.get_json()
+    text = data.get('text')
+    target_lang = data.get('target_lang', 'es')
+    source_lang = data.get('source_lang', 'en')
+    if not text:
+        return jsonify({"error": "Text is required"}), 400
+    output_file = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
+    if output_file == "Translation model not initialized.":
+        return jsonify({"error": "Translation failed"}), 500
+    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_translation.txt")
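
perform_translation builds mBART-style language codes (en_XX, es_XX) from the two-letter codes in the request body, so the JSON payload stays short. A client sketch, assuming a /translate route (the path and port are assumptions):

import requests

resp = requests.post("http://localhost:7860/translate",  # hypothetical path
                     json={"text": "Hello world",
                           "target_lang": "es", "source_lang": "en"})
print(resp.text)  # contents of output_translation.txt
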
tts_api.py
ADDED
@@ -0,0 +1,23 @@
+import os
+import torch
+import torchaudio
+from flask import jsonify, send_file, request
+from main import *
+#from main import tts_model, device
+
+def text_to_speech_func(text, output_path="output_tts.wav"):
+    if tts_model is None:
+        return "TTS model not initialized."
+    input_tokens = tts_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
+    with torch.no_grad():
+        audio_output = tts_model(input_tokens['input_ids'])
+    torchaudio.save(output_path, audio_output.cpu(), 16000)
+    return output_path
+
+def tts_api():
+    data = request.get_json()
+    text = data.get('text')
+    if not text:
+        return jsonify({"error": "Text is required"}), 400
+    output_file = text_to_speech_func(text)
+    if output_file == "TTS model not initialized.":
+        return jsonify({"error": "TTS generation failed"}), 500
+    return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
utils.py
ADDED
@@ -0,0 +1,190 @@
+import requests
+from bs4 import BeautifulSoup
+from faker import Faker
+from urllib.request import urlretrieve
+import urllib.request
+from urllib3.util.retry import Retry
+import time
+import os
+import wget
+import json
+import regex  # for the \p{...} classes used by the codegen tokenizer pattern
+import unicodedata
+import nltk
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
+from sklearn.multiclass import OneVsRestClassifier
+import warnings
+from requests.adapters import HTTPAdapter
+from constants import *
+# Encoder and bytes_to_unicode were referenced below without being defined or
+# imported in the original file; pulling them from tokenxxx fixes that.
+from tokenxxx import Encoder, bytes_to_unicode
+
+MAX_XDD = 5
+use_google_search = True
+use_20newsgroup = True
+fake = Faker()
+
+def create_retry_session():
+    retry_strategy = Retry(
+        total=5,
+        status_forcelist=[429, 500, 502, 503, 504],
+        method_whitelist=["GET"],  # renamed to allowed_methods in urllib3 >= 1.26
+        backoff_factor=1,
+    )
+    adapter = HTTPAdapter(max_retries=retry_strategy)
+    http = requests.Session()
+    http.mount("https://", adapter)
+    http.mount("http://", adapter)
+    return http
+
+def get_google_search_results(query, retry_session):
+    if not use_google_search:
+        return []
+    headers = {"User-Agent": fake.user_agent()}
+    search_url = f"https://www.google.com/search?q={query}"
+    try:
+        response = retry_session.get(search_url, headers=headers, timeout=10)
+        response.raise_for_status()
+    except requests.exceptions.RequestException as e:
+        return []
+    soup = BeautifulSoup(response.text, "html.parser")
+    search_results = []
+    for a_tag in soup.find_all('a', href=True):
+        if 'url?q=' in a_tag['href'] and not a_tag['href'].startswith("https://accounts.google.com"):
+            search_results.append(a_tag['href'].split('url?q=')[1].split('&')[0])
+    return search_results
+
+def fetch_20newsgroup_data():
+    if not use_20newsgroup:
+        return []
+    try:
+        newsgroups_train = fetch_20newsgroups(subset='train', categories=['talk.trivia', 'rec.sport.baseball', 'sci.med', 'comp.sys.ibm.pc.hardware', 'soc.religion.christian'])
+        data = newsgroups_train.data
+        return data
+    except Exception as e:
+        return []
+
+def download_file(url, filename, folder, retries=3):
+    filepath = os.path.join(folder, filename)
+    if os.path.exists(filepath):
+        return True
+    os.makedirs(folder, exist_ok=True)
+    for attempt in range(retries):
+        try:
+            wget.download(url, out=filepath)
+            return True
+        except Exception as e:
+            if attempt < retries - 1:
+                time.sleep(2)
+            else:
+                return False
+    return False
+
+def download_gpt2_files(folder, model_url, model_file, encoder_url, encoder_file, vocab_url, vocab_file):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    if not os.path.exists(os.path.join(folder, model_file)):
+        download_file(model_url, model_file, folder)
+    if not os.path.exists(os.path.join(folder, encoder_file)):
+        download_file(encoder_url, encoder_file, folder)
+    if not os.path.exists(os.path.join(folder, vocab_file)):
+        download_file(vocab_url, vocab_file, folder)
+
+def download_translation_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_codegen_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_summarization_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_imagegen_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_image_to_3d_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_text_to_video_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_sentiment_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_stt_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_tts_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def download_musicgen_files(folder, model_files_urls):
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    for url, filename in model_files_urls:
+        if not os.path.exists(os.path.join(folder, filename)):
+            download_file(url, filename, folder)
+
+def bytes_to_unicode_gpt2():
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n = n + 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+def get_codegen_tokenizer_pure(vocab_file, merges_file):
+    # Duplicates tokenxxx.get_codegen_tokenizer_pure. The merge strings are
+    # split into pairs as Encoder expects; the original passed decoder/
+    # bpe_ranks/byte_encoder/byte_decoder kwargs that Encoder.__init__ does
+    # not accept, so only the supported arguments are forwarded here.
+    vocab = json.load(open(vocab_file))
+    merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]
+    bpe_merges = [tuple(m.split()) for m in merges]
+    tokenizer_regex = regex.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
+    tokenize = lambda text: regex.findall(tokenizer_regex, text)
+    encoder_obj = Encoder(
+        encoder=vocab,
+        bpe_merges=bpe_merges,
+        tokenize=tokenize
+    )
+    return encoder_obj
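
create_retry_session wraps urllib3's Retry in a plain requests.Session, so every request made through it transparently retries on 429/5xx responses with exponential backoff. A typical call (the URL is just an example):

session = create_retry_session()
# Up to 5 automatic retries with backoff on 429/500/502/503/504 responses.
resp = session.get("https://example.com/", timeout=10)
print(resp.status_code)
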
xxx.py
ADDED
@@ -0,0 +1,142 @@
+# Note: this module mirrors tokenxxx.py (same byte-level BPE Encoder).
+import json
+import re
+import regex  # needed for the \p{...} classes in the codegen tokenizer pattern
+import unicodedata
+from functools import lru_cache
+import wget
+import os
+# ENCODER_URL and VOCAB_URL are used by get_encoder_gpt2 below but were
+# missing from the original import list.
+from constants import GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN, ENCODER_URL, VOCAB_URL
+import nltk
+
+@lru_cache()
+def bytes_to_unicode():
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+def get_pairs(word):
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+class Encoder:
+    def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
+        self.encoder = encoder
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        if tokenize is None:
+            self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
+            self.tokenize = lambda text: re.findall(self.pat, text)
+        else:
+            self.tokenize = tokenize
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+        if not pairs:
+            return token
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        normalized_text = unicodedata.normalize('NFKC', text)
+        normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
+        normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
+        for token in self.tokenize(normalized_text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
+        decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
+        sentences = nltk.sent_tokenize(decoded_text)
+        return ' '.join(sentences).replace("<br>", "<br>\n")
+
+def get_encoder_gpt2():
+    encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
+    vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
+    if not os.path.exists(GPT2_FOLDER):
+        os.makedirs(GPT2_FOLDER)
+    if not os.path.exists(encoder_path):
+        wget.download(ENCODER_URL, out=encoder_path)
+    if not os.path.exists(vocab_path):
+        wget.download(VOCAB_URL, out=vocab_path)
+
+    with open(encoder_path, 'r') as f:
+        encoder = json.load(f)
+    with open(vocab_path, 'r', encoding="utf-8") as f:
+        bpe_data = f.read()
+    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+    encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
+    encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
+    encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
+    return encoder_obj
+
+def get_codegen_tokenizer_pure(vocab_file, merges_file):
+    vocab = json.load(open(vocab_file))
+    merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]
+    bpe_merges = [tuple(m.split()) for m in merges]
+    # \p{L}/\p{N} classes need the third-party regex module, not stdlib re.
+    tokenizer_regex = regex.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
+    tokenize = lambda text: regex.findall(tokenizer_regex, text)
+    # Only the arguments Encoder.__init__ accepts are forwarded; the original
+    # also passed byte_encoder/byte_decoder kwargs, which it does not take.
+    encoder_obj = Encoder(
+        encoder=vocab,
+        bpe_merges=bpe_merges,
+        tokenize=tokenize
+    )
+    return encoder_obj
+
+def codegen_tokenize(text, tokenizer):
+    return tokenizer.encode(text)
+
+def codegen_decode(tokens, tokenizer):
+    return tokenizer.decode(tokens)