From 175ed943dfae6c6c3923079ca9f14ee6e01f2435 Mon Sep 17 00:00:00 2001 From: Brian Date: Wed, 8 Apr 2026 10:25:03 -0400 Subject: [PATCH] add KV-cache and move system prompt to the config --- config.yml | 7 +++ server/llm.py | 143 ++++++++++++++++++++++++++++++++------------- server/models.py | 11 ++-- server/pipeline.py | 13 ++++- 4 files changed, 127 insertions(+), 47 deletions(-) diff --git a/config.yml b/config.yml index 133620c..ded8996 100644 --- a/config.yml +++ b/config.yml @@ -1,6 +1,13 @@ # LLM backend: "local" or "lmstudio" llm: backend: local # change to "lmstudio" to use LM Studio instead + max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 means no size limit + + system_prompt: >- + You are a helpful voice assistant. Keep your responses concise and natural + for spoken conversation. Respond in 1-3 short sentences. + Do not use markdown, bullet points, code blocks, emojis, or any + formatting that doesn't work in speech. # Settings used only when backend = "lmstudio" lmstudio: diff --git a/server/llm.py b/server/llm.py index ab4b7b9..0a3ccd7 100644 --- a/server/llm.py +++ b/server/llm.py @@ -1,32 +1,38 @@ +import copy +import dataclasses import logging import threading from typing import AsyncIterator import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache from server.audio_utils import split_sentences log = logging.getLogger(__name__) +@dataclasses.dataclass +class KVCacheState: + """Per-session KV-cache persisted across generate() calls.""" + past_key_values: DynamicCache | None + cached_token_count: int + cached_messages: list[dict] # snapshot of messages when cache was built + class LLMEngine: - """Wraps Qwen3 for conversation generation.""" + """Wraps Qwen3 for conversation generation with persistent KV-cache.""" - SYSTEM_PROMPT = ( - "You are a helpful voice assistant. Keep your responses concise and natural " - "for spoken conversation. 
Respond in 1-3 short sentences. " - "Do not use markdown, bullet points, code blocks, emojis, or any " - "formatting that doesn't work in speech." - ) - - def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer): + def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, system_prompt: str, + max_cache_tokens: int = 4096): self.model = model self.tokenizer = tokenizer + self.system_prompt = system_prompt + self.max_cache_tokens = max_cache_tokens + self._generate_lock = threading.Lock() def _build_inputs(self, messages: list[dict]): """Build input token ids using the model's chat template.""" - chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}] + chat_messages = [{"role": "system", "content": self.system_prompt}] for msg in messages: chat_messages.append({"role": msg["role"], "content": msg["content"]}) @@ -38,36 +44,93 @@ class LLMEngine: ) return self.tokenizer(text, return_tensors="pt").to(self.model.device) - def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str: - """Generate a complete response (blocking).""" - inputs = self._build_inputs(messages) - input_len = inputs["input_ids"].shape[1] + def _validate_cache(self, messages: list[dict], cache_state: KVCacheState | None) -> DynamicCache | None: + """Return past_key_values if the cache is valid for the given messages, else None.""" + if cache_state is None or cache_state.past_key_values is None: + return None + if self.max_cache_tokens and cache_state.cached_token_count > self.max_cache_tokens: + log.info("KV-cache exceeds max size, discarding.") + return None + cached = cache_state.cached_messages + # The current messages must start with the cached messages as a prefix + if len(cached) > len(messages): + return None + for cached_msg, current_msg in zip(cached, messages): + if cached_msg["role"] != current_msg["role"] or cached_msg["content"] != current_msg["content"]: + return None + return cache_state.past_key_values - with 
torch.no_grad(): - output_ids = self.model.generate( - **inputs, - max_new_tokens=max_new_tokens, - temperature=0.7, - top_p=0.9, - do_sample=True, - repetition_penalty=1.2, + def generate( + self, + messages: list[dict], + max_new_tokens: int = 256, + cache_state: KVCacheState | None = None, + ) -> tuple[str, KVCacheState]: + """Generate a complete response (blocking). Returns (response, updated_cache_state).""" + with self._generate_lock: + inputs = self._build_inputs(messages) + input_ids = inputs["input_ids"] + input_len = input_ids.shape[1] + + past_kv = self._validate_cache(messages, cache_state) + cached_len = cache_state.cached_token_count if past_kv is not None else 0 + + log.info( + f"KV-cache: {cached_len}/{input_len} tokens cached, " + f"processing {input_len - cached_len} new tokens" ) - # Decode only the generated tokens (skip prompt) - new_ids = output_ids[0][input_len:] - response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() - log.info(f"LLM response: {response}") - return response + with torch.no_grad(): + outputs = self.model.generate( + input_ids=input_ids, + attention_mask=inputs.get("attention_mask"), + past_key_values=past_kv, + max_new_tokens=max_new_tokens, + temperature=0.7, + top_p=0.9, + do_sample=True, + repetition_penalty=1.2, + return_dict_in_generate=True, + use_cache=True, + ) + + # Decode only the generated tokens (skip prompt) + new_ids = outputs.sequences[0][input_len:] + response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() + log.info(f"LLM response: {response}") + + # Build updated cache state: messages now include the assistant response + new_messages = copy.deepcopy(messages) + [{"role": "assistant", "content": response}] + new_cache = KVCacheState( + past_key_values=outputs.past_key_values, + cached_token_count=outputs.sequences.shape[1], + cached_messages=new_messages, + ) + return response, new_cache + + def trim_cache(self, cache_state: KVCacheState, messages: list[dict]) -> 
KVCacheState | None: + """Trim cache to match the actual conversation history (e.g. after barge-in).""" + if cache_state is None or cache_state.past_key_values is None: + return None + inputs = self._build_inputs(messages) + target_len = inputs["input_ids"].shape[1] + if target_len >= cache_state.cached_token_count: + return cache_state + cache_state.past_key_values.crop(target_len) + cache_state.cached_token_count = target_len + cache_state.cached_messages = copy.deepcopy(messages) + return cache_state async def generate_sentences( self, messages: list[dict], cancel_event: threading.Event | None = None, + cache_state: KVCacheState | None = None, ) -> AsyncIterator[str]: """Generate response and yield it sentence by sentence for TTS pipelining.""" import asyncio - response = await asyncio.to_thread(self.generate, messages) + response = await asyncio.to_thread(self.generate, messages, 256, cache_state) if cancel_event and cancel_event.is_set(): return @@ -83,25 +146,23 @@ class LLMEngine: yield remainder -SYSTEM_PROMPT = ( - "You are a helpful voice assistant. Keep your responses concise and natural " - "for spoken conversation. Respond in 1-3 short sentences. " - "Do not use markdown, bullet points, code blocks, emojis, or any " - "formatting that doesn't work in speech." 
-) - - class LMStudioEngine: """LLM engine that delegates to an LM Studio server via its OpenAI-compatible API.""" - def __init__(self, base_url: str, model: str): + def __init__(self, base_url: str, model: str, system_prompt: str): self.base_url = base_url.rstrip("/") self.model = model + self.system_prompt = system_prompt - def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str: + def generate( + self, + messages: list[dict], + max_new_tokens: int = 256, + cache_state: KVCacheState | None = None, + ) -> tuple[str, None]: import requests - payload_messages = [{"role": "system", "content": SYSTEM_PROMPT}] + payload_messages = [{"role": "system", "content": self.system_prompt}] payload_messages.extend(messages) body: dict = { @@ -121,7 +182,7 @@ class LMStudioEngine: resp.raise_for_status() response = resp.json()["choices"][0]["message"]["content"].strip() log.info(f"LM Studio response: {response}") - return response + return response, None async def generate_sentences( self, diff --git a/server/models.py b/server/models.py index efa459c..00042d1 100644 --- a/server/models.py +++ b/server/models.py @@ -68,16 +68,18 @@ class ModelManager: def _load_llm(self): from server.config import config - backend = config.get("llm", {}).get("backend", "local") + llm_config = config.get("llm", {}) + backend = llm_config.get("backend", "local") + system_prompt = llm_config.get("system_prompt", "You are a helpful assistant.") if backend == "lmstudio": from server.llm import LMStudioEngine - lms = config.get("llm", {}).get("lmstudio", {}) + lms = llm_config.get("lmstudio", {}) url = lms.get("url", "http://host.docker.internal:1234") model = lms.get("model", "") or "" log.info(f"Using LM Studio backend at {url} (model={model or 'server default'})") - self.llm_engine = LMStudioEngine(url, model) + self.llm_engine = LMStudioEngine(url, model, system_prompt) else: log.info("Loading Qwen3-4B (GPTQ 4-bit)...") from transformers import AutoModelForCausalLM, 
AutoTokenizer @@ -90,7 +92,8 @@ class ModelManager: model_name, device_map=device, ) - self.llm_engine = LLMEngine(model, tokenizer) + max_cache_tokens = llm_config.get("max_cache_tokens", 4096) + self.llm_engine = LLMEngine(model, tokenizer, system_prompt, max_cache_tokens) log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).") def _load_tts(self): diff --git a/server/pipeline.py b/server/pipeline.py index 064d3ac..47fa2b9 100644 --- a/server/pipeline.py +++ b/server/pipeline.py @@ -6,6 +6,7 @@ import threading import numpy as np from server.audio_utils import float32_to_pcm_bytes +from server.llm import KVCacheState from server.models import ModelManager from server.vad import StreamingVAD @@ -27,6 +28,7 @@ class ConversationSession: self.vad: StreamingVAD = models.create_vad() self.conversation_history: list[dict] = [] + self.kv_cache_state: KVCacheState | None = None self.cancel_event = threading.Event() self.is_responding = False self._response_task: asyncio.Task | None = None @@ -38,6 +40,7 @@ class ConversationSession: self.cancel_event.set() if self._response_task and not self._response_task.done(): self._response_task.cancel() + self.kv_cache_state = None async def handle_audio_chunk(self, chunk_16k: np.ndarray): utterance = self.vad.process_chunk(chunk_16k) @@ -91,8 +94,8 @@ class ConversationSession: # LLM log.info(f"Conversation history ({len(self.conversation_history)} messages): " + str([m['content'][:50] for m in self.conversation_history])) - response = await asyncio.to_thread( - self.models.llm_engine.generate, self.conversation_history + response, self.kv_cache_state = await asyncio.to_thread( + self.models.llm_engine.generate, self.conversation_history, 256, self.kv_cache_state ) if self.cancel_event.is_set(): @@ -147,12 +150,18 @@ class ConversationSession: tts_thread.join(timeout=2.0) # Save only what was actually spoken + was_interrupted = spoken_text.strip() != response.strip() if spoken_text.strip(): self.conversation_history.append( {"role": 
"assistant", "content": spoken_text.strip()} ) + if was_interrupted and self.kv_cache_state is not None: + self.kv_cache_state = self.models.llm_engine.trim_cache( + self.kv_cache_state, self.conversation_history + ) elif self.conversation_history and self.conversation_history[-1]["role"] == "user": self.conversation_history.pop() + self.kv_cache_state = None if not self.cancel_event.is_set(): await self.send_json({"type": "response_text", "text": "", "final": True})