add KV-cache and move system promt to the config

This commit is contained in:
2026-04-08 10:25:03 -04:00
parent c7c4019ecc
commit 175ed943df
4 changed files with 127 additions and 47 deletions
+7
View File
@@ -1,6 +1,13 @@
# LLM backend: "local" or "lmstudio" # LLM backend: "local" or "lmstudio"
llm: llm:
backend: local # change to "lmstudio" to use LM Studio instead backend: local # change to "lmstudio" to use LM Studio instead
max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 to disable caching
system_prompt: >-
You are a helpful voice assistant. Keep your responses concise and natural
for spoken conversation. Respond in 1-3 short sentences.
Do not use markdown, bullet points, code blocks, emojis, or any
formatting that doesn't work in speech.
# Settings used only when backend = "lmstudio" # Settings used only when backend = "lmstudio"
lmstudio: lmstudio:
+92 -31
View File
@@ -1,32 +1,38 @@
import copy
import dataclasses
import logging import logging
import threading import threading
from typing import AsyncIterator from typing import AsyncIterator
import torch import torch
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from server.audio_utils import split_sentences from server.audio_utils import split_sentences
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@dataclasses.dataclass
class KVCacheState:
"""Per-session KV-cache persisted across generate() calls."""
past_key_values: DynamicCache | None
cached_token_count: int
cached_messages: list[dict] # snapshot of messages when cache was built
class LLMEngine: class LLMEngine:
"""Wraps Qwen3 for conversation generation.""" """Wraps Qwen3 for conversation generation with persistent KV-cache."""
SYSTEM_PROMPT = ( def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, system_prompt: str,
"You are a helpful voice assistant. Keep your responses concise and natural " max_cache_tokens: int = 4096):
"for spoken conversation. Respond in 1-3 short sentences. "
"Do not use markdown, bullet points, code blocks, emojis, or any "
"formatting that doesn't work in speech."
)
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.system_prompt = system_prompt
self.max_cache_tokens = max_cache_tokens
self._generate_lock = threading.Lock()
def _build_inputs(self, messages: list[dict]): def _build_inputs(self, messages: list[dict]):
"""Build input token ids using the model's chat template.""" """Build input token ids using the model's chat template."""
chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}] chat_messages = [{"role": "system", "content": self.system_prompt}]
for msg in messages: for msg in messages:
chat_messages.append({"role": msg["role"], "content": msg["content"]}) chat_messages.append({"role": msg["role"], "content": msg["content"]})
@@ -38,36 +44,93 @@ class LLMEngine:
) )
return self.tokenizer(text, return_tensors="pt").to(self.model.device) return self.tokenizer(text, return_tensors="pt").to(self.model.device)
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str: def _validate_cache(self, messages: list[dict], cache_state: KVCacheState | None) -> DynamicCache | None:
"""Generate a complete response (blocking).""" """Return past_key_values if the cache is valid for the given messages, else None."""
if cache_state is None or cache_state.past_key_values is None:
return None
if self.max_cache_tokens and cache_state.cached_token_count > self.max_cache_tokens:
log.info("KV-cache exceeds max size, discarding.")
return None
cached = cache_state.cached_messages
# The current messages must start with the cached messages as a prefix
if len(cached) > len(messages):
return None
for cached_msg, current_msg in zip(cached, messages):
if cached_msg["role"] != current_msg["role"] or cached_msg["content"] != current_msg["content"]:
return None
return cache_state.past_key_values
def generate(
self,
messages: list[dict],
max_new_tokens: int = 256,
cache_state: KVCacheState | None = None,
) -> tuple[str, KVCacheState]:
"""Generate a complete response (blocking). Returns (response, updated_cache_state)."""
with self._generate_lock:
inputs = self._build_inputs(messages) inputs = self._build_inputs(messages)
input_len = inputs["input_ids"].shape[1] input_ids = inputs["input_ids"]
input_len = input_ids.shape[1]
past_kv = self._validate_cache(messages, cache_state)
cached_len = cache_state.cached_token_count if past_kv is not None else 0
log.info(
f"KV-cache: {cached_len}/{input_len} tokens cached, "
f"processing {input_len - cached_len} new tokens"
)
with torch.no_grad(): with torch.no_grad():
output_ids = self.model.generate( outputs = self.model.generate(
**inputs, input_ids=input_ids,
attention_mask=inputs.get("attention_mask"),
past_key_values=past_kv,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
temperature=0.7, temperature=0.7,
top_p=0.9, top_p=0.9,
do_sample=True, do_sample=True,
repetition_penalty=1.2, repetition_penalty=1.2,
return_dict_in_generate=True,
use_cache=True,
) )
# Decode only the generated tokens (skip prompt) # Decode only the generated tokens (skip prompt)
new_ids = output_ids[0][input_len:] new_ids = outputs.sequences[0][input_len:]
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
log.info(f"LLM response: {response}") log.info(f"LLM response: {response}")
return response
# Build updated cache state: messages now include the assistant response
new_messages = copy.deepcopy(messages) + [{"role": "assistant", "content": response}]
new_cache = KVCacheState(
past_key_values=outputs.past_key_values,
cached_token_count=outputs.sequences.shape[1],
cached_messages=new_messages,
)
return response, new_cache
def trim_cache(self, cache_state: KVCacheState, messages: list[dict]) -> KVCacheState | None:
"""Trim cache to match the actual conversation history (e.g. after barge-in)."""
if cache_state is None or cache_state.past_key_values is None:
return None
inputs = self._build_inputs(messages)
target_len = inputs["input_ids"].shape[1]
if target_len >= cache_state.cached_token_count:
return cache_state
cache_state.past_key_values.crop(target_len)
cache_state.cached_token_count = target_len
cache_state.cached_messages = copy.deepcopy(messages)
return cache_state
async def generate_sentences( async def generate_sentences(
self, self,
messages: list[dict], messages: list[dict],
cancel_event: threading.Event | None = None, cancel_event: threading.Event | None = None,
cache_state: KVCacheState | None = None,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Generate response and yield it sentence by sentence for TTS pipelining.""" """Generate response and yield it sentence by sentence for TTS pipelining."""
import asyncio import asyncio
response = await asyncio.to_thread(self.generate, messages) response = await asyncio.to_thread(self.generate, messages, 256, cache_state)
if cancel_event and cancel_event.is_set(): if cancel_event and cancel_event.is_set():
return return
@@ -83,25 +146,23 @@ class LLMEngine:
yield remainder yield remainder
SYSTEM_PROMPT = (
"You are a helpful voice assistant. Keep your responses concise and natural "
"for spoken conversation. Respond in 1-3 short sentences. "
"Do not use markdown, bullet points, code blocks, emojis, or any "
"formatting that doesn't work in speech."
)
class LMStudioEngine: class LMStudioEngine:
"""LLM engine that delegates to an LM Studio server via its OpenAI-compatible API.""" """LLM engine that delegates to an LM Studio server via its OpenAI-compatible API."""
def __init__(self, base_url: str, model: str): def __init__(self, base_url: str, model: str, system_prompt: str):
self.base_url = base_url.rstrip("/") self.base_url = base_url.rstrip("/")
self.model = model self.model = model
self.system_prompt = system_prompt
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str: def generate(
self,
messages: list[dict],
max_new_tokens: int = 256,
cache_state: KVCacheState | None = None,
) -> tuple[str, None]:
import requests import requests
payload_messages = [{"role": "system", "content": SYSTEM_PROMPT}] payload_messages = [{"role": "system", "content": self.system_prompt}]
payload_messages.extend(messages) payload_messages.extend(messages)
body: dict = { body: dict = {
@@ -121,7 +182,7 @@ class LMStudioEngine:
resp.raise_for_status() resp.raise_for_status()
response = resp.json()["choices"][0]["message"]["content"].strip() response = resp.json()["choices"][0]["message"]["content"].strip()
log.info(f"LM Studio response: {response}") log.info(f"LM Studio response: {response}")
return response return response, None
async def generate_sentences( async def generate_sentences(
self, self,
+7 -4
View File
@@ -68,16 +68,18 @@ class ModelManager:
def _load_llm(self): def _load_llm(self):
from server.config import config from server.config import config
backend = config.get("llm", {}).get("backend", "local") llm_config = config.get("llm", {})
backend = llm_config.get("backend", "local")
system_prompt = llm_config.get("system_prompt", "You are a helpful assistant.")
if backend == "lmstudio": if backend == "lmstudio":
from server.llm import LMStudioEngine from server.llm import LMStudioEngine
lms = config.get("llm", {}).get("lmstudio", {}) lms = llm_config.get("lmstudio", {})
url = lms.get("url", "http://host.docker.internal:1234") url = lms.get("url", "http://host.docker.internal:1234")
model = lms.get("model", "") or "" model = lms.get("model", "") or ""
log.info(f"Using LM Studio backend at {url} (model={model or 'server default'})") log.info(f"Using LM Studio backend at {url} (model={model or 'server default'})")
self.llm_engine = LMStudioEngine(url, model) self.llm_engine = LMStudioEngine(url, model, system_prompt)
else: else:
log.info("Loading Qwen3-4B (GPTQ 4-bit)...") log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -90,7 +92,8 @@ class ModelManager:
model_name, model_name,
device_map=device, device_map=device,
) )
self.llm_engine = LLMEngine(model, tokenizer) max_cache_tokens = llm_config.get("max_cache_tokens", 4096)
self.llm_engine = LLMEngine(model, tokenizer, system_prompt, max_cache_tokens)
log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).") log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).")
def _load_tts(self): def _load_tts(self):
+11 -2
View File
@@ -6,6 +6,7 @@ import threading
import numpy as np import numpy as np
from server.audio_utils import float32_to_pcm_bytes from server.audio_utils import float32_to_pcm_bytes
from server.llm import KVCacheState
from server.models import ModelManager from server.models import ModelManager
from server.vad import StreamingVAD from server.vad import StreamingVAD
@@ -27,6 +28,7 @@ class ConversationSession:
self.vad: StreamingVAD = models.create_vad() self.vad: StreamingVAD = models.create_vad()
self.conversation_history: list[dict] = [] self.conversation_history: list[dict] = []
self.kv_cache_state: KVCacheState | None = None
self.cancel_event = threading.Event() self.cancel_event = threading.Event()
self.is_responding = False self.is_responding = False
self._response_task: asyncio.Task | None = None self._response_task: asyncio.Task | None = None
@@ -38,6 +40,7 @@ class ConversationSession:
self.cancel_event.set() self.cancel_event.set()
if self._response_task and not self._response_task.done(): if self._response_task and not self._response_task.done():
self._response_task.cancel() self._response_task.cancel()
self.kv_cache_state = None
async def handle_audio_chunk(self, chunk_16k: np.ndarray): async def handle_audio_chunk(self, chunk_16k: np.ndarray):
utterance = self.vad.process_chunk(chunk_16k) utterance = self.vad.process_chunk(chunk_16k)
@@ -91,8 +94,8 @@ class ConversationSession:
# LLM # LLM
log.info(f"Conversation history ({len(self.conversation_history)} messages): " log.info(f"Conversation history ({len(self.conversation_history)} messages): "
+ str([m['content'][:50] for m in self.conversation_history])) + str([m['content'][:50] for m in self.conversation_history]))
response = await asyncio.to_thread( response, self.kv_cache_state = await asyncio.to_thread(
self.models.llm_engine.generate, self.conversation_history self.models.llm_engine.generate, self.conversation_history, 256, self.kv_cache_state
) )
if self.cancel_event.is_set(): if self.cancel_event.is_set():
@@ -147,12 +150,18 @@ class ConversationSession:
tts_thread.join(timeout=2.0) tts_thread.join(timeout=2.0)
# Save only what was actually spoken # Save only what was actually spoken
was_interrupted = spoken_text.strip() != response.strip()
if spoken_text.strip(): if spoken_text.strip():
self.conversation_history.append( self.conversation_history.append(
{"role": "assistant", "content": spoken_text.strip()} {"role": "assistant", "content": spoken_text.strip()}
) )
if was_interrupted and self.kv_cache_state is not None:
self.kv_cache_state = self.models.llm_engine.trim_cache(
self.kv_cache_state, self.conversation_history
)
elif self.conversation_history and self.conversation_history[-1]["role"] == "user": elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
self.conversation_history.pop() self.conversation_history.pop()
self.kv_cache_state = None
if not self.cancel_event.is_set(): if not self.cancel_event.is_set():
await self.send_json({"type": "response_text", "text": "", "final": True}) await self.send_json({"type": "response_text", "text": "", "final": True})