"""LLM engines used for conversation generation.

Two engines with parallel interfaces are provided:

* ``LLMEngine`` runs a local Hugging Face causal LM and can reuse a KV-cache
  across calls.
* ``LMStudioEngine`` delegates to an LM Studio server over HTTP.
"""

import asyncio
import copy
import dataclasses
import logging
import threading
from typing import AsyncIterator, Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

from server.audio_utils import split_sentences

log = logging.getLogger(__name__)


@dataclasses.dataclass
class KVCacheState:
    """Per-session KV-cache persisted across generate() calls."""

    # Cache object handed back to ``model.generate`` on the next call.
    past_key_values: DynamicCache | None
    # Length of the full token sequence (prompt + generated) the cache was built from.
    cached_token_count: int
    # Snapshot of the messages (including the assistant reply) when the cache was built.
    cached_messages: list[dict]


def _iter_sentences(
    response: str,
    cancel_event: threading.Event | None,
) -> Iterator[str]:
    """Yield the pieces of ``response`` produced by ``split_sentences``.

    ``split_sentences`` returns ``(sentences, remainder)``; each sentence is
    yielded in order and the remainder (if non-empty) last. ``cancel_event``
    is checked before every sentence and iteration stops once it is set.
    """
    sentences, remainder = split_sentences(response)
    for sentence in sentences:
        if cancel_event and cancel_event.is_set():
            return
        yield sentence
    if remainder:
        yield remainder


class LLMEngine:
    """Wraps Qwen3 for conversation generation with persistent KV-cache."""

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        system_prompt: str,
        max_cache_tokens: int = 4096,
    ):
        """Store the model, tokenizer and generation settings.

        Args:
            model: Loaded causal LM used for generation.
            tokenizer: Tokenizer providing the chat template for ``model``.
            system_prompt: Text sent as the system message on every call.
            max_cache_tokens: Caches larger than this are discarded; a falsy
                value disables the limit.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.system_prompt = system_prompt
        self.max_cache_tokens = max_cache_tokens
        # Serializes generate() calls on the shared model.
        self._generate_lock = threading.Lock()

    def _build_inputs(self, messages: list[dict]):
        """Build input token ids using the model's chat template.

        The system prompt is prepended and only the ``role``/``content`` keys
        of each message are forwarded. The result is the tokenizer output
        moved to the model's device.
        """
        chat_messages = [
            {"role": "system", "content": self.system_prompt},
            *({"role": msg["role"], "content": msg["content"]} for msg in messages),
        ]
        text = self.tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return self.tokenizer(text, return_tensors="pt").to(self.model.device)

    def _validate_cache(
        self,
        messages: list[dict],
        cache_state: KVCacheState | None,
    ) -> DynamicCache | None:
        """Return past_key_values if the cache is valid for the given messages, else None.

        A cache is rejected when it is missing, when it exceeds
        ``max_cache_tokens``, or when the cached messages are not a prefix of
        ``messages`` (compared on ``role`` and ``content``).
        """
        if cache_state is None or cache_state.past_key_values is None:
            return None

        if self.max_cache_tokens and cache_state.cached_token_count > self.max_cache_tokens:
            log.info("KV-cache exceeds max size, discarding.")
            return None

        cached = cache_state.cached_messages
        # The current messages must start with the cached messages as a prefix.
        if len(cached) > len(messages):
            return None
        for cached_msg, current_msg in zip(cached, messages):
            if (
                cached_msg["role"] != current_msg["role"]
                or cached_msg["content"] != current_msg["content"]
            ):
                return None

        # NOTE(review): the check above compares messages, not token ids. Reuse
        # presumably relies on the re-templated history tokenizing to exactly
        # the sequence the cache was built from -- confirm for this template.
        return cache_state.past_key_values

    def generate(
        self,
        messages: list[dict],
        max_new_tokens: int = 256,
        cache_state: KVCacheState | None = None,
    ) -> tuple[str, KVCacheState]:
        """Generate a complete response (blocking).

        Args:
            messages: Conversation so far, as ``{"role", "content"}`` dicts.
            max_new_tokens: Upper bound on generated tokens.
            cache_state: Cache from a previous call; reused only if
                ``_validate_cache`` accepts it.

        Returns:
            ``(response, updated_cache_state)`` where the cache state's message
            snapshot includes the new assistant reply.
        """
        with self._generate_lock:
            inputs = self._build_inputs(messages)
            input_ids = inputs["input_ids"]
            input_len = input_ids.shape[1]

            past_kv = self._validate_cache(messages, cache_state)
            cached_len = cache_state.cached_token_count if past_kv is not None else 0
            log.info(
                f"KV-cache: {cached_len}/{input_len} tokens cached, "
                f"processing {input_len - cached_len} new tokens"
            )

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=inputs.get("attention_mask"),
                    past_key_values=past_kv,
                    max_new_tokens=max_new_tokens,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    repetition_penalty=1.2,
                    return_dict_in_generate=True,
                    use_cache=True,
                )

            # Decode only the generated tokens (skip prompt).
            new_ids = outputs.sequences[0][input_len:]
            response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
            log.info(f"LLM response: {response}")

            # Build updated cache state: messages now include the assistant response.
            new_messages = copy.deepcopy(messages) + [
                {"role": "assistant", "content": response}
            ]
            new_cache = KVCacheState(
                past_key_values=outputs.past_key_values,
                cached_token_count=outputs.sequences.shape[1],
                cached_messages=new_messages,
            )
            return response, new_cache

    def trim_cache(
        self,
        cache_state: KVCacheState | None,
        messages: list[dict],
    ) -> KVCacheState | None:
        """Trim cache to match the actual conversation history (e.g. after barge-in).

        ``cache_state`` is modified in place and returned. If the templated
        ``messages`` are at least as long as the cache, the cache is returned
        untouched. Returns ``None`` when there is no cache to trim.
        """
        if cache_state is None or cache_state.past_key_values is None:
            return None

        inputs = self._build_inputs(messages)
        target_len = inputs["input_ids"].shape[1]
        if target_len >= cache_state.cached_token_count:
            return cache_state

        cache_state.past_key_values.crop(target_len)
        cache_state.cached_token_count = target_len
        cache_state.cached_messages = copy.deepcopy(messages)
        return cache_state

    async def generate_sentences(
        self,
        messages: list[dict],
        cancel_event: threading.Event | None = None,
        cache_state: KVCacheState | None = None,
        max_new_tokens: int = 256,
    ) -> AsyncIterator[str]:
        """Generate response and yield it sentence by sentence for TTS pipelining.

        The blocking ``generate`` call runs in a worker thread. Nothing is
        yielded if ``cancel_event`` is set once generation finishes, and
        iteration stops as soon as it becomes set between sentences.

        The updated cache state returned by ``generate`` is discarded here;
        call ``generate`` directly when the cache must be carried forward.
        """
        # generate() returns (text, cache_state); only the text is split.
        response, _ = await asyncio.to_thread(
            self.generate, messages, max_new_tokens, cache_state
        )

        if cancel_event and cancel_event.is_set():
            return

        for sentence in _iter_sentences(response, cancel_event):
            yield sentence


class LMStudioEngine:
    """LLM engine that delegates to an LM Studio server via its OpenAI-compatible API."""

    def __init__(self, base_url: str, model: str, system_prompt: str):
        """Store connection settings.

        Args:
            base_url: Server root URL; trailing slashes are stripped.
            model: Model identifier sent in the request; omitted when empty.
            system_prompt: Text sent as the system message on every call.
        """
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.system_prompt = system_prompt

    def generate(
        self,
        messages: list[dict],
        max_new_tokens: int = 256,
        cache_state: KVCacheState | None = None,
    ) -> tuple[str, None]:
        """Request a complete response from the server (blocking).

        ``cache_state`` is accepted for signature parity with
        ``LLMEngine.generate`` and is not used.

        Returns:
            ``(response, None)``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # Imported lazily so the module loads without requests installed.
        import requests

        payload_messages = [{"role": "system", "content": self.system_prompt}]
        payload_messages.extend(messages)

        body: dict = {
            "messages": payload_messages,
            "max_tokens": max_new_tokens,
            "temperature": 0.7,
            "stream": False,
        }
        if self.model:
            body["model"] = self.model

        resp = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=body,
            timeout=30,
        )
        resp.raise_for_status()

        # The content field may be null in an OpenAI-style reply; treat as empty.
        content = resp.json()["choices"][0]["message"]["content"]
        response = (content or "").strip()
        log.info(f"LM Studio response: {response}")
        return response, None

    async def generate_sentences(
        self,
        messages: list[dict],
        cancel_event: threading.Event | None = None,
        cache_state: KVCacheState | None = None,
        max_new_tokens: int = 256,
    ) -> AsyncIterator[str]:
        """Generate response and yield it sentence by sentence for TTS pipelining.

        The blocking ``generate`` call runs in a worker thread. Nothing is
        yielded if ``cancel_event`` is set once generation finishes, and
        iteration stops as soon as it becomes set between sentences.
        ``cache_state`` is accepted for signature parity with
        ``LLMEngine.generate_sentences`` and is not used.
        """
        # generate() returns (text, None); only the text is split.
        response, _ = await asyncio.to_thread(
            self.generate, messages, max_new_tokens, cache_state
        )

        if cancel_event and cancel_event.is_set():
            return

        for sentence in _iter_sentences(response, cancel_event):
            yield sentence