# Hhhh / model_loader.py
# Uploaded by Hjgugugjhuhjggg ("Upload 28 files", commit e83e49f, 27.9 kB).
from tokenxxx import *
from constants import *
from utils import *
import os
import json
import urllib.request
import urllib.parse
import torch
import hashlib
from tqdm import tqdm
from skimage import img_as_ubyte
from torch import nn
import torch.nn.functional as F
import inspect
# Module-wide compute device: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def filter_kwargs(cls, kwargs):
    """Return the subset of *kwargs* whose keys are parameter names of ``cls.__init__``.

    ``self`` is never passed through. Names of ``*args``/``**kwargs``
    parameters count as accepted names, exactly like ordinary parameters.
    """
    params = inspect.signature(cls.__init__).parameters
    return {name: value for name, value in kwargs.items() if name != "self" and name in params}
def sanitize_filename(name, url=None):
    """Remove the characters ``<>:"/\\|?*`` from *name*.

    If nothing is left and *url* is given, the MD5 hex digest of *url* is
    returned instead so the file still gets a deterministic name.
    """
    cleaned = name.translate(str.maketrans('', '', '<>:"/\\|?*'))
    if cleaned or url is None:
        return cleaned
    return hashlib.md5(url.encode()).hexdigest()
def download_file(url, filepath, max_attempts=5, backoff=1.0):
    """Download *url* to *filepath* with a tqdm progress bar, retrying on failure.

    Nothing is downloaded when *filepath* already exists. Data is first written
    to ``filepath + ".part"`` and moved into place only after the transfer
    completes. An interrupted download therefore never leaves a truncated file
    that a later call would mistake for a finished one.

    Args:
        url: Source URL; any scheme ``urllib.request.urlretrieve`` supports.
        filepath: Destination path. Missing parent directories are created.
        max_attempts: Total number of download attempts (at least one is made).
        backoff: Seconds to wait after the first failure; the wait doubles
            after each further failed attempt.

    Raises:
        Exception: the error raised by the last attempt once all attempts fail,
            instead of looping forever on permanent errors such as HTTP 404.
    """
    import time

    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    if os.path.exists(filepath):
        return
    part_path = filepath + ".part"

    def make_hook(bar):
        # urlretrieve reports cumulative (block_count, block_size, total_size);
        # tqdm.update wants increments, so remember what was already reported.
        reported = [0]

        def hook(block_count, block_size, total_size):
            if total_size > 0:
                bar.total = total_size
            done = block_count * block_size
            bar.update(done - reported[0])
            reported[0] = done

        return hook

    attempts = max(1, max_attempts)
    for attempt in range(attempts):
        try:
            with tqdm(unit='B', unit_scale=True, unit_divisor=1024,
                      desc=os.path.basename(filepath)) as bar:
                urllib.request.urlretrieve(url, part_path, reporthook=make_hook(bar))
            # Atomic on the same filesystem: filepath appears only when complete.
            os.replace(part_path, filepath)
            return
        except Exception:
            if os.path.exists(part_path):
                os.remove(part_path)
            if attempt == attempts - 1:
                raise
            time.sleep(backoff * (2 ** attempt))
def download_files(folder, files_spec):
    """Download every file described by *files_spec* into *folder*.

    *files_spec* is either
      * a dict ``{filename: url}``, or
      * a list whose items are a URL string (filename taken from the URL path,
        or the MD5 of the URL when the path has no basename), a 2-item
        ``(url, filename)`` list/tuple, or a dict with ``"filename"`` and
        ``"url"`` keys.
    Filenames are passed through ``sanitize_filename``. Items are downloaded
    in order; an invalid item raises ``ValueError`` when it is reached.
    """
    if isinstance(files_spec, dict):
        for raw_name, url in files_spec.items():
            download_file(url, os.path.join(folder, sanitize_filename(raw_name, url)))
        return
    if not isinstance(files_spec, list):
        raise ValueError("files_spec must be dict or list")
    for spec in files_spec:
        if isinstance(spec, str):
            url = spec
            name = os.path.basename(urllib.parse.urlparse(url).path)
            name = name or hashlib.md5(url.encode()).hexdigest()
            name = sanitize_filename(name, url)
        elif isinstance(spec, (list, tuple)) and len(spec) == 2:
            url, name = spec
            name = sanitize_filename(name, url)
        elif isinstance(spec, dict) and "filename" in spec and "url" in spec:
            name = sanitize_filename(spec["filename"], spec["url"])
            url = spec["url"]
        else:
            raise ValueError("Invalid file specification")
        download_file(url, os.path.join(folder, name))
def read_json(fp):
    """Parse the UTF-8 encoded JSON file at path *fp* and return the decoded object."""
    with open(fp, encoding='utf-8') as handle:
        data = json.load(handle)
    return data
def get_codegen_tokenizer(vocab_path, merges_path):
    """Build a simple BPE tokenizer from a ``vocab.json`` / ``merges.txt`` pair.

    The returned callable splits text on whitespace and BPE-merges each word.
    At every step it merges all occurrences of the adjacent pair whose line
    appears earliest in ``merges.txt``. It returns a flat list of vocabulary ids,
    using 0 for pieces missing from the vocab.

    NOTE(review): GPT-2/CodeGen vocabularies are presumably byte-level (a
    leading space is encoded as 'Ġ'); no byte mapping is done here, so pieces
    that only exist in that form will map to 0 — confirm this is acceptable.

    Args:
        vocab_path: JSON file mapping token string -> id.
        merges_path: Text file with one ``"left right"`` merge rule per line.

    Returns:
        ``tokenizer(text) -> list[int]``.
    """
    import functools

    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    with open(merges_path, 'r', encoding='utf-8') as f:
        merges = f.read().splitlines()

    merge_ranks = {}
    for i, merge in enumerate(merges):
        # Skip the "#version: ..." header line; it is not a merge rule.
        if i == 0 and merge.startswith('#version'):
            continue
        parts = merge.strip().split()
        if len(parts) == 2:
            merge_ranks[tuple(parts)] = i

    @functools.lru_cache(maxsize=4096)
    def bpe(token):
        # Cached per word: the same words recur constantly in source code.
        word = list(token)
        while len(word) > 1:
            candidates = [pair for pair in zip(word, word[1:]) if pair in merge_ranks]
            if not candidates:
                break
            first, second = min(candidates, key=merge_ranks.__getitem__)
            merged = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                    merged.append(first + second)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            word = merged
        return tuple(word)

    def tokenizer(text):
        return [vocab.get(piece, 0) for token in text.split() for piece in bpe(token)]

    return tokenizer
def simple_tokenizer(text, vocab, max_length=77):
    """Whitespace-tokenize *text* into a fixed-length id tensor of shape (1, max_length).

    Words missing from *vocab* map to id 1. The sequence is truncated to
    *max_length* or right-padded with id 0, and the tensor is placed on the
    module-level ``device``.
    """
    ids = [vocab.get(word, 1) for word in text.split()][:max_length]
    ids += [0] * (max_length - len(ids))
    return torch.tensor([ids], dtype=torch.long).to(device)
def load_state_dict_safe(model, loaded_state_dict):
    """Copy matching tensors from *loaded_state_dict* into *model*, ignoring the rest.

    An entry is used only if its key exists in ``model.state_dict()``, it is a
    ``torch.Tensor``, and its shape matches the model's tensor. Every other
    parameter/buffer keeps its current value. Non-tensor entries (checkpoint
    metadata) are skipped instead of raising ``AttributeError`` on ``.shape``.

    Args:
        model: ``nn.Module`` to update in place.
        loaded_state_dict: Mapping of key -> value, typically from ``torch.load``.

    Returns:
        list[str]: keys actually taken from *loaded_state_dict*, in model order.
        An empty list means nothing matched (e.g. differing key names).
    """
    model_state = model.state_dict()
    new_state = {}
    loaded_keys = []
    for key, current in model_state.items():
        candidate = loaded_state_dict.get(key)
        if isinstance(candidate, torch.Tensor) and candidate.shape == current.shape:
            new_state[key] = candidate
            loaded_keys.append(key)
        else:
            new_state[key] = current
    model.load_state_dict(new_state, strict=False)
    return loaded_keys
class GPT2Config:
    """Loose attribute container for GPT-2 settings.

    ``vocab_size`` defaults to 50257; every extra keyword is stored verbatim
    as an instance attribute.
    """

    def __init__(self, vocab_size=50257, **kwargs):
        vars(self).update(vocab_size=vocab_size, **kwargs)

    @classmethod
    def from_dict(cls, d):
        """Alternate constructor: splat mapping *d* into ``__init__``."""
        return cls(**d)
class MBartConfig:
    """Loose attribute container for mBART settings.

    ``vocab_size`` defaults to 50265; every extra keyword is stored verbatim
    as an instance attribute.
    """

    def __init__(self, vocab_size=50265, **kwargs):
        vars(self).update(vocab_size=vocab_size, **kwargs)

    @classmethod
    def from_dict(cls, d):
        """Alternate constructor: splat mapping *d* into ``__init__``."""
        return cls(**d)
class CodeGenConfig:
    """Loose attribute container for CodeGen settings.

    ``vocab_size`` defaults to 50257; every extra keyword is stored verbatim
    as an instance attribute.
    """

    def __init__(self, vocab_size=50257, **kwargs):
        vars(self).update(vocab_size=vocab_size, **kwargs)

    @classmethod
    def from_dict(cls, d):
        """Alternate constructor: splat mapping *d* into ``__init__``."""
        return cls(**d)
class BartConfig:
    """Loose attribute container for BART settings.

    ``vocab_size`` defaults to 50265; every extra keyword is stored verbatim
    as an instance attribute.
    """

    def __init__(self, vocab_size=50265, **kwargs):
        vars(self).update(vocab_size=vocab_size, **kwargs)

    @classmethod
    def from_dict(cls, d):
        """Alternate constructor: splat mapping *d* into ``__init__``."""
        return cls(**d)
class AutoencoderKLConfig:
    """Attribute bag for VAE settings: each keyword becomes an instance attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    @classmethod
    def from_dict(cls, d):
        """Alternate constructor taking a mapping instead of keywords."""
        return cls(**d)
class OpenLRMConfig:
    """Attribute bag for OpenLRM settings: each keyword becomes an instance attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    @classmethod
    def from_dict(cls, d):
        """Alternate constructor taking a mapping instead of keywords."""
        return cls(**d)
class UNet2DConditionModelConfig:
    """Attribute bag for UNet settings: each keyword becomes an instance attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    @classmethod
    def from_dict(cls, d):
        """Alternate constructor taking a mapping instead of keywords."""
        return cls(**d)
class MusicGenConfig:
    """Attribute bag for MusicGen settings: each keyword becomes an instance attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    @classmethod
    def from_dict(cls, d):
        """Alternate constructor taking a mapping instead of keywords."""
        return cls(**d)
class GPT2LMHeadModel(nn.Module):
    """Transformer-encoder stand-in for GPT-2 mapping embedded inputs to vocab logits.

    Sizes are read from *config* (GPT-2 naming), falling back to the GPT-2 small
    values this class previously hard-coded:
      * ``n_embd`` (768): model width
      * ``n_head`` (12): attention heads
      * ``n_layer`` (12): encoder layers
      * ``vocab_size`` (required): output logits

    No token embedding and no causal mask are applied. ``forward`` takes
    already-embedded input of shape ``(seq, batch, n_embd)``, because
    ``batch_first`` is left at its default ``False``.
    """

    def __init__(self, config):
        super().__init__()
        d_model = getattr(config, "n_embd", 768)
        n_head = getattr(config, "n_head", 12)
        n_layer = getattr(config, "n_layer", 12)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head)
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layer)
        self.lm_head = nn.Linear(d_model, config.vocab_size)

    def forward(self, x):
        return self.lm_head(self.transformer(x))
class MBartForConditionalGeneration(nn.Module):
    """Encoder-decoder transformer stand-in for mBART (width 768, 12 heads, 6+6 layers).

    Only ``config.vocab_size`` is read (output logits); the config object is
    also kept on ``self.config``. Inputs are already-embedded sequences shaped
    ``(seq, batch, 768)`` (``batch_first`` left at its default ``False``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=768, nhead=12), num_layers=6
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=768, nhead=12), num_layers=6
        )
        self.output_layer = nn.Linear(768, config.vocab_size)

    def forward(self, src, tgt):
        memory = self.encoder(src)
        hidden = self.decoder(tgt, memory)
        return self.output_layer(hidden)
class CodeGenForCausalLM(nn.Module):
    """Transformer-decoder stand-in for CodeGen producing vocab logits.

    Reads ``d_model`` (1024), ``n_head`` (16) and ``num_layers`` (12) from
    *config* when present, plus the required ``vocab_size``. Inputs are
    ``(seq, batch, d_model)`` since ``batch_first`` keeps its default ``False``.
    When *memory* is omitted, an all-zeros tensor shaped like *tgt* is used as
    the cross-attention memory.
    """

    def __init__(self, config):
        super().__init__()
        width = getattr(config, "d_model", 1024)
        heads = getattr(config, "n_head", 16)
        depth = getattr(config, "num_layers", 12)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=width, nhead=heads), num_layers=depth
        )
        self.lm_head = nn.Linear(width, config.vocab_size)

    def forward(self, tgt, memory=None):
        memory = torch.zeros_like(tgt) if memory is None else memory
        return self.lm_head(self.transformer_decoder(tgt, memory))
class BartForConditionalGeneration(nn.Module):
    """Encoder-decoder transformer stand-in for BART (width 768, 12 heads, 6+6 layers).

    Only ``config.vocab_size`` is read. Inputs are already-embedded sequences
    shaped ``(seq, batch, 768)`` (``batch_first`` left at its default ``False``).
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=768, nhead=12), num_layers=6
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=768, nhead=12), num_layers=6
        )
        self.output_layer = nn.Linear(768, config.vocab_size)

    def forward(self, src, tgt):
        memory = self.encoder(src)
        hidden = self.decoder(tgt, memory)
        return self.output_layer(hidden)
class ResnetBlock(nn.Module):
    """Residual block: (GroupNorm -> SiLU -> 3x3 conv) twice, plus a 1x1 conv shortcut.

    Args:
        in_ch: input channels; must be divisible by 32 (GroupNorm groups).
        out_ch: output channels; must be divisible by 32.

    Spatial size is preserved (3x3 convs with padding 1).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # The shortcut is a 1x1 conv even when in_ch == out_ch.
        self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        sc = self.conv_shortcut(x)
        h = self.conv1(F.silu(self.norm1(x)))
        # norm2 must see conv1's output (out_ch channels). Normalizing x here
        # discarded conv1 and crashed whenever in_ch != out_ch.
        h = self.conv2(F.silu(self.norm2(h)))
        return h + sc
class Downsample(nn.Module):
    """Halve height and width with a stride-2 3x3 convolution (padding 1)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        out = self.conv(x)
        return out
class DownBlock(nn.Module):
    """``num_res`` ResnetBlocks (first maps in_ch -> out_ch) followed by one Downsample."""

    def __init__(self, in_ch, out_ch, num_res):
        super().__init__()
        blocks = []
        for idx in range(num_res):
            blocks.append(ResnetBlock(in_ch if idx == 0 else out_ch, out_ch))
        self.resnets = nn.ModuleList(blocks)
        self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])

    def forward(self, x):
        for layer in [*self.resnets, *self.downsamplers]:
            x = layer(x)
        return x
class Upsample(nn.Module):
    """Double height and width with a 4x4, stride-2, padding-1 transposed convolution."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        out = self.conv(x)
        return out
class UpBlock(nn.Module):
    """``num_res`` ResnetBlocks (first maps in_ch -> out_ch) followed by one Upsample."""

    def __init__(self, in_ch, out_ch, num_res):
        super().__init__()
        blocks = []
        for idx in range(num_res):
            blocks.append(ResnetBlock(in_ch if idx == 0 else out_ch, out_ch))
        self.resnets = nn.ModuleList(blocks)
        self.upsampler = Upsample(out_ch, out_ch)

    def forward(self, x):
        h = x
        for block in self.resnets:
            h = block(h)
        return self.upsampler(h)
class AttentionBlock(nn.Module):
    """Single-head spatial self-attention over all H*W positions, with a residual add.

    ``ch`` must be divisible by 32 (GroupNorm groups). Scores are scaled by
    ``1/sqrt(ch)``.
    """

    def __init__(self, ch):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)
        self.query = nn.Conv2d(ch, ch, 1)
        self.key = nn.Conv2d(ch, ch, 1)
        self.value = nn.Conv2d(ch, ch, 1)
        self.proj_attn = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        batch, channels, height, width = x.shape
        normed = self.norm(x)
        # Flatten spatial dims: queries/values -> (batch, hw, c), keys -> (batch, c, hw).
        queries = self.query(normed).view(batch, channels, -1).transpose(1, 2)
        keys = self.key(normed).view(batch, channels, -1)
        values = self.value(normed).view(batch, channels, -1).transpose(1, 2)
        weights = torch.softmax(torch.bmm(queries, keys) / (channels ** 0.5), dim=-1)
        mixed = torch.bmm(weights, values).transpose(1, 2).view(batch, channels, height, width)
        return x + self.proj_attn(mixed)
class Encoder(nn.Module):
    """VAE-style image encoder: conv_in, 4 DownBlocks (each /2), mid attention, projection.

    Channel widths are multiples of ``base_ch`` (1,1)->(1,2)->(2,4)->(4,4).
    ``conv_out`` produces ``2*latent_ch`` channels, which ``quant_conv``
    maps down to ``latent_ch``.
    """

    def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
        super().__init__()
        self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        multipliers = [(1, 1), (1, 2), (2, 4), (4, 4)]
        self.down_blocks = nn.ModuleList(
            [DownBlock(base_ch * a, base_ch * b, 2) for a, b in multipliers]
        )
        top = base_ch * 4
        self.mid_block = nn.ModuleList([
            ResnetBlock(top, top),
            AttentionBlock(top),
            ResnetBlock(top, top),
        ])
        self.conv_norm_out = nn.GroupNorm(32, top)
        self.conv_out = nn.Conv2d(top, latent_ch * 2, 3, padding=1)
        self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)

    def forward(self, x):
        h = self.conv_in(x)
        for module in [*self.down_blocks, *self.mid_block]:
            h = module(h)
        return self.quant_conv(self.conv_out(self.conv_norm_out(h)))
class Decoder(nn.Module):
    """VAE-style decoder: latent -> image, upsampling 16x through four UpBlocks.

    Input: a 4-D latent ``(batch, latent_ch, h, w)`` (all layers are 2-D convs).
    Output: ``(batch, out_ch, 16*h, 16*w)``.
    """

    def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
        super().__init__()
        # Must keep latent_ch channels: conv_in below consumes latent_ch. The
        # previous latent_ch -> 2*latent_ch mapping made every forward crash.
        self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch, 1)
        self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
        self.mid_block = nn.ModuleList([
            ResnetBlock(base_ch * 4, base_ch * 4),
            AttentionBlock(base_ch * 4),
            ResnetBlock(base_ch * 4, base_ch * 4)
        ])
        self.up_blocks = nn.ModuleList([
            UpBlock(base_ch * 4, base_ch * 4, 3),
            UpBlock(base_ch * 4, base_ch * 2, 3),
            UpBlock(base_ch * 2, base_ch, 3),
            UpBlock(base_ch, base_ch, 3)
        ])
        self.conv_norm_out = nn.GroupNorm(32, base_ch)
        self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)

    def forward(self, x):
        x = self.post_quant_conv(x)
        x = self.conv_in(x)
        for m in self.mid_block:
            x = m(x)
        for up in self.up_blocks:
            x = up(x)
        x = self.conv_norm_out(x)
        return self.conv_out(x)
class AutoencoderKL(nn.Module):
    """Encoder/Decoder pair configured from a dict or an attribute-bag object.

    Keys read (defaults): ``in_channels`` (3), ``out_channels`` (3),
    ``base_channels`` (128), ``latent_channels`` (4). For non-dict configs the
    values are looked up in the object's ``__dict__``.
    """

    def __init__(self, config):
        super().__init__()
        settings = config if isinstance(config, dict) else config.__dict__
        in_ch = settings.get("in_channels", 3)
        out_ch = settings.get("out_channels", 3)
        base_ch = settings.get("base_channels", 128)
        latent_ch = settings.get("latent_channels", 4)
        self.encoder = Encoder(in_ch, base_ch, latent_ch)
        self.decoder = Decoder(out_ch, base_ch, latent_ch)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

    def decode(self, x):
        """Run only the decoder on latent *x*."""
        return self.decoder(x)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then + MLP(LN(.)) with 4x hidden width.

    Expects batch-first input ``(batch, seq, embed_dim)``. It is transposed to
    sequence-first around ``nn.MultiheadAttention``, which is built with the
    default ``batch_first=False``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        seq_first = self.norm1(x).transpose(0, 1)
        attended, _ = self.attn(seq_first, seq_first, seq_first)
        x = x + attended.transpose(0, 1)
        return x + self.mlp(self.norm2(x))
class VisionTransformer(nn.Module):
    """ViT-style image encoder returning the final normalized CLS-token embedding.

    ``config`` is a dict or an object whose ``__dict__`` provides (defaults):
    ``img_size`` (592), ``patch_size`` (16), ``hidden_size`` (768),
    ``depth`` (12), ``num_heads`` (12). Input images must have 3 channels and
    exactly ``img_size`` pixels per side, so that the patch count matches
    ``pos_embed``.
    """

    def __init__(self, config):
        super().__init__()
        settings = config if isinstance(config, dict) else config.__dict__
        self.img_size = settings.get("img_size", 592)
        self.patch_size = settings.get("patch_size", 16)
        self.embed_dim = settings.get("hidden_size", 768)
        depth = settings.get("depth", 12)
        num_heads = settings.get("num_heads", 12)
        dim = self.embed_dim
        patches_per_side = self.img_size // self.patch_size
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # One learned position per patch plus one for the CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, patches_per_side ** 2 + 1, dim))
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.blocks = nn.ModuleList([TransformerBlock(dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        # NOTE(review): allocated (and loadable from checkpoints) but unused in forward.
        self.register_tokens = nn.Parameter(torch.zeros(1, 4, dim))
        self._init_weights()

    def _init_weights(self):
        for param in (self.cls_token, self.pos_embed):
            nn.init.normal_(param, std=0.02)

    def forward(self, x):
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        for block in self.blocks:
            tokens = block(tokens)
        return self.norm(tokens)[:, 0]
class OpenLRM(nn.Module):
    """VisionTransformer CLS embedding followed by a square linear projection.

    The projection width is ``hidden_size`` (default 768), read from a dict
    config or from the object's ``__dict__``.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
        settings = config if isinstance(config, dict) else config.__dict__
        hidden = settings.get("hidden_size", 768)
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, x):
        features = self.encoder["model"](x)
        return self.linear(features)
class VideoUNet(nn.Module):
    """3-D U-Net over ``(batch, in_ch, frames, height, width)`` volumes.

    Each encoder stage is two 3x3x3 conv+ReLU layers followed by 2x max-pooling.
    Each decoder stage upsamples (trilinear) to its skip connection's size,
    concatenates the skip, and applies two conv+ReLU layers. ``t`` and
    ``encoder_hidden_states`` are accepted for interface compatibility but not used.

    Args:
        in_ch: input channels.
        out_ch: output channels of the final 1x1x1 conv.
        features: channel width per level (default ``[64, 128, 256]``).
    """

    def __init__(self, in_ch=4, out_ch=4, features=None):
        super().__init__()
        if features is None:
            features = [64, 128, 256]
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool3d(2, 2)
        self.decoder = nn.ModuleList()
        for f in features:
            self.encoder.append(nn.Sequential(
                nn.Conv3d(in_ch, f, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv3d(f, f, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
            in_ch = f
        # Interpolation keeps the channel count, so each decoder stage receives the
        # previous stage's width plus the skip's width (not 2*f: that only holds for
        # the first stage and crashed the second one).
        prev_ch = features[-1]
        for f in reversed(features):
            self.decoder.append(nn.Sequential(
                nn.Conv3d(prev_ch + f, f, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv3d(f, f, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
            prev_ch = f
        self.final_conv = nn.Conv3d(features[0], out_ch, 1)

    def forward(self, x, t, encoder_hidden_states):
        skips = []
        for enc in self.encoder:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)
        for dec in self.decoder:
            skip = skips.pop()
            # Resize to the skip's exact size so odd frame/height/width counts,
            # which max-pooling floors, still line up.
            x = F.interpolate(x, size=skip.shape[2:], mode='trilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)
        return self.final_conv(x)
class SentimentClassifierModel(nn.Module):
    """MLP head mapping 768-d features to 2 logits (768 -> 256 -> ReLU -> 2).

    ``config`` is accepted for a uniform constructor signature but not read.
    """

    def __init__(self, config):
        super().__init__()
        layers = [nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 2)]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.classifier(x)
        return logits
class STTModel(nn.Module):
    """Placeholder MLP 768 -> 512 -> ReLU -> 768; ``config`` is accepted but not read."""

    def __init__(self, config):
        super().__init__()
        layers = [nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 768)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class TTSModel(nn.Module):
    """Placeholder MLP 768 -> 512 -> ReLU -> 768; ``config`` is accepted but not read."""

    def __init__(self, config):
        super().__init__()
        layers = [nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 768)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class MusicGenModel(nn.Module):
    """12-layer transformer encoder (width 768, 12 heads) plus a 768x768 linear head.

    ``config`` is accepted but not read. Input is ``(seq, batch, 768)`` since
    ``batch_first`` keeps its default ``False``.
    """

    def __init__(self, config):
        super().__init__()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=768, nhead=12), num_layers=12
        )
        self.linear = nn.Linear(768, 768)

    def forward(self, x):
        hidden = self.transformer(x)
        return self.linear(hidden)
class SimpleTextEncoder(nn.Module):
    """Token-id -> embedding lookup used as a minimal text encoder.

    ``max_length`` is stored on the instance; ``forward`` itself neither pads
    nor truncates.
    """

    def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.max_length = max_length

    def forward(self, text_tokens):
        embedded = self.embedding(text_tokens)
        return embedded
class DiffusionScheduler:
    """Deterministic DDIM-style sampler step over a fixed noise schedule.

    Betas are spaced linearly from 0.1 down to 0.001 over ``steps`` entries and
    placed on the module-level ``device``; ``alpha_bars`` is their cumulative
    product of ``1 - beta``.
    NOTE(review): the betas decrease, whereas the usual convention is increasing
    (e.g. 1e-4 -> 0.02) — confirm this is intended.
    """

    def __init__(self, steps):
        self.steps = steps
        self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
        self.alphas = 1 - self.betas
        self.alpha_bars = self.alphas.cumprod(dim=0)

    def step(self, noise, t, sample):
        """Move *sample* from timestep ``t`` to ``t - 1`` given predicted *noise*.

        Reconstructs x0 from the sample and the noise, then re-noises it to the
        previous timestep's level (fully clean when ``t == 0``).
        """
        ab_t = self.alpha_bars[t]
        if t > 0:
            ab_prev = self.alpha_bars[t - 1]
        else:
            ab_prev = torch.tensor(1.0, device=sample.device)
        pred_x0 = (sample - (1 - ab_t).sqrt() * noise) / ab_t.sqrt()
        return ab_prev.sqrt() * pred_x0 + (1 - ab_prev).sqrt() * noise
class VideoOutput:
    """Holds the frames of the first video in a batch, converted with ``img_as_ubyte``.

    ``frames[0]`` is iterated frame by frame; each converted frame is stored in
    the list ``self.frames``.
    """

    def __init__(self, frames):
        first_video = frames[0]
        self.frames = list(map(img_as_ubyte, first_video))
class VideoPipeline(nn.Module):
    """Text-to-video sampler: tokenize prompt, denoise 3-D latents with the UNet, decode with the VAE.

    Args (constructor):
        unet: callable ``unet(latent, t, text_emb) -> predicted noise``.
        vae: object with ``decode(latent_2d)``; decoding is done one frame at a time.
        text_encoder: maps token ids from ``simple_tokenizer`` to embeddings.
        vocab: word -> id mapping for ``simple_tokenizer``.
    """

    def __init__(self, unet, vae, text_encoder, vocab):
        super().__init__()
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.vocab = vocab

    @torch.no_grad()
    def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
        """Generate ``num_frames`` frames for *prompt* using ``steps`` denoising steps.

        Runs under ``no_grad`` so no autograd graph accumulates across steps.

        Returns:
            VideoOutput built from frames laid out as (batch, frame, H, W, C).
        """
        token_ids = simple_tokenizer(prompt, self.vocab)
        text_emb = self.text_encoder(token_ids)
        # (batch, latent channels, frames, h, w)
        latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
        sched = DiffusionScheduler(steps)
        # step() moves a sample from t to t-1, so start at the noisiest
        # timestep (steps - 1) and walk down to 0.
        for t in reversed(range(steps)):
            noise = self.unet(latent, t, text_emb)
            latent = sched.step(noise, t, latent)
        frames = self._decode_frames(latent / 0.18215)
        # NOTE(review): clamps directly to [0, 1]; if the VAE emits [-1, 1]
        # this should be (x / 2 + 0.5).clamp(0, 1) — confirm the VAE's range.
        frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
        return VideoOutput(frames)

    def _decode_frames(self, latent):
        """Decode a (batch, c, frames, h, w) latent one frame at a time.

        The VAE uses 2-D convolutions and rejects 5-D input. Returns
        (batch, C, frames, H, W).
        """
        decoded = [self.vae.decode(latent[:, :, i]) for i in range(latent.shape[2])]
        return torch.stack(decoded, dim=2)
def initialize_gpt2_model(folder, files):
    """Download GPT-2 assets into *folder* and return ``(model, encoder_json)``.

    The model uses a default ``GPT2Config`` and receives whichever tensors in
    ``gpt2-pytorch_model.bin`` match by key and shape; ``encoder.json`` is
    returned as parsed JSON.
    """
    download_files(folder, files)
    weights_path = os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin"))
    encoder_path = os.path.join(folder, sanitize_filename("encoder.json"))
    model = GPT2LMHeadModel(GPT2Config()).to(device)
    # NOTE(review): torch.load unpickles the downloaded file (arbitrary code
    # execution risk); consider weights_only=True where supported.
    state = torch.load(weights_path, map_location=device)
    load_state_dict_safe(model, state)
    model.eval()
    return model, read_json(encoder_path)
def initialize_translation_model(folder, files):
    """Download and build the mBART-style translation model.

    Attaches ``model.tokenizer``:
      * if ``vocab.json`` exists, it maps whitespace-split words to ids (0 if unknown);
      * otherwise it returns the text unchanged.
    Also sets ``model.config.lang_code_to_id`` to a fixed two-language map.
    """
    download_files(folder, files)
    cfg = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = MBartForConditionalGeneration(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    state = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, state)
    model.eval()
    vocab_path = os.path.join(folder, "vocab.json")
    if not os.path.exists(vocab_path):
        model.tokenizer = lambda txt: txt
    else:
        vocab = read_json(vocab_path)
        model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
    model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
    return model
def initialize_codegen_model(folder, files):
    """Download and build the CodeGen-style model plus its BPE tokenizer.

    Returns:
        ``(model, tokenizer, vocab, id_to_word, vocab)``. ``vocab`` appears
        twice, matching the original return shape callers unpack.
        ``model.tokenizer`` is set to the same tokenizer.
    """
    download_files(folder, files)
    cfg = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = CodeGenForCausalLM(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    state = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, state)
    model.eval()
    vocab_path = os.path.join(folder, "vocab.json")
    tok = get_codegen_tokenizer(vocab_path, os.path.join(folder, "merges.txt"))
    vocab = read_json(vocab_path)
    idx2w = {index: word for word, index in vocab.items()}
    model.tokenizer = tok
    return model, tok, vocab, idx2w, vocab
def initialize_summarization_model(folder, files):
    """Download and build the BART-style summarization model.

    Returns:
        ``(model, vocab_keys, vocab_json, id_to_token)`` when ``vocab.json``
        exists, where ``vocab_keys`` is the set of token strings; otherwise
        ``(model, None, None, None)``.
    """
    download_files(folder, files)
    cfg = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = BartForConditionalGeneration(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    state = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, state)
    model.eval()
    vocab_path = os.path.join(folder, "vocab.json")
    if not os.path.exists(vocab_path):
        return model, None, None, None
    vocab_json = read_json(vocab_path)
    id_to_token = {v: k for k, v in vocab_json.items()}
    return model, set(vocab_json.keys()), vocab_json, id_to_token
def initialize_imagegen_model(folder, files):
    """Download the VAE assets, build ``AutoencoderKL`` from ``config.json`` and load weights.

    Returns the VAE in eval mode.
    """
    download_files(folder, files)
    cfg = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    vae = AutoencoderKL(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    weights = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
    load_state_dict_safe(vae, weights)
    vae.eval()
    return vae
def initialize_image_to_3d_model(folder, files):
    """Download OpenLRM assets, build the model from ``config.json`` and load weights.

    Returns the model in eval mode.
    """
    download_files(folder, files)
    cfg = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model3d = OpenLRM(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    weights = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model3d, weights)
    model3d.eval()
    return model3d
def initialize_text_to_video_model(folder, files):
    """Download text-to-video assets and assemble a ``VideoPipeline``.

    Files read from *folder*:
      * ``config.json``: read once and shared (after ``filter_kwargs``) by the UNet and VAE.
      * ``diffusion_pytorch_model.fp16.bin``: UNet weights.
      * ``diffusion_pytorch_model.bin``: VAE weights.
      * ``vocab.json`` (optional): text vocabulary.
      * ``text_encoder.bin`` (optional): text-encoder weights.

    Returns:
        VideoPipeline wrapping the half-precision UNet and VAE, a float text
        encoder, and the vocabulary (``{}`` when ``vocab.json`` is absent).
    """
    download_files(folder, files)
    cfg = read_json(os.path.join(folder, "config.json"))
    # NOTE(review): torch.load unpickles the downloaded files (arbitrary code
    # execution risk); consider weights_only=True where supported.
    unet = VideoUNet(**filter_kwargs(VideoUNet, cfg)).half().to(device)
    sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
    load_state_dict_safe(unet, sd_unet)
    unet.eval()
    # AutoencoderKL.__init__ only accepts ``config``, so this filter keeps at most
    # that key and the VAE otherwise falls back to its built-in defaults.
    vae = AutoencoderKL(filter_kwargs(AutoencoderKL, cfg)).half().to(device)
    sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
    load_state_dict_safe(vae, sd_vae)
    vae.eval()
    vp = os.path.join(folder, "vocab.json")
    text_vocab = read_json(vp) if os.path.exists(vp) else {}
    vocab_size = max(text_vocab.values()) + 1 if text_vocab else 10000
    text_encoder = SimpleTextEncoder(vocab_size=vocab_size, embed_dim=768, max_length=77).to(device)
    te_path = os.path.join(folder, "text_encoder.bin")
    if os.path.exists(te_path):
        load_state_dict_safe(text_encoder, torch.load(te_path, map_location=device))
    text_encoder.eval()
    return VideoPipeline(unet, vae, text_encoder, text_vocab)
def initialize_sentiment_model(folder, files):
    """Download assets and build the sentiment classifier head in eval mode."""
    download_files(folder, files)
    cfg = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = SentimentClassifierModel(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    weights = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, weights)
    model.eval()
    vocab_path = os.path.join(folder, "vocab.json")
    if os.path.exists(vocab_path):
        # NOTE(review): the parsed vocab is discarded; this only raises early
        # if the file is not valid JSON.
        read_json(vocab_path)
    return model
def initialize_stt_model(folder, files):
    """Download assets and build the speech-to-text placeholder model in eval mode."""
    download_files(folder, files)
    cfg = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = STTModel(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    weights = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, weights)
    model.eval()
    vocab_path = os.path.join(folder, "vocab.json")
    if os.path.exists(vocab_path):
        # NOTE(review): the parsed vocab is discarded; this only raises early
        # if the file is not valid JSON.
        read_json(vocab_path)
    return model
def initialize_tts_model(folder, files):
    """Download assets and build the text-to-speech placeholder model in eval mode."""
    download_files(folder, files)
    cfg = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = TTSModel(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    weights = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, weights)
    model.eval()
    vocab_path = os.path.join(folder, "vocab.json")
    if os.path.exists(vocab_path):
        # NOTE(review): the parsed vocab is discarded; this only raises early
        # if the file is not valid JSON.
        read_json(vocab_path)
    return model
def initialize_musicgen_model(folder, files):
    """Download MusicGen assets, build the model from ``config.json`` and load weights in eval mode."""
    download_files(folder, files)
    cfg = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = MusicGenModel(cfg).to(device)
    # NOTE(review): torch.load unpickles downloaded data; consider weights_only=True.
    weights = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, weights)
    model.eval()
    return model