import logging import re from threading import Thread from typing import List, Optional import torch import spaces from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoConfig, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, ) from .chat_utils import Conversation, get_conv_template logger = logging.getLogger(__name__) def load_model(model_path: str = "moonshotai/Kimi-VL-A3B-Thinking"): # hotfix the model to use flash attention 2 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) config._attn_implementation = "flash_attention_2" config.vision_config._attn_implementation = "flash_attention_2" config.text_config._attn_implementation = "flash_attention_2" print("Successfully set the attn_implementation to flash_attention_2") model = AutoModelForCausalLM.from_pretrained( model_path, config=config, torch_dtype="auto", device_map="auto", trust_remote_code=True, ) processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True) return model, processor class StoppingCriteriaSub(StoppingCriteria): def __init__(self, stops=[], encounters=1): super().__init__() self.stops = [stop.to("cuda") for stop in stops] def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs): for stop in self.stops: if input_ids.shape[-1] < len(stop): continue if torch.all((stop == input_ids[0][-len(stop) :])).item(): return True return False def format_messages( conversations: list[Conversation], system_prompt: Optional[str] = "", sft_format: Optional[str] = "kimi-vl", ): """ Format the conversations to the input format of the model. """ converstion = get_conv_template(sft_format) converstion.set_system_message(system_prompt) for message in conversations: converstion.append_message(message["role"], message["content"]) return converstion def preprocess( messages: list[dict], processor, sft_format: Optional[str] = "kimi-vl", ): """ Build messages from the conversations and images. """ # get images from conversations results = [] images = [] # get texts from conversations converstion = get_conv_template(sft_format) # only use the last 3 round of messages latest_messages = messages[-3:] for mid, message in enumerate(latest_messages): if message["role"] == converstion.roles[0] or message["role"] == "user": record = { "role": message["role"], "content": [], } if "images" in message: per_round_images = message["images"] if len(per_round_images) > 2: per_round_images = per_round_images[-2:] print(f"Only use the last 2 images in the {mid}-th round") images.extend(per_round_images) for image in per_round_images: record["content"].append( { "type": "image", "image": image, } ) if 'content' in message: record["content"].append( { "type": "text", "text": str(message["content"]).strip(), } ) results.append(record) elif message["role"] == converstion.roles[1] or message["role"] == "assistant": formatted_answer = message["content"].strip() # ◁think▷用户说了“你好”,这是一个非常简单的问候,通常用于开启对话。我需要判断用户的意图。可能性一:用户只是礼貌性地打招呼,想要开启一段对话;可能性二:用户可能有更具体的需求,比如询问我的功能、功能或者需要帮助。由于用户没有提供更多信息,我需要保持开放,同时引导用户进一步说明他们的需求。 # 我的回复需要既友好又开放,不能显得过于正式或冷漠。同时,我需要避免假设用户的具体需求,而是提供一个轻松的、鼓励继续对话的回应。◁/think▷你好!很高兴见到你。有什么我可以帮助你的吗 # delete all the texts between ◁think▷ and ◁/think▷ # FIXME: this is a hack to remove the thinking texts # formatted_answer = re.sub(r"◁think▷.*◁/think▷", "", formatted_answer) think_end_token = '◁/think▷' formatted_answer = formatted_answer.split(think_end_token)[-1] results.append( { "role": message["role"], "content": [ { "type": "text", "text": formatted_answer, } ], } ) assert ( formatted_answer.count(processor.image_token) == 0 ), f"there should be no {processor.image_token} in the assistant's reply, but got {messages}" converstion.append_message(converstion.roles[1], formatted_answer) text = processor.apply_chat_template(results, add_generation_prompt=True) print(f"raw text = {text}") if len(images) == 0: images = None inputs = processor( images=images, text=[text], return_tensors="pt", padding=True, truncation=True, ) return inputs @torch.no_grad() @torch.inference_mode() def kimi_vl_generate( model: torch.nn.Module, processor: AutoProcessor, conversations: list[Conversation], stop_words: list, max_length: int = 256, temperature: float = 1.0, top_p: float = 1.0, chunk_size: int = -1, ): # convert conversation to inputs print(f"conversations = {conversations}") inputs = preprocess(conversations, processor=processor) inputs = inputs.to(model.device) return generate( model, processor, inputs, max_gen_len=max_length, temperature=temperature, top_p=top_p, stop_words=stop_words, chunk_size=chunk_size, ) def generate( model, processor, inputs, max_gen_len: int = 256, temperature: float = 0, top_p: float = 0.95, stop_words: List[str] = [], chunk_size: int = -1, ): """Stream the text output from the multimodality model with prompt and image inputs.""" tokenizer = processor.tokenizer stop_words_ids = [torch.tensor(tokenizer.encode(stop_word)) for stop_word in stop_words] stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True) kwargs = dict( **inputs, max_new_tokens=max_gen_len, do_sample=True, use_cache=True, streamer=streamer, stopping_criteria=stopping_criteria, ) if temperature > 0: kwargs.update( { "do_sample": True, "top_p": top_p, "temperature": temperature, } ) else: kwargs["do_sample"] = False thread = Thread(target=model.generate, kwargs=kwargs) thread.start() yield from streamer