44a10667c2
- Added environment variables to prevent CPU thread pools from busy-waiting. - Deferred loading of video models until first use to reduce VRAM footprint. - Implemented streaming of speaking clips for improved responsiveness. - Introduced a queue for managing speaking clips to handle multiple requests smoothly. - Updated video playback logic to ensure proper handling of clip generation.
234 lines
8.7 KiB
Python
234 lines
8.7 KiB
Python
import copy
|
|
import dataclasses
|
|
import logging
|
|
import threading
|
|
from typing import AsyncIterator
|
|
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
|
|
|
|
from server.audio_utils import split_sentences
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
@dataclasses.dataclass
class KVCacheState:
    """KV-cache belonging to one session, carried from one generate() call to the next."""

    # Cached key/value object from the previous generation, or None when
    # there is nothing to reuse.
    past_key_values: DynamicCache | None

    # Number of tokens this cache is recorded as covering.
    cached_token_count: int

    # Snapshot of the conversation messages taken when the cache was built.
    cached_messages: list[dict]
|
|
|
|
|
|
class LLMEngine:
    """Wraps Qwen3 for conversation generation with persistent KV-cache.

    Messages are dicts carrying at least the ``"role"`` and ``"content"``
    keys. The configured system prompt is prepended to every request.
    ``generate()`` holds an internal lock for its whole duration, so calls on
    one engine run one at a time.
    """

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, system_prompt: str,
                 max_cache_tokens: int = 4096):
        """Store the model, tokenizer and prompt used by later calls.

        Args:
            model: Causal LM used for generation.
            tokenizer: Tokenizer providing the chat template.
            system_prompt: Text sent as the leading system message.
            max_cache_tokens: A cache recorded as larger than this is
                discarded instead of reused; a falsy value disables the limit.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.system_prompt = system_prompt
        self.max_cache_tokens = max_cache_tokens
        # Held for the full duration of generate().
        self._generate_lock = threading.Lock()

    def _build_inputs(self, messages: list[dict]):
        """Build input token ids using the model's chat template.

        Only the ``role`` and ``content`` of each message are forwarded. The
        template is rendered with a generation prompt and
        ``enable_thinking=False``; the result is tokenized and moved to the
        model's device.
        """
        chat_messages = [{"role": "system", "content": self.system_prompt}]
        chat_messages.extend(
            {"role": msg["role"], "content": msg["content"]} for msg in messages
        )

        text = self.tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return self.tokenizer(text, return_tensors="pt").to(self.model.device)

    def _validate_cache(self, messages: list[dict], cache_state: KVCacheState | None) -> DynamicCache | None:
        """Return past_key_values if the cache is valid for the given messages, else None.

        The cache is rejected when it is missing, when its recorded token
        count exceeds ``max_cache_tokens``, or when the cached messages are
        not a prefix (by role and content) of ``messages``.
        """
        if cache_state is None or cache_state.past_key_values is None:
            return None
        if self.max_cache_tokens and cache_state.cached_token_count > self.max_cache_tokens:
            log.info("KV-cache exceeds max size, discarding.")
            return None

        cached = cache_state.cached_messages
        # The current messages must start with the cached messages as a prefix.
        if len(cached) > len(messages):
            return None
        for cached_msg, current_msg in zip(cached, messages):
            if cached_msg["role"] != current_msg["role"] or cached_msg["content"] != current_msg["content"]:
                return None
        return cache_state.past_key_values

    def generate(
        self,
        messages: list[dict],
        max_new_tokens: int = 256,
        cache_state: KVCacheState | None = None,
    ) -> tuple[str, KVCacheState]:
        """Generate a complete response (blocking).

        Args:
            messages: Conversation so far (without the system prompt).
            max_new_tokens: Upper bound on generated tokens.
            cache_state: Cache returned by a previous call, or None.

        Returns:
            ``(response, updated_cache_state)``. The new state's message
            snapshot is a deep copy of ``messages`` plus the assistant reply.
        """
        with self._generate_lock:
            inputs = self._build_inputs(messages)
            input_ids = inputs["input_ids"]
            input_len = input_ids.shape[1]

            past_kv = self._validate_cache(messages, cache_state)
            cached_len = cache_state.cached_token_count if past_kv is not None else 0

            # Guard: if the cache claims to have seen >= input tokens, it's
            # stale (can happen after barge-in races or tokenizer mismatches).
            # An invalid cache causes an empty cache_position in transformers,
            # which raises IndexError inside model.generate().
            if past_kv is not None:
                cache_seq_len = (
                    past_kv.get_seq_length()
                    if hasattr(past_kv, "get_seq_length")
                    else cached_len
                )
                if cache_seq_len >= input_len:
                    log.warning(
                        "KV-cache stale (cache_seq=%s >= input=%s), discarding.",
                        cache_seq_len,
                        input_len,
                    )
                    past_kv = None
                    cached_len = 0

            # Logged after the stale guard so the numbers describe the cache
            # that is actually passed to the model.
            log.info(
                "KV-cache: %s/%s tokens cached, processing %s new tokens",
                cached_len,
                input_len,
                input_len - cached_len,
            )

            def _do_generate(pkv):
                return self.model.generate(
                    input_ids=input_ids,
                    attention_mask=inputs.get("attention_mask"),
                    past_key_values=pkv,
                    max_new_tokens=max_new_tokens,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    repetition_penalty=1.2,
                    return_dict_in_generate=True,
                    use_cache=True,
                )

            with torch.no_grad():
                try:
                    outputs = _do_generate(past_kv)
                except IndexError:
                    log.warning("KV-cache caused IndexError during generate; retrying without cache.")
                    outputs = _do_generate(None)

            # Decode only the generated tokens (skip prompt).
            new_ids = outputs.sequences[0][input_len:]
            response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
            log.info("LLM response: %s", response)

            # Build updated cache state: messages now include the assistant response.
            new_messages = copy.deepcopy(messages) + [{"role": "assistant", "content": response}]
            new_cache = KVCacheState(
                past_key_values=outputs.past_key_values,
                cached_token_count=outputs.sequences.shape[1],
                cached_messages=new_messages,
            )
            return response, new_cache

    def trim_cache(self, cache_state: KVCacheState | None, messages: list[dict]) -> KVCacheState | None:
        """Trim cache to match the actual conversation history (e.g. after barge-in).

        Returns None when there is no cache to trim. Otherwise the same
        ``cache_state`` object is returned; when it is longer than the
        tokenized ``messages`` it is cropped in place and its token count and
        message snapshot are updated.
        """
        if cache_state is None or cache_state.past_key_values is None:
            return None

        # NOTE(review): target_len includes the generation-prompt tokens that
        # _build_inputs appends; confirm the cropped cache holds matching
        # tokens at those positions.
        target_len = self._build_inputs(messages)["input_ids"].shape[1]
        if target_len >= cache_state.cached_token_count:
            return cache_state

        cache_state.past_key_values.crop(target_len)
        cache_state.cached_token_count = target_len
        cache_state.cached_messages = copy.deepcopy(messages)
        return cache_state

    async def generate_sentences(
        self,
        messages: list[dict],
        cancel_event: threading.Event | None = None,
        cache_state: KVCacheState | None = None,
        *,
        max_new_tokens: int = 256,
    ) -> AsyncIterator[str]:
        """Generate response and yield it sentence by sentence for TTS pipelining.

        The blocking ``generate()`` call runs in a worker thread. Nothing is
        yielded once ``cancel_event`` is set. The updated cache state produced
        by ``generate()`` is not exposed by this method.

        Args:
            messages: Conversation so far (without the system prompt).
            cancel_event: Optional event; when set, iteration stops.
            cache_state: Cache from a previous generate() call, or None.
            max_new_tokens: Upper bound on generated tokens.
        """
        import asyncio

        # generate() returns (text, cache_state); only the text is split.
        response, _ = await asyncio.to_thread(
            self.generate, messages, max_new_tokens, cache_state
        )

        def _cancelled() -> bool:
            return cancel_event is not None and cancel_event.is_set()

        if _cancelled():
            return

        # Split into sentences and yield each.
        sentences, remainder = split_sentences(response)
        for sentence in sentences:
            if _cancelled():
                return
            yield sentence

        if remainder and not _cancelled():
            yield remainder
|
|
|
|
|
|
class LMStudioEngine:
    """LLM engine that delegates to an LM Studio server via its OpenAI-compatible API.

    Exposes the same ``generate`` / ``generate_sentences`` call shapes as
    ``LLMEngine``; no KV-cache is kept on this side, so the cache slot of the
    return value is always None.
    """

    def __init__(self, base_url: str, model: str, system_prompt: str):
        """Store connection settings.

        Args:
            base_url: Server root URL; trailing slashes are removed.
            model: Model identifier; when empty, no ``model`` field is sent.
            system_prompt: Text sent as the leading system message.
        """
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.system_prompt = system_prompt

    def generate(
        self,
        messages: list[dict],
        max_new_tokens: int = 256,
        cache_state: KVCacheState | None = None,
    ) -> tuple[str, None]:
        """Request one non-streamed chat completion (blocking).

        Args:
            messages: Conversation so far (without the system prompt).
            max_new_tokens: Sent as ``max_tokens``.
            cache_state: Unused by this engine.

        Returns:
            ``(response, None)``.

        Raises:
            requests.HTTPError: If the server answers with an error status
                (via ``raise_for_status``).
        """
        import requests

        payload_messages = [{"role": "system", "content": self.system_prompt}, *messages]

        body: dict = {
            "messages": payload_messages,
            "max_tokens": max_new_tokens,
            "temperature": 0.7,
            "stream": False,
        }
        if self.model:
            body["model"] = self.model

        resp = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=body,
            timeout=30,
        )
        resp.raise_for_status()

        # A null "content" is treated as an empty reply instead of crashing
        # on .strip().
        content = resp.json()["choices"][0]["message"]["content"]
        response = (content or "").strip()
        log.info("LM Studio response: %s", response)
        return response, None

    async def generate_sentences(
        self,
        messages: list[dict],
        cancel_event: threading.Event | None = None,
        cache_state: KVCacheState | None = None,
        *,
        max_new_tokens: int = 256,
    ) -> AsyncIterator[str]:
        """Generate response and yield it sentence by sentence for TTS pipelining.

        The blocking ``generate()`` call runs in a worker thread. Nothing is
        yielded once ``cancel_event`` is set.

        Args:
            messages: Conversation so far (without the system prompt).
            cancel_event: Optional event; when set, iteration stops.
            cache_state: Accepted so the signature matches
                ``LLMEngine.generate_sentences``; forwarded to ``generate()``,
                which ignores it.
            max_new_tokens: Sent to the server as ``max_tokens``.
        """
        import asyncio

        # generate() returns (text, None); only the text is split.
        response, _ = await asyncio.to_thread(
            self.generate, messages, max_new_tokens, cache_state
        )

        def _cancelled() -> bool:
            return cancel_event is not None and cancel_event.is_set()

        if _cancelled():
            return

        sentences, remainder = split_sentences(response)
        for sentence in sentences:
            if _cancelled():
                return
            yield sentence

        if remainder and not _cancelled():
            yield remainder
|