Hjgugugjhuhjggg committed
Commit 1c817fd · verified · 1 Parent(s): bb7e407

Upload 27 files

Files changed (27)
  1. Dockerfile +20 -0
  2. README.md +14 -12
  3. api.py +444 -0
  4. background_tasks.py +197 -0
  5. codegen_api.py +23 -0
  6. configs.py +206 -0
  7. constants.py +449 -0
  8. extensions.py +252 -0
  9. image_to_3d_api.py +32 -0
  10. imagegen_api.py +33 -0
  11. main.py +118 -0
  12. model_loader.py +674 -0
  13. models.py +96 -0
  14. musicgen_api.py +35 -0
  15. requirements.txt +40 -0
  16. sadtalker_api.py +202 -0
  17. sadtalker_utils.py +866 -0
  18. sentiment_api.py +27 -0
  19. stt_api.py +36 -0
  20. summarization_api.py +29 -0
  21. text_generation.py +152 -0
  22. text_to_video_api.py +37 -0
  23. tokenxxx.py +161 -0
  24. translation_api.py +27 -0
  25. tts_api.py +23 -0
  26. utils.py +190 -0
  27. xxx.py +142 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
1
+ FROM python:3.11-slim-buster
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+ ENV NUMBA_DISABLE_CACHE=1
5
+ WORKDIR /app
6
+
7
+ RUN apt-get update && apt-get upgrade -y
8
+ RUN apt-get install libgl1-mesa-glx ffmpeg -y
9
+
10
+ RUN mkdir -p /.cache/huggingface/hub && chmod -R 777 /.cache/huggingface/hub
11
+ RUN mkdir -p /.config/matplotlib && chmod -R 777 /.config/matplotlib
12
+ RUN mkdir -p /nltk_data && chmod -R 777 /nltk_data
13
+
14
+ RUN pip install --no-cache-dir accelerate retry asyncio basicsr beautifulsoup4 bs4 opencv-python deep-translator duckduckgo-search fastapi flask flask-cors facexlib ffmpeg-python gfpgan imageio imageio-ffmpeg langdetect librosa nltk numpy Pillow pydub pytorch-lightning PyYAML retry safetensors scikit-learn scipy scikit-image soundfile torch torchaudio torchvision tqdm wget yacs numba
15
+
16
+ COPY . .
17
+
18
+ EXPOSE 7860
19
+
20
+ CMD ["python", "main.py"]
README.md CHANGED
@@ -1,12 +1,14 @@
1
- ---
2
- title: Hhhh
3
- emoji:
4
- colorFrom: indigo
5
- colorTo: red
6
- sdk: docker
7
- pinned: false
8
- license: apache-2.0
9
- short_description: Apache2
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ---
2
+ title: Ggggggc
3
+ emoji: 📈
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
+ sdk: docker
7
+ sdk_version: 5.18.0
8
+ app_file: main.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Apache
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api.py ADDED
@@ -0,0 +1,444 @@
1
+ from main import *
2
+ from tts_api import *
3
+ from stt_api import *
4
+ from sentiment_api import *
5
+ from imagegen_api import *
6
+ from musicgen_api import *
7
+ from translation_api import *
8
+ from codegen_api import *
9
+ from text_to_video_api import *
10
+ from summarization_api import *
11
+ from image_to_3d_api import *
12
+ from flask import Flask, request, jsonify, Response, send_file, stream_with_context
13
+ from flask_cors import CORS
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+ import numpy as np
19
+ from PIL import Image
20
+ import io
21
+ import tempfile
22
+ import queue
23
+ import json
24
+ import base64
25
+
26
+ app = Flask(__name__)
27
+ CORS(app)
28
+ html_code = """<!DOCTYPE html>
29
+ <html lang="en">
30
+ <head>
31
+ <meta charset="UTF-8">
32
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
33
+ <title>AI Text Generation</title>
34
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
35
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
36
+ <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
37
+ <style>
38
+ body {
39
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
40
+ background: #f0f0f0;
41
+ color: #333;
42
+ margin: 0;
43
+ padding: 0;
44
+ display: flex;
45
+ flex-direction: column;
46
+ align-items: center;
47
+ min-height: 100vh;
48
+ }
49
+ .container {
50
+ width: 95%;
51
+ max-width: 900px;
52
+ padding: 20px;
53
+ background-color: #fff;
54
+ box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
55
+ border-radius: 8px;
56
+ margin-top: 20px;
57
+ margin-bottom: 20px;
58
+ display: flex;
59
+ flex-direction: column;
60
+ }
61
+ .header {
62
+ text-align: center;
63
+ margin-bottom: 20px;
64
+ }
65
+ .header h1 {
66
+ font-size: 2em;
67
+ color: #333;
68
+ }
69
+ .form-group {
70
+ margin-bottom: 15px;
71
+ }
72
+ .form-group textarea {
73
+ width: 100%;
74
+ padding: 10px;
75
+ border: 1px solid #ccc;
76
+ border-radius: 5px;
77
+ font-size: 16px;
78
+ box-sizing: border-box;
79
+ resize: vertical;
80
+ }
81
+ button {
82
+ padding: 10px 15px;
83
+ border: none;
84
+ border-radius: 5px;
85
+ background-color: #007bff;
86
+ color: white;
87
+ font-size: 18px;
88
+ cursor: pointer;
89
+ transition: background-color 0.3s ease;
90
+ }
91
+ button:hover {
92
+ background-color: #0056b3;
93
+ }
94
+ #output {
95
+ margin-top: 20px;
96
+ padding: 15px;
97
+ border: 1px solid #ddd;
98
+ border-radius: 5px;
99
+ background-color: #f9f9f9;
100
+ white-space: pre-wrap;
101
+ word-break: break-word;
102
+ overflow-y: auto;
103
+ max-height: 100vh;
104
+ }
105
+ #output strong {
106
+ font-weight: bold;
107
+ }
108
+ .animated-text {
109
+ position: fixed;
110
+ top: 20px;
111
+ left: 20px;
112
+ font-size: 1.5em;
113
+ color: rgba(0, 0, 0, 0.1);
114
+ pointer-events: none;
115
+ z-index: -1;
116
+ }
117
+ @media (max-width: 768px) {
118
+ .container {
119
+ width: 98%;
120
+ margin-top: 10px;
121
+ margin-bottom: 10px;
122
+ padding: 15px;
123
+ }
124
+ .header h1 {
125
+ font-size: 1.8em;
126
+ }
127
+ .form-group textarea, .form-group input[type="text"] {
128
+ font-size: 14px;
129
+ padding: 8px;
130
+ }
131
+ button {
132
+ font-size: 16px;
133
+ padding: 8px 12px;
134
+ }
135
+ #output {
136
+ font-size: 14px;
137
+ padding: 10px;
138
+ margin-top: 15px;
139
+ }
140
+ }
141
+ </style>
142
+ </head>
143
+ <body>
144
+ <div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
145
+ <div class="container">
146
+ <div class="header animate__animated animate__fadeInDown">
147
+ </div>
148
+ <div class="form-group animate__animated animate__fadeInLeft">
149
+ <textarea id="text" rows="5" placeholder="Enter text"></textarea>
150
+ </div>
151
+ <button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
152
+ <div id="output" class="animate__animated">
153
+ <strong>Response:</strong><br>
154
+ <span id="generatedText"></span>
155
+ </div>
156
+ </div>
157
+ <script>
158
+ let eventSource = null;
159
+ let accumulatedText = "";
160
+ let lastResponse = "";
161
+ async function generateText() {
162
+ const inputText = document.getElementById("text").value;
163
+ document.getElementById("generatedText").innerText = "";
164
+ accumulatedText = "";
165
+ if (eventSource) {
166
+ eventSource.close();
167
+ }
168
+ const temp = 0.7;
169
+ const top_k_val = 40;
170
+ const top_p_val = 0.0;
171
+ const repetition_penalty_val = 1.2;
172
+ const requestData = {
173
+ text: inputText,
174
+ temp: temp,
175
+ top_k: top_k_val,
176
+ top_p: top_p_val,
177
+ reppenalty: repetition_penalty_val
178
+ };
179
+ const params = new URLSearchParams(requestData).toString();
180
+ eventSource = new EventSource('/api/v1/generate_stream?' + params);
181
+ eventSource.onmessage = function(event) {
182
+ if (event.data === "<END_STREAM>") {
183
+ eventSource.close();
184
+ const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
185
+ if (currentResponse === lastResponse.trim()) {
186
+ accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
187
+ } else {
188
+ lastResponse = currentResponse;
189
+ }
190
+ document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
191
+ return;
192
+ }
193
+ accumulatedText += event.data;
194
+ let partialText = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
195
+ document.getElementById("generatedText").innerHTML = marked.parse(partialText);
196
+ };
197
+ eventSource.onerror = function(error) {
198
+ console.error("SSE error", error);
199
+ eventSource.close();
200
+ };
201
+ const outputDiv = document.getElementById("output");
202
+ outputDiv.classList.add("show");
203
+ }
204
+ function base64ToBlob(base64Data, contentType) {
205
+ contentType = contentType || '';
206
+ const sliceSize = 1024;
207
+ const byteCharacters = atob(base64Data);
208
+ const bytesLength = byteCharacters.length;
209
+ const slicesCount = Math.ceil(bytesLength / sliceSize);
210
+ const byteArrays = new Array(slicesCount);
211
+ for (let sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) {
212
+ const begin = sliceIndex * sliceSize;
213
+ const end = Math.min(begin + sliceSize, bytesLength);
214
+ const bytes = new Array(end - begin);
215
+ for (let offset = begin, i = 0; offset < end; ++i, ++offset) {
216
+ bytes[i] = byteCharacters[offset].charCodeAt(0);
217
+ }
218
+ byteArrays[sliceIndex] = new Uint8Array(bytes);
219
+ }
220
+ return new Blob(byteArrays, { type: contentType });
221
+ }
222
+ </script>
223
+ </body>
224
+ </html>
225
+ """
226
+ feedback_queue = queue.Queue()
227
+
228
+ class TextGenerationModel(nn.Module):
229
+ def __init__(self, vocab_size, embed_dim, hidden_dim):
230
+ super(TextGenerationModel, self).__init__()
231
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
232
+ self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
233
+ self.fc = nn.Linear(hidden_dim, vocab_size)
234
+ def forward(self, x, hidden=None):
235
+ x = self.embedding(x)
236
+ out, hidden = self.rnn(x, hidden)
237
+ out = self.fc(out)
238
+ return out, hidden
239
+
240
+ vocab = ["hola", "mundo", "este", "es", "un", "ejemplo", "de", "texto", "generado", "con", "torch"]
241
+ vocab_size = len(vocab)
242
+ embed_dim = 16
243
+ hidden_dim = 32
244
+ text_model = TextGenerationModel(vocab_size, embed_dim, hidden_dim)
245
+ text_model.eval()
246
+
247
+ def tokenize(text):
248
+ tokens = text.lower().split()
249
+ indices = [vocab.index(token) if token in vocab else 0 for token in tokens]
250
+ return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
251
+
252
+ def perform_reasoning_stream(text, temperature, top_k, top_p, repetition_penalty):
253
+ input_tensor = tokenize(text)
254
+ hidden = None
255
+ for _ in range(20):
256
+ outputs, hidden = text_model(input_tensor, hidden)
257
+ logits = outputs[:, -1, :] / temperature
258
+ probs = F.softmax(logits, dim=-1)
259
+ topk_probs, topk_indices = torch.topk(probs, min(top_k, logits.shape[-1]))
260
+ chosen_index = topk_indices[0, torch.multinomial(topk_probs[0], 1).item()].item()
261
+ token_str = vocab[chosen_index]
262
+ yield token_str
263
+ input_tensor = torch.cat([input_tensor, torch.tensor([[chosen_index]], dtype=torch.long)], dim=1)
264
+ yield "<END_STREAM>"
265
+
266
+ class SentimentModel(nn.Module):
267
+ def __init__(self, input_dim, hidden_dim, output_dim):
268
+ super(SentimentModel, self).__init__()
269
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
270
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
271
+ def forward(self, x):
272
+ x = F.relu(self.fc1(x))
273
+ x = self.fc2(x)
274
+ return x
275
+
276
+ sentiment_model = SentimentModel(10, 16, 2)
277
+ sentiment_model.eval()
278
+
279
+ @app.route("/")
280
+ def index():
281
+ return html_code
282
+
283
+ @app.route("/api/v1/generate_stream", methods=["GET"])
284
+ def generate_stream():
285
+ text = request.args.get("text", "")
286
+ temp = float(request.args.get("temp", 0.7))
287
+ top_k = int(request.args.get("top_k", 40))
288
+ top_p = float(request.args.get("top_p", 0.0))
289
+ reppenalty = float(request.args.get("reppenalty", 1.2))
290
+ @stream_with_context
291
+ def event_stream():
292
+ try:
293
+ for token in perform_reasoning_stream(text, temperature=temp, top_k=top_k, top_p=top_p, repetition_penalty=reppenalty):
294
+ if token == "<END_STREAM>":
295
+ yield "data: <END_STREAM>\n\n"
296
+ break
297
+ yield "data: " + token + "\n\n"
298
+ except Exception as e:
299
+ yield "data: <ERROR>\n\n"
300
+ return Response(event_stream(), mimetype="text/event-stream")
301
+
302
+ @app.route("/api/v1/generate", methods=["POST"])
303
+ def generate():
304
+ data = request.get_json()
305
+ text = data.get("text", "")
306
+ temp = float(data.get("temp", 0.7))
307
+ top_k = int(data.get("top_k", 40))
308
+ top_p = float(data.get("top_p", 0.0))
309
+ reppenalty = float(data.get("reppenalty", 1.2))
310
+ result = ""
311
+ try:
312
+ for token in perform_reasoning_stream(text, temperature=temp, top_k=top_k, top_p=top_p, repetition_penalty=reppenalty):
313
+ if token == "<END_STREAM>":
314
+ break
315
+ result += token + " "
316
+ except Exception as e:
317
+ return jsonify({"error": str(e)}), 500
318
+ return jsonify({"solidity": result.strip()})
319
+
320
+ @app.route("/api/v1/feedback", methods=["POST"])
321
+ def feedback():
322
+ data = request.get_json()
323
+ feedback_text = data.get("feedback_text")
324
+ correct_category = data.get("correct_category")
325
+ if feedback_text and correct_category:
326
+ feedback_queue.put((feedback_text, correct_category))
327
+ return jsonify({"status": "feedback received"})
328
+ return jsonify({"status": "feedback failed"}), 400
329
+
330
+ @app.route("/api/v1/tts", methods=["POST"])
331
+ def tts_api():
332
+ data = request.get_json()
333
+ text = data.get("text", "")
334
+ sr = 22050
335
+ duration = 3.0
336
+ t = torch.linspace(0, duration, int(sr * duration))
337
+ frequency = 440.0
338
+ audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
339
+ audio = audio.unsqueeze(0)
340
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
341
+ torchaudio.save(tmp.name, audio, sr)
342
+ tmp_path = tmp.name
343
+ return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
344
+
345
+ @app.route("/api/v1/stt", methods=["POST"])
346
+ def stt_api():
347
+ data = request.get_json()
348
+ audio_b64 = data.get("audio", "")
349
+ if audio_b64:
350
+ audio_bytes = base64.b64decode(audio_b64)
351
+ buf = io.BytesIO(audio_bytes)
352
+ waveform, sr = torchaudio.load(buf)
353
+ mean_amp = waveform.abs().mean().item()
354
+ recognized_text = f"Audio processed with mean amplitude {mean_amp:.3f}"
355
+ return jsonify({"text": recognized_text})
356
+ return jsonify({"text": ""})
357
+
358
+ @app.route("/api/v1/sentiment", methods=["POST"])
359
+ def sentiment_api():
360
+ data = request.get_json()
361
+ text = data.get("text", "")
362
+ if not text:
363
+ return jsonify({"sentiment": "neutral"})
364
+ ascii_vals = [ord(c) for c in text[:10]]
365
+ while len(ascii_vals) < 10:
366
+ ascii_vals.append(0)
367
+ features = torch.tensor(ascii_vals, dtype=torch.float32).unsqueeze(0)
368
+ output = sentiment_model(features)
369
+ sentiment_idx = torch.argmax(output, dim=1).item()
370
+ sentiment = "positivo" if sentiment_idx == 1 else "negativo"
371
+ return jsonify({"sentiment": sentiment})
372
+
373
+ @app.route("/api/v1/imagegen", methods=["POST"])
374
+ def imagegen_api():
375
+ data = request.get_json()
376
+ prompt = data.get("prompt", "")
377
+ image_tensor = torch.rand(3, 256, 256)
378
+ np_image = image_tensor.mul(255).clamp(0, 255).byte().numpy().transpose(1, 2, 0)
379
+ img = Image.fromarray(np_image)
380
+ buf = io.BytesIO()
381
+ img.save(buf, format="PNG")
382
+ buf.seek(0)
383
+ return send_file(buf, mimetype="image/png", as_attachment=True, download_name="image.png")
384
+
385
+ @app.route("/api/v1/musicgen", methods=["POST"])
386
+ def musicgen_api():
387
+ data = request.get_json()
388
+ prompt = data.get("prompt", "")
389
+ sr = 22050
390
+ duration = 5.0
391
+ t = torch.linspace(0, duration, int(sr * duration))
392
+ frequency = 440.0
393
+ audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
394
+ audio = audio.unsqueeze(0)
395
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
396
+ torchaudio.save(tmp.name, audio, sr)
397
+ tmp_path = tmp.name
398
+ return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="music.wav")
399
+
400
+ @app.route("/api/v1/translation", methods=["POST"])
401
+ def translation_api():
402
+ data = request.get_json()
403
+ text = data.get("text", "")
404
+ translated = " ".join(text.split()[::-1])
405
+ return jsonify({"translated_text": translated})
406
+
407
+ @app.route("/api/v1/codegen", methods=["POST"])
408
+ def codegen_api():
409
+ data = request.get_json()
410
+ prompt = data.get("prompt", "")
411
+ generated_code = f"# Generated code based on prompt: {prompt}\nprint('Hello from Torch-generated code')"
412
+ return jsonify({"code": generated_code})
413
+
414
+ @app.route("/api/v1/text_to_video", methods=["POST"])
415
+ def text_to_video_api():
416
+ data = request.get_json()
417
+ prompt = data.get("prompt", "")
418
+ video_tensor = torch.randint(0, 255, (10, 3, 64, 64), dtype=torch.uint8)
419
+ video_bytes = video_tensor.numpy().tobytes()
420
+ buf = io.BytesIO(video_bytes)
421
+ return send_file(buf, mimetype="video/mp4", as_attachment=True, download_name="video.mp4")
422
+
423
+ @app.route("/api/v1/summarization", methods=["POST"])
424
+ def summarization_api():
425
+ data = request.get_json()
426
+ text = data.get("text", "")
427
+ sentences = text.split('.')
428
+ summary = sentences[0] if sentences[0] else text
429
+ return jsonify({"summary": summary})
430
+
431
+ @app.route("/api/v1/image_to_3d", methods=["POST"])
432
+ def image_to_3d_api():
433
+ data = request.get_json()
434
+ prompt = data.get("prompt", "")
435
+ obj_data = "o Cube\nv 0 0 0\nv 1 0 0\nv 1 1 0\nv 0 1 0\nf 1 2 3 4"
436
+ buf = io.BytesIO(obj_data.encode("utf-8"))
437
+ return send_file(buf, mimetype="text/plain", as_attachment=True, download_name="model.obj")
438
+
439
+ @app.route("/api/v1/sadtalker", methods=["GET"])
440
+ def sadtalker():
441
+ return jsonify({"message": "Respuesta de sadtalker"})
442
+
443
+ if __name__ == "__main__":
444
+ app.run(host="0.0.0.0", port=7860)
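
As a usage reference, here is a minimal Python client sketch for two of the endpoints defined above. It assumes the app is running locally on port 7860 (as configured in the Dockerfile/main.py) and that the `requests` package is available; the endpoint paths and payload keys are taken from the routes in api.py.

# Client sketch (assumption: server reachable at http://localhost:7860, `requests` installed).
import requests

BASE = "http://localhost:7860"

# Non-streaming generation: POST /api/v1/generate returns {"solidity": "<generated text>"}.
resp = requests.post(f"{BASE}/api/v1/generate",
                     json={"text": "hola mundo", "temp": 0.7, "top_k": 40})
print(resp.json())

# Streaming generation: GET /api/v1/generate_stream emits SSE lines of the form
# "data: <token>" and terminates with "data: <END_STREAM>".
with requests.get(f"{BASE}/api/v1/generate_stream",
                  params={"text": "hola mundo"}, stream=True) as r:
    for raw in r.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        token = raw[len("data: "):]
        if token == "<END_STREAM>":
            break
        print(token, end=" ")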
background_tasks.py ADDED
@@ -0,0 +1,197 @@
1
+ import time
2
+ import threading
3
+ import queue
4
+ import uuid
5
+ import unicodedata
6
+ import re
7
+ from deep_translator import GoogleTranslator
8
+ from duckduckgo_search import DDGS
9
+ import nltk
10
+ import torch
11
+ import torch.nn as nn
12
+ import math
13
+
14
+ nltk.download('punkt')
15
+
16
+ categories = ['News', 'Sports', 'Entertainment']
17
+ TEXT_GENERATION_RATE = 10
18
+ text_queue = queue.Queue()
19
+ reasoning_queue = queue.Queue()
20
+ feedback_queue = queue.Queue()
21
+ vocabulary = ["<PAD>", "<EOS>"]
22
+ word_to_index = {word: idx for idx, word in enumerate(vocabulary)}
23
+ seen_responses = set()
24
+ news_clf = None
25
+
26
+ class SimpleClassifier(nn.Module):
27
+ def __init__(self, vocab_size, num_classes, embedding_dim=128):
28
+ super(SimpleClassifier, self).__init__()
29
+ self.embedding = nn.Embedding(vocab_size, embedding_dim)
30
+ self.fc = nn.Linear(embedding_dim, num_classes)
31
+ def forward(self, x):
32
+ embedded = self.embedding(x)
33
+ pooled = embedded.mean(dim=1)
34
+ out = self.fc(pooled)
35
+ return out
36
+
37
+ def tokenize_text(text):
38
+ return nltk.word_tokenize(text)
39
+
40
+ def update_vocabulary(tokens):
41
+ global vocabulary, word_to_index
42
+ for token in tokens:
43
+ if token not in word_to_index:
44
+ word_to_index[token] = len(vocabulary)
45
+ vocabulary.append(token)
46
+
47
+ def text_to_vector(text):
48
+ tokens = tokenize_text(text)
49
+ update_vocabulary(tokens)
50
+ indices = [word_to_index.get(token, 0) for token in tokens]
51
+ return torch.tensor(indices, dtype=torch.long)
52
+
53
+ def generate_and_queue_text(language):
54
+ global categories, text_queue
55
+ num_categories = len(categories)
56
+ num_texts_per_category = TEXT_GENERATION_RATE // (2 * num_categories)
57
+ while True:
58
+ for category in categories:
59
+ for _ in range(num_texts_per_category):
60
+ uid = uuid.uuid4()
61
+ base_text = f"Category: {category}. ID:{uid}"
62
+ try:
63
+ translator = GoogleTranslator(source='auto', target=language)
64
+ text = translator.translate(base_text)
65
+ except Exception:
66
+ text = base_text
67
+ processed_text = ''.join(c for c in unicodedata.normalize('NFKC', text) if c.isprintable())
68
+ text_queue.put((processed_text, category))
69
+ time.sleep(0)
70
+
71
+ def background_training():
72
+ global categories, news_clf, feedback_queue, vocabulary
73
+ if categories is None:
74
+ categories = ['DefaultCategory']
75
+ num_classes = len(categories)
76
+ learning_rate = 0.01
77
+ epochs = 1
78
+ if news_clf is None:
79
+ news_clf = SimpleClassifier(len(vocabulary), num_classes)
80
+ optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
81
+ criterion = nn.CrossEntropyLoss()
82
+ while True:
83
+ try:
84
+ feedback_item = feedback_queue.get(timeout=10)
85
+ if feedback_item:
86
+ input_text, generated_text = feedback_item
87
+ input_vector = text_to_vector(input_text)
88
+ if len(vocabulary) == 0:
89
+ vocabulary.extend(["<PAD>", "<EOS>"])
90
+ news_clf = SimpleClassifier(len(vocabulary), num_classes)
91
+ optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
92
+ if input_vector.size(0) != len(vocabulary) and len(vocabulary) > 0:
93
+ news_clf = SimpleClassifier(len(vocabulary), num_classes)
94
+ optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
95
+ input_vector = text_to_vector(input_text)
96
+ tokens = tokenize_text(input_text)
97
+ update_vocabulary(tokens)
98
+ tokens_indices = [word_to_index.get(word, 0) for word in tokens]
99
+ input_tensor = torch.tensor([tokens_indices], dtype=torch.long)
100
+ target_index = categories.index(generated_text) if generated_text in categories else 0
101
+ target_category_index = torch.tensor([target_index], dtype=torch.long)
102
+ if num_classes <= 1:
103
+ num_classes = 2
104
+ news_clf.fc = nn.Linear(128, num_classes)
105
+ for _ in range(epochs):
106
+ optimizer.zero_grad()
107
+ output = news_clf(input_tensor)
108
+ loss = criterion(output, target_category_index)
109
+ loss.backward()
110
+ optimizer.step()
111
+ feedback_queue.task_done()
112
+ except queue.Empty:
113
+ pass
114
+ except Exception:
115
+ time.sleep(5)
116
+
117
+ class ReasoningModel(nn.Module):
118
+ def __init__(self, vocab_size, embed_dim=128, hidden_dim=128):
119
+ super(ReasoningModel, self).__init__()
120
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
121
+ self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
122
+ self.fc = nn.Linear(hidden_dim, vocab_size)
123
+ def forward(self, x, hidden=None):
124
+ emb = self.embedding(x)
125
+ output, hidden = self.rnn(emb, hidden)
126
+ logits = self.fc(output)
127
+ return logits, hidden
128
+ def generate(self, input_seq, max_length=50, temperature=1.0):
129
+ self.eval()
130
+ tokens = input_seq.copy()
131
+ hidden = None
132
+ generated = []
133
+ for _ in range(max_length):
134
+ input_tensor = torch.tensor([tokens], dtype=torch.long)
135
+ logits, hidden = self.forward(input_tensor, hidden)
136
+ next_token_logits = logits[0, -1, :] / temperature
137
+ probabilities = torch.softmax(next_token_logits, dim=0)
138
+ next_token = torch.multinomial(probabilities, 1).item()
139
+ tokens.append(next_token)
140
+ generated.append(next_token)
141
+ if next_token == word_to_index.get("<EOS>"):
142
+ break
143
+ return generated
144
+
145
+ reasoning_model = ReasoningModel(len(vocabulary))
146
+
147
+ def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
148
+ tokens = tokenize_text(text_input)
149
+ update_vocabulary(tokens)
150
+ tokens_indices = [word_to_index.get(token, 0) for token in tokens]
151
+ generated_indices = reasoning_model.generate(tokens_indices, max_length=50, temperature=temperature)
152
+ for idx in generated_indices:
153
+ yield vocabulary[idx] + " "
154
+ yield "<END_STREAM>"
155
+
156
+ def background_reasoning_queue():
157
+ global reasoning_queue, seen_responses
158
+ while True:
159
+ try:
160
+ item = reasoning_queue.get(timeout=1)
161
+ if item is None:
162
+ reasoning_queue.task_done()
163
+ continue
164
+ text_input = item.get('text_input')
165
+ temperature = item.get('temperature', 0.7)
166
+ top_k = item.get('top_k', 40)
167
+ top_p = item.get('top_p', 0.0)
168
+ repetition_penalty = item.get('repetition_penalty', 1.2)
169
+ resp_queue = item.get('response_queue', queue.Queue())
170
+ if not text_input:
171
+ resp_queue.put({"error": "Empty text input received."})
172
+ reasoning_queue.task_done()
173
+ continue
174
+ generated_text_stream = perform_reasoning_stream(text_input, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
175
+ full_response = ""
176
+ for chunk in generated_text_stream:
177
+ if chunk == "<END_STREAM>":
178
+ break
179
+ full_response += chunk
180
+ cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "")).strip()
181
+ if cleaned_response in seen_responses:
182
+ final_response = "**Response is repetitive. Please try again or rephrase your query.**"
183
+ resp_queue.put({"text": final_response})
184
+ else:
185
+ seen_responses.add(cleaned_response)
186
+ final_response = cleaned_response
187
+ resp_queue.put({"text": final_response})
188
+ reasoning_queue.task_done()
189
+ except queue.Empty:
190
+ pass
191
+ except Exception as e:
192
+ try:
193
+ resp_queue.put({"error": str(e)})
194
+ except Exception:
195
+ pass
196
+ if reasoning_queue and not reasoning_queue.empty():
197
+ reasoning_queue.task_done()
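
For orientation, a sketch of how these workers could be started and exercised from another module. The wiring itself is an assumption (the actual startup presumably lives in main.py, which is not part of this hunk); the queue item keys follow background_reasoning_queue() above.

# Hypothetical wiring: start the worker threads and submit one reasoning request.
import queue
import threading

import background_tasks

threading.Thread(target=background_tasks.background_training, daemon=True).start()
threading.Thread(target=background_tasks.background_reasoning_queue, daemon=True).start()

response_queue = queue.Queue()
background_tasks.reasoning_queue.put({
    "text_input": "Category: News. Summarize today's headlines.",
    "temperature": 0.7,
    "top_k": 40,
    "top_p": 0.0,
    "repetition_penalty": 1.2,
    "response_queue": response_queue,
})
print(response_queue.get(timeout=60))  # {"text": "..."} or {"error": "..."}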
codegen_api.py ADDED
@@ -0,0 +1,23 @@
1
+ from flask import jsonify, send_file, request
2
+ from main import *
3
+ # from main import codegen_model, codegen_tokenizer, device
4
+
5
+ def generate_code(prompt, output_path="output_code.py"):
6
+ if codegen_model is None:
7
+ return "Code generation model not initialized."
8
+ input_ids = codegen_tokenizer.encode(prompt, return_tensors='pt').to(device)
9
+ output = codegen_model.generate(input_ids, max_length=512, temperature=0.7, top_p=0.9)
10
+ code = codegen_tokenizer.decode(output[0], skip_special_tokens=True)
11
+ with open(output_path, "w") as file:
12
+ file.write(code)
13
+ return output_path
14
+
15
+ def codegen_api():
16
+ data = request.get_json()
17
+ prompt = data.get('prompt')
18
+ if not prompt:
19
+ return jsonify({"error": "Prompt is required"}), 400
20
+ output_file = generate_code(prompt)
21
+ if output_file == "Code generation model not initialized.":
22
+ return jsonify({"error": "Code generation failed"}), 500
23
+ return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")
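
codegen_api() reads the prompt from the Flask request context, but the route registration is not part of this file (it presumably happens in main.py). A hedged sketch of one way it could be wired up:

# Hypothetical route registration; the real wiring lives in main.py (not shown in this diff).
from flask import Flask
from codegen_api import codegen_api

app = Flask(__name__)
app.add_url_rule("/api/v1/codegen", "codegen_api", codegen_api, methods=["POST"])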
configs.py ADDED
@@ -0,0 +1,206 @@
1
+ from constants import *
2
+
3
+ class GPT2Config:
4
+ def __init__(self, vocab_size_or_config_json_file=50257, n_positions=MAX_LENGTH, n_ctx=MAX_LENGTH, n_embd=768, n_layer=12, n_head=12, layer_norm_epsilon=1e-05, initializer_range=0.02):
5
+ self.vocab_size = vocab_size_or_config_json_file
6
+ self.n_ctx = n_ctx
7
+ self.n_positions = n_positions
8
+ self.n_embd = n_embd
9
+ self.n_layer = n_layer
10
+ self.n_head = n_head
11
+ self.layer_norm_epsilon = layer_norm_epsilon
12
+ self.initializer_range = initializer_range
13
+
14
+ @classmethod
15
+ def from_dict(cls, config_dict):
16
+ return cls(**config_dict)
17
+
18
+ class MBartConfig:
19
+ def __init__(self, vocab_size, d_model, num_layers, num_heads, pad_token_id, eos_token_id):
20
+ self.vocab_size = vocab_size
21
+ self.d_model = d_model
22
+ self.encoder_layers = num_layers
23
+ self.decoder_layers = num_layers
24
+ self.encoder_attention_heads = num_heads
25
+ self.decoder_attention_heads = num_heads
26
+ self.encoder_ffn_dim = d_model * 4
27
+ self.decoder_ffn_dim = d_model * 4
28
+ self.dropout = 0.1
29
+ self.attention_dropout = 0.0
30
+ self.activation_dropout = 0.0
31
+ self.max_position_embeddings = 1024
32
+ self.init_std = 0.02
33
+ self.layer_norm_eps = 1e-5
34
+ self.pad_token_id = pad_token_id
35
+ self.eos_token_id = eos_token_id
36
+ self.bos_token_id = 0
37
+ self.decoder_start_token_id = 2
38
+ self.output_past = True
39
+ self.scale_embedding = True
40
+ self.use_cache = True
41
+ self.num_hidden_layers = num_layers
42
+
43
+ class CodeGenConfig:
44
+ def __init__(self, vocab_size, n_embd, n_layer, n_head):
45
+ self.vocab_size = vocab_size
46
+ self.n_embd = n_embd
47
+ self.n_layer = n_layer
48
+ self.n_head = n_head
49
+ self.n_positions = 2048
50
+ self.resid_pdrop = 0.1
51
+ self.embd_pdrop = 0.1
52
+ self.attn_pdrop = 0.1
53
+ self.activation_function = "gelu_new"
54
+ self.n_ctx = 2048
55
+ self.pad_token_id = 50256
56
+ self.eos_token_id = 50256
57
+ self.initializer_range = 0.02
58
+
59
+ class SummarizationConfig:
60
+ def __init__(self):
61
+ self.vocab_size = 10000
62
+ self.embedding_dim = 256
63
+ self.hidden_dim = 512
64
+ self.num_layers = 2
65
+ self.max_seq_len = 512
66
+
67
+ class Clip4ClipConfig:
68
+ def __init__(self, vocab_size=30522, hidden_size=512, num_hidden_layers=6, num_attention_heads=8, intermediate_size=2048, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
69
+ self.vocab_size = vocab_size
70
+ self.hidden_size = hidden_size
71
+ self.num_hidden_layers = num_hidden_layers
72
+ self.num_attention_heads = num_attention_heads
73
+ self.intermediate_size = intermediate_size
74
+ self.hidden_act = hidden_act
75
+ self.hidden_dropout_prob = hidden_dropout_prob
76
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
77
+ self.max_position_embeddings = max_position_embeddings
78
+ self.type_vocab_size = type_vocab_size
79
+ self.initializer_range = initializer_range
80
+ self.layer_norm_eps = layer_norm_eps
81
+ self.pad_token_id = pad_token_id
82
+ self.bos_token_id = bos_token_id
83
+ self.eos_token_id = eos_token_id
84
+ self.all_head_size = self.num_attention_heads * self.hidden_size
85
+ self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
86
+ for key, value in kwargs.items():
87
+ setattr(self, key, value)
88
+
89
+ @classmethod
90
+ def from_dict(cls, config_dict):
91
+ return cls(**config_dict)
92
+
93
+ class MusicGenConfig:
94
+ def __init__(self, vocab_size=2048, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-05, initializer_range=0.02, pad_token_id=0, bos_token_id=1, eos_token_id=2, n_positions=2048, n_ctx=2048, **kwargs):
95
+ self.vocab_size = vocab_size
96
+ self.hidden_size = hidden_size
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.num_attention_heads = num_attention_heads
99
+ self.intermediate_size = intermediate_size
100
+ self.hidden_act = hidden_act
101
+ self.hidden_dropout_prob = hidden_dropout_prob
102
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
103
+ self.layer_norm_eps = layer_norm_eps
104
+ self.initializer_range = initializer_range
105
+ self.pad_token_id = pad_token_id
106
+ self.bos_token_id = bos_token_id
107
+ self.eos_token_id = eos_token_id
108
+ self.n_positions = n_positions
109
+ self.n_ctx = n_ctx
110
+ self.all_head_size = self.num_attention_heads * self.hidden_size
111
+ for key, value in kwargs.items():
112
+ setattr(self, key, value)
113
+
114
+ @classmethod
115
+ def from_dict(cls, config_dict):
116
+ return cls(**config_dict)
117
+
118
+ class BartConfig:
119
+ def __init__(self, vocab_size=50265, max_position_embeddings=1024, encoder_layers=12, encoder_ffn_dim=4096, encoder_attention_heads=16, decoder_layers=12, decoder_ffn_dim=4096, decoder_attention_heads=16, encoder_layerdrop=0.0, decoder_layerdrop=0.0, activation_function="gelu", d_model=1024, dropout=0.1, attention_dropout=0.0, activation_dropout=0.0, init_std=0.02, classifier_dropout=0.0, num_labels=3, pad_token_id=1, bos_token_id=0, eos_token_id=2, layer_norm_eps=1e-05, num_beams=4, early_stopping=True, max_length=100, min_length=30, scale_embedding=False, **kwargs):
120
+ self.vocab_size = vocab_size
121
+ self.max_position_embeddings = max_position_embeddings
122
+ self.encoder_layers = encoder_layers
123
+ self.encoder_ffn_dim = encoder_ffn_dim
124
+ self.encoder_attention_heads = encoder_attention_heads
125
+ self.decoder_layers = decoder_layers
126
+ self.decoder_ffn_dim = decoder_ffn_dim
127
+ self.decoder_attention_heads = decoder_attention_heads
128
+ self.encoder_layerdrop = encoder_layerdrop
129
+ self.decoder_layerdrop = decoder_layerdrop
130
+ self.activation_function = activation_function
131
+ self.d_model = d_model
132
+ self.dropout = dropout
133
+ self.attention_dropout = attention_dropout
134
+ self.activation_dropout = activation_dropout
135
+ self.init_std = init_std
136
+ self.classifier_dropout = classifier_dropout
137
+ self.num_labels = num_labels
138
+ self.pad_token_id = pad_token_id
139
+ self.bos_token_id = bos_token_id
140
+ self.eos_token_id = eos_token_id
141
+ self.layer_norm_eps = layer_norm_eps
142
+ self.num_beams = num_beams
143
+ self.early_stopping = early_stopping
144
+ self.max_length = max_length
145
+ self.min_length = min_length
146
+ self.scale_embedding = scale_embedding
147
+ for key, value in kwargs.items():
148
+ setattr(self, key, value)
149
+
150
+ @classmethod
151
+ def from_dict(cls, config_dict):
152
+ return cls(**config_dict)
153
+
154
+ class OpenLRMConfig:
155
+ def __init__(self, obj_dim=1024, hidden_dim=512, num_layers=6, num_heads=8, dropout_prob=0.1, **kwargs):
156
+ self.obj_dim = obj_dim
157
+ self.hidden_dim = hidden_dim
158
+ self.num_layers = num_layers
159
+ self.num_heads = num_heads
160
+ self.dropout_prob = dropout_prob
161
+ self.all_head_size = self.num_heads * self.hidden_dim
162
+ for key, value in kwargs.items():
163
+ setattr(self, key, value)
164
+
165
+ @classmethod
166
+ def from_dict(cls, config_dict):
167
+ return cls(**config_dict)
168
+
169
+ class UNet2DConditionModelConfig:
170
+ def __init__(self, sample_size=64, layers_per_block=2, block_out_channels=[320, 640, 1280, 1280], downsample=[2, 2, 2, 2], upsample=[2, 2, 2, 2], cross_attention_dim=768, act_fn="silu", norm_num_groups=32, num_attention_heads=8, in_channels=4, out_channels=4, attention_head_dim=64, **kwargs):
171
+ self.sample_size = sample_size
172
+ self.layers_per_block = layers_per_block
173
+ self.block_out_channels = block_out_channels
174
+ self.downsample = downsample
175
+ self.upsample = upsample
176
+ self.cross_attention_dim = cross_attention_dim
177
+ self.act_fn = act_fn
178
+ self.norm_num_groups = norm_num_groups
179
+ self.num_attention_heads = num_attention_heads
180
+ self.in_channels = in_channels
181
+ self.out_channels = out_channels
182
+ self.attention_head_dim = attention_head_dim
183
+ for key, value in kwargs.items():
184
+ setattr(self, key, value)
185
+
186
+ @classmethod
187
+ def from_dict(cls, config_dict):
188
+ return cls(**config_dict)
189
+
190
+ class AutoencoderKLConfig:
191
+ def __init__(self, **kwargs):
192
+ self.sample_size = 64
193
+ self.latent_channels = 4
194
+ self.layers_per_block = 2
195
+ self.block_out_channels = [128, 256, 512, 512]
196
+ self.downsample = [2, 2, 2, 2]
197
+ self.upsample = [2, 2, 2, 2]
198
+ self.act_fn = "silu"
199
+ self.norm_num_groups = 32
200
+ self.num_channels_every_n_layers = 2
201
+ for key, value in kwargs.items():
202
+ setattr(self, key, value)
203
+
204
+ @classmethod
205
+ def from_dict(cls, config_dict):
206
+ return cls(**config_dict)
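
A short usage sketch for the config classes above (illustrative only; the values are placeholders):

# Illustrative use of the constructors and from_dict classmethods defined above.
from configs import GPT2Config, BartConfig

gpt2_cfg = GPT2Config.from_dict({"n_layer": 12, "n_head": 12, "n_embd": 768})
bart_cfg = BartConfig(num_beams=4, max_length=100)
print(gpt2_cfg.vocab_size, bart_cfg.d_model)  # 50257 1024 with the defaults above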
constants.py ADDED
@@ -0,0 +1,449 @@
1
+ import os
2
+
3
+ TEXT_GENERATION_RATE = 40000
4
+ MAX_LENGTH = 2048
5
+ MAX_XDD = 5
6
+ END_OF_TEXT_TOKEN = "<|endoftext|>"
7
+ SYSTEM_PROMPT = """Eres un asistente experto con habilidades avanzadas en diversas áreas. Responde de manera amigable, educada y razonada. Siempre piensa cuidadosamente antes de responder para asegurar la claridad y completitud. Posees la capacidad de autoaprendizaje continuo y recuerdas interacciones pasadas para mejorar tus respuestas y evitar errores repetidos."""
8
+ XML_COT_FORMAT = """<reasoning>\n{reasoning}\n</reasoning>\n<answer>\n{answer}\n</answer>\n"""
9
+
10
+ html_code = """<!DOCTYPE html>
11
+ <html lang="en">
12
+ <head>
13
+ <meta charset="UTF-8">
14
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
15
+ <title>AI Text Generation</title>
16
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
17
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
18
+ <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
19
+ <style>
20
+ body {
21
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
22
+ background: #f0f0f0;
23
+ color: #333;
24
+ margin: 0;
25
+ padding: 0;
26
+ display: flex;
27
+ flex-direction: column;
28
+ align-items: center;
29
+ min-height: 100vh;
30
+ }
31
+ .container {
32
+ width: 95%;
33
+ max-width: 900px;
34
+ padding: 20px;
35
+ background-color: #fff;
36
+ box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
37
+ border-radius: 8px;
38
+ margin-top: 20px;
39
+ margin-bottom: 20px;
40
+ display: flex;
41
+ flex-direction: column;
42
+ }
43
+ .header {
44
+ text-align: center;
45
+ margin-bottom: 20px;
46
+ }
47
+ .header h1 {
48
+ font-size: 2em;
49
+ color: #333;
50
+ }
51
+ .form-group {
52
+ margin-bottom: 15px;
53
+ }
54
+ .form-group textarea {
55
+ width: 100%;
56
+ padding: 10px;
57
+ border: 1px solid #ccc;
58
+ border-radius: 5px;
59
+ font-size: 16px;
60
+ box-sizing: border-box;
61
+ resize: vertical;
62
+ }
63
+ button {
64
+ padding: 10px 15px;
65
+ border: none;
66
+ border-radius: 5px;
67
+ background-color: #007bff;
68
+ color: white;
69
+ font-size: 18px;
70
+ cursor: pointer;
71
+ transition: background-color 0.3s ease;
72
+ }
73
+ button:hover {
74
+ background-color: #0056b3;
75
+ }
76
+ #output {
77
+ margin-top: 20px;
78
+ padding: 15px;
79
+ border: 1px solid #ddd;
80
+ border-radius: 5px;
81
+ background-color: #f9f9f9;
82
+ white-space: pre-wrap;
83
+ word-break: break-word;
84
+ overflow-y: auto;
85
+ max-height: 100vh;
86
+ }
87
+ #output strong {
88
+ font-weight: bold;
89
+ }
90
+ .animated-text {
91
+ position: fixed;
92
+ top: 20px;
93
+ left: 20px;
94
+ font-size: 1.5em;
95
+ color: rgba(0, 0, 0, 0.1);
96
+ pointer-events: none;
97
+ z-index: -1;
98
+ }
99
+ @media (max-width: 768px) {
100
+ .container {
101
+ width: 98%;
102
+ margin-top: 10px;
103
+ margin-bottom: 10px;
104
+ padding: 15px;
105
+ }
106
+ .header h1 {
107
+ font-size: 1.8em;
108
+ }
109
+ .form-group textarea, .form-group input[type="text"] {
110
+ font-size: 14px;
111
+ padding: 8px;
112
+ }
113
+ button {
114
+ font-size: 16px;
115
+ padding: 8px 12px;
116
+ }
117
+ #output {
118
+ font-size: 14px;
119
+ padding: 10px;
120
+ margin-top: 15px;
121
+ }
122
+ }
123
+ </style>
124
+ </head>
125
+ <body>
126
+ <div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
127
+ <div class="container">
128
+ <div class="header animate__animated animate__fadeInDown">
129
+ </div>
130
+ <div class="form-group animate__animated animate__fadeInLeft">
131
+ <textarea id="text" rows="5" placeholder="Enter text"></textarea>
132
+ </div>
133
+ <button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
134
+ <div id="output" class="animate__animated">
135
+ <strong >Response:</strong><br>
136
+ <span id="generatedText"></span>
137
+ </div>
138
+ </div>
139
+ <script>
140
+ let eventSource = null;
141
+ let accumulatedText = "";
142
+ let lastResponse = "";
143
+ async function generateText() {
144
+ const inputText = document.getElementById("text").value;
145
+ document.getElementById("generatedText").innerText = "";
146
+ accumulatedText = "";
147
+ if (eventSource) {
148
+ eventSource.close();
149
+ }
150
+ const temp = 0.7;
151
+ const top_k_val = 40;
152
+ const top_p_val = 0.0;
153
+ const repetition_penalty_val = 1.2;
154
+ const requestData = {
155
+ text: inputText,
156
+ temp: temp,
157
+ top_k: top_k_val,
158
+ top_p: top_p_val,
159
+ reppenalty: repetition_penalty_val
160
+ };
161
+ const params = new URLSearchParams(requestData).toString();
162
+ eventSource = new EventSource('/generate_stream?' + params);
168
+ eventSource.onmessage = function(event) {
169
+ if (event.data === "<END_STREAM>") {
170
+ eventSource.close();
171
+ const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
172
+ if (currentResponse === lastResponse.trim()) {
173
+ accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
174
+ } else {
175
+ lastResponse = currentResponse;
176
+ }
177
+ document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
178
+ return;
179
+ }
180
+ accumulatedText += event.data;
181
+ let partialText = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
182
+ document.getElementById("generatedText").innerHTML = marked.parse(partialText);
183
+ };
184
+ eventSource.onerror = function(error) {
185
+ console.error("SSE error", error);
186
+ eventSource.close();
187
+ };
188
+ const outputDiv = document.getElementById("output");
189
+ outputDiv.classList.add("show");
190
+ }
191
+ function base64ToBlob(base64Data, contentType) {
192
+ contentType = contentType || '';
193
+ const sliceSize = 1024;
194
+ const byteCharacters = atob(base64Data);
195
+ const bytesLength = byteCharacters.length;
196
+ const slicesCount = Math.ceil(bytesLength / sliceSize);
197
+ const byteArrays = new Array(slicesCount);
198
+ for (let sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) {
199
+ const begin = sliceIndex * sliceSize;
200
+ const end = Math.min(begin + sliceSize, bytesLength);
201
+ const bytes = new Array(end - begin);
202
+ for (let offset = begin, i = 0; offset < end; ++i, ++offset) {
203
+ bytes[i] = byteCharacters[offset].charCodeAt(0);
204
+ }
205
+ byteArrays[sliceIndex] = new Uint8Array(bytes);
206
+ }
207
+ return new Blob(byteArrays, { type: contentType });
208
+ }
209
+ </script>
210
+ </body>
211
+ </html>
212
+ """
213
+
214
+ HTML_CODE = html_code
215
+
216
+ # =============================================================================
217
+ # User-defined constants
218
+ # =============================================================================
219
+
220
+ # GPT-2
221
+ GPT2_FOLDER = "./GPT2"
222
+ MODEL_FILE = "gpt2-pytorch_model.bin"
223
+ ENCODER_FILE = "encoder.json"
224
+ VOCAB_FILE = "vocab.bpe"
225
+ CONFIG_FILE = "config.json"
226
+ GPT2CONFHG = "https://huggingface.co/openai-community/gpt2/resolve/main/config.json"
227
+ MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"
228
+ ENCODER_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/encoder.json"
229
+ VOCAB_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/vocab.bpe"
230
+
231
+ # Translation (MBart)
232
+ TRANSLATION_FOLDER = "./TranslationModel"
233
+ TRANSLATION_MODEL_WEIGHTS_FILE = "pytorch_model.bin"
234
+ TRANSLATION_MODEL_CONFIG_FILE = "config.json"
235
+ TRANSLATION_MODEL_VOCAB_FILE = "sentencepiece.bpe.model"
236
+ TRANSLATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/pytorch_model.bin"
237
+ TRANSLATION_MODEL_CONFIG_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/config.json"
238
+ TRANSLATION_MODEL_VOCAB_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
239
+ TRANSLATION_MODEL_FILES_URLS = [
240
+ (TRANSLATION_MODEL_WEIGHTS_URL, TRANSLATION_MODEL_WEIGHTS_FILE),
241
+ (TRANSLATION_MODEL_CONFIG_URL, TRANSLATION_MODEL_CONFIG_FILE),
242
+ (TRANSLATION_MODEL_VOCAB_URL, TRANSLATION_MODEL_VOCAB_FILE),
243
+ ]
244
+
245
+ # CodeGen
246
+ CODEGEN_FOLDER = "./CodeGenModel"
247
+ CODEGEN_MODEL_NAME = "codegen-350M-multi"
248
+ CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
249
+ CODEGEN_CONFIG = "config.json"
250
+ CODEGEN_VOCAB = "vocab.json"
251
+ CODEGEN_MERGES = "merges.txt"
252
+ CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
253
+ CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
254
+ CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
255
+ CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
256
+ CODEGEN_FILES_URLS = [
257
+ (CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
258
+ (CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
259
+ (CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
260
+ (CODEGEN_MERGES_URL, CODEGEN_MERGES),
261
+ ]
262
+
263
+ # MusicGen
264
+ MUSICGEN_FOLDER = "./MusicGenModel"
265
+ MUSICGEN_MODEL_NAME = "melody"
266
+ MUSICGEN_MODEL_WEIGHTS = "pytorch_model.bin"
267
+ MUSICGEN_CONFIG = "config.json"
268
+ MUSICGEN_SAMPLE_RATE = 32000
269
+ MUSICGEN_DURATION = 8
270
+ MUSICGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/pytorch_model.bin"
271
+ MUSICGEN_CONFIG_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json"
272
+ MUSICGEN_FILES_URLS = [
273
+ (MUSICGEN_MODEL_WEIGHTS_URL, MUSICGEN_MODEL_WEIGHTS),
274
+ (MUSICGEN_CONFIG_URL, MUSICGEN_CONFIG)
275
+ ]
276
+
277
+ # Summarization (Bart)
278
+ SUMMARIZATION_FOLDER = "./SummarizationModel"
279
+ SUMMARIZATION_MODEL_WEIGHTS = "pytorch_model.bin"
280
+ SUMMARIZATION_CONFIG = "config.json"
281
+ SUMMARIZATION_VOCAB = "vocab.json"
282
+ SUMMARIZATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/pytorch_model.bin"
283
+ SUMMARIZATION_CONFIG_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json"
284
+ SUMMARIZATION_VOCAB_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json"
285
+ SUMMARIZATION_FILES_URLS = [
286
+ (SUMMARIZATION_MODEL_WEIGHTS_URL, SUMMARIZATION_MODEL_WEIGHTS),
287
+ (SUMMARIZATION_CONFIG_URL, SUMMARIZATION_CONFIG),
288
+ (SUMMARIZATION_VOCAB_URL, SUMMARIZATION_VOCAB)
289
+ ]
290
+
291
+ # TTS
292
+ TTS_FOLDER = "./TTSModel"
293
+ TTS_MODEL_NAME = "vits"
294
+ TTS_MODEL_CONFIG = "config.json"
295
+ TTS_MODEL_WEIGHTS = "pytorch_model.bin"
296
+ TTS_VOCAB = "vocab.json"
297
+ TTS_CONFIG_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/config.json"
298
+ TTS_MODEL_WEIGHTS_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/pytorch_model.bin"
299
+ TTS_VOCAB_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/vocab.json"
300
+ TTS_FILES_URLS = [
301
+ (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
302
+ (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
303
+ (TTS_VOCAB_URL, TTS_VOCAB)
304
+ ]
305
+
306
+ # STT
307
+ STT_FOLDER = "./STTModel"
308
+ STT_MODEL_NAME = "wav2vec2"
309
+ STT_MODEL_WEIGHTS = "pytorch_model.bin"
310
+ STT_CONFIG = "config.json"
311
+ STT_VOCAB = "vocab.json"
312
+ STT_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin"
313
+ STT_CONFIG_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json"
314
+ STT_VOCAB_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
315
+ STT_FILES_URLS = [
316
+ (STT_MODEL_WEIGHTS_URL, STT_MODEL_WEIGHTS),
317
+ (STT_CONFIG_URL, STT_CONFIG),
318
+ (STT_VOCAB_URL, STT_VOCAB)
319
+ ]
320
+
321
+ # Sentiment Analysis
322
+ SENTIMENT_FOLDER = "./SentimentModel"
323
+ SENTIMENT_MODEL_WEIGHTS = "pytorch_model.bin"
324
+ SENTIMENT_VOCAB = "vocab.json"
325
+ SENTIMENT_CONFIG_FILE = "config.json"
326
+ SENTIMENT_MODEL_WEIGHTS_URL = "https://huggingface.co/climatebert/distilroberta-base-climate-sentiment/resolve/main/pytorch_model.bin"
327
+ SENTIMENT_VOCAB_URL = "https://huggingface.co/climatebert/distilroberta-base-climate-sentiment/resolve/main/vocab.json"
328
+ SENTIMENT_CONFIG_URL = "https://huggingface.co/climatebert/distilroberta-base-climate-sentiment/resolve/main/config.json"
329
+ SENTIMENT_FILES_URLS = [
330
+ (SENTIMENT_MODEL_WEIGHTS_URL, SENTIMENT_MODEL_WEIGHTS),
331
+ (SENTIMENT_VOCAB_URL, SENTIMENT_VOCAB),
332
+ (SENTIMENT_CONFIG_URL, SENTIMENT_CONFIG_FILE)
333
+ ]
334
+
335
+ # Image Generation (VAE)
336
+ IMAGEGEN_FOLDER = "./ImageGenModel"
337
+ IMAGEGEN_MODEL_WEIGHTS = "diffusion_pytorch_model.bin"
338
+ IMAGEGEN_CONFIG = "config.json"
339
+ IMAGEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
340
+ IMAGEGEN_CONFIG_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json"
341
+ IMAGEGEN_FILES_URLS = [
342
+ (IMAGEGEN_MODEL_WEIGHTS_URL, IMAGEGEN_MODEL_WEIGHTS),
343
+ (IMAGEGEN_CONFIG_URL, IMAGEGEN_CONFIG)
344
+ ]
345
+
346
+ # Image to 3D
347
+ IMAGE_TO_3D_FOLDER = "./ImageTo3DModel"
348
+ IMAGE_TO_3D_MODEL_WEIGHTS = "pytorch_model.bin"
349
+ IMAGE_TO_3D_CONFIG = "config.json"
350
+ IMAGE_TO_3D_MODEL_WEIGHTS_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/pytorch_model.bin"
351
+ IMAGE_TO_3D_CONFIG_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/config.json"
352
+ IMAGE_TO_3D_FILES_URLS = [
353
+ (IMAGE_TO_3D_MODEL_WEIGHTS_URL, IMAGE_TO_3D_MODEL_WEIGHTS),
354
+ (IMAGE_TO_3D_CONFIG_URL, IMAGE_TO_3D_CONFIG)
355
+ ]
356
+
357
+ # Text to Video
358
+ TEXT_TO_VIDEO_FOLDER = "./TextToVideoModel"
359
+ TEXT_TO_VIDEO_MODEL_WEIGHTS = "diffusion_pytorch_model.bin" # Used for both (UNet and VAE)
360
+ TEXT_TO_VIDEOX_MODEL_WEIGHTS = "diffusion_pytorch_model.fp16.bin" # Used for both (UNet and VAE)
361
+ TEXT_TO_VIDEO_CONFIG = "config.json" # Used for both (UNet and VAE)
362
+ TEXT_TO_VIDEO_VOCAB = "vocab.json"
363
+ TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_UNET = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/unet/diffusion_pytorch_model.fp16.bin"
364
+ TEXT_TO_VIDEO_CONFIG_URL_UNET = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/unet/config.json"
365
+ TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_VAE = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/vae/diffusion_pytorch_model.fp16.bin"
366
+ TEXT_TO_VIDEO_CONFIG_URL_VAE = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/vae/config.json"
367
+ TEXT_TO_VIDEO_VOCAB_URL = "https://huggingface.co/ali-vilab/text-to-video-ms-1.7b/resolve/main/tokenizer/vocab.json"
368
+ TEXT_TO_VIDEO_FILES_URLS = [
369
+ (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_UNET, TEXT_TO_VIDEO_MODEL_WEIGHTS),
370
+ (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_UNET, TEXT_TO_VIDEOX_MODEL_WEIGHTS),
371
+ (TEXT_TO_VIDEO_CONFIG_URL_UNET, TEXT_TO_VIDEO_CONFIG),
372
+ (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_VAE, TEXT_TO_VIDEO_MODEL_WEIGHTS),
373
+ (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL_VAE, TEXT_TO_VIDEOX_MODEL_WEIGHTS),
374
+ (TEXT_TO_VIDEO_CONFIG_URL_VAE, TEXT_TO_VIDEO_CONFIG),
375
+ (TEXT_TO_VIDEO_VOCAB_URL, TEXT_TO_VIDEO_VOCAB),
376
+ ]
377
+
378
+ # SadTalker
379
+ # ============================================================================
380
+ # Restoration models for SadTalker (face restoration / super-resolution)
381
+ # ============================================================================
382
+ # GFPGAN
383
+ GFPGAN_FOLDER = "./GFPGAN"
384
+ GFPGAN_MODEL_FILE = "GFPGANv1.4.pth"
385
+ GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
386
+
387
+ # RestoreFormer
388
+ RESTOREFORMER_FOLDER = "./RestoreFormer"
389
+ RESTOREFORMER_MODEL_FILE = "RestoreFormer.pth"
390
+ RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
391
+
392
+ # CodeFormer
393
+ CODEFORMER_FOLDER = "./CodeFormer"
394
+ CODEFORMER_MODEL_FILE = "codeformer.pth"
395
+ CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
396
+
397
+ # RealESRGAN
398
+ REALESRGAN_FOLDER = "./RealESRGAN"
399
+ REALESRGAN_MODEL_FILE = "RealESRGAN_x2plus.pth"
400
+ REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
401
+
402
+
403
+
404
+ kp = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
405
+ kp_file = "kp_detector.safetensors"
406
+ aud = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
407
+ aud_file = "auido2pose_00140-model.pth"
408
+ wav = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
409
+ wav_file = "wav2vec2.bin"
410
+ gen = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
411
+ gen_file = "generator.bin"
412
+ mapx = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
413
+ mapx_file = "mapping.pth"
414
+ den = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
415
+ den_file = "dense_motion.pth"
416
+
417
+ # --- Define constants for new SadTalker models ---
418
+ SADTALKER_KP_FOLDER = "checkpoints"
419
+ SADTALKER_KP_MODEL_FILE = kp_file
420
+ SADTALKER_KP_URL = kp
421
+
422
+ SADTALKER_AUD_FOLDER = "checkpoints" # Assuming these go in the main checkpoints folder for SadTalker
423
+ SADTALKER_AUD_MODEL_FILE = aud_file
424
+ SADTALKER_AUD_URL = aud
425
+
426
+ SADTALKER_WAV_FOLDER = "checkpoints" # Assuming these go in the main checkpoints folder for SadTalker
427
+ SADTALKER_WAV_MODEL_FILE = wav_file
428
+ SADTALKER_WAV_URL = wav
429
+
430
+ SADTALKER_GEN_FOLDER = "checkpoints" # Assuming these go in the main checkpoints folder for SadTalker
431
+ SADTALKER_GEN_MODEL_FILE = gen_file
432
+ SADTALKER_GEN_URL = gen
433
+
434
+ SADTALKER_MAPX_FOLDER = "checkpoints" # Assuming these go in the main checkpoints folder for SadTalker
435
+ SADTALKER_MAPX_MODEL_FILE = mapx_file
436
+ SADTALKER_MAPX_URL = mapx
437
+
438
+ SADTALKER_DEN_FOLDER = "checkpoints" # Assuming these go in the main checkpoints folder for SadTalker
439
+ SADTALKER_DEN_MODEL_FILE = den_file
440
+ SADTALKER_DEN_URL = den
441
+
442
+
443
+
444
+
445
+ # =============================================================================
446
+ # SadTalker
447
+ # =============================================================================
448
+ SADTALKER_CHECKPOINTS_FOLDER = "./checkpoints"
449
+ SADTALKER_CONFIG_FOLDER = "./src/config"
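
Note (illustration only, not part of the uploaded files): the folder/file/URL constants above are consumed by the download helpers defined later in model_loader.py. A minimal sketch of that wiring, using only names declared in constants.py:

    import os
    from constants import GFPGAN_FOLDER, GFPGAN_MODEL_FILE, GFPGAN_URL
    from model_loader import download_file

    # Fetches GFPGANv1.4.pth into ./GFPGAN/ only if it is not already on disk.
    download_file(GFPGAN_URL, os.path.join(GFPGAN_FOLDER, GFPGAN_MODEL_FILE))
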
extensions.py ADDED
@@ -0,0 +1,252 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import yaml
5
+ from PIL import Image
6
+ from skimage import img_as_ubyte, transform
7
+ import safetensors
8
+ import librosa
9
+ from pydub import AudioSegment
10
+ import imageio
11
+ from scipy.io import loadmat, savemat, wavfile
12
+ import glob
13
+ import tempfile
14
+ from tqdm import tqdm
15
+ import numpy as np
16
+ import math
17
+ import torchvision
18
+ import os
19
+ import re
20
+ import shutil
21
+ from yacs.config import CfgNode as CN
22
+ import requests
23
+ import subprocess
24
+ import cv2
25
+ from collections import OrderedDict
+ from basicsr.archs.rrdbnet_arch import RRDBNet  # needed below when RealESRGANer is constructed without an explicit model
26
+
27
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
28
+ if isinstance(imgs, np.ndarray):
29
+ if imgs.ndim == 3:
30
+ imgs = imgs[..., np.newaxis]
31
+ imgs = torch.from_numpy(imgs.transpose((2, 0, 1)))
32
+ elif isinstance(imgs, Image.Image):
33
+ imgs = torch.from_numpy(np.array(imgs)).permute(2, 0, 1)
34
+ else:
35
+ raise TypeError(f'Type `{type(imgs)}` is not suitable for img2tensor')
36
+ if bgr2rgb:
37
+ if imgs.shape[0] == 3:
38
+ imgs = imgs[[2, 1, 0], :, :]
39
+ if float32:
40
+ imgs = imgs.float() / 255.
41
+ return imgs
42
+
43
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
44
+ if not isinstance(tensor, torch.Tensor):
45
+ raise TypeError(f'Input tensor should be torch.Tensor, but got {type(tensor)}')
46
+ tensor = tensor.float().cpu()
47
+ tensor = tensor.clamp_(*min_max)
48
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
49
+ output_img = tensor.mul(255).round()
50
+ output_img = np.transpose(output_img.numpy(), (1, 2, 0))
51
+ output_img = np.clip(output_img, 0, 255).astype(np.uint8)
52
+ if rgb2bgr:
53
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
54
+ return output_img if out_type == np.uint8 else output_img.astype(out_type) / 255.
55
+
56
+ class RealESRGANer():
57
+ def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=0, half=False, device=None, gpu_id=None):
58
+ self.scale = scale
59
+ self.tile = tile
60
+ self.tile_pad = tile_pad
61
+ self.pre_pad = pre_pad
62
+ self.mod_scale = None
63
+ self.half = half
64
+ if device is None:
65
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
66
+ else:
67
+ self.device = device
68
+ if model is None:
69
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
70
+ if half:
71
+ model.half()
72
+ loadnet = torch.load(model_path, map_location=lambda storage, loc: storage)
73
+ if 'params' in loadnet:
74
+ model.load_state_dict(loadnet['params'], strict=True)
75
+ elif 'params_ema' in loadnet:
76
+ model.load_state_dict(loadnet['params_ema'], strict=True)
77
+ else:
78
+ model.load_state_dict(loadnet, strict=True)
79
+ model.eval()
80
+ self.model = model.to(self.device)
81
+
82
+ def enhance(self, img, outscale=None, tile=None, tile_pad=None, pre_pad=None, half=None):
83
+ h_input, w_input = img.shape[0:2]
84
+ if outscale is None:
85
+ outscale = self.scale
86
+ if tile is None:
87
+ tile = self.tile
88
+ if tile_pad is None:
89
+ tile_pad = self.tile_pad
90
+ if pre_pad is None:
91
+ pre_pad = self.pre_pad
92
+ if half is None:
93
+ half = self.half
94
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
95
+ img_tensor = img2tensor(img)
96
+ img_tensor = img_tensor.unsqueeze(0).to(self.device)
97
+ if half:
98
+ img_tensor = img_tensor.half()
99
+ mod_scale = self.mod_scale
100
+ h_pad, w_pad = 0, 0
101
+ if mod_scale is not None:
102
+ h_pad, w_pad = int(np.ceil(h_input / mod_scale) * mod_scale - h_input), int(np.ceil(w_input / mod_scale) * mod_scale - w_input)
103
+ img_tensor = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'reflect')
104
+ window_size = 256
105
+ scale = self.scale
106
+ overlap_ratio = 0.5
107
+ if w_input * h_input < window_size**2:
108
+ tile = None
109
+ if tile is not None and tile > 0:
110
+ tile_overlap = tile * overlap_ratio
111
+ sf = scale
112
+ stride_w = math.ceil(tile - tile_overlap)
113
+ stride_h = math.ceil(tile - tile_overlap)
114
+ numW = math.ceil((w_input + tile_overlap) / stride_w)
115
+ numH = math.ceil((h_input + tile_overlap) / stride_h)
116
+ paddingW = (numW - 1) * stride_w + tile - w_input
117
+ paddingH = (numH - 1) * stride_h + tile - h_input
118
+ padding_bottom = int(max(paddingH, 0))
119
+ padding_right = int(max(paddingW, 0))
120
+ padding_left, padding_top = 0, 0
121
+ img_tensor = F.pad(img_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode='reflect')
122
+ output_h, output_w = padding_top + h_input * scale + padding_bottom, padding_left + w_input * scale + padding_right
123
+ output_tensor = torch.zeros([1, 3, output_h, output_w], dtype=img_tensor.dtype, device=self.device)
124
+ windows = []
125
+ for row in range(numH):
126
+ for col in range(numW):
127
+ start_x = col * stride_w
128
+ start_y = row * stride_h
129
+ end_x = min(start_x + tile, img_tensor.shape[3])
130
+ end_y = min(start_y + tile, img_tensor.shape[2])
131
+ windows.append(img_tensor[:, :, start_y:end_y, start_x:end_x])
132
+ results = []
133
+ batch_size = 8
134
+ for i in range(0, len(windows), batch_size):
135
+ batch_windows = torch.stack(windows[i:min(i + batch_size, len(windows))], dim=0)
136
+ with torch.no_grad():
137
+ results.append(self.model(batch_windows))
138
+ results = torch.cat(results, dim=0)
139
+ count = 0
140
+ for row in range(numH):
141
+ for col in range(numW):
142
+ start_x = col * stride_w
143
+ start_y = row * stride_h
144
+ end_x = min(start_x + tile, img_tensor.shape[3])
145
+ end_y = min(start_y + tile, img_tensor.shape[2])
146
+ out_start_x, out_start_y = start_x * sf, start_y * sf
147
+ out_end_x, out_end_y = end_x * sf, end_y * sf
148
+ output_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += results[count][:, :, :end_y * sf - out_start_y, :end_x * sf - out_start_x]
149
+ count += 1
150
+ forward_img = output_tensor[:, :, :h_input * sf, :w_input * sf]
151
+ else:
152
+ with torch.no_grad():
153
+ forward_img = self.model(img_tensor)
154
+ if half:
155
+ forward_img = forward_img.float()
156
+ output_img = tensor2img(forward_img.squeeze(0).clamp_(0, 1))
157
+ if mod_scale is not None:
158
+ output_img = output_img[:h_input * self.scale, :w_input * self.scale, ...]
159
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
160
+ return [output_img, None]
161
+
162
+ def save_video_with_watermark(video_frames, audio_path, output_path, watermark_path='./assets/sadtalker_logo.png'):
163
+ try:
164
+ watermark = imageio.imread(watermark_path)
165
+ except FileNotFoundError:
166
+ watermark = None
167
+ writer = imageio.get_writer(output_path, fps=25)
168
+ try:
169
+ for frame in tqdm(video_frames, 'Generating video'):
170
+ if watermark is not None:
171
+ frame_h, frame_w = frame.shape[:2]
172
+ watermark_h, watermark_w = watermark.shape[:2]
173
+ if watermark_h > frame_h or watermark_w > frame_w:
174
+ watermark = transform.resize(watermark, (frame_h // 4, frame_w // 4))
175
+ watermark_h, watermark_w = watermark.shape[:2]
176
+ start_h = frame_h - watermark_h - 10
177
+ start_w = frame_w - watermark_w - 10
178
+ frame[start_h:start_h+watermark_h, start_w:start_w+watermark_w, :] = watermark
179
+ writer.append_data(img_as_ubyte(frame))
180
+ except Exception as e:
181
+ print(f"Error in video writing: {e}")
182
+ finally:
183
+ writer.close()
184
+ if audio_path is not None:
185
+ try:
186
+ command = "ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}".format(audio_path, output_path, output_path.replace('.mp4', '_with_audio.mp4'))
187
+ subprocess.call(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
188
+ os.remove(output_path)
189
+ os.rename(output_path.replace('.mp4', '_with_audio.mp4'), output_path)
190
+ except Exception as e:
191
+ print(f"Error adding audio to video: {e}")
192
+
193
+ def paste_pic(video_path, pic_path, crop_info, audio_path, output_path):
194
+ try:
195
+ y_start, y_end, x_start, x_end, old_size, cropped_size = crop_info[0][0], crop_info[0][1], crop_info[1][0], crop_info[1][1], crop_info[2], crop_info[3]
196
+ source_image_h, source_image_w = old_size
197
+ cropped_h, cropped_w = cropped_size
198
+ delta_h, delta_w = source_image_h - cropped_h, source_image_w - cropped_w
199
+ box = [x_start, y_start, source_image_w - x_end, source_image_h - y_end]
200
+ command = "ffmpeg -y -i {} -i {} -filter_complex \"[1]crop=w={}:h={}:x={}:y={}[s];[0][s]overlay=x={}:y={}\" -codec:a copy {}".format(video_path, pic_path, cropped_w, cropped_h, box[0], box[1], box[0], box[1], output_path)
201
+ subprocess.call(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
202
+ except Exception as e:
203
+ print(f"Error pasting picture to video: {e}")
204
+
205
+ def color_transfer_batch(source, target, mode='numpy'):
206
+ source_np = tensor2img(source)
207
+ target_np = tensor2img(target)
208
+ source_lab = cv2.cvtColor(source_np, cv2.COLOR_RGB2LAB).astype(np.float32)
209
+ target_lab = cv2.cvtColor(target_np, cv2.COLOR_RGB2LAB).astype(np.float32)
210
+ source_mu = np.mean(source_lab, axis=(0, 1), keepdims=True)
211
+ source_std = np.std(source_lab, axis=(0, 1), keepdims=True)
212
+ target_mu = np.mean(target_lab, axis=(0, 1), keepdims=True)
213
+ target_std = np.std(target_lab, axis=(0, 1), keepdims=True)
214
+ transfer_lab = (target_lab - target_mu) * (source_std / target_std) + source_mu
215
+ transfer_rgb = cv2.cvtColor(np.clip(transfer_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2RGB)
216
+ transfer_rgb_tensor = img2tensor(transfer_rgb)
217
+ return transfer_rgb_tensor.unsqueeze(0).to(source.device)
218
+
219
+ def load_video_to_cv2(path, resize=None):
220
+ video = []
221
+ try:
222
+ cap = cv2.VideoCapture(path)
223
+ if not cap.isOpened():
224
+ raise Exception("Error opening video stream or file")
225
+ while(cap.isOpened()):
226
+ ret, frame = cap.read()
227
+ if ret:
228
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
229
+ if resize is not None:
230
+ frame_rgb = cv2.resize(frame_rgb, resize)
231
+ video.append(frame_rgb)
232
+ else:
233
+ break
234
+ cap.release()
235
+ except Exception as e:
236
+ print(f"Error loading video: {e}")
237
+ return video
238
+
239
+ def get_prior_from_bfm(bfm_path):
240
+ mat_path = os.path.join(bfm_path, 'BFM_prior.mat')
241
+ C = loadmat(mat_path)
242
+ pc_tex = torch.tensor(C['pc_tex'].astype(np.float32)).unsqueeze(0)
243
+ pc_exp = torch.tensor(C['pc_exp'].astype(np.float32)).unsqueeze(0)
244
+ u_tex = torch.tensor(C['u_tex'].astype(np.float32)).unsqueeze(0)
245
+ u_exp = torch.tensor(C['u_exp'].astype(np.float32)).unsqueeze(0)
246
+ prior_coeff = {
247
+ 'pc_tex': pc_tex,
248
+ 'pc_exp': pc_exp,
249
+ 'u_tex': u_tex,
250
+ 'u_exp': u_exp
251
+ }
252
+ return prior_coeff
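
A hypothetical usage sketch for the RealESRGANer defined above (illustration only, not part of the uploaded files). It assumes the x2 weights referenced by REALESRGAN_URL in constants.py have already been downloaded, and that basicsr's RRDBNet serves as the default backbone; enhance() takes and returns RGB uint8 arrays:

    import cv2
    from extensions import RealESRGANer

    upsampler = RealESRGANer(scale=2, model_path="./RealESRGAN/RealESRGAN_x2plus.pth", tile=256)
    img_rgb = cv2.cvtColor(cv2.imread("face.png"), cv2.COLOR_BGR2RGB)  # "face.png" is a placeholder input
    restored_rgb, _ = upsampler.enhance(img_rgb, outscale=2)
    cv2.imwrite("face_x2.png", cv2.cvtColor(restored_rgb, cv2.COLOR_RGB2BGR))
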
image_to_3d_api.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ import uuid
3
+ from flask import jsonify, send_file, request
4
+ from main import *
5
+ # from main import image_to_3d_model, device
6
+ from PIL import Image
7
+ import torch
8
+ import numpy as np
9
+
10
+ def image_to_3d_func(image_path, output_path="output_3d.obj"):
11
+ if image_to_3d_model is None:
12
+ return "Image-to-3D model not initialized."
13
+ pil_image = Image.open(image_path).convert("RGB")
14
+ image = torch.tensor(np.array(pil_image)).float().permute(2,0,1).unsqueeze(0) / 255.0
15
+ image = image.to(device)
16
+ with torch.no_grad():
17
+ mesh_obj = image_to_3d_model(image)
18
+ with open(output_path, 'w') as f:
19
+ f.write(mesh_obj)
20
+ return output_path
21
+
22
+ def image_to_3d_api():
23
+ if 'image' not in request.files:
24
+ return jsonify({"error": "Image file is required"}), 400
25
+ image_file = request.files['image']
26
+ temp_image_path = f"temp_image_{uuid.uuid4()}.png"
27
+ image_file.save(temp_image_path)
28
+ output_file = image_to_3d_func(temp_image_path)
29
+ os.remove(temp_image_path)
30
+ if output_file == "Image-to-3D model not initialized.":
31
+ return jsonify({"error": "Image to 3D failed"}), 500
32
+ return send_file(output_file, mimetype="model/obj", as_attachment=True, download_name="output_3d.obj")
imagegen_api.py ADDED
@@ -0,0 +1,33 @@
1
+ import os
2
+ from flask import jsonify, send_file, request
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ from main import *
6
+ # from main import imagegen_model, device
7
+ import torch
8
+
9
+ def generate_image(prompt, output_path="output_image.png"):
10
+ if imagegen_model is None:
11
+ return "Image generation model not initialized."
12
+
13
+ generator = torch.Generator(device=device).manual_seed(0)
14
+ image = imagegen_model(
15
+ prompt,
16
+ generator=generator,
17
+ ).images[0]
18
+ image.save(output_path)
19
+ return output_path
20
+
21
+ def imagegen_api():
22
+ data = request.get_json()
23
+ prompt = data.get('prompt')
24
+ if not prompt:
25
+ return jsonify({"error": "Prompt is required"}), 400
26
+ output_file = generate_image(prompt)
27
+ if output_file == "Image generation model not initialized.":
28
+ return jsonify({"error": "Image generation failed"}), 500
29
+ image_io = BytesIO()
30
+ pil_image = Image.open(output_file)
31
+ pil_image.save(image_io, 'PNG')
32
+ image_io.seek(0)
33
+ return send_file(image_io, mimetype='image/png', as_attachment=True, download_name="output.png")
main.py ADDED
@@ -0,0 +1,118 @@
1
+ import threading
2
+ import queue
3
+ import time
4
+ import os
5
+ import nltk
6
+ import re
7
+ import json
8
+ from flask import Flask
9
+ from flask_cors import CORS
10
+ from api import *
11
+ from extensions import *
12
+ from constants import *
13
+ from configs import *
14
+ from tokenxxx import *
15
+ from models import *
16
+ from model_loader import *
17
+ from utils import *
18
+ from background_tasks import generate_and_queue_text, background_training, background_reasoning_queue
19
+ from text_generation import *
20
+ from sadtalker_utils import *
21
+ import torch
22
+
23
+ state_dict = None
24
+ enc = None
25
+ config = None
26
+ model_gpt2 = None
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ news_clf = None
29
+ tfidf_vectorizer = None
30
+ text_queue = queue.Queue()
31
+ categories = None
32
+ background_threads = []
33
+ feedback_queue = queue.Queue()
34
+ reasoning_queue = queue.Queue()
35
+ seen_responses = set()
36
+ dialogue_history = []
37
+ vocabulary = set()
38
+ word_to_index = {}
39
+ index_to_word = []
40
+ translation_model = None
41
+ sp = None
42
+ codegen_model = None
43
+ codegen_tokenizer = None
44
+ codegen_vocabulary = None
45
+ codegen_index_to_word = None
46
+ codegen_word_to_index = None
47
+ summarization_model = None
48
+ summarization_vocabulary = set()
49
+ summarization_word_to_index = {}
50
+ summarization_index_to_word = []
51
+ sadtalker_instance = None
52
+ imagegen_model = None
53
+ image_to_3d_model = None
54
+ text_to_video_model = None
55
+ stream_type = "text"
56
+ sentiment_model = None
57
+ stt_model = None
58
+ tts_model = None
59
+ musicgen_model = None
60
+
61
+ def load_models():
62
+ global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model, checkpoint_path, gfpgan_model_file, restoreformer_model_file, codeformer_model_file, realesrgan_model_file, kp_file, aud_file, wav_file, gen_file, mapx_file, den_file
63
+ model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
64
+ translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
65
+ codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
66
+ summarization_model, _, _, _ = initialize_summarization_model(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)
67
+ imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
68
+ image_to_3d_model = initialize_image_to_3d_model(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)
69
+ text_to_video_model = initialize_text_to_video_model(TEXT_TO_VIDEO_FOLDER, TEXT_TO_VIDEO_FILES_URLS)
70
+ sentiment_model = initialize_sentiment_model(SENTIMENT_FOLDER, SENTIMENT_FILES_URLS)
71
+ stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
72
+ tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
73
+ musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
74
+
75
+ class SimpleClassifier(torch.nn.Module):
76
+ def __init__(self, vocab_size, num_classes):
77
+ super(SimpleClassifier, self).__init__()
78
+ self.embedding = torch.nn.Embedding(vocab_size, 128)
79
+ self.linear = torch.nn.Linear(128, num_classes)
80
+ def forward(self, x):
81
+ embedded = self.embedding(x)
82
+ pooled = torch.mean(embedded, dim=1)
83
+ return self.linear(pooled)
84
+
85
+ def tokenize_text(text):
86
+ global vocabulary, word_to_index, index_to_word
87
+ tokens = text.lower().split()
88
+ for token in tokens:
89
+ if token not in vocabulary:
90
+ vocabulary.add(token)
91
+ word_to_index[token] = len(index_to_word)
92
+ index_to_word.append(token)
93
+ return tokens
94
+
95
+ def text_to_vector(text):
96
+ global vocabulary, word_to_index
97
+ tokens = tokenize_text(text)
98
+ vector = torch.zeros(len(vocabulary))
99
+ for token in tokens:
100
+ if token in word_to_index:
101
+ vector[word_to_index[token]] += 1
102
+ return vector
103
+
104
+ if __name__ == "__main__":
105
+ nltk.download('punkt')
106
+ load_models()
107
+ categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']
108
+ import background_tasks
109
+ background_tasks.categories = categories
110
+ background_tasks.text_queue = text_queue
111
+ background_tasks.reasoning_queue = reasoning_queue
112
+ background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
113
+ background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
114
+ background_threads.append(threading.Thread(target=background_training, daemon=True))
115
+ background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
116
+ for thread in background_threads:
117
+ thread.start()
118
+ app.run(host='0.0.0.0', port=7860)
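
For reference, a small sketch of how tokenize_text and SimpleClassifier above fit together (illustration only, not part of the uploaded files; importing main also executes its module-level imports):

    import torch
    import main

    tokens = main.tokenize_text("breaking news about science")     # also grows the shared vocabulary
    ids = torch.tensor([[main.word_to_index[t] for t in tokens]])  # shape (1, seq_len) of token indices
    clf = main.SimpleClassifier(vocab_size=len(main.vocabulary), num_classes=5)  # 5 matches the categories set in __main__
    logits = clf(ids)                                               # shape (1, 5)
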
model_loader.py ADDED
@@ -0,0 +1,674 @@
1
+ import os
2
+ import json
3
+ import urllib.request
4
+ import urllib.parse
5
+ import torch
6
+ import hashlib
7
+ from tqdm import tqdm
8
+ from skimage import img_as_ubyte
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ import inspect
12
+
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+ def filter_kwargs(cls, kwargs):
16
+ sig = inspect.signature(cls.__init__)
17
+ accepted = set(sig.parameters.keys()) - {"self"}
18
+ return {k: v for k, v in kwargs.items() if k in accepted}
19
+
20
+ def sanitize_filename(name, url=None):
21
+ for c in '<>:"/\\|?*':
22
+ name = name.replace(c, '')
23
+ if not name and url is not None:
24
+ name = hashlib.md5(url.encode()).hexdigest()
25
+ return name
26
+
27
+ def download_file(url, filepath):
28
+ d = os.path.dirname(filepath)
29
+ if d and not os.path.exists(d):
30
+ os.makedirs(d, exist_ok=True)
31
+ if not os.path.exists(filepath):
32
+ def prog(t):
33
+ last = [0]
34
+ def inner(n, bs, ts):
35
+ if ts > 0:
36
+ t.total = ts
37
+ t.update(n * bs - last[0])
38
+ last[0] = n * bs
39
+ return inner
40
+ with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t:
41
+ urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
42
+
43
+ def download_files(folder, files_spec):
44
+ if isinstance(files_spec, dict):
45
+ for fn, url in files_spec.items():
46
+ fn = sanitize_filename(fn, url)
47
+ fp = os.path.join(folder, fn)
48
+ download_file(url, fp)
49
+ elif isinstance(files_spec, list):
50
+ for item in files_spec:
51
+ if isinstance(item, str):
52
+ url = item
53
+ parsed = urllib.parse.urlparse(url)
54
+ fn = os.path.basename(parsed.path)
55
+ if not fn:
56
+ fn = hashlib.md5(url.encode()).hexdigest()
57
+ fn = sanitize_filename(fn, url)
58
+ elif isinstance(item, (list, tuple)) and len(item) == 2:
59
+ url, fn = item
60
+ fn = sanitize_filename(fn, url)
61
+ elif isinstance(item, dict) and "filename" in item and "url" in item:
62
+ fn = sanitize_filename(item["filename"], item["url"])
63
+ url = item["url"]
64
+ else:
65
+ raise ValueError("Invalid file specification")
66
+ fp = os.path.join(folder, fn)
67
+ download_file(url, fp)
68
+ else:
69
+ raise ValueError("files_spec must be dict or list")
70
+
71
+ def read_json(fp):
72
+ with open(fp, 'r', encoding='utf-8') as f:
73
+ return json.load(f)
74
+
75
+ def get_codegen_tokenizer(vocab_path, merges_path):
76
+ with open(vocab_path, 'r', encoding='utf-8') as f:
77
+ vocab = json.load(f)
78
+ with open(merges_path, 'r', encoding='utf-8') as f:
79
+ merges = f.read().splitlines()
80
+ def tokenizer(text):
81
+ toks = text.split()
82
+ return [vocab.get(t, 0) for t in toks]
83
+ return tokenizer
84
+
85
+ def simple_tokenizer(text, vocab, max_length=77):
86
+ toks = text.split()
87
+ ids = [vocab.get(t, 1) for t in toks]
88
+ if len(ids) < max_length:
89
+ ids = ids + [0]*(max_length - len(ids))
90
+ else:
91
+ ids = ids[:max_length]
92
+ return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
93
+
94
+ def load_state_dict_safe(model, loaded_state_dict):
95
+ model_state = model.state_dict()
96
+ new_state = {}
97
+ for key, value in model_state.items():
98
+ if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
99
+ new_state[key] = loaded_state_dict[key]
100
+ else:
101
+ new_state[key] = value
102
+ model.load_state_dict(new_state, strict=False)
103
+
104
+ class GPT2Config:
105
+ def __init__(self, vocab_size=50257, **kwargs):
106
+ self.vocab_size = vocab_size
107
+ self.__dict__.update(kwargs)
108
+ @classmethod
109
+ def from_dict(cls, d):
110
+ return cls(**d)
111
+
112
+ class MBartConfig:
113
+ def __init__(self, vocab_size=50265, **kwargs):
114
+ self.vocab_size = vocab_size
115
+ self.__dict__.update(kwargs)
116
+ @classmethod
117
+ def from_dict(cls, d):
118
+ return cls(**d)
119
+
120
+ class CodeGenConfig:
121
+ def __init__(self, vocab_size=50257, **kwargs):
122
+ self.vocab_size = vocab_size
123
+ self.__dict__.update(kwargs)
124
+ @classmethod
125
+ def from_dict(cls, d):
126
+ return cls(**d)
127
+
128
+ class BartConfig:
129
+ def __init__(self, vocab_size=50265, **kwargs):
130
+ self.vocab_size = vocab_size
131
+ self.__dict__.update(kwargs)
132
+ @classmethod
133
+ def from_dict(cls, d):
134
+ return cls(**d)
135
+
136
+ class AutoencoderKLConfig:
137
+ def __init__(self, **kwargs):
138
+ self.__dict__.update(kwargs)
139
+ @classmethod
140
+ def from_dict(cls, d):
141
+ return cls(**d)
142
+
143
+ class OpenLRMConfig:
144
+ def __init__(self, **kwargs):
145
+ self.__dict__.update(kwargs)
146
+ @classmethod
147
+ def from_dict(cls, d):
148
+ return cls(**d)
149
+
150
+ class UNet2DConditionModelConfig:
151
+ def __init__(self, **kwargs):
152
+ self.__dict__.update(kwargs)
153
+ @classmethod
154
+ def from_dict(cls, d):
155
+ return cls(**d)
156
+
157
+ class MusicGenConfig:
158
+ def __init__(self, **kwargs):
159
+ self.__dict__.update(kwargs)
160
+ @classmethod
161
+ def from_dict(cls, d):
162
+ return cls(**d)
163
+
164
+ class GPT2LMHeadModel(nn.Module):
165
+ def __init__(self, config):
166
+ super().__init__()
167
+ layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
168
+ self.transformer = nn.TransformerEncoder(layer, num_layers=12)
169
+ self.lm_head = nn.Linear(768, config.vocab_size)
170
+ def forward(self, x):
171
+ return self.lm_head(self.transformer(x))
172
+
173
+ class MBartForConditionalGeneration(nn.Module):
174
+ def __init__(self, config):
175
+ super().__init__()
176
+ self.config = config
177
+ layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
178
+ self.encoder = nn.TransformerEncoder(layer, num_layers=6)
179
+ dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
180
+ self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
181
+ self.output_layer = nn.Linear(768, config.vocab_size)
182
+ def forward(self, src, tgt):
183
+ return self.output_layer(self.decoder(tgt, self.encoder(src)))
184
+
185
+ class CodeGenForCausalLM(nn.Module):
186
+ def __init__(self, config):
187
+ super().__init__()
188
+ d_model = getattr(config, "d_model", 1024)
189
+ n_head = getattr(config, "n_head", 16)
190
+ num_layers = getattr(config, "num_layers", 12)
191
+ dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
192
+ self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
193
+ self.lm_head = nn.Linear(d_model, config.vocab_size)
194
+ def forward(self, tgt, memory=None):
195
+ if memory is None:
196
+ memory = torch.zeros_like(tgt)
197
+ return self.lm_head(self.transformer_decoder(tgt, memory))
198
+
199
+ class BartForConditionalGeneration(nn.Module):
200
+ def __init__(self, config):
201
+ super().__init__()
202
+ layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
203
+ self.encoder = nn.TransformerEncoder(layer, num_layers=6)
204
+ dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
205
+ self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
206
+ self.output_layer = nn.Linear(768, config.vocab_size)
207
+ def forward(self, src, tgt):
208
+ return self.output_layer(self.decoder(tgt, self.encoder(src)))
209
+
210
+ class ResnetBlock(nn.Module):
211
+ def __init__(self, in_ch, out_ch):
212
+ super().__init__()
213
+ self.norm1 = nn.GroupNorm(32, in_ch)
214
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
215
+ self.norm2 = nn.GroupNorm(32, out_ch)
216
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
217
+ self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
218
+ def forward(self, x):
219
+ sc = self.conv_shortcut(x)
220
+ h = F.silu(self.norm1(x))
221
+ h = self.conv1(h)
222
+ h = F.silu(self.norm2(h))
223
+ h = self.conv2(h)
224
+ return h + sc
225
+
226
+ class Downsample(nn.Module):
227
+ def __init__(self, in_ch, out_ch):
228
+ super().__init__()
229
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
230
+ def forward(self, x):
231
+ return self.conv(x)
232
+
233
+ class DownBlock(nn.Module):
234
+ def __init__(self, in_ch, out_ch, num_res):
235
+ super().__init__()
236
+ self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
237
+ self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
238
+ def forward(self, x):
239
+ for r in self.resnets:
240
+ x = r(x)
241
+ for ds in self.downsamplers:
242
+ x = ds(x)
243
+ return x
244
+
245
+ class Upsample(nn.Module):
246
+ def __init__(self, in_ch, out_ch):
247
+ super().__init__()
248
+ self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
249
+ def forward(self, x):
250
+ return self.conv(x)
251
+
252
+ class UpBlock(nn.Module):
253
+ def __init__(self, in_ch, out_ch, num_res):
254
+ super().__init__()
255
+ self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
256
+ self.upsampler = Upsample(out_ch, out_ch)
257
+ def forward(self, x):
258
+ for r in self.resnets:
259
+ x = r(x)
260
+ return self.upsampler(x)
261
+
262
+ class AttentionBlock(nn.Module):
263
+ def __init__(self, ch):
264
+ super().__init__()
265
+ self.norm = nn.GroupNorm(32, ch)
266
+ self.query = nn.Conv2d(ch, ch, 1)
267
+ self.key = nn.Conv2d(ch, ch, 1)
268
+ self.value = nn.Conv2d(ch, ch, 1)
269
+ self.proj_attn = nn.Conv2d(ch, ch, 1)
270
+ def forward(self, x):
271
+ b, c, h, w = x.shape
272
+ xn = self.norm(x)
273
+ q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
274
+ k = self.key(xn).view(b, c, -1)
275
+ v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
276
+ attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
277
+ out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
278
+ return x + self.proj_attn(out)
279
+
280
+ class Encoder(nn.Module):
281
+ def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
282
+ super().__init__()
283
+ self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
284
+ self.down_blocks = nn.ModuleList([
285
+ DownBlock(base_ch, base_ch, 2),
286
+ DownBlock(base_ch, base_ch * 2, 2),
287
+ DownBlock(base_ch * 2, base_ch * 4, 2),
288
+ DownBlock(base_ch * 4, base_ch * 4, 2)
289
+ ])
290
+ self.mid_block = nn.ModuleList([
291
+ ResnetBlock(base_ch * 4, base_ch * 4),
292
+ AttentionBlock(base_ch * 4),
293
+ ResnetBlock(base_ch * 4, base_ch * 4)
294
+ ])
295
+ self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
296
+ self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
297
+ self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
298
+ def forward(self, x):
299
+ x = self.conv_in(x)
300
+ for blk in self.down_blocks:
301
+ x = blk(x)
302
+ for m in self.mid_block:
303
+ x = m(x)
304
+ x = self.conv_norm_out(x)
305
+ x = self.conv_out(x)
306
+ return self.quant_conv(x)
307
+
308
+ class Decoder(nn.Module):
309
+ def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
310
+ super().__init__()
311
+ self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1)
312
+ self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
313
+ self.mid_block = nn.ModuleList([
314
+ ResnetBlock(base_ch * 4, base_ch * 4),
315
+ AttentionBlock(base_ch * 4),
316
+ ResnetBlock(base_ch * 4, base_ch * 4)
317
+ ])
318
+ self.up_blocks = nn.ModuleList([
319
+ UpBlock(base_ch * 4, base_ch * 4, 3),
320
+ UpBlock(base_ch * 4, base_ch * 2, 3),
321
+ UpBlock(base_ch * 2, base_ch, 3),
322
+ UpBlock(base_ch, base_ch, 3)
323
+ ])
324
+ self.conv_norm_out = nn.GroupNorm(32, base_ch)
325
+ self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
326
+ def forward(self, x):
327
+ x = self.post_quant_conv(x)
328
+ x = self.conv_in(x)
329
+ for m in self.mid_block:
330
+ x = m(x)
331
+ for up in self.up_blocks:
332
+ x = up(x)
333
+ x = self.conv_norm_out(x)
334
+ return self.conv_out(x)
335
+
336
+ class AutoencoderKL(nn.Module):
337
+ def __init__(self, config):
338
+ super().__init__()
339
+ in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
340
+ out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3)
341
+ base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
342
+ latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4)
343
+ self.encoder = Encoder(in_ch, base_ch, latent_ch)
344
+ self.decoder = Decoder(out_ch, base_ch, latent_ch)
345
+ def forward(self, x):
346
+ return self.decoder(self.encoder(x))
347
+ def decode(self, x):
348
+ return self.decoder(x)
349
+
350
+ class TransformerBlock(nn.Module):
351
+ def __init__(self, embed_dim, num_heads):
352
+ super().__init__()
353
+ self.norm1 = nn.LayerNorm(embed_dim)
354
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
355
+ self.norm2 = nn.LayerNorm(embed_dim)
356
+ hidden_dim = embed_dim * 4
357
+ self.mlp = nn.Sequential(
358
+ nn.Linear(embed_dim, hidden_dim),
359
+ nn.GELU(),
360
+ nn.Linear(hidden_dim, embed_dim)
361
+ )
362
+ def forward(self, x):
363
+ res = x
364
+ x = self.norm1(x)
365
+ x = x.transpose(0, 1)
366
+ attn, _ = self.attn(x, x, x)
367
+ x = attn.transpose(0, 1)
368
+ x = res + x
369
+ return x + self.mlp(self.norm2(x))
370
+
371
+ class VisionTransformer(nn.Module):
372
+ def __init__(self, config):
373
+ super().__init__()
374
+ if isinstance(config, dict):
375
+ self.img_size = config.get("img_size", 592)
376
+ self.patch_size = config.get("patch_size", 16)
377
+ self.embed_dim = config.get("hidden_size", 768)
378
+ depth = config.get("depth", 12)
379
+ num_heads = config.get("num_heads", 12)
380
+ else:
381
+ self.img_size = config.__dict__.get("img_size", 592)
382
+ self.patch_size = config.__dict__.get("patch_size", 16)
383
+ self.embed_dim = config.__dict__.get("hidden_size", 768)
384
+ depth = config.__dict__.get("depth", 12)
385
+ num_heads = config.__dict__.get("num_heads", 12)
386
+ num_patches = (self.img_size // self.patch_size) ** 2
387
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
388
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
389
+ self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
390
+ self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
391
+ self.norm = nn.LayerNorm(self.embed_dim)
392
+ self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
393
+ self._init_weights()
394
+ def _init_weights(self):
395
+ nn.init.normal_(self.cls_token, std=0.02)
396
+ nn.init.normal_(self.pos_embed, std=0.02)
397
+ def forward(self, x):
398
+ x = self.patch_embed(x)
399
+ x = x.flatten(2).transpose(1, 2)
400
+ cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
401
+ x = torch.cat((cls_tokens, x), dim=1)
402
+ x = x + self.pos_embed
403
+ for blk in self.blocks:
404
+ x = blk(x)
405
+ return self.norm(x)[:, 0]
406
+
407
+ class OpenLRM(nn.Module):
408
+ def __init__(self, config):
409
+ super().__init__()
410
+ self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
411
+ hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
412
+ self.linear = nn.Linear(hidden, hidden)
413
+ def forward(self, x):
414
+ return self.linear(self.encoder["model"](x))
415
+
416
+ class VideoUNet(nn.Module):
417
+ def __init__(self, in_ch=4, out_ch=4, features=None):
418
+ super().__init__()
419
+ if features is None:
420
+ features = [64, 128, 256]
421
+ self.encoder = nn.ModuleList()
422
+ self.pool = nn.MaxPool3d(2, 2)
423
+ self.decoder = nn.ModuleList()
424
+ for f in features:
425
+ self.encoder.append(nn.Sequential(
426
+ nn.Conv3d(in_ch, f, 3, padding=1),
427
+ nn.ReLU(inplace=True),
428
+ nn.Conv3d(f, f, 3, padding=1),
429
+ nn.ReLU(inplace=True)
430
+ ))
431
+ in_ch = f
432
+ for f in reversed(features):
433
+ self.decoder.append(nn.Sequential(
434
+ nn.Conv3d(f * 2, f, 3, padding=1),
435
+ nn.ReLU(inplace=True),
436
+ nn.Conv3d(f, f, 3, padding=1),
437
+ nn.ReLU(inplace=True)
438
+ ))
439
+ self.final_conv = nn.Conv3d(features[0], out_ch, 1)
440
+ def forward(self, x, t, encoder_hidden_states):
441
+ skips = []
442
+ for enc in self.encoder:
443
+ x = enc(x)
444
+ skips.append(x)
445
+ x = self.pool(x)
446
+ for dec in self.decoder:
447
+ skip = skips.pop()
448
+ x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
449
+ x = torch.cat([x, skip], dim=1)
450
+ x = dec(x)
451
+ return self.final_conv(x)
452
+
453
+ class SentimentClassifierModel(nn.Module):
454
+ def __init__(self, config):
455
+ super().__init__()
456
+ self.classifier = nn.Sequential(
457
+ nn.Linear(768, 256),
458
+ nn.ReLU(),
459
+ nn.Linear(256, 2)
460
+ )
461
+ def forward(self, x):
462
+ return self.classifier(x)
463
+
464
+ class STTModel(nn.Module):
465
+ def __init__(self, config):
466
+ super().__init__()
467
+ self.net = nn.Sequential(
468
+ nn.Linear(768, 512),
469
+ nn.ReLU(),
470
+ nn.Linear(512, 768)
471
+ )
472
+ def forward(self, x):
473
+ return self.net(x)
474
+
475
+ class TTSModel(nn.Module):
476
+ def __init__(self, config):
477
+ super().__init__()
478
+ self.net = nn.Sequential(
479
+ nn.Linear(768, 512),
480
+ nn.ReLU(),
481
+ nn.Linear(512, 768)
482
+ )
483
+ def forward(self, x):
484
+ return self.net(x)
485
+
486
+ class MusicGenModel(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
490
+ self.transformer = nn.TransformerEncoder(layer, num_layers=12)
491
+ self.linear = nn.Linear(768, 768)
492
+ def forward(self, x):
493
+ return self.linear(self.transformer(x))
494
+
495
+ class SimpleTextEncoder(nn.Module):
496
+ def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
497
+ super().__init__()
498
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
499
+ self.max_length = max_length
500
+ def forward(self, text_tokens):
501
+ return self.embedding(text_tokens)
502
+
503
+ class DiffusionScheduler:
504
+ def __init__(self, steps):
505
+ self.steps = steps
506
+ self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
507
+ def step(self, noise, t, sample):
508
+ beta = self.betas[t]
509
+ return sample - beta * noise
510
+
511
+ class VideoOutput:
512
+ def __init__(self, frames):
513
+ self.frames = [img_as_ubyte(frame) for frame in frames[0]]
514
+
515
+ class VideoPipeline(nn.Module):
516
+ def __init__(self, unet, vae, text_encoder, vocab):
517
+ super().__init__()
518
+ self.unet = unet
519
+ self.vae = vae
520
+ self.text_encoder = text_encoder
521
+ self.vocab = vocab
522
+ def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
523
+ token_ids = simple_tokenizer(prompt, self.vocab)
524
+ text_emb = self.text_encoder(token_ids)
525
+ latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
526
+ sched = DiffusionScheduler(steps)
527
+ for t in range(steps):
528
+ noise = self.unet(latent, t, text_emb)
529
+ latent = sched.step(noise, t, latent)
530
+ frames = self.vae.decode(latent / 0.18215)
531
+ frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
532
+ return VideoOutput(frames)
533
+
534
+ def initialize_gpt2_model(folder, files):
535
+ download_files(folder, files)
536
+ config = GPT2Config()
537
+ model = GPT2LMHeadModel(config).to(device)
538
+ sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
539
+ load_state_dict_safe(model, sd)
540
+ model.eval()
541
+ enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
542
+ return model, enc
543
+
544
+ def initialize_translation_model(folder, files):
545
+ download_files(folder, files)
546
+ config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
547
+ model = MBartForConditionalGeneration(config).to(device)
548
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
549
+ load_state_dict_safe(model, sd)
550
+ model.eval()
551
+ vp = os.path.join(folder, "vocab.json")
552
+ if os.path.exists(vp):
553
+ vocab = read_json(vp)
554
+ model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
555
+ else:
556
+ model.tokenizer = lambda txt: txt
557
+ model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
558
+ return model
559
+
560
+ def initialize_codegen_model(folder, files):
561
+ download_files(folder, files)
562
+ config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
563
+ model = CodeGenForCausalLM(config).to(device)
564
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
565
+ load_state_dict_safe(model, sd)
566
+ model.eval()
567
+ tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
568
+ vocab = read_json(os.path.join(folder, "vocab.json"))
569
+ idx2w = {v: k for k, v in vocab.items()}
570
+ model.tokenizer = tok
571
+ return model, tok, vocab, idx2w, vocab
572
+
573
+ def initialize_summarization_model(folder, files):
574
+ download_files(folder, files)
575
+ config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
576
+ model = BartForConditionalGeneration(config).to(device)
577
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
578
+ load_state_dict_safe(model, sd)
579
+ model.eval()
580
+ vp = os.path.join(folder, "vocab.json")
581
+ if os.path.exists(vp):
582
+ vocab_json = read_json(vp)
583
+ vocab = set(vocab_json.keys())
584
+ return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
585
+ return model, None, None, None
586
+
587
+ def initialize_imagegen_model(folder, files):
588
+ download_files(folder, files)
589
+ config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
590
+ vae = AutoencoderKL(config).to(device)
591
+ sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
592
+ load_state_dict_safe(vae, sd)
593
+ vae.eval()
594
+ return vae
595
+
596
+ def initialize_image_to_3d_model(folder, files):
597
+ download_files(folder, files)
598
+ config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
599
+ model3d = OpenLRM(config).to(device)
600
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
601
+ load_state_dict_safe(model3d, sd)
602
+ model3d.eval()
603
+ return model3d
604
+
605
+ def initialize_text_to_video_model(folder, files):
606
+ download_files(folder, files)
607
+ unet_cfg = read_json(os.path.join(folder, "config.json"))
608
+ unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
609
+ unet = VideoUNet(**unet_cfg).half().to(device)
610
+ sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
611
+ load_state_dict_safe(unet, sd_unet)
612
+ unet.eval()
613
+ vae_cfg = read_json(os.path.join(folder, "config.json"))
614
+ vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg)
615
+ vae = AutoencoderKL(vae_cfg).half().to(device)
616
+ sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
617
+ load_state_dict_safe(vae, sd_vae)
618
+ vae.eval()
619
+ vp = os.path.join(folder, "vocab.json")
620
+ text_vocab = read_json(vp) if os.path.exists(vp) else {}
621
+ te_path = os.path.join(folder, "text_encoder.bin")
622
+ if os.path.exists(te_path):
623
+ text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
624
+ sd_te = torch.load(te_path, map_location=device)
625
+ load_state_dict_safe(text_encoder, sd_te)
626
+ else:
627
+ text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
628
+ text_encoder.eval()
629
+ return VideoPipeline(unet, vae, text_encoder, text_vocab)
630
+
631
+ def initialize_sentiment_model(folder, files):
632
+ download_files(folder, files)
633
+ config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
634
+ model = SentimentClassifierModel(config).to(device)
635
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
636
+ load_state_dict_safe(model, sd)
637
+ model.eval()
638
+ vp = os.path.join(folder, "vocab.json")
639
+ if os.path.exists(vp):
640
+ read_json(vp)
641
+ return model
642
+
643
+ def initialize_stt_model(folder, files):
644
+ download_files(folder, files)
645
+ config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
646
+ model = STTModel(config).to(device)
647
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
648
+ load_state_dict_safe(model, sd)
649
+ model.eval()
650
+ vp = os.path.join(folder, "vocab.json")
651
+ if os.path.exists(vp):
652
+ read_json(vp)
653
+ return model
654
+
655
+ def initialize_tts_model(folder, files):
656
+ download_files(folder, files)
657
+ config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
658
+ model = TTSModel(config).to(device)
659
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
660
+ load_state_dict_safe(model, sd)
661
+ model.eval()
662
+ vp = os.path.join(folder, "vocab.json")
663
+ if os.path.exists(vp):
664
+ read_json(vp)
665
+ return model
666
+
667
+ def initialize_musicgen_model(folder, files):
668
+ download_files(folder, files)
669
+ config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
670
+ model = MusicGenModel(config).to(device)
671
+ sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
672
+ load_state_dict_safe(model, sd)
673
+ model.eval()
674
+ return model
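
For reference, download_files above accepts three spec shapes; a short illustration (not part of the uploaded files; the folder paths are placeholders and the constant names are the ones defined in constants.py and referenced from main.py):

    from model_loader import download_files
    from constants import MODEL_FILE, MODEL_URL, GFPGAN_URL, TEXT_TO_VIDEO_VOCAB_URL, TEXT_TO_VIDEO_VOCAB

    download_files("./GPT2", {MODEL_FILE: MODEL_URL})                                  # dict: filename -> url
    download_files("./GFPGAN", [GFPGAN_URL])                                           # bare URLs: filename taken from the URL path
    download_files("./TextToVideo", [(TEXT_TO_VIDEO_VOCAB_URL, TEXT_TO_VIDEO_VOCAB)])  # (url, filename) pairs
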
models.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import copy
6
+ #from configs import GPT2Config, MBartConfig, CodeGenConfig, SummarizationConfig, OpenLRMConfig, UNet2DConditionModelConfig, AutoencoderKLConfig, BartConfig, MusicGenConfig
7
+ from configs import *
8
+ #from extensions import gelu, LayerNorm, Conv1D, Attention, MLP, Block, GPT2Model, GPT2LMHead, MBartEncoderLayer, MBartDecoderLayer, MBartEncoder, MBartDecoder, MBartModel, MBartForConditionalGeneration, CodeGenAttention, CodeGenBlock, CodeGenModel, CodeGenForCausalLM, SummarizationModel, OpenLRM, OpenLRMLayer, OpenLRMAttention, OpenLRMFeedForward, AutoencoderKL, Encoder_, Decoder_, DownBlock, UpBlock, ResnetBlock, MidBlock, Downsample2D, Upsample2D, UNet2DConditionModel, UNetMidBlock2DConditionModel, UNetDownBlock2DConditionModel, UNetUpBlock2DConditionModel, ResnetBlock2D, CrossAttentionBlock2D, CrossAttention, SimpleClassifier
9
+ from extensions import *
10
+
11
+ class SentimentClassifierModel(nn.Module):
12
+ def __init__(self, config):
13
+ super().__init__()
14
+ self.config = config
15
+ self.embedding = nn.Embedding(config.vocab_size, config.d_model)
16
+ self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
17
+ self.fc = nn.Linear(config.d_model * 2, 3)
18
+
19
+ def forward(self, input_ids):
20
+ embedded = self.embedding(input_ids)
21
+ packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
22
+ packed_output, _ = self.lstm(packed_embedded)
23
+ output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
24
+ pooled = output[:, -1, :]
25
+ logits = self.fc(pooled)
26
+ return logits
27
+
28
+ class STTModel(nn.Module):
29
+ def __init__(self, config):
30
+ super().__init__()
31
+ self.config = config
32
+ self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=1)
33
+ self.relu1 = nn.ReLU()
34
+ self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
35
+ self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
36
+ self.relu2 = nn.ReLU()
37
+ self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
38
+ self.lstm = nn.LSTM(32 * (config.max_position_embeddings // 8), 128, batch_first=True, bidirectional=True)
39
+ self.fc = nn.Linear(128 * 2, config.vocab_size)
40
+
41
+ def forward(self, audio_data):
42
+ x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
43
+ x = self.pool2(self.relu2(self.conv2(x)))
44
+ x = x.transpose(1, 2).contiguous()
45
+ x = x.view(x.size(0), -1, x.size(2))
46
+ packed_output = nn.utils.rnn.pack_padded_sequence(x, lengths=[x.size(1)]*x.size(0), batch_first=True, enforce_sorted=False)
47
+ packed_output, _ = self.lstm(packed_output)
48
+ output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
49
+ logits = self.fc(output)
50
+ return logits
51
+
52
+ class TTSModel(nn.Module):
53
+ def __init__(self, config):
54
+ super().__init__()
55
+ self.config = config
56
+ self.embedding = nn.Embedding(config.vocab_size, config.d_model)
57
+ self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
58
+ self.fc = nn.Linear(config.d_model * 2, 1)
59
+ self.sigmoid = nn.Sigmoid()
60
+
61
+ def forward(self, input_ids):
62
+ embedded = self.embedding(input_ids)
63
+ packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
64
+ packed_output, _ = self.lstm(packed_embedded)
65
+ output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
66
+ logits = self.fc(output)
67
+ audio = self.sigmoid(logits)
68
+ return audio
69
+
70
+ class MusicGenModel(nn.Module):
71
+ def __init__(self, config: MusicGenConfig):
72
+ super().__init__()
73
+ self.config = config
74
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
75
+ self.transformer_layers = nn.ModuleList([CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
76
+ self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)
77
+
78
+ def forward(self, input_ids):
79
+ embedded_tokens = self.embedding(input_ids)
80
+ hidden_states = embedded_tokens
81
+ for layer in self.transformer_layers:
82
+ hidden_states = layer(hidden_states)
83
+ logits = self.fc_out(hidden_states)
84
+ return logits
85
+
86
+ def sample(self, attributes, sample_rate, duration):
87
+ input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long).to(device)
88
+ audio_output = []
89
+ num_steps = int(duration * sample_rate / 1024)
90
+ for _ in tqdm(range(num_steps), desc="Generating music"):
91
+ logits = self.forward(input_tokens)
92
+ predicted_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
93
+ audio_output.append(predicted_token.cpu())
94
+ input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
95
+ audio_output = torch.cat(audio_output, dim=1).float()
96
+ return audio_output
musicgen_api.py ADDED
@@ -0,0 +1,35 @@
1
+ from flask import jsonify, send_file, request
2
+ from main import *
3
+ # from main import musicgen_model, device
4
+ import torch
5
+ import soundfile as sf
6
+ import numpy as np
7
+ import io
8
+
9
+ def generate_music(prompt, output_path="output_music.wav"):
10
+ if musicgen_model is None:
11
+ return "Music generation model not initialized."
12
+
13
+ attributes = [prompt]
14
+ sample_rate = 32000
15
+ duration = 8
16
+ audio_values = musicgen_model.sample(
17
+ attributes=attributes,
18
+ sample_rate=sample_rate,
19
+ duration=duration,
20
+ )
21
+ output_audio = audio_values.cpu().numpy().squeeze()
22
+ sf.write(output_path, output_audio, sample_rate)
23
+ return output_path
24
+
25
+ def musicgen_api():
26
+ data = request.get_json()
27
+ prompt = data.get('prompt')
28
+ if not prompt:
29
+ return jsonify({"error": "Prompt is required"}), 400
30
+ output_file = generate_music(prompt)
31
+ if output_file == "Music generation model not initialized.":
32
+ return jsonify({"error": "Music generation failed"}), 500
33
+ with open(output_file, 'rb') as f:
34
+ audio_content = f.read()
35
+ return send_file(io.BytesIO(audio_content), mimetype="audio/wav", as_attachment=True, download_name="output.wav")
requirements.txt ADDED
@@ -0,0 +1,40 @@
1
+ accelerate
2
+ retry
3
+ asyncio
4
+ basicsr
5
+ beautifulsoup4
6
+ bs4
7
+ opencv-python
8
+ deep-translator
9
+ duckduckgo-search
10
+ fastapi
11
+ faker
12
+ flask
13
+ flask-cors
14
+ facexlib
15
+ ffmpeg-python
16
+ gfpgan
17
+ imageio
18
+ imageio-ffmpeg
19
+ langdetect
20
+ librosa
21
+ nltk
22
+ numpy
23
+ Pillow
24
+ pydub
25
+ pytorch-lightning
26
+ PyYAML
27
+ retry
28
+ safetensors
29
+ scikit-learn
30
+ scipy
31
+ scikit-image
32
+ soundfile
33
+ torch
34
+ torchaudio
35
+ torchvision
36
+ tqdm
37
+ wget
38
+ yacs
39
+ numba
40
+ librosa
sadtalker_api.py ADDED
@@ -0,0 +1,202 @@
1
+ import os
2
+ import tempfile
3
+ import uuid
4
+ import asyncio
5
+ import shutil
6
+ import requests
7
+ from urllib.parse import urlparse
8
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Form, WebSocket
9
+ from fastapi.responses import JSONResponse
10
+ #from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi import APIRouter
12
+ from extensions import *
13
+ from main import *
14
+ # from main import sadtalker_instance
15
+ from tts_api import *
16
+ from sadtalker_utils import *
17
+ import base64
18
+ from stt_api import *
19
+ from text_generation import *
20
+
21
+ router = APIRouter()
22
+
23
+ @router.post("/sadtalker")
24
+ async def create_video(
25
+ source_image: str = Form(None),
26
+ source_image_file: UploadFile = File(None),
27
+ driven_audio: str = Form(None),
28
+ driven_audio_file: UploadFile = File(None),
29
+ preprocess: str = Form('crop'),
30
+ still_mode: bool = Form(False),
31
+ use_enhancer: bool = Form(False),
32
+ batch_size: int = Form(1),
33
+ size: int = Form(256),
34
+ pose_style: int = Form(0),
35
+ exp_scale: float = Form(1.0),
36
+ use_ref_video: bool = Form(False),
37
+ ref_video: str = Form(None),
38
+ ref_video_file: UploadFile = File(None),
39
+ ref_info: str = Form(None),
40
+ use_idle_mode: bool = Form(False),
41
+ length_of_audio: int = Form(0),
42
+ use_blink: bool = Form(True),
43
+ checkpoint_dir: str = Form('checkpoints'),
44
+ config_dir: str = Form('src/config'),
45
+ old_version: bool = Form(False),
46
+ tts_text: str = Form(None),
47
+ tts_lang: str = Form('en'),
48
+ ):
49
+ if source_image_file and source_image:
50
+ raise HTTPException(status_code=400, detail="Provide either source_image or source_image_file, not both")
51
+ if driven_audio and driven_audio_file:
52
+ raise HTTPException(status_code=400, detail="Provide either driven_audio or driven_audio_file, not both")
53
+ if ref_video and ref_video_file:
54
+ raise HTTPException(status_code=400, detail="Provide either ref_video or ref_video_file, not both")
55
+ tmp_source_image = None
56
+ if source_image_file:
57
+ tmp_source_image = tempfile.NamedTemporaryFile(suffix=os.path.splitext(source_image_file.filename)[1], delete=False)
58
+ content = await source_image_file.read()
59
+ tmp_source_image.write(content)
60
+ source_image_path = tmp_source_image.name
61
+ elif source_image:
62
+ if urlparse(source_image).scheme in ["http", "https"]:
63
+ response = requests.get(source_image, stream=True)
64
+ response.raise_for_status()
65
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_source_image:
66
+ for chunk in response.iter_content(chunk_size=8192):
67
+ tmp_source_image.write(chunk)
68
+ source_image_path = tmp_source_image.name
69
+ else:
70
+ source_image_path = source_image
71
+ else:
72
+ raise HTTPException(status_code=400, detail="source_image not provided")
73
+ tmp_driven_audio = None
74
+ if driven_audio_file:
75
+ tmp_driven_audio = tempfile.NamedTemporaryFile(suffix=os.path.splitext(driven_audio_file.filename)[1], delete=False)
76
+ content = await driven_audio_file.read()
77
+ tmp_driven_audio.write(content)
78
+ driven_audio_path = tmp_driven_audio.name
79
+ elif driven_audio:
80
+ if urlparse(driven_audio).scheme in ["http", "https"]:
81
+ response = requests.get(driven_audio, stream=True)
82
+ response.raise_for_status()
83
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_driven_audio:
84
+ for chunk in response.iter_content(chunk_size=8192):
85
+ tmp_driven_audio.write(chunk)
86
+ driven_audio_path = tmp_driven_audio.name
87
+ else:
88
+ driven_audio_path = driven_audio
89
+ else:
90
+ driven_audio_path = None
91
+ tmp_ref_video = None
92
+ if ref_video_file:
93
+ tmp_ref_video = tempfile.NamedTemporaryFile(suffix=os.path.splitext(ref_video_file.filename)[1], delete=False)
94
+ content = await ref_video_file.read()
95
+ tmp_ref_video.write(content)
96
+ ref_video_path = tmp_ref_video.name
97
+ elif ref_video:
98
+ if urlparse(ref_video).scheme in ["http", "https"]:
99
+ response = requests.get(ref_video, stream=True)
100
+ response.raise_for_status()
101
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_ref_video:
102
+ for chunk in response.iter_content(chunk_size=8192):
103
+ tmp_ref_video.write(chunk)
104
+ ref_video_path = tmp_ref_video.name
105
+ else:
106
+ ref_video_path = ref_video
107
+ else:
108
+ ref_video_path=None
109
+ try:
110
+ loop = asyncio.get_running_loop()
111
+ output_path = await loop.run_in_executor(None, sadtalker_instance.test,
112
+ source_image_path,
113
+ driven_audio_path,
114
+ preprocess,
115
+ still_mode,
116
+ use_enhancer,
117
+ batch_size,
118
+ size,
119
+ pose_style,
120
+ exp_scale,
121
+ use_ref_video,
122
+ ref_video_path,
123
+ ref_info,
124
+ use_idle_mode,
125
+ length_of_audio,
126
+ use_blink,
127
+ './results/',
128
+ tts_text=tts_text,
129
+ tts_lang=tts_lang,
130
+ )
131
+ return {"video_url": output_path}
132
+ except Exception as e:
133
+ raise HTTPException(status_code=500, detail=str(e))
134
+ finally:
135
+ if tmp_source_image:
136
+ os.remove(tmp_source_image.name)
137
+ if tmp_driven_audio:
138
+ os.remove(tmp_driven_audio.name)
139
+ if tmp_ref_video:
140
+ os.remove(tmp_ref_video.name)
141
+
142
+ @router.websocket("/ws")
143
+ async def websocket_endpoint(websocket: WebSocket):
144
+ await websocket.accept()
145
+ tts_model = TTSTalker()
146
+ try:
147
+ while True:
148
+ data = await websocket.receive_json()
149
+ text = data.get("text")
150
+ audio_base64 = data.get("audio")
151
+ if text:
152
+ audio_path = await asyncio.get_running_loop().run_in_executor(None, tts_model.test, text)
153
+ elif audio_base64:
154
+ try:
155
+ audio_bytes = base64.b64decode(audio_base64)
156
+ tmp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
157
+ tmp_audio_file.write(audio_bytes)
158
+ audio_path = tmp_audio_file.name
159
+ transcription_text_file = speech_to_text_func(tmp_audio_file.name)
160
+ with open(transcription_text_file, 'r') as f:
161
+ transcription_text = f.read()
162
+ response_stream = perform_reasoning_stream(f"respond to this sentence in 10 words or less {transcription_text}", 0.7, 40, 0.0, 1.2)
163
+ response_text = ""
164
+ for chunk in response_stream:
165
+ if chunk == "<END_STREAM>":
166
+ break
167
+ response_text += chunk
168
+ audio_path = await asyncio.get_running_loop().run_in_executor(None, tts_model.test, response_text)
169
+
170
+ except Exception as e:
171
+ await websocket.send_json({"error":str(e)})
172
+ continue
173
+ finally:
174
+ if 'tmp_audio_file' in locals() and tmp_audio_file:
175
+ os.remove(tmp_audio_file.name)
176
+ else:
177
+ continue
178
+ source_image_path = './examples/source_image/cyarh.png'
179
+ ref_video_path='./examples/driven_video/vid_xdd.mp4'
180
+ loop = asyncio.get_running_loop()
181
+ output = await loop.run_in_executor(None, sadtalker_instance.test,
182
+ source_image_path,
183
+ audio_path,
184
+ 'full',
185
+ True,
186
+ True,
187
+ 1,
188
+ 256,
189
+ 0,
190
+ 1,
191
+ True,
192
+ ref_video_path,
193
+ "pose+blink",
194
+ False,
195
+ 0,
196
+ True,
197
+ './results/'
198
+ )
199
+ await websocket.send_json({"video_url": output})
200
+ except Exception as e:
201
+ print(e)
202
+ await websocket.send_json({"error":str(e)})
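A hedged usage sketch for the /sadtalker route above; the mount point, host and port are assumptions, since the APIRouter is included into the FastAPI app elsewhere:

```python
# Hypothetical request against the SadTalker endpoint; paths, host and port are
# placeholders. Either an image URL/path or an uploaded file can be supplied.
import requests

resp = requests.post(
    "http://localhost:7860/sadtalker",
    data={
        "source_image": "https://example.com/face.png",
        "tts_text": "Hello, this is a generated talking head.",
        "preprocess": "crop",
        "still_mode": "true",
        "use_enhancer": "false",
    },
    timeout=1800,
)
resp.raise_for_status()
print(resp.json()["video_url"])  # server-side path of the rendered video
```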
sadtalker_utils.py ADDED
@@ -0,0 +1,866 @@
1
+ import os
2
+ import shutil
3
+ import uuid
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import yaml
10
+ from PIL import Image
11
+ from skimage import img_as_ubyte, transform
12
+ import safetensors
13
+ import librosa
14
+ from pydub import AudioSegment
15
+ import imageio
16
+ from scipy import signal
17
+ from scipy.io import loadmat, savemat, wavfile
18
+ import glob
19
+ import tempfile
20
+ from tqdm import tqdm
21
+ import math
22
+ import torchaudio
23
+ import urllib.request
24
+
25
+ REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
26
+ CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
27
+ RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
28
+ GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
29
+ kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
30
+ kp_file = "kp_detector.safetensors"
31
+ aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
32
+ aud_file = "auido2pose_00140-model.pth"
33
+ wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
34
+ wav_file = "wav2vec2.pth"
35
+ gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
36
+ gen_file = "generator.pth"
37
+ mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
38
+ mapx_file = "mapping.pth"
39
+ den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
40
+ den_file = "dense_motion.pth"
41
+
42
+
43
+ def download_model(url, filename, checkpoint_dir):
44
+ if not os.path.exists(os.path.join(checkpoint_dir, filename)):
45
+ print(f"Downloading {filename}...")
46
+ os.makedirs(checkpoint_dir, exist_ok=True)
47
+ urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
48
+ print(f"{filename} downloaded.")
49
+ else:
50
+ print(f"{filename} already exists.")
51
+
52
+
53
+ def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
54
+ AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")
55
+
56
+
57
+ def load_wav_util(path, sr):
58
+ return librosa.core.load(path, sr=sr)[0]
59
+
60
+
61
+ def save_wav_util(wav, path, sr):
62
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
63
+ wavfile.write(path, sr, wav.astype(np.int16))
64
+
65
+
66
+ class OcclusionAwareKPDetector(nn.Module):
67
+
68
+ def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
69
+ super(OcclusionAwareKPDetector, self).__init__()
70
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
71
+ self.bn1 = nn.BatchNorm2d(64)
72
+ self.relu = nn.ReLU()
73
+ self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)
74
+
75
+ def forward(self, x):
76
+ x = self.relu(self.bn1(self.conv1(x)))
77
+ x = self.conv2(x)
78
+ kp = {'value': x.view(x.size(0), -1)}
79
+ return kp
80
+
81
+
82
+ class Wav2Vec2Model(nn.Module):
83
+
84
+ def __init__(self):
85
+ super(Wav2Vec2Model, self).__init__()
86
+ self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
87
+ self.bn = nn.BatchNorm1d(64)
88
+ self.relu = nn.ReLU()
89
+ self.fc = nn.Linear(64, 2048)
90
+
91
+ def forward(self, audio):
92
+ x = audio.unsqueeze(1)
93
+ x = self.relu(self.bn(self.conv(x)))
94
+ x = torch.mean(x, dim=-1)
95
+ x = self.fc(x)
96
+ return x
97
+
98
+
99
+ class AudioCoeffsPredictor(nn.Module):
100
+
101
+ def __init__(self, input_dim, output_dim):
102
+ super(AudioCoeffsPredictor, self).__init__()
103
+ self.linear = nn.Linear(input_dim, output_dim)
104
+
105
+ def forward(self, audio_embedding):
106
+ return self.linear(audio_embedding)
107
+
108
+
109
+ class MappingNet(nn.Module):
110
+
111
+ def __init__(self, num_coeffs, num_layers, hidden_dim):
112
+ super(MappingNet, self).__init__()
113
+ layers = []
114
+ input_dim = num_coeffs * 2
115
+ for _ in range(num_layers):
116
+ layers.append(nn.Linear(input_dim, hidden_dim))
117
+ layers.append(nn.ReLU())
118
+ input_dim = hidden_dim
119
+ layers.append(nn.Linear(hidden_dim, num_coeffs))
120
+ self.net = nn.Sequential(*layers)
121
+
122
+ def forward(self, x):
123
+ return self.net(x)
124
+
125
+
126
+ class DenseMotionNetwork(nn.Module):
127
+
128
+ def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
129
+ super(DenseMotionNetwork, self).__init__()
130
+ self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
131
+ self.relu = nn.ReLU()
132
+ self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)
133
+
134
+ def forward(self, kp_source, kp_driving, jacobian):
135
+ x = self.relu(self.conv1(kp_source))
136
+ x = self.conv2(x)
137
+ sparse_motion = {'dense_motion': x}
138
+ return sparse_motion
139
+
140
+
141
+ class Hourglass(nn.Module):
142
+
143
+ def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
144
+ super(Hourglass, self).__init__()
145
+ self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
146
+ nn.BatchNorm2d(max_features), nn.ReLU())
147
+ self.decoder = nn.Sequential(
148
+ nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())
149
+
150
+ def forward(self, source_image, kp_driving, **kwargs):
151
+ x = self.encoder(source_image)
152
+ x = self.decoder(x)
153
+ B, C, H, W = x.size()
154
+ video = []
155
+ for _ in range(10):
156
+ frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(
157
+ np.uint8)
158
+ video.append(frame)
159
+ return video
160
+
161
+
162
+ class Face3DHelper:
163
+
164
+ def __init__(self, local_pca_path, device):
165
+ self.local_pca_path = local_pca_path
166
+ self.device = device
167
+
168
+ def run(self, source_image):
169
+ h, w, _ = source_image.shape
170
+ x_min = w // 4
171
+ y_min = h // 4
172
+ x_max = x_min + w // 2
173
+ y_max = y_min + h // 2
174
+ return [x_min, y_min, x_max, y_max]
175
+
176
+
177
+ class Face3DHelperOld(Face3DHelper):
178
+
179
+ def __init__(self, local_pca_path, device):
180
+ super(Face3DHelperOld, self).__init__(local_pca_path, device)
181
+
182
+
183
+ class MouthDetector:
184
+
185
+ def __init__(self):
186
+ pass
187
+
188
+ def detect(self, image):
189
+ h, w = image.shape[:2]
190
+ return (w // 2, h // 2)
191
+
192
+
193
+ class KeypointNorm(nn.Module):
194
+
195
+ def __init__(self, device):
196
+ super(KeypointNorm, self).__init__()
197
+ self.device = device
198
+
199
+ def forward(self, kp_driving):
200
+ return kp_driving
201
+
202
+
203
+ def save_video_with_watermark(video_frames, audio_path, output_path):
204
+ H, W, _ = video_frames[0].shape
205
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
206
+ for frame in video_frames:
207
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
208
+ out.release()
209
+
210
+
211
+ def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
212
+ shutil.copy(video_path, output_path)
213
+
214
+
215
+ class TTSTalker:
216
+
217
+ def __init__(self):
218
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
219
+ self.tts_model = None
220
+
221
+ def load_model(self):
222
+ self.tts_model = self
223
+
224
+ def tokenizer(self, text):
225
+ return [ord(c) for c in text]
226
+
227
+ def __call__(self, input_tokens):
228
+ return torch.zeros(1, 16000, device=self.device)
229
+
230
+ def test(self, text, lang='en'):
231
+ if self.tts_model is None:
232
+ self.load_model()
233
+ output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
234
+ os.makedirs('./results', exist_ok=True)
235
+ tokens = self.tokenizer(text)
236
+ input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
237
+ with torch.no_grad():
238
+ audio_output = self(input_tokens)
239
+ torchaudio.save(output_path, audio_output.cpu(), 16000)
240
+ return output_path
241
+
242
+
243
+ class SadTalker:
244
+
245
+ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
246
+ old_version=False):
247
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
248
+ self.cfg = self.get_cfg_defaults()
249
+ self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
250
+ self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
251
+ self.cfg['MODEL']['CONFIG_DIR'] = config_path
252
+ self.cfg['MODEL']['DEVICE'] = self.device
253
+ self.cfg['INPUT_IMAGE'] = {}
254
+ self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
255
+ self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
256
+ self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
257
+ self.cfg['INPUT_IMAGE']['SIZE'] = size
258
+ self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
259
+
260
+ download_model(kp_url, kp_file, checkpoint_path)
261
+ download_model(aud_url, aud_file, checkpoint_path)
262
+ download_model(wav_url, wav_file, checkpoint_path)
263
+ download_model(gen_url, gen_file, checkpoint_path)
264
+ download_model(mapx_url, mapx_file, checkpoint_path)
265
+ download_model(den_url, den_file, checkpoint_path)
266
+ download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
267
+ download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)
268
+
269
+ self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
270
+
271
+ def get_cfg_defaults(self):
272
+ return {
273
+ 'MODEL': {
274
+ 'CHECKPOINTS_DIR': '',
275
+ 'CONFIG_DIR': '',
276
+ 'DEVICE': self.device,
277
+ 'SCALE': 64,
278
+ 'NUM_VOXEL_FRAMES': 8,
279
+ 'NUM_MOTION_FRAMES': 10,
280
+ 'MAX_FEATURES': 256,
281
+ 'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
282
+ 'VIDEO_FPS': 25,
283
+ 'OUTPUT_VIDEO_FPS': None,
284
+ 'OUTPUT_AUDIO_SAMPLE_RATE': None,
285
+ 'USE_ENHANCER': False,
286
+ 'ENHANCER_NAME': '',
287
+ 'BG_UPSAMPLER': None,
288
+ 'IS_HALF': False
289
+ },
290
+ 'INPUT_IMAGE': {}
291
+ }
292
+
293
+ def merge_from_file(self, filepath):
294
+ if os.path.exists(filepath):
295
+ with open(filepath, 'r') as f:
296
+ cfg_from_file = yaml.safe_load(f)
297
+ self.cfg.update(cfg_from_file)
298
+
299
+ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
300
+ batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
301
+ ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
302
+ tts_text=None, tts_lang='en'):
303
+ self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
304
+ pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
305
+ length_of_audio, use_blink, result_dir, tts_text, tts_lang)
306
+ return self.sadtalker_model.save_result()
307
+
308
+
309
+ class SadTalkerModel:
310
+
311
+ def __init__(self, sadtalker_cfg, device_id=[0]):
312
+ self.cfg = sadtalker_cfg
313
+ self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
314
+ self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
315
+ self.preprocesser = self.sadtalker.preprocesser
316
+ self.kp_extractor = self.sadtalker.kp_extractor
317
+ self.generator = self.sadtalker.generator
318
+ self.mapping = self.sadtalker.mapping
319
+ self.he_estimator = self.sadtalker.he_estimator
320
+ self.audio_to_coeff = self.sadtalker.audio_to_coeff
321
+ self.animate_from_coeff = self.sadtalker.animate_from_coeff
322
+ self.face_enhancer = self.sadtalker.face_enhancer
323
+
324
+ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
325
+ batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
326
+ ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
327
+ tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
328
+ self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
329
+ batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
330
+ use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
331
+ jitter_amount, jitter_source_image)
332
+ return self.inner_test.test()
333
+
334
+ def save_result(self):
335
+ return self.inner_test.save_result()
336
+
337
+
338
+ class SadTalkerInner:
339
+
340
+ def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
341
+ batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
342
+ length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
343
+ self.sadtalker_model = sadtalker_model
344
+ self.source_image = source_image
345
+ self.driven_audio = driven_audio
346
+ self.preprocess = preprocess
347
+ self.still_mode = still_mode
348
+ self.use_enhancer = use_enhancer
349
+ self.batch_size = batch_size
350
+ self.size = size
351
+ self.pose_style = pose_style
352
+ self.exp_scale = exp_scale
353
+ self.use_ref_video = use_ref_video
354
+ self.ref_video = ref_video
355
+ self.ref_info = ref_info
356
+ self.use_idle_mode = use_idle_mode
357
+ self.length_of_audio = length_of_audio
358
+ self.use_blink = use_blink
359
+ self.result_dir = result_dir
360
+ self.tts_text = tts_text
361
+ self.tts_lang = tts_lang
362
+ self.jitter_amount = jitter_amount
363
+ self.jitter_source_image = jitter_source_image
364
+ self.device = self.sadtalker_model.device
365
+ self.output_path = None
366
+
367
+ def get_test_data(self):
368
+ proc = self.sadtalker_model.preprocesser
369
+ if self.tts_text is not None:
370
+ temp_dir = tempfile.mkdtemp()
371
+ audio_path = os.path.join(temp_dir, 'audio.wav')
372
+ tts = TTSTalker()
373
+ tts_wav = tts.test(self.tts_text, self.tts_lang)
+ shutil.copy(tts_wav, audio_path)
374
+ self.driven_audio = audio_path
375
+ source_image_pil = Image.open(self.source_image).convert('RGB')
376
+ if self.jitter_source_image:
377
+ jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
378
+ jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
379
+ source_image_pil = Image.fromarray(
380
+ np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
381
+ source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
382
+ if self.still_mode or self.use_idle_mode:
383
+ ref_pose_coeff = proc.generate_still_pose(self.pose_style)
384
+ ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
385
+ elif self.use_idle_mode:
386
+ ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
387
+ ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
388
+ else:
389
+ ref_pose_coeff = None
390
+ ref_expression_coeff = None
391
+ audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
392
+ self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
393
+ batch = {
394
+ 'source_image': source_image_tensor.unsqueeze(0).to(self.device),
395
+ 'audio': audio_tensor.unsqueeze(0).to(self.device),
396
+ 'ref_pose_coeff': ref_pose_coeff,
397
+ 'ref_expression_coeff': ref_expression_coeff,
398
+ 'source_image_crop': cropped_image,
399
+ 'crop_info': crop_info,
400
+ 'use_blink': self.use_blink,
401
+ 'pose_style': self.pose_style,
402
+ 'exp_scale': self.exp_scale,
403
+ 'ref_video': self.ref_video,
404
+ 'use_ref_video': self.use_ref_video,
405
+ 'ref_info': self.ref_info,
406
+ }
407
+ return batch, audio_sample_rate
408
+
409
+ def run_inference(self, batch):
410
+ kp_extractor = self.sadtalker_model.kp_extractor
411
+ generator = self.sadtalker_model.generator
412
+ mapping = self.sadtalker_model.mapping
413
+ he_estimator = self.sadtalker_model.he_estimator
414
+ audio_to_coeff = self.sadtalker_model.audio_to_coeff
415
+ animate_from_coeff = self.sadtalker_model.animate_from_coeff
416
+ proc = self.sadtalker_model.preprocesser
417
+ with torch.no_grad():
418
+ kp_source = kp_extractor(batch['source_image'])
419
+ if self.still_mode or self.use_idle_mode:
420
+ ref_pose_coeff = batch['ref_pose_coeff']
421
+ ref_expression_coeff = batch['ref_expression_coeff']
422
+ pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
423
+ expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
424
+ elif self.use_idle_mode:
425
+ ref_pose_coeff = batch['ref_pose_coeff']
426
+ ref_expression_coeff = batch['ref_expression_coeff']
427
+ pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
428
+ expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
429
+ else:
430
+ if self.use_ref_video:
431
+ kp_ref = kp_extractor(batch['source_image'])
432
+ pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
433
+ use_ref_info=batch['ref_info'])
434
+ else:
435
+ pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
436
+ expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
437
+ coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
438
+ if self.use_blink:
439
+ coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
440
+ else:
441
+ coeff['blink_coeff'] = None
442
+ kp_driving = audio_to_coeff(batch['audio'])[0]
443
+ kp_norm = animate_from_coeff.normalize_kp(kp_driving)
444
+ coeff['kp_driving'] = kp_norm
445
+ coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
446
+ face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
447
+ output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
448
+ he_estimator, batch['audio'], batch['source_image_crop'],
449
+ face_enhancer=face_enhancer)
450
+ return output_video
451
+
452
+ def post_processing(self, output_video, audio_sample_rate, batch):
453
+ proc = self.sadtalker_model.preprocesser
454
+ base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
455
+ audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
456
+ output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
457
+ self.output_path = output_video_path
458
+ video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL'][
459
+ 'OUTPUT_VIDEO_FPS'] is None else \
460
+ self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
461
+ audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if \
462
+ self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else \
463
+ self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
464
+ if self.use_enhancer:
465
+ enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
466
+ save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
467
+ paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
468
+ output_video_path)
469
+ os.remove(enhanced_path)
470
+ else:
471
+ save_video_with_watermark(output_video, self.driven_audio, output_video_path)
472
+ if self.tts_text is not None:
473
+ shutil.rmtree(os.path.dirname(self.driven_audio))
474
+
475
+ def save_result(self):
476
+ return self.output_path
477
+
478
+ def __call__(self):
479
+ return self.output_path
480
+
481
+ def test(self):
482
+ batch, audio_sample_rate = self.get_test_data()
483
+ output_video = self.run_inference(batch)
484
+ self.post_processing(output_video, audio_sample_rate, batch)
485
+ return self.save_result()
486
+
487
+
488
+ class SadTalkerInnerModel:
489
+
490
+ def __init__(self, sadtalker_cfg, device_id=[0]):
491
+ self.cfg = sadtalker_cfg
492
+ self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
493
+ self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
494
+ self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
495
+ self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
496
+ self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
497
+ self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL'][
498
+ 'USE_ENHANCER'] else None
499
+ self.generator = Generator(sadtalker_cfg, self.device)
500
+ self.mapping = Mapping(sadtalker_cfg, self.device)
501
+ self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
502
+
503
+
504
+ class Preprocesser:
505
+
506
+ def __init__(self, sadtalker_cfg, device):
507
+ self.cfg = sadtalker_cfg
508
+ self.device = device
509
+ if self.cfg['INPUT_IMAGE'].get('OLD_VERSION', False):
510
+ self.face3d_helper = Face3DHelperOld(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
511
+ else:
512
+ self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
513
+ self.mouth_detector = MouthDetector()
514
+
515
+ def crop(self, source_image_pil, preprocess_type, size=256):
516
+ source_image = np.array(source_image_pil)
517
+ face_info = self.face3d_helper.run(source_image)
518
+ if face_info is None:
519
+ raise Exception("No face detected")
520
+ x_min, y_min, x_max, y_max = face_info[:4]
521
+ old_size = (x_max - x_min, y_max - y_min)
522
+ x_center = (x_max + x_min) / 2
523
+ y_center = (y_max + y_min) / 2
524
+ if preprocess_type == 'crop':
525
+ face_size = max(x_max - x_min, y_max - y_min)
526
+ x_min = int(x_center - face_size / 2)
527
+ y_min = int(y_center - face_size / 2)
528
+ x_max = int(x_center + face_size / 2)
529
+ y_max = int(y_center + face_size / 2)
530
+ else:
531
+ x_min -= int((x_max - x_min) * 0.1)
532
+ y_min -= int((y_max - y_min) * 0.1)
533
+ x_max += int((x_max - x_min) * 0.1)
534
+ y_max += int((y_max - y_min) * 0.1)
535
+ h, w = source_image.shape[:2]
536
+ x_min = max(0, x_min)
537
+ y_min = max(0, y_min)
538
+ x_max = min(w, x_max)
539
+ y_max = min(h, y_max)
540
+ cropped_image = source_image[y_min:y_max, x_min:x_max]
541
+ cropped_image_pil = Image.fromarray(cropped_image)
542
+ if size is not None and size != 0:
543
+ cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
544
+ source_image_tensor = self.img2tensor(cropped_image_pil)
545
+ return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
546
+ self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
547
+
548
+ def img2tensor(self, img):
549
+ img = np.array(img).astype(np.float32) / 255.0
550
+ img = np.transpose(img, (2, 0, 1))
551
+ return torch.FloatTensor(img)
552
+
553
+ def video_to_tensor(self, video, device):
554
+ video_tensor_list = []
555
+ import torchvision.transforms as transforms
556
+ transform_func = transforms.ToTensor()
557
+ for frame in video:
558
+ frame_pil = Image.fromarray(frame)
559
+ frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
560
+ video_tensor_list.append(frame_tensor)
561
+ video_tensor = torch.cat(video_tensor_list, dim=0)
562
+ return video_tensor
563
+
564
+ def process_audio(self, audio_path, sample_rate):
565
+ wav = load_wav_util(audio_path, sample_rate)
566
+ wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
567
+ return wav_tensor, sample_rate
568
+
569
+ def generate_still_pose(self, pose_style):
570
+ ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
571
+ ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
572
+ return ref_pose_coeff
573
+
574
+ def generate_still_expression(self, exp_scale):
575
+ ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
576
+ ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
577
+ return ref_expression_coeff
578
+
579
+ def generate_idles_pose(self, length_of_audio, pose_style):
580
+ num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
581
+ ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
582
+ start_pose = self.generate_still_pose(pose_style)
583
+ end_pose = self.generate_still_pose(pose_style)
584
+ for frame_idx in range(num_frames):
585
+ alpha = frame_idx / num_frames
586
+ ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose + alpha * end_pose
587
+ return ref_pose_coeff
588
+
589
+ def generate_idles_expression(self, length_of_audio):
590
+ num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
591
+ ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
592
+ start_exp = self.generate_still_expression(1.0)
593
+ end_exp = self.generate_still_expression(1.0)
594
+ for frame_idx in range(num_frames):
595
+ alpha = frame_idx / num_frames
596
+ ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp + alpha * end_exp
597
+ return ref_expression_coeff
598
+
599
+
600
+ class KeyPointExtractor(nn.Module):
601
+
602
+ def __init__(self, sadtalker_cfg, device):
603
+ super(KeyPointExtractor, self).__init__()
604
+ self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
605
+ num_kp=10,
606
+ num_dilation_blocks=2,
607
+ dropout_rate=0.1).to(device)
608
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
609
+ self.load_kp_detector(checkpoint_path, device)
610
+
611
+ def load_kp_detector(self, checkpoint_path, device):
612
+ if os.path.exists(checkpoint_path):
613
+ if checkpoint_path.endswith('safetensors'):
614
+ checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
615
+ else:
616
+ checkpoint = torch.load(checkpoint_path, map_location=device)
617
+ self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}))
618
+ else:
619
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
620
+
621
+ def forward(self, x):
622
+ kp = self.kp_extractor(x)
623
+ return kp
624
+
625
+
626
+ class Audio2Coeff(nn.Module):
627
+
628
+ def __init__(self, sadtalker_cfg, device):
629
+ super(Audio2Coeff, self).__init__()
630
+ self.audio_model = Wav2Vec2Model().to(device)
631
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
632
+ self.load_audio_model(checkpoint_path, device)
633
+ self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
634
+ self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
635
+ self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
636
+ mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], aud_file)
637
+ self.load_mapping_model(mapping_checkpoint, device)
638
+
639
+ def load_audio_model(self, checkpoint_path, device):
640
+ if os.path.exists(checkpoint_path):
641
+ if checkpoint_path.endswith('safetensors'):
642
+ checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
643
+ else:
644
+ checkpoint = torch.load(checkpoint_path, map_location=device)
645
+ self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}))
646
+ else:
647
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
648
+
649
+ def load_mapping_model(self, checkpoint_path, device):
650
+ if os.path.exists(checkpoint_path):
651
+ if checkpoint_path.endswith('safetensors'):
652
+ checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
653
+ else:
654
+ checkpoint = torch.load(checkpoint_path, map_location=device)
655
+ self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}))
656
+ self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}))
657
+ self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}))
658
+ else:
659
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
660
+
661
+ def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
662
+ audio_embedding = self.audio_model(audio_tensor)
663
+ pose_coeff = self.pose_mapper(audio_embedding)
664
+ if ref_pose_coeff is not None:
665
+ pose_coeff = ref_pose_coeff
666
+ if kp_ref is not None and use_ref_info == 'pose':
667
+ ref_pose_6d = kp_ref['value'][:, :6]
668
+ pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
669
+ return pose_coeff
670
+
671
+ def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
672
+ audio_embedding = self.audio_model(audio_tensor)
673
+ expression_coeff = self.exp_mapper(audio_embedding)
674
+ if ref_expression_coeff is not None:
675
+ expression_coeff = ref_expression_coeff
676
+ return expression_coeff
677
+
678
+ def get_blink_coeff(self, audio_tensor):
679
+ audio_embedding = self.audio_model(audio_tensor)
680
+ blink_coeff = self.blink_mapper(audio_embedding)
681
+ return blink_coeff
682
+
683
+ def forward(self, audio):
684
+ audio_embedding = self.audio_model(audio)
685
+ pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(
686
+ audio_embedding), self.blink_mapper(audio_embedding)
687
+ return pose_coeff, expression_coeff, blink_coeff
688
+
689
+ def mean_std_normalize(self, coeff):
690
+ mean = coeff.mean(dim=1, keepdim=True)
691
+ std = coeff.std(dim=1, keepdim=True)
692
+ return (coeff - mean) / std
693
+
694
+
695
+ class AnimateFromCoeff(nn.Module):
696
+
697
+ def __init__(self, sadtalker_cfg, device):
698
+ super(AnimateFromCoeff, self).__init__()
699
+ self.generator = Generator(sadtalker_cfg, device)
700
+ self.mapping = Mapping(sadtalker_cfg, device)
701
+ self.kp_norm = KeypointNorm(device=device)
702
+ self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
703
+
704
+ def normalize_kp(self, kp_driving):
705
+ return self.kp_norm(kp_driving)
706
+
707
+ def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
708
+ face_enhancer=None):
709
+ kp_driving = coeff['kp_driving']
710
+ jacobian = coeff['jacobian']
711
+ pose_coeff = coeff['pose_coeff']
712
+ expression_coeff = coeff['expression_coeff']
713
+ blink_coeff = coeff['blink_coeff']
714
+ with torch.no_grad():
715
+ if blink_coeff is not None:
716
+ sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
717
+ dense_motion = sparse_motion['dense_motion']
718
+ video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
719
+ face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
720
+ video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
721
+ face_3d_param=face_3d)
722
+ video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
723
+ video_output = self.make_animation(video_output)
724
+ else:
725
+ sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
726
+ dense_motion = sparse_motion['dense_motion']
727
+ face_3d = mapping(expression_coeff, pose_coeff)
728
+ video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
729
+ face_3d_param=face_3d)
730
+ video_output = video_3d['video_3d']
731
+ video_output = self.make_animation(video_output)
732
+ if face_enhancer is not None:
733
+ video_output_enhanced = []
734
+ for frame in tqdm(video_output, 'Face enhancer running'):
735
+ pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
736
+ enhanced_image = face_enhancer.enhance(np.array(pil_image))[0]
737
+ video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
738
+ video_output = video_output_enhanced
739
+ return video_output
740
+
741
+ def make_animation(self, video_array):
742
+ H, W, _ = video_array[0].shape
743
+ out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
744
+ for img in video_array:
745
+ out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
746
+ out.release()
747
+ video = imageio.mimread('./tmp.mp4')
748
+ os.remove('./tmp.mp4')
749
+ return video
750
+
751
+
752
+ class Generator(nn.Module):
753
+
754
+ def __init__(self, sadtalker_cfg, device):
755
+ super(Generator, self).__init__()
756
+ self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
757
+ num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
758
+ max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
759
+ num_channels=3,
760
+ kp_size=10,
761
+ num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
762
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
763
+ self.load_generator(checkpoint_path, device)
764
+
765
+ def load_generator(self, checkpoint_path, device):
766
+ if os.path.exists(checkpoint_path):
767
+ if checkpoint_path.endswith('safetensors'):
768
+ checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
769
+ else:
770
+ checkpoint = torch.load(checkpoint_path, map_location=device)
771
+ self.generator.load_state_dict(checkpoint.get('generator', {}))
772
+ else:
773
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
774
+
775
+ def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
776
+ if face_3d_param is not None:
777
+ video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
778
+ face_3d_param=face_3d_param)
779
+ else:
780
+ video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
781
+ return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
782
+
783
+
784
+ class Mapping(nn.Module):
785
+
786
+ def __init__(self, sadtalker_cfg, device):
787
+ super(Mapping, self).__init__()
788
+ self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
789
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
790
+ self.load_mapping_net(checkpoint_path, device)
791
+ self.f_3d_mean = torch.zeros(1, 64, device=device)
792
+
793
+ def load_mapping_net(self, checkpoint_path, device):
794
+ if os.path.exists(checkpoint_path):
795
+ if checkpoint_path.endswith('safetensors'):
796
+ checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
797
+ else:
798
+ checkpoint = torch.load(checkpoint_path, map_location=device)
799
+ self.mapping_net.load_state_dict(checkpoint.get('mapping', {}))
800
+ else:
801
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
802
+
803
+ def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
804
+ coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
805
+ face_3d = self.mapping_net(coeff) + self.f_3d_mean
806
+ if blink_coeff is not None:
807
+ face_3d[:, -1:] = blink_coeff
808
+ return face_3d
809
+
810
+
811
+ class OcclusionAwareDenseMotion(nn.Module):
812
+
813
+ def __init__(self, sadtalker_cfg, device):
814
+ super(OcclusionAwareDenseMotion, self).__init__()
815
+ self.dense_motion_network = DenseMotionNetwork(num_kp=10,
816
+ num_channels=3,
817
+ block_expansion=sadtalker_cfg['MODEL']['SCALE'],
818
+ num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
819
+ max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
820
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
821
+ self.load_dense_motion_network(checkpoint_path, device)
822
+
823
+ def load_dense_motion_network(self, checkpoint_path, device):
824
+ if os.path.exists(checkpoint_path):
825
+ if checkpoint_path.endswith('safetensors'):
826
+ checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
827
+ else:
828
+ checkpoint = torch.load(checkpoint_path, map_location=device)
829
+ self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}))
830
+ else:
831
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
832
+
833
+ def forward(self, kp_source, kp_driving, jacobian):
834
+ sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
835
+ return sparse_motion
836
+
837
+
838
+ class FaceEnhancer(nn.Module):
839
+
840
+ def __init__(self, sadtalker_cfg, device):
841
+ super(FaceEnhancer, self).__init__()
842
+ enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
843
+ bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
844
+ if enhancer_name == 'gfpgan':
845
+ from gfpgan import GFPGANer
846
+ self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
847
+ upscale=1,
848
+ arch='clean',
849
+ channel_multiplier=2,
850
+ bg_upsampler=bg_upsampler)
851
+ elif enhancer_name == 'realesrgan':
852
+ from realesrgan import RealESRGANer
853
+ half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
854
+ self.face_enhancer = RealESRGANer(scale=2,
855
+ model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
856
+ 'RealESRGAN_x2plus.pth'),
857
+ tile=0,
858
+ tile_pad=10,
859
+ pre_pad=0,
860
+ half=half,
861
+ device=device)
862
+ else:
863
+ self.face_enhancer = None
864
+
865
+ def forward(self, x):
866
+ return self.face_enhancer.enhance(x, outscale=1)[0]
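The classes above can also be driven directly, without the HTTP layer. A minimal sketch, assuming the example image and audio paths exist on disk and the checkpoints can be downloaded on first construction:

```python
# Direct use of the SadTalker wrapper defined above; input paths are placeholders.
from sadtalker_utils import SadTalker

talker = SadTalker(checkpoint_path="checkpoints", config_path="src/config", size=256)
video_path = talker.test(
    source_image="examples/source_image/face.png",
    driven_audio="examples/driven_audio/speech.wav",
    preprocess="crop",
    still_mode=True,
    use_enhancer=False,
    result_dir="./results/",
)
print(video_path)
```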
sentiment_api.py ADDED
@@ -0,0 +1,27 @@
1
+ from flask import jsonify
2
+ from main import *
3
+ #from main import sentiment_model, device
4
+ import torch
5
+
6
+ def analyze_sentiment(text, output_path="output_sentiment.json"):
7
+ if sentiment_model is None:
8
+ return "Sentiment model not initialized."
9
+
10
+ input_tokens = sentiment_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
11
+ with torch.no_grad():
12
+ sentiment_logits = sentiment_model(input_tokens['input_ids'])
13
+ predicted_class_id = torch.argmax(sentiment_logits, dim=-1).item()
14
+ sentiment_label = sentiment_model.config.id2label[predicted_class_id]
15
+ probability = torch.softmax(sentiment_logits, dim=-1)[0][predicted_class_id].item()
16
+
17
+ return {"sentiment": sentiment_label, "probability": probability}
18
+
19
+ def sentiment_api():
20
+ data = request.get_json()
21
+ text = data.get('text')
22
+ if not text:
23
+ return jsonify({"error": "Text is required"}), 400
24
+ output_file = analyze_sentiment(text)
25
+ if output_file == "Sentiment model not initialized.":
26
+ return jsonify({"error": "Sentiment analysis failed"}), 500
27
+ return jsonify(output_file)
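A small client sketch for the sentiment endpoint; the /sentiment route name is an assumption, since routes are registered in main.py:

```python
# Hypothetical client call for analyze_sentiment via HTTP.
import requests

resp = requests.post(
    "http://localhost:7860/sentiment",
    json={"text": "I really enjoyed this space, great work!"},
)
print(resp.json())  # e.g. {"sentiment": "POSITIVE", "probability": 0.98}
```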
stt_api.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+ import uuid
3
+ from flask import jsonify, send_file, request
4
+ from main import *
5
+ #from main import stt_model, device
6
+ import torch
7
+ import torchaudio
8
+
9
+ def speech_to_text_func(audio_path, output_path="output_stt.txt"):
10
+ if stt_model is None:
11
+ return "STT model not initialized."
12
+
13
+ waveform, sample_rate = torchaudio.load(audio_path)
14
+ if waveform.ndim > 1:
15
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
16
+ waveform = waveform.to(device)
17
+ with torch.no_grad():
18
+ logits = stt_model(waveform)
19
+ predicted_ids = torch.argmax(logits, dim=-1)
20
+ transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist())
21
+
22
+ with open(output_path, "w") as file:
23
+ file.write(transcription)
24
+ return output_path
25
+
26
+ def stt_api():
27
+ if 'audio' not in request.files:
28
+ return jsonify({"error": "Audio file is required"}), 400
29
+ audio_file = request.files['audio']
30
+ temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
31
+ audio_file.save(temp_audio_path)
32
+ output_file = speech_to_text_func(temp_audio_path)
33
+ os.remove(temp_audio_path)
34
+ if output_file == "STT model not initialized.":
35
+ return jsonify({"error": "STT failed"}), 500
36
+ return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output.txt")
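A client sketch for the speech-to-text endpoint; the /stt route name is an assumption:

```python
# Hypothetical multipart upload to the STT endpoint defined above.
import requests

with open("sample.wav", "rb") as audio:
    resp = requests.post("http://localhost:7860/stt", files={"audio": audio})
resp.raise_for_status()
print(resp.text)  # the endpoint returns the transcription as a text attachment
```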
summarization_api.py ADDED
@@ -0,0 +1,29 @@
1
+ from flask import jsonify, send_file, request
2
+ from main import *
3
+ #from main import summarization_model, summarization_word_to_index, device
4
+ import torch
5
+
6
+ def summarize_text(text, output_path="output_summary.txt"):
7
+ if summarization_model is None:
8
+ return "Summarization model not initialized."
9
+
10
+ input_tokens = [summarization_word_to_index.get(word.lower(), 1) for word in text.split()]
11
+ input_tensor = torch.tensor([input_tokens], dtype=torch.long).to(device)
12
+
13
+ with torch.no_grad():
14
+ summary_ids = summarization_model.generate(input_tensor, num_beams=4, max_length=100, early_stopping=True)
15
+ summary_text = summarization_model.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
16
+
17
+ with open(output_path, "w") as file:
18
+ file.write(summary_text)
19
+ return output_path
20
+
21
+ def summarization_api():
22
+ data = request.get_json()
23
+ text = data.get('text')
24
+ if not text:
25
+ return jsonify({"error": "Text is required"}), 400
26
+ output_file = summarize_text(text)
27
+ if output_file == "Summarization model not initialized.":
28
+ return jsonify({"error": "Summarization failed"}), 500
29
+ return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")
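A client sketch for the summarization endpoint; the /summarize route name is an assumption:

```python
# Hypothetical client call for summarize_text via HTTP.
import requests

resp = requests.post(
    "http://localhost:7860/summarize",
    json={"text": "Paste a long article here to receive a shorter summary back."},
)
print(resp.text)  # the summary is returned as a plain-text attachment
```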
text_generation.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from tqdm import trange
4
+ import time
5
+ from tokenxxx import *
6
+ from main import *
7
+ #from main import model_gpt2, enc, codegen_model, codegen_tokenizer, summarization_model, device, system_prompt, MAX_LENGTH, summarize_text as summarize_func
8
+ from duckduckgo_search import DDGS
9
+
10
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
11
+ top_k = min(top_k, logits.size(-1))
12
+ if top_k > 0:
13
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]
14
+ logits[indices_to_remove] = filter_value
15
+ if top_p > 0.0:
16
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
17
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
18
+ sorted_indices_to_remove = cumulative_probs > top_p
19
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
20
+ sorted_indices_to_remove[..., 0] = 0
21
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
22
+ logits[indices_to_remove] = filter_value
23
+ return logits
24
+
25
+ def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
26
+ start_time = time.time()
27
+ context_tokens = enc.encode(prompt)
28
+ context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
29
+ generated = context_tokens
30
+ past = None
31
+ text_generated_count = 0
32
+ past_key_values = past if past is not None else None
33
+
34
+ with torch.no_grad():
35
+ outputs = model(context_tokens_tensor, past_key_values=past_key_values)
36
+ next_token_logits = outputs[0][:, -1, :] / temperature
37
+ past = outputs[1]
38
+ for token_index in set(generated):
39
+ next_token_logits[0, token_index] /= repetition_penalty
40
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
41
+ if temperature == 0:
42
+ next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
43
+ else:
44
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
45
+ generated += next_token.tolist()[0]
46
+ text_generated_count += 1
47
+ token = next_token.tolist()[0][0]
48
+ yield enc.decode([token])
49
+ if token == enc.encoder[END_OF_TEXT_TOKEN]:
50
+ yield "<END_STREAM>"
51
+ if text_generated_count > length:
52
+ yield "<END_STREAM>"
53
+ if (time.time() - start_time) * 1000 > 5000:
54
+ yield "<END_STREAM>"
55
+
56
+ def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
57
+ start_time = time.time()
58
+ context_tokens = tokenizer.encode(prompt)
59
+ context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device).unsqueeze(0)
60
+ generated = context_tokens
61
+ past = None
62
+ text_generated_count = 0
63
+ with torch.no_grad():
64
+ outputs = model(input_ids=context_tokens_tensor, past_key_values=past, labels=None)
65
+ next_token_logits = outputs[0][:, -1, :] / temperature
66
+ past = outputs[1]
67
+ for token_index in set(generated):
68
+ next_token_logits[0, token_index] /= repetition_penalty
69
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
70
+ if temperature == 0:
71
+ next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
72
+ else:
73
+ next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
74
+ generated.append(next_token.tolist()[0][0])
75
+ text_generated_count += 1
76
+ token = next_token.tolist()[0][0]
77
+ yield tokenizer.decode([token])
78
+ if token == 50256:
79
+ yield "<END_STREAM>"
80
+ if text_generated_count > length:
81
+ yield "<END_STREAM>"
82
+ if (time.time() - start_time) * 1000 > 5000:
83
+ yield "<END_STREAM>"
84
+
85
+ def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
86
+ try:
87
+ prompt_text = system_prompt + "\n\n"
88
+ prompt_text += "User: " + text_input + "\nCyrah: "
89
+ reasoning_prompt = prompt_text
90
+
91
+ ddgs = DDGS()
92
+ search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
93
+ if search_results:
94
+ prompt_text += "\nWeb Search Results:\n"
95
+ for result in search_results:
96
+ prompt_text += f"- {result['body']}\n"
97
+ prompt_text += "\n"
98
+
99
+ generated_text_stream = []
100
+ stream_type = "text"
101
+
102
+ if "code" in text_input.lower() or "program" in text_input.lower():
103
+ if codegen_model and codegen_tokenizer:
104
+ generated_text_stream = sample_sequence_codegen(
105
+ prompt=reasoning_prompt,
106
+ model=codegen_model,
107
+ tokenizer=codegen_tokenizer,
108
+ length=MAX_LENGTH,
109
+ temperature=temperature,
110
+ top_k=top_k,
111
+ top_p=top_p,
112
+ repetition_penalty=repetition_penalty,
113
+ device=device
114
+ )
115
+ stream_type = "text"
116
+ elif "summarize" in text_input.lower() or "summary" in text_input.lower():
117
+ if summarization_model:
118
+ summary = summarize_func(text_input)
119
+ yield f"SUMMARY_TEXT:{summary}"
120
+ yield "<END_STREAM>"
121
+ stream_type = "summary"
122
+ else:
123
+ if model_gpt2 and enc:
124
+ generated_text_stream = sample_sequence(
125
+ prompt=reasoning_prompt,
126
+ model=model_gpt2,
127
+ enc=enc,
128
+ length=MAX_LENGTH,
129
+ temperature=temperature,
130
+ top_k=top_k,
131
+ top_p=top_p,
132
+ repetition_penalty=repetition_penalty,
133
+ device=device
134
+ )
135
+ stream_type = "text"
136
+
137
+ accumulated_text = ""
138
+ if stream_type == "text":
139
+ for token in generated_text_stream:
140
+ if token == "<END_STREAM>":
141
+ yield accumulated_text
142
+ yield "<END_STREAM>"
143
+ return
144
+ if token == END_OF_TEXT_TOKEN:
145
+ accumulated_text += END_OF_TEXT_TOKEN
146
+ continue
147
+ if token:
148
+ accumulated_text += token
149
+ except Exception as e:
150
+ print(f"Reasoning Error: {e}")
151
+ yield "Error during reasoning. Please try again."
152
+ yield "<END_STREAM>"
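perform_reasoning_stream is a generator that yields text chunks followed by a <END_STREAM> sentinel. A minimal consumer sketch, assuming the models have already been initialised by main.py; the sampling values are examples:

```python
# Sketch of consuming the streaming generator above.
from text_generation import perform_reasoning_stream

pieces = []
for chunk in perform_reasoning_stream(
    "Explain what a tokenizer does",
    temperature=0.7, top_k=40, top_p=0.9, repetition_penalty=1.2,
):
    if chunk == "<END_STREAM>":
        break
    pieces.append(chunk)
print("".join(pieces))
```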
text_to_video_api.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+ import uuid
3
+ from flask import jsonify, send_file, request
4
+ from main import *
5
+ #from main import text_to_video_model
6
+ import torch
7
+ import io
8
+ from skimage import img_as_ubyte
9
+ import imageio
10
+
11
+ def text_to_video_func(prompt, output_path="output_video.mp4"):
12
+ if text_to_video_model is None:
13
+ return "Text-to-Video model not initialized."
14
+ video_frames_list = text_to_video_model(prompt)
15
+ if video_frames_list and hasattr(video_frames_list, 'frames'):
16
+ video_frames = video_frames_list.frames
17
+ export_to_video_pure(video_frames, output_video=output_path)
18
+ return output_path
19
+ return "Video generation failed."
20
+
21
+ def export_to_video_pure(video_frames, output_video="output_video.mp4", fps=25):
22
+ writer = imageio.get_writer(output_video, fps=fps)
23
+ for frame in video_frames:
24
+ writer.append_data(img_as_ubyte(frame))
25
+ writer.close()
26
+
27
+ def text_to_video_api():
28
+ data = request.get_json()
29
+ prompt = data.get('prompt')
30
+ if not prompt:
31
+ return jsonify({"error": "Prompt is required"}), 400
32
+ output_file = text_to_video_func(prompt)
33
+ if output_file == "Text-to-Video model not initialized." or output_file == "Video generation failed.":
34
+ return jsonify({"error": "Text to video failed"}), 500
35
+ with open(output_file, 'rb') as f:
36
+ video_content = f.read()
37
+ return send_file(io.BytesIO(video_content), mimetype='video/mp4', as_attachment=True, download_name="output_video.mp4")
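A client sketch for the text-to-video endpoint; the /text_to_video route name is an assumption:

```python
# Hypothetical client call; the generated MP4 is returned as an attachment.
import requests

resp = requests.post(
    "http://localhost:7860/text_to_video",
    json={"prompt": "a rocket lifting off at sunrise"},
    timeout=1800,
)
resp.raise_for_status()
with open("output_video.mp4", "wb") as f:
    f.write(resp.content)
```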
tokenxxx.py ADDED
@@ -0,0 +1,161 @@
1
+ import json
2
+ import re
3
+ import unicodedata
4
+ from functools import lru_cache
5
+ import wget
6
+ import os
7
+ from constants import *
8
+ import nltk
+ import torch  # used by text_to_vector below; imported explicitly rather than relying on the wildcard constants import
9
+
10
+ @lru_cache()
11
+ def bytes_to_unicode():
12
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
13
+ cs = bs[:]
14
+ n = 0
15
+ for b in range(2**8):
16
+ if b not in bs:
17
+ bs.append(b)
18
+ cs.append(2**8 + n)
19
+ n += 1
20
+ cs = [chr(n) for n in cs]
21
+ return dict(zip(bs, cs))
22
+
23
+ def get_pairs(word):
24
+ pairs = set()
25
+ prev_char = word[0]
26
+ for char in word[1:]:
27
+ pairs.add((prev_char, char))
28
+ prev_char = char
29
+ return pairs
30
+
31
+ class Encoder:
32
+ def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None, byte_encoder=None, byte_decoder=None):
33
+ self.encoder = encoder
34
+ self.decoder = {v:k for k,v in self.encoder.items()}
35
+ self.errors = errors
36
+ self.byte_encoder = byte_encoder if byte_encoder is not None else bytes_to_unicode()
37
+ self.byte_decoder = byte_decoder if byte_decoder is not None else {v: k for k, v in self.byte_encoder.items()}
38
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
39
+ self.cache = {}
40
+ if tokenize is None:
41
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
42
+ self.tokenize = lambda text: re.findall(self.pat, text)
43
+ else:
44
+ self.tokenize = tokenize
45
+
46
+ def bpe(self, token):
47
+ if token in self.cache:
48
+ return self.cache[token]
49
+ word = tuple(token)
50
+ pairs = get_pairs(word)
51
+ if not pairs:
52
+ return token
53
+ while True:
54
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
55
+ if bigram not in self.bpe_ranks:
56
+ break
57
+ first, second = bigram
58
+ new_word = []
59
+ i = 0
60
+ while i < len(word):
61
+ try:
62
+ j = word.index(first, i)
63
+ new_word.extend(word[i:j])
64
+ i = j
65
+ except ValueError:
66
+ new_word.extend(word[i:])
67
+ break
68
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
69
+ new_word.append(first+second)
70
+ i += 2
71
+ else:
72
+ new_word.append(word[i])
73
+ i += 1
74
+ new_word = tuple(new_word)
75
+ word = new_word
76
+ if len(word) == 1:
77
+ break
78
+ else:
79
+ pairs = get_pairs(word)
80
+ word = ' '.join(word)
81
+ self.cache[token] = word
82
+ return word
83
+
84
+ def encode(self, text):
85
+ bpe_tokens = []
86
+ normalized_text = unicodedata.normalize('NFKC', text)
87
+ normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
88
+ normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
89
+ for token in self.tokenize(normalized_text):
90
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
91
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
92
+ return bpe_tokens
93
+
94
+ def decode(self, tokens):
95
+ text = ''.join([self.decoder[token] for token in tokens])
96
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
97
+ decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
98
+ sentences = nltk.sent_tokenize(decoded_text)
99
+ return ' '.join(sentences).replace("<br>", "<br>\n")
100
+
101
+ def get_encoder_gpt2():
102
+ encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
103
+ vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
104
+ if not os.path.exists(GPT2_FOLDER):
105
+ os.makedirs(GPT2_FOLDER)
106
+ if not os.path.exists(encoder_path):
107
+ wget.download(ENCODER_URL, out=encoder_path)
108
+ if not os.path.exists(vocab_path):
109
+ wget.download(VOCAB_URL, out=vocab_path)
110
+
111
+ with open(encoder_path, 'r') as f:
112
+ encoder = json.load(f)
113
+ with open(vocab_path, 'r', encoding="utf-8") as f:
114
+ bpe_data = f.read()
115
+ bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
116
+ encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
117
+ encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
118
+ encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
119
+ return encoder_obj
120
+
121
+ def get_codegen_tokenizer_pure(vocab_file, merges_file):
122
+ vocab = json.load(open(vocab_file))
123
+ merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]
124
+ bpe_merges = [tuple(m.split()) for m in merges]
125
+ byte_encoder = bytes_to_unicode()
126
+ byte_decoder = {v: k for k, v in byte_encoder.items()}
127
+ tokenizer_regex = re.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\W\d_]+|\d| ?[^\s\w]+|\s+(?!\S)|\s+''')  # stdlib re has no \p{L}/\p{N}; this is a close Unicode-aware approximation
128
+ tokenize = lambda text: tokenizer_regex.findall(text)
129
+ encoder_obj = Encoder(
130
+ encoder=vocab,
131
+ bpe_merges=bpe_merges,
132
+ byte_encoder=byte_encoder,
133
+ byte_decoder=byte_decoder,
134
+ tokenize=tokenize
135
+ )
136
+ return encoder_obj
137
+
138
+ def codegen_tokenize(text, tokenizer):
139
+ return tokenizer.encode(text)
140
+
141
+ def codegen_decode(tokens, tokenizer):
142
+ return tokenizer.decode(tokens)
143
+
144
+ def tokenize_text(text):
145
+ global vocabulary, word_to_index, index_to_word
146
+ tokens = text.lower().split()
147
+ for token in tokens:
148
+ if token not in vocabulary:
149
+ vocabulary.add(token)
150
+ word_to_index[token] = len(index_to_word)
151
+ index_to_word.append(token)
152
+ return tokens
153
+
154
+ def text_to_vector(text):
155
+ global vocabulary, word_to_index
156
+ tokens = tokenize_text(text)
157
+ vector = torch.zeros(len(vocabulary))
158
+ for token in tokens:
159
+ if token in word_to_index:
160
+ vector[word_to_index[token]] += 1
161
+ return vector
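As a usage sketch (not part of the upload): the GPT-2 encoder defined above round-trips text as shown below; get_encoder_gpt2() downloads encoder.json and vocab.bpe on first use, and decode() additionally needs the NLTK 'punkt' data for sentence splitting.

enc = get_encoder_gpt2()
ids = enc.encode("Hello world, this is a BPE round trip.")
print(ids[:8])          # first few BPE token ids
print(enc.decode(ids))  # reconstructed, sentence-tokenized text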
translation_api.py ADDED
@@ -0,0 +1,27 @@
1
+ from flask import jsonify, send_file, request
2
+ from main import *
3
+ #from main import translation_model, device
4
+
5
+ def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX', output_path="output_translation.txt"):
6
+ if translation_model is None:
7
+ return "Translation model not initialized."
8
+
9
+ encoded_text = translation_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
10
+ generated_tokens = translation_model.generate(input_ids=encoded_text['input_ids'], attention_mask=encoded_text['attention_mask'], forced_bos_token_id=translation_model.config.lang_code_to_id[target_language_code])
11
+ translation = translation_model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
12
+
13
+ with open(output_path, "w") as file:
14
+ file.write(translation)
15
+ return output_path
16
+
17
+ def translation_api():
18
+ data = request.get_json()
19
+ text = data.get('text')
20
+ target_lang = data.get('target_lang', 'es')
21
+ source_lang = data.get('source_lang', 'en')
22
+ if not text:
23
+ return jsonify({"error": "Text is required"}), 400
24
+ output_file = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
25
+ if output_file == "Translation model not initialized.":
26
+ return jsonify({"error": "Translation failed"}), 500
27
+ return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_translation.txt")
tts_api.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
+ import torch  # torch/torchaudio are used below; imported explicitly rather than relying on main's wildcard import
+ import torchaudio
2
+ from flask import jsonify, send_file, request
3
+ from main import *
4
+ #from main import tts_model, device
5
+
6
+ def text_to_speech_func(text, output_path="output_tts.wav"):
7
+ if tts_model is None:
8
+ return "TTS model not initialized."
9
+ input_tokens = tts_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
10
+ with torch.no_grad():
11
+ audio_output = tts_model(input_tokens['input_ids'])
12
+ torchaudio.save(output_path, audio_output.cpu(), 16000)
13
+ return output_path
14
+
15
+ def tts_api():
16
+ data = request.get_json()
17
+ text = data.get('text')
18
+ if not text:
19
+ return jsonify({"error": "Text is required"}), 400
20
+ output_file = text_to_speech_func(text)
21
+ if output_file == "TTS model not initialized.":
22
+ return jsonify({"error": "TTS generation failed"}), 500
23
+ return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
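A possible client-side call, assuming main.py wires tts_api() to a /tts route on port 7860 (route name and host are illustrative assumptions):

import requests

resp = requests.post("http://localhost:7860/tts", json={"text": "Hello from the TTS endpoint."})
if resp.ok:
    with open("output.wav", "wb") as f:
        f.write(resp.content)
else:
    print(resp.status_code, resp.json().get("error"))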
utils.py ADDED
@@ -0,0 +1,190 @@
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ from faker import Faker
4
+ from urllib.request import urlretrieve
5
+ import urllib.request
6
+ from urllib3.util.retry import Retry
7
+ import time
8
+ import os
9
+ import wget
10
+ import json
11
+ import unicodedata
12
+ import nltk
13
+ from sklearn.datasets import fetch_20newsgroups
14
+ from sklearn.feature_extraction.text import TfidfVectorizer
15
+ from sklearn.linear_model import LogisticRegression
16
+ from sklearn.multiclass import OneVsRestClassifier
17
+ import warnings
18
+ from requests.adapters import HTTPAdapter
19
+ from constants import *
20
+
21
+ MAX_XDD = 5
22
+ use_google_search = True
23
+ use_20newsgroup = True
24
+ fake = Faker()
25
+
26
+ def create_retry_session():
27
+ retry_strategy = Retry(
28
+ total=5,
29
+ status_forcelist=[429, 500, 502, 503, 504],
30
+ method_whitelist=["GET"],
31
+ backoff_factor=1,
32
+ )
33
+ adapter = HTTPAdapter(max_retries=retry_strategy)
34
+ http = requests.Session()
35
+ http.mount("https://", adapter)
36
+ http.mount("http://", adapter)
37
+ return http
38
+
39
+ def get_google_search_results(query, retry_session):
40
+ if not use_google_search:
41
+ return []
42
+ headers = {"User-Agent": fake.user_agent()}
43
+ search_url = f"https://www.google.com/search?q={query}"
44
+ try:
45
+ response = retry_session.get(search_url, headers=headers, timeout=10)
46
+ response.raise_for_status()
47
+ except requests.exceptions.RequestException as e:
48
+ return []
49
+ soup = BeautifulSoup(response.text, "html.parser")
50
+ search_results = []
51
+ for a_tag in soup.find_all('a', href=True):
52
+ if 'url?q=' in a_tag['href'] and not a_tag['href'].startswith("https://accounts.google.com"):
53
+ search_results.append(a_tag['href'].split('url?q=')[1].split('&')[0])
54
+ return search_results
55
+
56
+ def fetch_20newsgroup_data():
57
+ if not use_20newsgroup:
58
+ return []
59
+ try:
60
+ newsgroups_train = fetch_20newsgroups(subset='train', categories=['talk.politics.misc', 'rec.sport.baseball', 'sci.med', 'comp.sys.ibm.pc.hardware', 'soc.religion.christian'])  # 'talk.trivia' is not a 20newsgroups category and made this call raise; replaced with a valid one
61
+ data = newsgroups_train.data
62
+ return data
63
+ except Exception as e:
64
+ return []
65
+
66
+ def download_file(url, filename, folder, retries=3):
67
+ filepath = os.path.join(folder, filename)
68
+ if os.path.exists(filepath):
69
+ return True
70
+ os.makedirs(folder, exist_ok=True)
71
+ for attempt in range(retries):
72
+ try:
73
+ wget.download(url, out=filepath)
74
+ return True
75
+ except Exception as e:
76
+ if attempt < retries - 1:
77
+ time.sleep(2)
78
+ else:
79
+ return False
80
+ return False
81
+
82
+ def download_gpt2_files(folder, model_url, model_file, encoder_url, encoder_file, vocab_url, vocab_file):
83
+ if not os.path.exists(folder):
84
+ os.makedirs(folder)
85
+ if not os.path.exists(os.path.join(folder, model_file)):
86
+ download_file(model_url, model_file, folder)
87
+ if not os.path.exists(os.path.join(folder, encoder_file)):
88
+ download_file(encoder_url, encoder_file, folder)
89
+ if not os.path.exists(os.path.join(folder, vocab_file)):
90
+ download_file(vocab_url, vocab_file, folder)
91
+
92
+ def download_translation_files(folder, model_files_urls):
93
+ if not os.path.exists(folder):
94
+ os.makedirs(folder)
95
+ for url, filename in model_files_urls:
96
+ if not os.path.exists(os.path.join(folder, filename)):
97
+ download_file(url, filename, folder)
98
+
99
+ def download_codegen_files(folder, model_files_urls):
100
+ if not os.path.exists(folder):
101
+ os.makedirs(folder)
102
+ for url, filename in model_files_urls:
103
+ if not os.path.exists(os.path.join(folder, filename)):
104
+ download_file(url, filename, folder)
105
+
106
+ def download_summarization_files(folder, model_files_urls):
107
+ if not os.path.exists(folder):
108
+ os.makedirs(folder)
109
+ for url, filename in model_files_urls:
110
+ if not os.path.exists(os.path.join(folder, filename)):
111
+ download_file(url, filename, folder)
112
+
113
+ def download_imagegen_files(folder, model_files_urls):
114
+ if not os.path.exists(folder):
115
+ os.makedirs(folder)
116
+ for url, filename in model_files_urls:
117
+ if not os.path.exists(os.path.join(folder, filename)):
118
+ download_file(url, filename, folder)
119
+
120
+ def download_image_to_3d_files(folder, model_files_urls):
121
+ if not os.path.exists(folder):
122
+ os.makedirs(folder)
123
+ for url, filename in model_files_urls:
124
+ if not os.path.exists(os.path.join(folder, filename)):
125
+ download_file(url, filename, folder)
126
+
127
+ def download_text_to_video_files(folder, model_files_urls):
128
+ if not os.path.exists(folder):
129
+ os.makedirs(folder)
130
+ for url, filename in model_files_urls:
131
+ if not os.path.exists(os.path.join(folder, filename)):
132
+ download_file(url, filename, folder)
133
+
134
+ def download_sentiment_files(folder, model_files_urls):
135
+ if not os.path.exists(folder):
136
+ os.makedirs(folder)
137
+ for url, filename in model_files_urls:
138
+ if not os.path.exists(os.path.join(folder, filename)):
139
+ download_file(url, filename, folder)
140
+
141
+ def download_stt_files(folder, model_files_urls):
142
+ if not os.path.exists(folder):
143
+ os.makedirs(folder)
144
+ for url, filename in model_files_urls:
145
+ if not os.path.exists(os.path.join(folder, filename)):
146
+ download_file(url, filename, folder)
147
+
148
+ def download_tts_files(folder, model_files_urls):
149
+ if not os.path.exists(folder):
150
+ os.makedirs(folder)
151
+ for url, filename in model_files_urls:
152
+ if not os.path.exists(os.path.join(folder, filename)):
153
+ download_file(url, filename, folder)
154
+
155
+ def download_musicgen_files(folder, model_files_urls):
156
+ if not os.path.exists(folder):
157
+ os.makedirs(folder)
158
+ for url, filename in model_files_urls:
159
+ if not os.path.exists(os.path.join(folder, filename)):
160
+ download_file(url, filename, folder)
161
+
162
+ def bytes_to_unicode_gpt2():
163
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
164
+ cs = bs[:]
165
+ n = 0
166
+ for b in range(2**8):
167
+ if b not in bs:
168
+ bs.append(b)
169
+ cs.append(2**8+n)
170
+ n = n+1
171
+ cs = [chr(n) for n in cs]
172
+ return dict(zip(bs, cs))
173
+
174
+ def get_codegen_tokenizer_pure(vocab_file, merges_file):
+ # The inline copy that lived here referenced Encoder, bytes_to_unicode and re, none of which
+ # are imported in utils.py, and passed keyword arguments Encoder does not accept.
+ # Delegate to the shared implementation in tokenxxx.py so there is a single source of truth.
+ from tokenxxx import get_codegen_tokenizer_pure as _get_codegen_tokenizer
+ return _get_codegen_tokenizer(vocab_file, merges_file)
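A short usage sketch for the helpers above (illustrative values only; Google may throttle scripted queries, in which case get_google_search_results returns an empty list):

session = create_retry_session()
links = get_google_search_results("huggingface spaces docker", session)
print(len(links), "result links")
download_file("https://example.com/model.bin", "model.bin", "./models")  # placeholder URL, retried up to 3 times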
xxx.py ADDED
@@ -0,0 +1,142 @@
1
+ import json
2
+ import re
3
+ import unicodedata
4
+ from functools import lru_cache
5
+ import wget
6
+ import os
7
+ from constants import GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, ENCODER_URL, VOCAB_URL, END_OF_TEXT_TOKEN
8
+ import nltk
9
+
10
+ @lru_cache()
11
+ def bytes_to_unicode():
12
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
13
+ cs = bs[:]
14
+ n = 0
15
+ for b in range(2**8):
16
+ if b not in bs:
17
+ bs.append(b)
18
+ cs.append(2**8 + n)
19
+ n += 1
20
+ cs = [chr(n) for n in cs]
21
+ return dict(zip(bs, cs))
22
+
23
+ def get_pairs(word):
24
+ pairs = set()
25
+ prev_char = word[0]
26
+ for char in word[1:]:
27
+ pairs.add((prev_char, char))
28
+ prev_char = char
29
+ return pairs
30
+
31
+ class Encoder:
32
+ def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None, byte_encoder=None, byte_decoder=None):
33
+ self.encoder = encoder
34
+ self.decoder = {v:k for k,v in self.encoder.items()}
35
+ self.errors = errors
36
+ self.byte_encoder = byte_encoder if byte_encoder is not None else bytes_to_unicode()
37
+ self.byte_decoder = byte_decoder if byte_decoder is not None else {v: k for k, v in self.byte_encoder.items()}
38
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
39
+ self.cache = {}
40
+ if tokenize is None:
41
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
42
+ self.tokenize = lambda text: re.findall(self.pat, text)
43
+ else:
44
+ self.tokenize = tokenize
45
+
46
+ def bpe(self, token):
47
+ if token in self.cache:
48
+ return self.cache[token]
49
+ word = tuple(token)
50
+ pairs = get_pairs(word)
51
+ if not pairs:
52
+ return token
53
+ while True:
54
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
55
+ if bigram not in self.bpe_ranks:
56
+ break
57
+ first, second = bigram
58
+ new_word = []
59
+ i = 0
60
+ while i < len(word):
61
+ try:
62
+ j = word.index(first, i)
63
+ new_word.extend(word[i:j])
64
+ i = j
65
+ except ValueError:
66
+ new_word.extend(word[i:])
67
+ break
68
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
69
+ new_word.append(first+second)
70
+ i += 2
71
+ else:
72
+ new_word.append(word[i])
73
+ i += 1
74
+ new_word = tuple(new_word)
75
+ word = new_word
76
+ if len(word) == 1:
77
+ break
78
+ else:
79
+ pairs = get_pairs(word)
80
+ word = ' '.join(word)
81
+ self.cache[token] = word
82
+ return word
83
+
84
+ def encode(self, text):
85
+ bpe_tokens = []
86
+ normalized_text = unicodedata.normalize('NFKC', text)
87
+ normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
88
+ normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
89
+ for token in self.tokenize(normalized_text):
90
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
91
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
92
+ return bpe_tokens
93
+
94
+ def decode(self, tokens):
95
+ text = ''.join([self.decoder[token] for token in tokens])
96
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
97
+ decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
98
+ sentences = nltk.sent_tokenize(decoded_text)
99
+ return ' '.join(sentences).replace("<br>", "<br>\n")
100
+
101
+ def get_encoder_gpt2():
102
+ encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
103
+ vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
104
+ if not os.path.exists(GPT2_FOLDER):
105
+ os.makedirs(GPT2_FOLDER)
106
+ if not os.path.exists(encoder_path):
107
+ wget.download(ENCODER_URL, out=encoder_path)
108
+ if not os.path.exists(vocab_path):
109
+ wget.download(VOCAB_URL, out=vocab_path)
110
+
111
+ with open(encoder_path, 'r') as f:
112
+ encoder = json.load(f)
113
+ with open(vocab_path, 'r', encoding="utf-8") as f:
114
+ bpe_data = f.read()
115
+ bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
116
+ encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
117
+ encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
118
+ encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
119
+ return encoder_obj
120
+
121
+ def get_codegen_tokenizer_pure(vocab_file, merges_file):
122
+ vocab = json.load(open(vocab_file))
123
+ merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]
124
+ bpe_merges = [tuple(m.split()) for m in merges]
125
+ byte_encoder = bytes_to_unicode()
126
+ byte_decoder = {v: k for k, v in byte_encoder.items()}
127
+ tokenizer_regex = re.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\W\d_]+|\d| ?[^\s\w]+|\s+(?!\S)|\s+''')  # stdlib re has no \p{L}/\p{N}; this is a close Unicode-aware approximation
128
+ tokenize = lambda text: tokenizer_regex.findall(text)
129
+ encoder_obj = Encoder(
130
+ encoder=vocab,
131
+ bpe_merges=bpe_merges,
132
+ byte_encoder=byte_encoder,
133
+ byte_decoder=byte_decoder,
134
+ tokenize=tokenize
135
+ )
136
+ return encoder_obj
137
+
138
+ def codegen_tokenize(text, tokenizer):
139
+ return tokenizer.encode(text)
140
+
141
+ def codegen_decode(tokens, tokenizer):
142
+ return tokenizer.decode(tokens)