import logging

import gradio as gr
import torch
from transformers import MODEL_FOR_MASKED_LM_MAPPING

from sdlm.arguments import get_args
from sdlm.models.utils import load_model
from sdlm.pipelines.simplex_ddpm import SimplexDDPMPipeline
from sdlm.schedulers import TokenWiseSimplexDDPMScheduler

logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def main():
    model_args, data_args, training_args, diffusion_args = get_args()
    tokenizer, model = load_model(
        model_args, data_args, training_args, diffusion_args, logger
    )
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = SimplexDDPMPipeline(
        model=model.to(device),
        scheduler=TokenWiseSimplexDDPMScheduler(
            num_train_timesteps=getattr(diffusion_args, "num_train_timesteps", 100),
            beta_schedule=getattr(
                diffusion_args, "beta_schedule", "squaredcos_improved_ddpm"
            ),
            simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
            clip_sample=getattr(diffusion_args, "clip_sample", False),
            device=device,
        ),
        simplex_value=getattr(diffusion_args, "simplex_value", 5.0),
        top_p=getattr(diffusion_args, "top_p", 0.99),
        sampling_type="top_p",
        is_conditional_generation=True,
        tokenizer=tokenizer,
        classifier_free_uncond_input="empty_token",
        temperature=getattr(diffusion_args, "temperature", 1.0),
        guidance_softmax_combination=True,
    )

    def generate(
        inputs,
        simplex_value=5.0,
        top_p=0.99,
        temperature=1.0,
        diffusion_steps=100,
        beta_schedule="squaredcos_improved_ddpm",
        clip_sample=False,
        guidance_scale=1.0,
        generated_sequence_length=256,
        progress=gr.Progress(),
    ):
        """
        Gradio-friendly generation function.

        Adjusts the pipeline's parameters (simplex_value, top_p, etc.) as
        requested, then runs generation, yielding the decoded sample at each
        diffusion step so the Gradio UI can stream intermediate results.
        """
""" with torch.inference_mode(): # Update pipeline scheduler with user-provided parameters: pipeline.scheduler.num_train_timesteps = diffusion_steps pipeline.scheduler.beta_schedule = beta_schedule pipeline.scheduler.simplex_value = simplex_value pipeline.scheduler.clip_sample = clip_sample pipeline.simplex_value = simplex_value pipeline.top_p = top_p pipeline.temperature = temperature # tulu chat template inputs = "<|user|>\n" + inputs + "<|assistant|>\n" # Tokenize and prepare input for diffusion tokenized_input = tokenizer([inputs], add_special_tokens=False, return_tensors="pt").input_ids tokenized_input_len = tokenized_input.shape[1] # Concatenate BOS + input + blank space for generation tokenized_input = torch.cat( [ torch.ones((1, 1), dtype=torch.long) * tokenizer.bos_token_id, tokenized_input, torch.ones((1, generated_sequence_length), dtype=torch.long) * tokenizer.pad_token_id, ], dim=-1, ) # Create a mask over the generation region span_mask = torch.cat( [ torch.zeros((1, tokenized_input_len + 1), dtype=torch.bool), torch.ones((1, generated_sequence_length), dtype=torch.bool), ], dim=-1, ) batch = { "input_ids": tokenized_input.to(device), "span_mask": span_mask.to(device), } # Run sampling pipe = pipeline(batch=batch, seq_length=generated_sequence_length, guidance_scale=guidance_scale) for out in pipe: output_ids = out.logits.argmax(dim=-1) generated_tokens = output_ids[:, tokenized_input_len + 1 :] yield tokenizer.decode(generated_tokens[0], skip_special_tokens=True) # Quick test call (uncomment if you want a quick, non-Gradio test) print("Test generation: ", generate("The best things in life are")) demo = gr.Interface( fn=generate, inputs=[ gr.Textbox(lines=5, label="Input Prompt"), gr.Number(value=5.0, label="Simplex value"), gr.Slider(0, 1, value=0.99, step=0.01, label="Top-p"), gr.Slider(0, 5, value=1.0, step=0.1, label="Temperature"), gr.Number(value=100, precision=0, label="Diffusion steps"), gr.Dropdown( choices=["linear", "scaled_linear", "squaredcos_cap_v2", "squaredcos_improved_ddpm"], value="squaredcos_improved_ddpm", label="Beta schedule", ), gr.Checkbox(value=False, label="Clip sample?"), gr.Number(value=1.0, label="Guidance scale"), gr.Number(value=256, label="Generation length (tokens)"), ], outputs="text", title="Simplex Diffusion LM", description="Generate text using a simplex-based diffusion model.", ) demo.queue().launch(server_name="0.0.0.0", server_port=8888, share=True) if __name__ == "__main__": main()