from tokenxxx import *
from constants import *
from utils import *
import os
import json
import urllib.request
import urllib.parse
import torch
import hashlib
from tqdm import tqdm
from skimage import img_as_ubyte
from torch import nn
import torch.nn.functional as F
import inspect

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def filter_kwargs(cls, kwargs):
    # Keep only the keyword arguments that cls.__init__ actually accepts.
    sig = inspect.signature(cls.__init__)
    accepted = set(sig.parameters.keys()) - {"self"}
    return {k: v for k, v in kwargs.items() if k in accepted}
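# Illustrative sketch (the Point class below is hypothetical): filter_kwargs drops
# config keys that a constructor does not accept, so an over-complete JSON config
# can still be splatted into __init__ safely. Never called at import time.
def _example_filter_kwargs():
    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y
    cfg = {"x": 1, "y": 2, "color": "red"}  # "color" is not a Point argument
    return Point(**filter_kwargs(Point, cfg))  # only x and y are passed through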
def sanitize_filename(name, url=None):
    # Strip characters that are illegal in file names; if nothing is left,
    # fall back to an MD5 hash of the URL.
    for c in '<>:"/\\|?*':
        name = name.replace(c, '')
    if not name and url is not None:
        name = hashlib.md5(url.encode()).hexdigest()
    return name
def download_file(url, filepath):
    d = os.path.dirname(filepath)
    if d and not os.path.exists(d):
        os.makedirs(d, exist_ok=True)
    # Retry until the file exists on disk; a failed attempt is simply retried.
    while not os.path.exists(filepath):
        try:
            def prog(t):
                # urlretrieve reporthook that feeds a tqdm progress bar.
                last = [0]
                def inner(n, bs, ts):
                    if ts > 0:
                        t.total = ts
                    t.update(n * bs - last[0])
                    last[0] = n * bs
                return inner
            with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t:
                urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
        except Exception:
            continue
def download_files(folder, files_spec):
    if isinstance(files_spec, dict):
        for fn, url in files_spec.items():
            fn = sanitize_filename(fn, url)
            fp = os.path.join(folder, fn)
            download_file(url, fp)
    elif isinstance(files_spec, list):
        for item in files_spec:
            if isinstance(item, str):
                url = item
                parsed = urllib.parse.urlparse(url)
                fn = os.path.basename(parsed.path)
                if not fn:
                    fn = hashlib.md5(url.encode()).hexdigest()
                fn = sanitize_filename(fn, url)
            elif isinstance(item, (list, tuple)) and len(item) == 2:
                url, fn = item
                fn = sanitize_filename(fn, url)
            elif isinstance(item, dict) and "filename" in item and "url" in item:
                fn = sanitize_filename(item["filename"], item["url"])
                url = item["url"]
            else:
                raise ValueError("Invalid file specification")
            fp = os.path.join(folder, fn)
            download_file(url, fp)
    else:
        raise ValueError("files_spec must be dict or list")
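# Illustrative sketch (folder name and URLs below are hypothetical): download_files
# accepts a dict of {filename: url}, a list of bare URLs, a list of (url, filename)
# pairs, or a list of {"filename": ..., "url": ...} dicts. Never called at import time.
def _example_download_files():
    download_files("checkpoints", {"model.bin": "https://example.com/model.bin"})
    download_files("checkpoints", ["https://example.com/vocab.json"])
    download_files("checkpoints", [("https://example.com/merges.txt", "merges.txt")])
    download_files("checkpoints", [{"filename": "config.json", "url": "https://example.com/config.json"}])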
def read_json(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        return json.load(f)
def get_codegen_tokenizer(vocab_path, merges_path):
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    with open(merges_path, 'r', encoding='utf-8') as f:
        merges = f.read().splitlines()
    merge_ranks = {}
    for i, merge in enumerate(merges):
        parts = merge.strip().split()
        if len(parts) == 2:
            merge_ranks[tuple(parts)] = i
    def bpe(token):
        # Greedy BPE: repeatedly apply the lowest-ranked (earliest-listed) merge
        # found in the current symbol sequence until no listed merge applies.
        word = list(token)
        pairs = [(word[i], word[i + 1]) for i in range(len(word) - 1)]
        while True:
            candidate = None
            candidate_rank = None
            for i, pair in enumerate(pairs):
                if pair in merge_ranks:
                    rank = merge_ranks[pair]
                    if candidate is None or rank < candidate_rank:
                        candidate = pair
                        candidate_rank = rank
            if candidate is None:
                break
            first, second = candidate
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = new_word
            if len(word) == 1:
                break
            pairs = [(word[i], word[i + 1]) for i in range(len(word) - 1)]
        return word
    def tokenizer(text):
        # Whitespace pre-tokenization, then BPE; unknown subtokens map to id 0.
        tokens = []
        for token in text.split():
            bpe_tokens = bpe(token)
            for subtoken in bpe_tokens:
                tokens.append(vocab.get(subtoken, 0))
        return tokens
    return tokenizer
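# Illustrative sketch (file paths are hypothetical and must exist on disk): the
# returned tokenizer maps whitespace-split words to vocabulary ids after greedy
# BPE merging. Never called at import time.
def _example_codegen_tokenizer():
    tokenize = get_codegen_tokenizer("codegen/vocab.json", "codegen/merges.txt")
    return tokenize("def add(a, b): return a + b")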
def simple_tokenizer(text, vocab, max_length=77):
    # Whitespace tokenizer: id 1 for unknown tokens, id 0 for padding.
    toks = text.split()
    ids = [vocab.get(t, 1) for t in toks]
    if len(ids) < max_length:
        ids = ids + [0] * (max_length - len(ids))
    else:
        ids = ids[:max_length]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
def load_state_dict_safe(model, loaded_state_dict):
    # Copy only tensors whose name and shape match the model; every other entry
    # keeps the model's freshly initialized value, so loading never aborts.
    model_state = model.state_dict()
    new_state = {}
    for key, value in model_state.items():
        if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
            new_state[key] = loaded_state_dict[key]
        else:
            new_state[key] = value
    model.load_state_dict(new_state, strict=False)
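# Illustrative sketch: a partial checkpoint with one mismatched shape. The matching
# weight is copied, the mismatched bias keeps its initialization. Never called at
# import time.
def _example_load_state_dict_safe():
    model = nn.Linear(4, 2)
    checkpoint = {"weight": torch.zeros(2, 4), "bias": torch.zeros(3)}  # bias shape mismatch
    load_state_dict_safe(model, checkpoint)  # "weight" is copied, "bias" is left as-is
    return model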
# Minimal config containers. from_dict must be a classmethod because callers
# invoke it directly on the class (e.g. MBartConfig.from_dict(...)).
class GPT2Config:
    def __init__(self, vocab_size=50257, **kwargs):
        self.vocab_size = vocab_size
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class MBartConfig:
    def __init__(self, vocab_size=50265, **kwargs):
        self.vocab_size = vocab_size
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class CodeGenConfig:
    def __init__(self, vocab_size=50257, **kwargs):
        self.vocab_size = vocab_size
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class BartConfig:
    def __init__(self, vocab_size=50265, **kwargs):
        self.vocab_size = vocab_size
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class AutoencoderKLConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class OpenLRMConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class UNet2DConditionModelConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)

class MusicGenConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
    @classmethod
    def from_dict(cls, d):
        return cls(**d)
class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.transformer = nn.TransformerEncoder(layer, num_layers=12)
        self.lm_head = nn.Linear(768, config.vocab_size)
    def forward(self, x):
        return self.lm_head(self.transformer(x))

class MBartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.encoder = nn.TransformerEncoder(layer, num_layers=6)
        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
        self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
        self.output_layer = nn.Linear(768, config.vocab_size)
    def forward(self, src, tgt):
        return self.output_layer(self.decoder(tgt, self.encoder(src)))

class CodeGenForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        d_model = getattr(config, "d_model", 1024)
        n_head = getattr(config, "n_head", 16)
        num_layers = getattr(config, "num_layers", 12)
        dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
        self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
        self.lm_head = nn.Linear(d_model, config.vocab_size)
    def forward(self, tgt, memory=None):
        if memory is None:
            memory = torch.zeros_like(tgt)
        return self.lm_head(self.transformer_decoder(tgt, memory))

class BartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.encoder = nn.TransformerEncoder(layer, num_layers=6)
        dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
        self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
        self.output_layer = nn.Linear(768, config.vocab_size)
    def forward(self, src, tgt):
        return self.output_layer(self.decoder(tgt, self.encoder(src)))
class ResnetBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
    def forward(self, x):
        sc = self.conv_shortcut(x)
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        h = F.silu(self.norm2(h))  # normalize the intermediate activation, not the input
        h = self.conv2(h)
        return h + sc
class Downsample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
    def forward(self, x):
        return self.conv(x)

class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_res):
        super().__init__()
        self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
        self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
    def forward(self, x):
        for r in self.resnets:
            x = r(x)
        for ds in self.downsamplers:
            x = ds(x)
        return x

class Upsample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
    def forward(self, x):
        return self.conv(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_res):
        super().__init__()
        self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
        self.upsampler = Upsample(out_ch, out_ch)
    def forward(self, x):
        for r in self.resnets:
            x = r(x)
        return self.upsampler(x)

class AttentionBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)
        self.query = nn.Conv2d(ch, ch, 1)
        self.key = nn.Conv2d(ch, ch, 1)
        self.value = nn.Conv2d(ch, ch, 1)
        self.proj_attn = nn.Conv2d(ch, ch, 1)
    def forward(self, x):
        # Single-head self-attention over spatial positions with a residual connection.
        b, c, h, w = x.shape
        xn = self.norm(x)
        q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
        k = self.key(xn).view(b, c, -1)
        v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
        attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
        out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
        return x + self.proj_attn(out)
class Encoder(nn.Module):
    def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
        super().__init__()
        self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        self.down_blocks = nn.ModuleList([
            DownBlock(base_ch, base_ch, 2),
            DownBlock(base_ch, base_ch * 2, 2),
            DownBlock(base_ch * 2, base_ch * 4, 2),
            DownBlock(base_ch * 4, base_ch * 4, 2)
        ])
        self.mid_block = nn.ModuleList([
            ResnetBlock(base_ch * 4, base_ch * 4),
            AttentionBlock(base_ch * 4),
            ResnetBlock(base_ch * 4, base_ch * 4)
        ])
        self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
        self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
        self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
    def forward(self, x):
        x = self.conv_in(x)
        for blk in self.down_blocks:
            x = blk(x)
        for m in self.mid_block:
            x = m(x)
        x = self.conv_norm_out(x)
        x = self.conv_out(x)
        return self.quant_conv(x)
class Decoder(nn.Module):
    def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
        super().__init__()
        # post_quant_conv must preserve the latent channel count so that conv_in
        # receives latent_ch channels.
        self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch, 1)
        self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
        self.mid_block = nn.ModuleList([
            ResnetBlock(base_ch * 4, base_ch * 4),
            AttentionBlock(base_ch * 4),
            ResnetBlock(base_ch * 4, base_ch * 4)
        ])
        self.up_blocks = nn.ModuleList([
            UpBlock(base_ch * 4, base_ch * 4, 3),
            UpBlock(base_ch * 4, base_ch * 2, 3),
            UpBlock(base_ch * 2, base_ch, 3),
            UpBlock(base_ch, base_ch, 3)
        ])
        self.conv_norm_out = nn.GroupNorm(32, base_ch)
        self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
    def forward(self, x):
        x = self.post_quant_conv(x)
        x = self.conv_in(x)
        for m in self.mid_block:
            x = m(x)
        for up in self.up_blocks:
            x = up(x)
        x = self.conv_norm_out(x)
        return self.conv_out(x)
class AutoencoderKL(nn.Module):
    def __init__(self, config):
        super().__init__()
        cfg = config if isinstance(config, dict) else config.__dict__
        in_ch = cfg.get("in_channels", 3)
        out_ch = cfg.get("out_channels", 3)
        base_ch = cfg.get("base_channels", 128)
        latent_ch = cfg.get("latent_channels", 4)
        self.encoder = Encoder(in_ch, base_ch, latent_ch)
        self.decoder = Decoder(out_ch, base_ch, latent_ch)
    def forward(self, x):
        return self.decoder(self.encoder(x))
    def decode(self, x):
        return self.decoder(x)
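# Illustrative sketch with random data: the autoencoder maps an RGB image to a
# 4-channel latent at 1/16 resolution (four stride-2 downsamples) and back;
# decode() alone is what the video pipeline below uses. Never called at import time.
def _example_autoencoder_roundtrip():
    vae = AutoencoderKL({"in_channels": 3, "out_channels": 3, "base_channels": 128, "latent_channels": 4})
    image = torch.randn(1, 3, 64, 64)
    latent = vae.encoder(image)          # (1, 4, 4, 4)
    reconstruction = vae.decode(latent)  # (1, 3, 64, 64)
    return reconstruction.shape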
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = embed_dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim)
        )
    def forward(self, x):
        # nn.MultiheadAttention expects (seq, batch, dim), hence the transposes.
        res = x
        x = self.norm1(x)
        x = x.transpose(0, 1)
        attn, _ = self.attn(x, x, x)
        x = attn.transpose(0, 1)
        x = res + x
        return x + self.mlp(self.norm2(x))

class VisionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        cfg = config if isinstance(config, dict) else config.__dict__
        self.img_size = cfg.get("img_size", 592)
        self.patch_size = cfg.get("patch_size", 16)
        self.embed_dim = cfg.get("hidden_size", 768)
        depth = cfg.get("depth", 12)
        num_heads = cfg.get("num_heads", 12)
        num_patches = (self.img_size // self.patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
        self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(self.embed_dim)
        self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
        self._init_weights()
    def _init_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)[:, 0]
class OpenLRM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
        cfg = config if isinstance(config, dict) else config.__dict__
        hidden = cfg.get("hidden_size", 768)
        self.linear = nn.Linear(hidden, hidden)
    def forward(self, x):
        return self.linear(self.encoder["model"](x))
class VideoUNet(nn.Module):
    def __init__(self, in_ch=4, out_ch=4, features=None):
        super().__init__()
        if features is None:
            features = [64, 128, 256]
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool3d(2, 2)
        self.decoder = nn.ModuleList()
        for f in features:
            self.encoder.append(nn.Sequential(
                nn.Conv3d(in_ch, f, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv3d(f, f, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
            in_ch = f
        # Each decoder stage consumes the previous stage's output concatenated with
        # the matching skip connection, so its input width is prev_ch + f.
        prev_ch = features[-1]
        for f in reversed(features):
            self.decoder.append(nn.Sequential(
                nn.Conv3d(prev_ch + f, f, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv3d(f, f, 3, padding=1),
                nn.ReLU(inplace=True)
            ))
            prev_ch = f
        self.final_conv = nn.Conv3d(features[0], out_ch, 1)
    def forward(self, x, t, encoder_hidden_states):
        skips = []
        for enc in self.encoder:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)
        for dec in self.decoder:
            skip = skips.pop()
            x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = dec(x)
        return self.final_conv(x)
class SentimentClassifierModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )
    def forward(self, x):
        return self.classifier(x)

class STTModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 768)
        )
    def forward(self, x):
        return self.net(x)

class TTSModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 768)
        )
    def forward(self, x):
        return self.net(x)

class MusicGenModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.transformer = nn.TransformerEncoder(layer, num_layers=12)
        self.linear = nn.Linear(768, 768)
    def forward(self, x):
        return self.linear(self.transformer(x))

class SimpleTextEncoder(nn.Module):
    def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.max_length = max_length
    def forward(self, text_tokens):
        return self.embedding(text_tokens)
class DiffusionScheduler:
    def __init__(self, steps):
        self.steps = steps
        self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
    def step(self, noise, t, sample):
        # Estimate x0 from the predicted noise, then re-noise it to timestep t-1.
        alpha_bar = self.alpha_bars[t]
        alpha_bar_prev = self.alpha_bars[t - 1] if t > 0 else torch.tensor(1.0, device=sample.device)
        x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)
        new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise
        return new_sample
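# Illustrative sketch with random tensors: a single reverse-diffusion update as
# performed inside VideoPipeline's denoising loop. Never called at import time.
def _example_scheduler_step():
    sched = DiffusionScheduler(steps=25)
    sample = torch.randn(1, 4, 8, 8, device=device)
    noise_pred = torch.randn_like(sample)
    sample = sched.step(noise_pred, t=5, sample=sample)
    return sample.shape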
class VideoOutput:
    def __init__(self, frames):
        # frames arrives as (batch, frame, height, width, channel); keep batch 0.
        self.frames = [img_as_ubyte(frame) for frame in frames[0]]
class VideoPipeline(nn.Module):
    def __init__(self, unet, vae, text_encoder, vocab):
        super().__init__()
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.vocab = vocab
    def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
        token_ids = simple_tokenizer(prompt, self.vocab)
        text_emb = self.text_encoder(token_ids)
        latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
        sched = DiffusionScheduler(steps)
        for t in range(steps):
            noise = self.unet(latent, t, text_emb)
            latent = sched.step(noise, t, latent)
        # The VAE is 2D, so fold the frame axis into the batch axis before decoding.
        b, c, f, h, w = latent.shape
        flat = latent.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
        decoded = self.vae.decode(flat / 0.18215)
        frames = decoded.reshape(b, f, *decoded.shape[1:])
        frames = frames.clamp(0, 1).float().cpu().permute(0, 1, 3, 4, 2).numpy()
        return VideoOutput(frames)
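# Illustrative sketch (vocabulary and weights are untrained placeholders, and a
# CUDA device is assumed because the pipeline runs in half precision): mirrors
# what initialize_text_to_video_model assembles. Compute-heavy; never called at
# import time.
def _example_video_pipeline():
    vocab = {"a": 2, "robot": 3, "dancing": 4}
    unet = VideoUNet().half().to(device)
    vae = AutoencoderKL({"latent_channels": 4}).half().to(device)
    text_encoder = SimpleTextEncoder(vocab_size=10, embed_dim=768).to(device)
    pipe = VideoPipeline(unet, vae, text_encoder, vocab)
    video = pipe("a robot dancing", steps=2, num_frames=8)
    return len(video.frames)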
def initialize_gpt2_model(folder, files):
    download_files(folder, files)
    config = GPT2Config()
    model = GPT2LMHeadModel(config).to(device)
    sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
    return model, enc
def initialize_translation_model(folder, files):
    download_files(folder, files)
    config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = MBartForConditionalGeneration(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        vocab = read_json(vp)
        model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
    else:
        model.tokenizer = lambda txt: txt
    model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
    return model
def initialize_codegen_model(folder, files):
    download_files(folder, files)
    config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = CodeGenForCausalLM(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
    vocab = read_json(os.path.join(folder, "vocab.json"))
    idx2w = {v: k for k, v in vocab.items()}
    model.tokenizer = tok
    return model, tok, vocab, idx2w, vocab
def initialize_summarization_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = BartForConditionalGeneration(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        vocab_json = read_json(vp)
        vocab = set(vocab_json.keys())
        return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
    return model, None, None, None
def initialize_imagegen_model(folder, files):
    download_files(folder, files)
    config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    vae = AutoencoderKL(config).to(device)
    sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
    load_state_dict_safe(vae, sd)
    vae.eval()
    return vae

def initialize_image_to_3d_model(folder, files):
    download_files(folder, files)
    config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model3d = OpenLRM(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model3d, sd)
    model3d.eval()
    return model3d
def initialize_text_to_video_model(folder, files):
    download_files(folder, files)
    unet_cfg = read_json(os.path.join(folder, "config.json"))
    unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
    unet = VideoUNet(**unet_cfg).half().to(device)
    sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
    load_state_dict_safe(unet, sd_unet)
    unet.eval()
    vae_cfg = read_json(os.path.join(folder, "config.json"))
    vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg)
    vae = AutoencoderKL(vae_cfg).half().to(device)
    sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
    load_state_dict_safe(vae, sd_vae)
    vae.eval()
    vp = os.path.join(folder, "vocab.json")
    text_vocab = read_json(vp) if os.path.exists(vp) else {}
    vocab_size = (max(text_vocab.values()) + 1) if text_vocab else 10000
    text_encoder = SimpleTextEncoder(vocab_size=vocab_size, embed_dim=768, max_length=77).to(device)
    te_path = os.path.join(folder, "text_encoder.bin")
    if os.path.exists(te_path):
        sd_te = torch.load(te_path, map_location=device)
        load_state_dict_safe(text_encoder, sd_te)
    text_encoder.eval()
    return VideoPipeline(unet, vae, text_encoder, text_vocab)
def initialize_sentiment_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = SentimentClassifierModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        read_json(vp)
    return model

def initialize_stt_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = STTModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        read_json(vp)
    return model

def initialize_tts_model(folder, files):
    download_files(folder, files)
    config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = TTSModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    vp = os.path.join(folder, "vocab.json")
    if os.path.exists(vp):
        read_json(vp)
    return model
def initialize_musicgen_model(folder, files):
    download_files(folder, files)
    config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
    model = MusicGenModel(config).to(device)
    sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
    load_state_dict_safe(model, sd)
    model.eval()
    return model
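# Illustrative sketch (folder and URLs below are hypothetical): every initialize_*
# helper follows the same pattern -- download the listed files, build the model
# from its config, load weights tolerantly via load_state_dict_safe, and switch
# to eval mode. Never called at import time.
def _example_initialize_gpt2():
    files = {
        "gpt2-pytorch_model.bin": "https://example.com/gpt2-pytorch_model.bin",
        "encoder.json": "https://example.com/encoder.json",
    }
    model, encoder = initialize_gpt2_model("checkpoints/gpt2", files)
    return model, encoder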