add KV-cache and move system promt to the config

This commit is contained in:
2026-04-08 10:25:03 -04:00
parent c7c4019ecc
commit 175ed943df
4 changed files with 127 additions and 47 deletions
+102 -41
View File
@@ -1,32 +1,38 @@
import copy
import dataclasses
import logging
import threading
from typing import AsyncIterator
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from server.audio_utils import split_sentences
log = logging.getLogger(__name__)
@dataclasses.dataclass
class KVCacheState:
"""Per-session KV-cache persisted across generate() calls."""
past_key_values: DynamicCache | None
cached_token_count: int
cached_messages: list[dict] # snapshot of messages when cache was built
class LLMEngine:
"""Wraps Qwen3 for conversation generation."""
"""Wraps Qwen3 for conversation generation with persistent KV-cache."""
SYSTEM_PROMPT = (
"You are a helpful voice assistant. Keep your responses concise and natural "
"for spoken conversation. Respond in 1-3 short sentences. "
"Do not use markdown, bullet points, code blocks, emojis, or any "
"formatting that doesn't work in speech."
)
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, system_prompt: str,
max_cache_tokens: int = 4096):
self.model = model
self.tokenizer = tokenizer
self.system_prompt = system_prompt
self.max_cache_tokens = max_cache_tokens
self._generate_lock = threading.Lock()
def _build_inputs(self, messages: list[dict]):
"""Build input token ids using the model's chat template."""
chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
chat_messages = [{"role": "system", "content": self.system_prompt}]
for msg in messages:
chat_messages.append({"role": msg["role"], "content": msg["content"]})
@@ -38,36 +44,93 @@ class LLMEngine:
)
return self.tokenizer(text, return_tensors="pt").to(self.model.device)
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
"""Generate a complete response (blocking)."""
inputs = self._build_inputs(messages)
input_len = inputs["input_ids"].shape[1]
def _validate_cache(self, messages: list[dict], cache_state: KVCacheState | None) -> DynamicCache | None:
"""Return past_key_values if the cache is valid for the given messages, else None."""
if cache_state is None or cache_state.past_key_values is None:
return None
if self.max_cache_tokens and cache_state.cached_token_count > self.max_cache_tokens:
log.info("KV-cache exceeds max size, discarding.")
return None
cached = cache_state.cached_messages
# The current messages must start with the cached messages as a prefix
if len(cached) > len(messages):
return None
for cached_msg, current_msg in zip(cached, messages):
if cached_msg["role"] != current_msg["role"] or cached_msg["content"] != current_msg["content"]:
return None
return cache_state.past_key_values
with torch.no_grad():
output_ids = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.2,
def generate(
self,
messages: list[dict],
max_new_tokens: int = 256,
cache_state: KVCacheState | None = None,
) -> tuple[str, KVCacheState]:
"""Generate a complete response (blocking). Returns (response, updated_cache_state)."""
with self._generate_lock:
inputs = self._build_inputs(messages)
input_ids = inputs["input_ids"]
input_len = input_ids.shape[1]
past_kv = self._validate_cache(messages, cache_state)
cached_len = cache_state.cached_token_count if past_kv is not None else 0
log.info(
f"KV-cache: {cached_len}/{input_len} tokens cached, "
f"processing {input_len - cached_len} new tokens"
)
# Decode only the generated tokens (skip prompt)
new_ids = output_ids[0][input_len:]
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
log.info(f"LLM response: {response}")
return response
with torch.no_grad():
outputs = self.model.generate(
input_ids=input_ids,
attention_mask=inputs.get("attention_mask"),
past_key_values=past_kv,
max_new_tokens=max_new_tokens,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.2,
return_dict_in_generate=True,
use_cache=True,
)
# Decode only the generated tokens (skip prompt)
new_ids = outputs.sequences[0][input_len:]
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
log.info(f"LLM response: {response}")
# Build updated cache state: messages now include the assistant response
new_messages = copy.deepcopy(messages) + [{"role": "assistant", "content": response}]
new_cache = KVCacheState(
past_key_values=outputs.past_key_values,
cached_token_count=outputs.sequences.shape[1],
cached_messages=new_messages,
)
return response, new_cache
def trim_cache(self, cache_state: KVCacheState, messages: list[dict]) -> KVCacheState | None:
"""Trim cache to match the actual conversation history (e.g. after barge-in)."""
if cache_state is None or cache_state.past_key_values is None:
return None
inputs = self._build_inputs(messages)
target_len = inputs["input_ids"].shape[1]
if target_len >= cache_state.cached_token_count:
return cache_state
cache_state.past_key_values.crop(target_len)
cache_state.cached_token_count = target_len
cache_state.cached_messages = copy.deepcopy(messages)
return cache_state
async def generate_sentences(
self,
messages: list[dict],
cancel_event: threading.Event | None = None,
cache_state: KVCacheState | None = None,
) -> AsyncIterator[str]:
"""Generate response and yield it sentence by sentence for TTS pipelining."""
import asyncio
response = await asyncio.to_thread(self.generate, messages)
response = await asyncio.to_thread(self.generate, messages, 256, cache_state)
if cancel_event and cancel_event.is_set():
return
@@ -83,25 +146,23 @@ class LLMEngine:
yield remainder
SYSTEM_PROMPT = (
"You are a helpful voice assistant. Keep your responses concise and natural "
"for spoken conversation. Respond in 1-3 short sentences. "
"Do not use markdown, bullet points, code blocks, emojis, or any "
"formatting that doesn't work in speech."
)
class LMStudioEngine:
"""LLM engine that delegates to an LM Studio server via its OpenAI-compatible API."""
def __init__(self, base_url: str, model: str):
def __init__(self, base_url: str, model: str, system_prompt: str):
self.base_url = base_url.rstrip("/")
self.model = model
self.system_prompt = system_prompt
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
def generate(
self,
messages: list[dict],
max_new_tokens: int = 256,
cache_state: KVCacheState | None = None,
) -> tuple[str, None]:
import requests
payload_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
payload_messages = [{"role": "system", "content": self.system_prompt}]
payload_messages.extend(messages)
body: dict = {
@@ -121,7 +182,7 @@ class LMStudioEngine:
resp.raise_for_status()
response = resp.json()["choices"][0]["message"]["content"].strip()
log.info(f"LM Studio response: {response}")
return response
return response, None
async def generate_sentences(
self,