add KV-cache and move system prompt to the config
This commit is contained in:
@@ -1,6 +1,13 @@
|
||||
# LLM backend: "local" or "lmstudio"
|
||||
llm:
|
||||
backend: local # change to "lmstudio" to use LM Studio instead
|
||||
max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 to disable caching
|
||||
|
||||
system_prompt: >-
|
||||
You are a helpful voice assistant. Keep your responses concise and natural
|
||||
for spoken conversation. Respond in 1-3 short sentences.
|
||||
Do not use markdown, bullet points, code blocks, emojis, or any
|
||||
formatting that doesn't work in speech.
|
||||
|
||||
# Settings used only when backend = "lmstudio"
|
||||
lmstudio:
|
||||
|
||||
+102
-41
@@ -1,32 +1,38 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
import threading
|
||||
from typing import AsyncIterator
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
|
||||
|
||||
from server.audio_utils import split_sentences
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@dataclasses.dataclass
class KVCacheState:
    """KV-cache kept for one session between successive generate() calls."""

    # Cached attention keys/values from the previous generation; may be None.
    past_key_values: DynamicCache | None
    # Number of tokens recorded as covered by the cache.
    cached_token_count: int
    # Copy of the conversation messages taken when the cache was built.
    cached_messages: list[dict]
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""Wraps Qwen3 for conversation generation."""
|
||||
"""Wraps Qwen3 for conversation generation with persistent KV-cache."""
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a helpful voice assistant. Keep your responses concise and natural "
|
||||
"for spoken conversation. Respond in 1-3 short sentences. "
|
||||
"Do not use markdown, bullet points, code blocks, emojis, or any "
|
||||
"formatting that doesn't work in speech."
|
||||
)
|
||||
|
||||
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
|
||||
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, system_prompt: str,
|
||||
max_cache_tokens: int = 4096):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.system_prompt = system_prompt
|
||||
self.max_cache_tokens = max_cache_tokens
|
||||
self._generate_lock = threading.Lock()
|
||||
|
||||
def _build_inputs(self, messages: list[dict]):
|
||||
"""Build input token ids using the model's chat template."""
|
||||
chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
|
||||
chat_messages = [{"role": "system", "content": self.system_prompt}]
|
||||
for msg in messages:
|
||||
chat_messages.append({"role": msg["role"], "content": msg["content"]})
|
||||
|
||||
@@ -38,36 +44,93 @@ class LLMEngine:
|
||||
)
|
||||
return self.tokenizer(text, return_tensors="pt").to(self.model.device)
|
||||
|
||||
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
|
||||
"""Generate a complete response (blocking)."""
|
||||
inputs = self._build_inputs(messages)
|
||||
input_len = inputs["input_ids"].shape[1]
|
||||
def _validate_cache(self, messages: list[dict], cache_state: KVCacheState | None) -> DynamicCache | None:
    """Return the cached key/values if they may be reused for ``messages``.

    The cache is reused only when all of the following hold:
      * a cache state with non-None ``past_key_values`` was supplied;
      * caching is enabled (``self.max_cache_tokens`` is a positive number);
      * the cache does not exceed ``self.max_cache_tokens`` tokens;
      * the messages the cache was built from are a prefix of ``messages``
        (same roles and contents, in the same order).

    Args:
        messages: Current conversation history; each item has "role" and
            "content" keys.
        cache_state: State returned by a previous generate() call, or None.

    Returns:
        ``cache_state.past_key_values`` when reusable, otherwise None.
    """
    if cache_state is None or cache_state.past_key_values is None:
        return None

    # The config documents 0 as "disable caching". Treat a zero, negative or
    # unset limit as "never reuse" instead of silently meaning "no limit".
    limit = self.max_cache_tokens
    if not limit or limit <= 0:
        return None
    if cache_state.cached_token_count > limit:
        log.info("KV-cache exceeds max size, discarding.")
        return None

    cached = cache_state.cached_messages
    # A cache built from more messages than we have now cannot be a prefix.
    if len(cached) > len(messages):
        return None

    # zip() stops at the shorter list, i.e. at the end of the cached messages.
    is_prefix = all(
        old["role"] == new["role"] and old["content"] == new["content"]
        for old, new in zip(cached, messages)
    )
    # NOTE(review): this compares message text only; it presumes the cached
    # tokens equal the tokens the chat template produces for these messages
    # - confirm for templates that rewrite earlier assistant turns.
    return cache_state.past_key_values if is_prefix else None
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
do_sample=True,
|
||||
repetition_penalty=1.2,
|
||||
def generate(
|
||||
self,
|
||||
messages: list[dict],
|
||||
max_new_tokens: int = 256,
|
||||
cache_state: KVCacheState | None = None,
|
||||
) -> tuple[str, KVCacheState]:
|
||||
"""Generate a complete response (blocking). Returns (response, updated_cache_state)."""
|
||||
with self._generate_lock:
|
||||
inputs = self._build_inputs(messages)
|
||||
input_ids = inputs["input_ids"]
|
||||
input_len = input_ids.shape[1]
|
||||
|
||||
past_kv = self._validate_cache(messages, cache_state)
|
||||
cached_len = cache_state.cached_token_count if past_kv is not None else 0
|
||||
|
||||
log.info(
|
||||
f"KV-cache: {cached_len}/{input_len} tokens cached, "
|
||||
f"processing {input_len - cached_len} new tokens"
|
||||
)
|
||||
|
||||
# Decode only the generated tokens (skip prompt)
|
||||
new_ids = output_ids[0][input_len:]
|
||||
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
|
||||
log.info(f"LLM response: {response}")
|
||||
return response
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=inputs.get("attention_mask"),
|
||||
past_key_values=past_kv,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
do_sample=True,
|
||||
repetition_penalty=1.2,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Decode only the generated tokens (skip prompt)
|
||||
new_ids = outputs.sequences[0][input_len:]
|
||||
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
|
||||
log.info(f"LLM response: {response}")
|
||||
|
||||
# Build updated cache state: messages now include the assistant response
|
||||
new_messages = copy.deepcopy(messages) + [{"role": "assistant", "content": response}]
|
||||
new_cache = KVCacheState(
|
||||
past_key_values=outputs.past_key_values,
|
||||
cached_token_count=outputs.sequences.shape[1],
|
||||
cached_messages=new_messages,
|
||||
)
|
||||
return response, new_cache
|
||||
|
||||
def trim_cache(self, cache_state: KVCacheState, messages: list[dict]) -> KVCacheState | None:
    """Shrink a session cache so it covers no more than ``messages``.

    Used when the stored history is shorter than what was generated (e.g.
    after a barge-in). The state is modified in place.

    Args:
        cache_state: Cache state to trim; may be None.
        messages: Conversation history the cache should correspond to.

    Returns:
        None when there is nothing cached; otherwise ``cache_state`` itself,
        cropped only if the templated ``messages`` are shorter than the
        recorded token count.
    """
    has_cache = cache_state is not None and cache_state.past_key_values is not None
    if not has_cache:
        return None

    # Token length of the prompt as rendered by the engine for these messages.
    prompt_len = self._build_inputs(messages)["input_ids"].shape[1]

    if prompt_len < cache_state.cached_token_count:
        # NOTE(review): presumes the first `prompt_len` cached positions hold
        # the same tokens as the re-templated messages - confirm for a
        # truncated assistant turn.
        cache_state.past_key_values.crop(prompt_len)
        cache_state.cached_token_count = prompt_len
        cache_state.cached_messages = copy.deepcopy(messages)

    return cache_state
|
||||
|
||||
async def generate_sentences(
|
||||
self,
|
||||
messages: list[dict],
|
||||
cancel_event: threading.Event | None = None,
|
||||
cache_state: KVCacheState | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Generate response and yield it sentence by sentence for TTS pipelining."""
|
||||
import asyncio
|
||||
|
||||
response = await asyncio.to_thread(self.generate, messages)
|
||||
response = await asyncio.to_thread(self.generate, messages, 256, cache_state)
|
||||
|
||||
if cancel_event and cancel_event.is_set():
|
||||
return
|
||||
@@ -83,25 +146,23 @@ class LLMEngine:
|
||||
yield remainder
|
||||
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a helpful voice assistant. Keep your responses concise and natural "
|
||||
"for spoken conversation. Respond in 1-3 short sentences. "
|
||||
"Do not use markdown, bullet points, code blocks, emojis, or any "
|
||||
"formatting that doesn't work in speech."
|
||||
)
|
||||
|
||||
|
||||
class LMStudioEngine:
|
||||
"""LLM engine that delegates to an LM Studio server via its OpenAI-compatible API."""
|
||||
|
||||
def __init__(self, base_url: str, model: str):
|
||||
def __init__(self, base_url: str, model: str, system_prompt: str):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
|
||||
def generate(
|
||||
self,
|
||||
messages: list[dict],
|
||||
max_new_tokens: int = 256,
|
||||
cache_state: KVCacheState | None = None,
|
||||
) -> tuple[str, None]:
|
||||
import requests
|
||||
|
||||
payload_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
payload_messages = [{"role": "system", "content": self.system_prompt}]
|
||||
payload_messages.extend(messages)
|
||||
|
||||
body: dict = {
|
||||
@@ -121,7 +182,7 @@ class LMStudioEngine:
|
||||
resp.raise_for_status()
|
||||
response = resp.json()["choices"][0]["message"]["content"].strip()
|
||||
log.info(f"LM Studio response: {response}")
|
||||
return response
|
||||
return response, None
|
||||
|
||||
async def generate_sentences(
|
||||
self,
|
||||
|
||||
+7
-4
@@ -68,16 +68,18 @@ class ModelManager:
|
||||
def _load_llm(self):
|
||||
from server.config import config
|
||||
|
||||
backend = config.get("llm", {}).get("backend", "local")
|
||||
llm_config = config.get("llm", {})
|
||||
backend = llm_config.get("backend", "local")
|
||||
system_prompt = llm_config.get("system_prompt", "You are a helpful assistant.")
|
||||
|
||||
if backend == "lmstudio":
|
||||
from server.llm import LMStudioEngine
|
||||
|
||||
lms = config.get("llm", {}).get("lmstudio", {})
|
||||
lms = llm_config.get("lmstudio", {})
|
||||
url = lms.get("url", "http://host.docker.internal:1234")
|
||||
model = lms.get("model", "") or ""
|
||||
log.info(f"Using LM Studio backend at {url} (model={model or 'server default'})")
|
||||
self.llm_engine = LMStudioEngine(url, model)
|
||||
self.llm_engine = LMStudioEngine(url, model, system_prompt)
|
||||
else:
|
||||
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
@@ -90,7 +92,8 @@ class ModelManager:
|
||||
model_name,
|
||||
device_map=device,
|
||||
)
|
||||
self.llm_engine = LLMEngine(model, tokenizer)
|
||||
max_cache_tokens = llm_config.get("max_cache_tokens", 4096)
|
||||
self.llm_engine = LLMEngine(model, tokenizer, system_prompt, max_cache_tokens)
|
||||
log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).")
|
||||
|
||||
def _load_tts(self):
|
||||
|
||||
+11
-2
@@ -6,6 +6,7 @@ import threading
|
||||
import numpy as np
|
||||
|
||||
from server.audio_utils import float32_to_pcm_bytes
|
||||
from server.llm import KVCacheState
|
||||
from server.models import ModelManager
|
||||
from server.vad import StreamingVAD
|
||||
|
||||
@@ -27,6 +28,7 @@ class ConversationSession:
|
||||
|
||||
self.vad: StreamingVAD = models.create_vad()
|
||||
self.conversation_history: list[dict] = []
|
||||
self.kv_cache_state: KVCacheState | None = None
|
||||
self.cancel_event = threading.Event()
|
||||
self.is_responding = False
|
||||
self._response_task: asyncio.Task | None = None
|
||||
@@ -38,6 +40,7 @@ class ConversationSession:
|
||||
self.cancel_event.set()
|
||||
if self._response_task and not self._response_task.done():
|
||||
self._response_task.cancel()
|
||||
self.kv_cache_state = None
|
||||
|
||||
async def handle_audio_chunk(self, chunk_16k: np.ndarray):
|
||||
utterance = self.vad.process_chunk(chunk_16k)
|
||||
@@ -91,8 +94,8 @@ class ConversationSession:
|
||||
# LLM
|
||||
log.info(f"Conversation history ({len(self.conversation_history)} messages): "
|
||||
+ str([m['content'][:50] for m in self.conversation_history]))
|
||||
response = await asyncio.to_thread(
|
||||
self.models.llm_engine.generate, self.conversation_history
|
||||
response, self.kv_cache_state = await asyncio.to_thread(
|
||||
self.models.llm_engine.generate, self.conversation_history, 256, self.kv_cache_state
|
||||
)
|
||||
|
||||
if self.cancel_event.is_set():
|
||||
@@ -147,12 +150,18 @@ class ConversationSession:
|
||||
tts_thread.join(timeout=2.0)
|
||||
|
||||
# Save only what was actually spoken
|
||||
was_interrupted = spoken_text.strip() != response.strip()
|
||||
if spoken_text.strip():
|
||||
self.conversation_history.append(
|
||||
{"role": "assistant", "content": spoken_text.strip()}
|
||||
)
|
||||
if was_interrupted and self.kv_cache_state is not None:
|
||||
self.kv_cache_state = self.models.llm_engine.trim_cache(
|
||||
self.kv_cache_state, self.conversation_history
|
||||
)
|
||||
elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
|
||||
self.conversation_history.pop()
|
||||
self.kv_cache_state = None
|
||||
|
||||
if not self.cancel_event.is_set():
|
||||
await self.send_json({"type": "response_text", "text": "", "final": True})
|
||||
|
||||
Reference in New Issue
Block a user