import asyncio
import logging
import threading
from typing import AsyncIterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from server.audio_utils import split_sentences

log = logging.getLogger(__name__)


class LLMEngine:
    """Wraps Qwen3 for conversation generation.

    The engine prepends a fixed system prompt to every conversation, renders
    the messages through the tokenizer's chat template, and samples a reply
    from the model. Replies can be obtained as one string (`generate`) or
    sentence by sentence (`generate_sentences`).
    """

    SYSTEM_PROMPT = (
        "You are a helpful voice assistant. Keep your responses concise and natural "
        "for spoken conversation. Respond in 1-3 short sentences. "
        "Do not use markdown, bullet points, code blocks, emojis, or any "
        "formatting that doesn't work in speech."
    )

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        """Store the already-loaded model and its matching tokenizer."""
        self.model = model
        self.tokenizer = tokenizer

    def _build_inputs(self, messages: list[dict]):
        """Build input token ids using the model's chat template.

        Args:
            messages: Conversation turns; each dict must provide ``"role"``
                and ``"content"``. Any other keys are ignored.

        Returns:
            The tokenizer output (PyTorch tensors) moved to the model's
            device, ready to be unpacked into ``model.generate``.

        Raises:
            KeyError: If a message lacks ``"role"`` or ``"content"``.
        """
        # The system prompt always comes first; only role/content are copied
        # so stray keys on the caller's dicts never reach the template.
        chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
        chat_messages.extend(
            {"role": msg["role"], "content": msg["content"]} for msg in messages
        )
        # NOTE(review): enable_thinking is forwarded to the chat template;
        # presumably it turns off Qwen3's reasoning block - confirm against
        # the model card.
        text = self.tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return self.tokenizer(text, return_tensors="pt").to(self.model.device)

    def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
        """Generate a complete response (blocking).

        Args:
            messages: Conversation turns with ``"role"`` and ``"content"``.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded reply with special tokens removed and surrounding
            whitespace stripped.
        """
        inputs = self._build_inputs(messages)
        input_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2,
            )

        # Decode only the generated tokens (skip prompt)
        new_ids = output_ids[0][input_len:]
        response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
        # Lazy %-formatting: the message is only built if INFO is enabled.
        log.info("LLM response: %s", response)
        return response

    async def generate_sentences(
        self,
        messages: list[dict],
        cancel_event: threading.Event | None = None,
        max_new_tokens: int = 256,
    ) -> AsyncIterator[str]:
        """Generate response and yield it sentence by sentence for TTS pipelining.

        The blocking `generate` call runs in a worker thread so the event
        loop stays responsive. Cancellation is cooperative: the event is
        checked once generation has finished and again before every yielded
        piece (including the trailing remainder); it cannot interrupt the
        model call itself.

        Args:
            messages: Conversation turns with ``"role"`` and ``"content"``.
            cancel_event: Optional event; once set, no further text is
                yielded.
            max_new_tokens: Upper bound on generated tokens, forwarded to
                `generate`.

        Yields:
            Each sentence returned by ``split_sentences``, followed by the
            remainder it reports when that remainder is non-empty.
        """

        def cancelled() -> bool:
            return cancel_event is not None and cancel_event.is_set()

        response = await asyncio.to_thread(self.generate, messages, max_new_tokens)
        if cancelled():
            return

        # Split into sentences and yield each
        sentences, remainder = split_sentences(response)
        for sentence in sentences:
            if cancelled():
                return
            yield sentence

        # The consumer may have cancelled while handling the last sentence,
        # so re-check before emitting the leftover text.
        if remainder and not cancelled():
            yield remainder