Compare commits
4 Commits
08c5757b31
...
680c5b04cc
| Author | SHA1 | Date | |
|---|---|---|---|
| 680c5b04cc | |||
| d509f92a9d | |||
| 175ed943df | |||
| c7c4019ecc |
@@ -35,6 +35,9 @@ RUN python3.11 -m pip install --no-cache-dir \
|
|||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN python3.11 -m pip install --no-cache-dir -r requirements.txt
|
RUN python3.11 -m pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Pre-download the spacy model that kokoro needs at runtime
|
||||||
|
RUN python3.11 -m spacy download en_core_web_sm
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|||||||
+14
@@ -0,0 +1,14 @@
|
|||||||
|
# LLM backend: "local" or "lmstudio"
|
||||||
|
llm:
|
||||||
|
backend: local # change to "lmstudio" to use LM Studio instead
|
||||||
|
max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 means no size limit (the cache is never discarded for size)
|
||||||
|
|
||||||
|
system_prompt: >-
|
||||||
|
You are a helpful voice assistant.
|
||||||
|
Keep your responses extremely concise but natural for spoken conversation.
|
||||||
|
Do not use markdown, bullet points, code blocks, emojis, or any formatting that doesn't work in speech.
|
||||||
|
|
||||||
|
# Settings used only when backend = "lmstudio"
|
||||||
|
lmstudio:
|
||||||
|
url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker
|
||||||
|
model: "" # leave empty to use whatever model LM Studio has loaded
|
||||||
@@ -6,6 +6,11 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
# Cache models on the host so they survive container rebuilds
|
# Cache models on the host so they survive container rebuilds
|
||||||
- huggingface-cache:/cache/huggingface
|
- huggingface-cache:/cache/huggingface
|
||||||
|
# Mount source so you can edit code/config without rebuilding the image
|
||||||
|
- ./config.yml:/app/config.yml:ro
|
||||||
|
- ./server:/app/server:ro
|
||||||
|
- ./static:/app/static:ro
|
||||||
|
- ./run.py:/app/run.py:ro
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
reservations:
|
reservations:
|
||||||
|
|||||||
@@ -13,3 +13,4 @@ numpy
|
|||||||
soundfile
|
soundfile
|
||||||
scipy
|
scipy
|
||||||
python-multipart
|
python-multipart
|
||||||
|
pyyaml
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
import pathlib
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
_CONFIG_PATH = pathlib.Path(__file__).parent.parent / "config.yml"
|
||||||
|
|
||||||
|
|
||||||
|
def load_config() -> dict:
    """Load the application configuration from ``config.yml``.

    Returns:
        The parsed top-level mapping. An empty file yields an empty dict so
        that callers can safely chain ``config.get(...)``.

    Raises:
        OSError: if the config file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the top-level YAML value is not a mapping.
    """
    # Explicit encoding: the prompt text may contain non-ASCII characters and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open(_CONFIG_PATH, encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{_CONFIG_PATH} must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
|
||||||
|
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
+143
-19
@@ -1,32 +1,38 @@
|
|||||||
|
import copy
|
||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
|
||||||
|
|
||||||
from server.audio_utils import split_sentences
|
from server.audio_utils import split_sentences
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class KVCacheState:
|
||||||
|
"""Per-session KV-cache persisted across generate() calls."""
|
||||||
|
past_key_values: DynamicCache | None
|
||||||
|
cached_token_count: int
|
||||||
|
cached_messages: list[dict] # snapshot of messages when cache was built
|
||||||
|
|
||||||
|
|
||||||
class LLMEngine:
|
class LLMEngine:
|
||||||
"""Wraps Qwen3 for conversation generation."""
|
"""Wraps Qwen3 for conversation generation with persistent KV-cache."""
|
||||||
|
|
||||||
SYSTEM_PROMPT = (
|
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, system_prompt: str,
|
||||||
"You are a helpful voice assistant. Keep your responses concise and natural "
|
max_cache_tokens: int = 4096):
|
||||||
"for spoken conversation. Respond in 1-3 short sentences. "
|
|
||||||
"Do not use markdown, bullet points, code blocks, emojis, or any "
|
|
||||||
"formatting that doesn't work in speech."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.max_cache_tokens = max_cache_tokens
|
||||||
|
self._generate_lock = threading.Lock()
|
||||||
|
|
||||||
def _build_inputs(self, messages: list[dict]):
|
def _build_inputs(self, messages: list[dict]):
|
||||||
"""Build input token ids using the model's chat template."""
|
"""Build input token ids using the model's chat template."""
|
||||||
chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}]
|
chat_messages = [{"role": "system", "content": self.system_prompt}]
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
chat_messages.append({"role": msg["role"], "content": msg["content"]})
|
chat_messages.append({"role": msg["role"], "content": msg["content"]})
|
||||||
|
|
||||||
@@ -38,26 +44,145 @@ class LLMEngine:
|
|||||||
)
|
)
|
||||||
return self.tokenizer(text, return_tensors="pt").to(self.model.device)
|
return self.tokenizer(text, return_tensors="pt").to(self.model.device)
|
||||||
|
|
||||||
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str:
|
def _validate_cache(self, messages: list[dict], cache_state: KVCacheState | None) -> DynamicCache | None:
|
||||||
"""Generate a complete response (blocking)."""
|
"""Return past_key_values if the cache is valid for the given messages, else None."""
|
||||||
|
if cache_state is None or cache_state.past_key_values is None:
|
||||||
|
return None
|
||||||
|
if self.max_cache_tokens and cache_state.cached_token_count > self.max_cache_tokens:
|
||||||
|
log.info("KV-cache exceeds max size, discarding.")
|
||||||
|
return None
|
||||||
|
cached = cache_state.cached_messages
|
||||||
|
# The current messages must start with the cached messages as a prefix
|
||||||
|
if len(cached) > len(messages):
|
||||||
|
return None
|
||||||
|
for cached_msg, current_msg in zip(cached, messages):
|
||||||
|
if cached_msg["role"] != current_msg["role"] or cached_msg["content"] != current_msg["content"]:
|
||||||
|
return None
|
||||||
|
return cache_state.past_key_values
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
max_new_tokens: int = 256,
|
||||||
|
cache_state: KVCacheState | None = None,
|
||||||
|
) -> tuple[str, KVCacheState]:
|
||||||
|
"""Generate a complete response (blocking). Returns (response, updated_cache_state)."""
|
||||||
|
with self._generate_lock:
|
||||||
inputs = self._build_inputs(messages)
|
inputs = self._build_inputs(messages)
|
||||||
input_len = inputs["input_ids"].shape[1]
|
input_ids = inputs["input_ids"]
|
||||||
|
input_len = input_ids.shape[1]
|
||||||
|
|
||||||
|
past_kv = self._validate_cache(messages, cache_state)
|
||||||
|
cached_len = cache_state.cached_token_count if past_kv is not None else 0
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
f"KV-cache: {cached_len}/{input_len} tokens cached, "
|
||||||
|
f"processing {input_len - cached_len} new tokens"
|
||||||
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output_ids = self.model.generate(
|
outputs = self.model.generate(
|
||||||
**inputs,
|
input_ids=input_ids,
|
||||||
|
attention_mask=inputs.get("attention_mask"),
|
||||||
|
past_key_values=past_kv,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
repetition_penalty=1.2,
|
repetition_penalty=1.2,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Decode only the generated tokens (skip prompt)
|
# Decode only the generated tokens (skip prompt)
|
||||||
new_ids = output_ids[0][input_len:]
|
new_ids = outputs.sequences[0][input_len:]
|
||||||
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
|
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
|
||||||
log.info(f"LLM response: {response}")
|
log.info(f"LLM response: {response}")
|
||||||
return response
|
|
||||||
|
# Build updated cache state: messages now include the assistant response
|
||||||
|
new_messages = copy.deepcopy(messages) + [{"role": "assistant", "content": response}]
|
||||||
|
new_cache = KVCacheState(
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
cached_token_count=outputs.sequences.shape[1],
|
||||||
|
cached_messages=new_messages,
|
||||||
|
)
|
||||||
|
return response, new_cache
|
||||||
|
|
||||||
|
def trim_cache(self, cache_state: KVCacheState, messages: list[dict]) -> KVCacheState | None:
|
||||||
|
"""Trim cache to match the actual conversation history (e.g. after barge-in)."""
|
||||||
|
if cache_state is None or cache_state.past_key_values is None:
|
||||||
|
return None
|
||||||
|
inputs = self._build_inputs(messages)
|
||||||
|
target_len = inputs["input_ids"].shape[1]
|
||||||
|
if target_len >= cache_state.cached_token_count:
|
||||||
|
return cache_state
|
||||||
|
cache_state.past_key_values.crop(target_len)
|
||||||
|
cache_state.cached_token_count = target_len
|
||||||
|
cache_state.cached_messages = copy.deepcopy(messages)
|
||||||
|
return cache_state
|
||||||
|
|
||||||
|
async def generate_sentences(
    self,
    messages: list[dict],
    cancel_event: threading.Event | None = None,
    cache_state: KVCacheState | None = None,
) -> AsyncIterator[str]:
    """Generate a response and yield it sentence by sentence for TTS pipelining.

    Args:
        messages: Conversation history passed to ``generate``.
        cancel_event: When set, iteration stops before the next sentence.
        cache_state: Optional KV-cache state forwarded to ``generate``.

    Yields:
        Each complete sentence of the response, followed by any trailing
        remainder returned by ``split_sentences``.

    Note:
        The updated cache state returned by ``generate`` is not exposed by
        this iterator; callers that need it should call ``generate`` directly.
    """
    import asyncio

    # generate() returns (response_text, updated_cache_state); only the text
    # must be handed to split_sentences, not the whole tuple.
    response, _updated_cache = await asyncio.to_thread(
        self.generate, messages, 256, cache_state
    )

    if cancel_event and cancel_event.is_set():
        return

    # Split into sentences and yield each
    sentences, remainder = split_sentences(response)
    for sentence in sentences:
        if cancel_event and cancel_event.is_set():
            return
        yield sentence

    if remainder:
        yield remainder
|
||||||
|
|
||||||
|
|
||||||
|
class LMStudioEngine:
|
||||||
|
"""LLM engine that delegates to an LM Studio server via its OpenAI-compatible API."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str, model: str, system_prompt: str):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.model = model
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
def generate(
    self,
    messages: list[dict],
    max_new_tokens: int = 256,
    cache_state: KVCacheState | None = None,
) -> tuple[str, None]:
    """Generate a complete response via the LM Studio chat-completions endpoint (blocking).

    Args:
        messages: Conversation history as dicts with ``role`` and ``content``.
        max_new_tokens: Sent as ``max_tokens`` in the request body.
        cache_state: Accepted for signature parity with ``LLMEngine.generate``;
            not used by this backend.

    Returns:
        ``(response_text, None)`` -- the second element is always ``None``
        because no KV-cache state is kept for this backend.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx status (via ``raise_for_status``).
    """
    import requests

    payload_messages = [{"role": "system", "content": self.system_prompt}]
    # Forward only role/content, mirroring LLMEngine._build_inputs, so any
    # extra bookkeeping keys on history entries never reach the server.
    payload_messages.extend(
        {"role": msg["role"], "content": msg["content"]} for msg in messages
    )

    body: dict = {
        "messages": payload_messages,
        "max_tokens": max_new_tokens,
        "temperature": 0.7,
        "stream": False,
    }
    # An empty model name lets the server pick whatever model it has loaded.
    if self.model:
        body["model"] = self.model

    resp = requests.post(
        f"{self.base_url}/v1/chat/completions",
        json=body,
        timeout=30,
    )
    resp.raise_for_status()

    content = resp.json()["choices"][0]["message"]["content"]
    # NOTE(review): OpenAI-style responses may carry a null content field;
    # treat that as an empty reply instead of failing on None.strip().
    response = (content or "").strip()
    log.info(f"LM Studio response: {response}")
    return response, None
|
||||||
|
|
||||||
async def generate_sentences(
|
async def generate_sentences(
|
||||||
self,
|
self,
|
||||||
@@ -72,7 +197,6 @@ class LLMEngine:
|
|||||||
if cancel_event and cancel_event.is_set():
|
if cancel_event and cancel_event.is_set():
|
||||||
return
|
return
|
||||||
|
|
||||||
# Split into sentences and yield each
|
|
||||||
sentences, remainder = split_sentences(response)
|
sentences, remainder = split_sentences(response)
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
if cancel_event and cancel_event.is_set():
|
if cancel_event and cancel_event.is_set():
|
||||||
|
|||||||
+3
-1
@@ -77,7 +77,9 @@ async def websocket_chat(ws: WebSocket):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if msg.get("type") == "interrupt":
|
if msg.get("type") == "interrupt":
|
||||||
await session.interrupt()
|
await session.interrupt(
|
||||||
|
last_chunk_id=msg.get("last_chunk_id")
|
||||||
|
)
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
log.info("WebSocket client disconnected.")
|
log.info("WebSocket client disconnected.")
|
||||||
|
|||||||
+18
-2
@@ -46,7 +46,7 @@ class ModelManager:
|
|||||||
from server.vad import SileroVADOnnx
|
from server.vad import SileroVADOnnx
|
||||||
|
|
||||||
model_path = hf_hub_download(
|
model_path = hf_hub_download(
|
||||||
repo_id="onnx-community/silero-vad", filename="silero_vad.onnx"
|
repo_id="onnx-community/silero-vad", filename="onnx/model.onnx"
|
||||||
)
|
)
|
||||||
self.vad_model = SileroVADOnnx(model_path)
|
self.vad_model = SileroVADOnnx(model_path)
|
||||||
log.info("Silero VAD loaded (ONNX, CPU).")
|
log.info("Silero VAD loaded (ONNX, CPU).")
|
||||||
@@ -66,6 +66,21 @@ class ModelManager:
|
|||||||
log.info("Qwen3-ASR-0.6B loaded.")
|
log.info("Qwen3-ASR-0.6B loaded.")
|
||||||
|
|
||||||
def _load_llm(self):
|
def _load_llm(self):
|
||||||
|
from server.config import config
|
||||||
|
|
||||||
|
llm_config = config.get("llm", {})
|
||||||
|
backend = llm_config.get("backend", "local")
|
||||||
|
system_prompt = llm_config.get("system_prompt", "You are a helpful assistant.")
|
||||||
|
|
||||||
|
if backend == "lmstudio":
|
||||||
|
from server.llm import LMStudioEngine
|
||||||
|
|
||||||
|
lms = llm_config.get("lmstudio", {})
|
||||||
|
url = lms.get("url", "http://host.docker.internal:1234")
|
||||||
|
model = lms.get("model", "") or ""
|
||||||
|
log.info(f"Using LM Studio backend at {url} (model={model or 'server default'})")
|
||||||
|
self.llm_engine = LMStudioEngine(url, model, system_prompt)
|
||||||
|
else:
|
||||||
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
|
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
@@ -77,7 +92,8 @@ class ModelManager:
|
|||||||
model_name,
|
model_name,
|
||||||
device_map=device,
|
device_map=device,
|
||||||
)
|
)
|
||||||
self.llm_engine = LLMEngine(model, tokenizer)
|
max_cache_tokens = llm_config.get("max_cache_tokens", 4096)
|
||||||
|
self.llm_engine = LLMEngine(model, tokenizer, system_prompt, max_cache_tokens)
|
||||||
log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).")
|
log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).")
|
||||||
|
|
||||||
def _load_tts(self):
|
def _load_tts(self):
|
||||||
|
|||||||
+106
-11
@@ -1,11 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
|
import re
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from server.audio_utils import float32_to_pcm_bytes
|
from server.audio_utils import float32_to_pcm_bytes
|
||||||
|
from server.llm import KVCacheState
|
||||||
from server.models import ModelManager
|
from server.models import ModelManager
|
||||||
from server.vad import StreamingVAD
|
from server.vad import StreamingVAD
|
||||||
|
|
||||||
@@ -13,6 +15,56 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_SENTINEL = None
|
_SENTINEL = None
|
||||||
|
|
||||||
|
# Regex: split after sentence-ending punctuation followed by whitespace
|
||||||
|
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
|
||||||
|
# Regex: split after clause-level punctuation followed by whitespace
|
||||||
|
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')
|
||||||
|
|
||||||
|
MAX_SEGMENT_WORDS = 20
|
||||||
|
MIN_SEGMENT_WORDS = 4
|
||||||
|
|
||||||
|
|
||||||
|
def _split_into_segments(text: str) -> list[str]:
|
||||||
|
"""Split text into small TTS-friendly segments for fine-grained streaming.
|
||||||
|
|
||||||
|
Splits on sentence boundaries first, then breaks long sentences at clause
|
||||||
|
boundaries (commas, semicolons, colons, em-dashes). Avoids tiny fragments
|
||||||
|
by merging short pieces with their neighbours.
|
||||||
|
"""
|
||||||
|
sentences = _SENTENCE_RE.split(text.strip())
|
||||||
|
segments: list[str] = []
|
||||||
|
for sent in sentences:
|
||||||
|
if len(sent.split()) <= MAX_SEGMENT_WORDS:
|
||||||
|
segments.append(sent)
|
||||||
|
else:
|
||||||
|
# Split long sentences at clause boundaries
|
||||||
|
clauses = _CLAUSE_RE.split(sent)
|
||||||
|
current = ""
|
||||||
|
for clause in clauses:
|
||||||
|
combined = (current + " " + clause) if current else clause
|
||||||
|
if current and len(combined.split()) > MAX_SEGMENT_WORDS:
|
||||||
|
segments.append(current)
|
||||||
|
current = clause
|
||||||
|
else:
|
||||||
|
current = combined
|
||||||
|
if current:
|
||||||
|
segments.append(current)
|
||||||
|
|
||||||
|
# Merge any tiny fragments into their neighbour
|
||||||
|
merged: list[str] = []
|
||||||
|
for seg in segments:
|
||||||
|
if not seg.strip():
|
||||||
|
continue
|
||||||
|
if merged and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
|
||||||
|
merged[-1] = merged[-1] + " " + seg
|
||||||
|
else:
|
||||||
|
merged.append(seg)
|
||||||
|
# Also merge a trailing runt
|
||||||
|
if len(merged) > 1 and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
|
||||||
|
merged[-2] = merged[-2] + " " + merged[-1]
|
||||||
|
merged.pop()
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
class ConversationSession:
|
class ConversationSession:
|
||||||
"""Manages a single client's voice conversation pipeline.
|
"""Manages a single client's voice conversation pipeline.
|
||||||
@@ -27,6 +79,7 @@ class ConversationSession:
|
|||||||
|
|
||||||
self.vad: StreamingVAD = models.create_vad()
|
self.vad: StreamingVAD = models.create_vad()
|
||||||
self.conversation_history: list[dict] = []
|
self.conversation_history: list[dict] = []
|
||||||
|
self.kv_cache_state: KVCacheState | None = None
|
||||||
self.cancel_event = threading.Event()
|
self.cancel_event = threading.Event()
|
||||||
self.is_responding = False
|
self.is_responding = False
|
||||||
self._response_task: asyncio.Task | None = None
|
self._response_task: asyncio.Task | None = None
|
||||||
@@ -38,6 +91,7 @@ class ConversationSession:
|
|||||||
self.cancel_event.set()
|
self.cancel_event.set()
|
||||||
if self._response_task and not self._response_task.done():
|
if self._response_task and not self._response_task.done():
|
||||||
self._response_task.cancel()
|
self._response_task.cancel()
|
||||||
|
self.kv_cache_state = None
|
||||||
|
|
||||||
async def handle_audio_chunk(self, chunk_16k: np.ndarray):
|
async def handle_audio_chunk(self, chunk_16k: np.ndarray):
|
||||||
utterance = self.vad.process_chunk(chunk_16k)
|
utterance = self.vad.process_chunk(chunk_16k)
|
||||||
@@ -50,15 +104,17 @@ class ConversationSession:
|
|||||||
elif self.vad.is_speaking and self.is_responding:
|
elif self.vad.is_speaking and self.is_responding:
|
||||||
await self._interrupt()
|
await self._interrupt()
|
||||||
|
|
||||||
async def interrupt(self):
|
async def interrupt(self, last_chunk_id: int | None = None):
|
||||||
"""Public interrupt method for WebSocket text messages."""
|
"""Public interrupt method for WebSocket text messages."""
|
||||||
if self.is_responding:
|
if self.is_responding:
|
||||||
await self._interrupt()
|
await self._interrupt(last_chunk_id=last_chunk_id)
|
||||||
|
|
||||||
async def _interrupt(self):
|
async def _interrupt(self, last_chunk_id: int | None = None):
|
||||||
log.info("Barge-in: cancelling response.")
|
log.info("Barge-in: cancelling response.")
|
||||||
self.cancel_event.set()
|
self.cancel_event.set()
|
||||||
self.is_responding = False
|
self.is_responding = False
|
||||||
|
if last_chunk_id is not None:
|
||||||
|
self._last_played_chunk_id = last_chunk_id
|
||||||
# Tell client to stop audio immediately
|
# Tell client to stop audio immediately
|
||||||
try:
|
try:
|
||||||
await self.send_json({"type": "interrupt"})
|
await self.send_json({"type": "interrupt"})
|
||||||
@@ -91,8 +147,8 @@ class ConversationSession:
|
|||||||
# LLM
|
# LLM
|
||||||
log.info(f"Conversation history ({len(self.conversation_history)} messages): "
|
log.info(f"Conversation history ({len(self.conversation_history)} messages): "
|
||||||
+ str([m['content'][:50] for m in self.conversation_history]))
|
+ str([m['content'][:50] for m in self.conversation_history]))
|
||||||
response = await asyncio.to_thread(
|
response, self.kv_cache_state = await asyncio.to_thread(
|
||||||
self.models.llm_engine.generate, self.conversation_history
|
self.models.llm_engine.generate, self.conversation_history, 256, self.kv_cache_state
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cancel_event.is_set():
|
if self.cancel_event.is_set():
|
||||||
@@ -102,11 +158,18 @@ class ConversationSession:
|
|||||||
# TTS - stream chunks with per-sentence text
|
# TTS - stream chunks with per-sentence text
|
||||||
await self.send_json({"type": "status", "state": "speaking"})
|
await self.send_json({"type": "status", "state": "speaking"})
|
||||||
chunk_queue = queue.Queue()
|
chunk_queue = queue.Queue()
|
||||||
|
self._last_played_chunk_id = None
|
||||||
|
|
||||||
|
segments = _split_into_segments(response)
|
||||||
|
log.info(f"TTS: split response into {len(segments)} segments")
|
||||||
|
|
||||||
def _tts_worker():
|
def _tts_worker():
|
||||||
try:
|
try:
|
||||||
|
for segment in segments:
|
||||||
|
if self.cancel_event.is_set():
|
||||||
|
break
|
||||||
for graphemes, _ps, audio in self.models.tts_engine.pipeline(
|
for graphemes, _ps, audio in self.models.tts_engine.pipeline(
|
||||||
response, voice=self.models.tts_engine.voice
|
segment, voice=self.models.tts_engine.voice
|
||||||
):
|
):
|
||||||
if self.cancel_event.is_set():
|
if self.cancel_event.is_set():
|
||||||
break
|
break
|
||||||
@@ -121,6 +184,9 @@ class ConversationSession:
|
|||||||
tts_thread.start()
|
tts_thread.start()
|
||||||
|
|
||||||
spoken_text = ""
|
spoken_text = ""
|
||||||
|
chunk_id = 0
|
||||||
|
# Maps chunk_id -> cumulative text up to and including that chunk
|
||||||
|
chunk_text_map: dict[int, str] = {}
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
|
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
|
||||||
@@ -134,8 +200,14 @@ class ConversationSession:
|
|||||||
|
|
||||||
sentence_text, audio = item
|
sentence_text, audio = item
|
||||||
spoken_text += sentence_text
|
spoken_text += sentence_text
|
||||||
|
chunk_text_map[chunk_id] = spoken_text
|
||||||
|
|
||||||
await self.send_json({"type": "response_text", "text": sentence_text, "final": False})
|
await self.send_json({
|
||||||
|
"type": "response_text",
|
||||||
|
"text": sentence_text,
|
||||||
|
"chunk_id": chunk_id,
|
||||||
|
"final": False,
|
||||||
|
})
|
||||||
pcm_bytes = float32_to_pcm_bytes(audio)
|
pcm_bytes = float32_to_pcm_bytes(audio)
|
||||||
try:
|
try:
|
||||||
await self.send_bytes(pcm_bytes)
|
await self.send_bytes(pcm_bytes)
|
||||||
@@ -143,19 +215,42 @@ class ConversationSession:
|
|||||||
log.warning("Failed to send audio, client disconnected.")
|
log.warning("Failed to send audio, client disconnected.")
|
||||||
self.cancel_event.set()
|
self.cancel_event.set()
|
||||||
break
|
break
|
||||||
|
chunk_id += 1
|
||||||
|
|
||||||
tts_thread.join(timeout=2.0)
|
tts_thread.join(timeout=2.0)
|
||||||
|
|
||||||
# Save only what was actually spoken
|
# Determine what was actually heard by the client
|
||||||
if spoken_text.strip():
|
was_interrupted = spoken_text.strip() != response.strip()
|
||||||
|
if was_interrupted and self._last_played_chunk_id is not None:
|
||||||
|
# Client told us the last chunk whose audio actually played
|
||||||
|
heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
|
||||||
|
log.info(f"Interrupted: client heard up to chunk {self._last_played_chunk_id}")
|
||||||
|
else:
|
||||||
|
heard_text = spoken_text
|
||||||
|
|
||||||
|
# Save only what was actually spoken/heard
|
||||||
|
if heard_text.strip():
|
||||||
|
# Use original LLM response when fully spoken (keeps KV-cache valid);
|
||||||
|
# use heard_text only when interrupted.
|
||||||
|
final_content = heard_text.strip() if was_interrupted else response
|
||||||
self.conversation_history.append(
|
self.conversation_history.append(
|
||||||
{"role": "assistant", "content": spoken_text.strip()}
|
{"role": "assistant", "content": final_content}
|
||||||
|
)
|
||||||
|
if was_interrupted and self.kv_cache_state is not None:
|
||||||
|
self.kv_cache_state = self.models.llm_engine.trim_cache(
|
||||||
|
self.kv_cache_state, self.conversation_history
|
||||||
)
|
)
|
||||||
elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
|
elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
|
||||||
self.conversation_history.pop()
|
self.conversation_history.pop()
|
||||||
|
self.kv_cache_state = None
|
||||||
|
|
||||||
if not self.cancel_event.is_set():
|
if not self.cancel_event.is_set():
|
||||||
await self.send_json({"type": "response_text", "text": "", "final": True})
|
await self.send_json({
|
||||||
|
"type": "response_text",
|
||||||
|
"text": "",
|
||||||
|
"final": True,
|
||||||
|
"total_chunks": chunk_id,
|
||||||
|
})
|
||||||
|
|
||||||
self.is_responding = False
|
self.is_responding = False
|
||||||
try:
|
try:
|
||||||
|
|||||||
+64
-14
@@ -13,6 +13,11 @@ const BARGE_IN_THRESHOLD = 0.03; // RMS energy threshold for barge-in
|
|||||||
const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger
|
const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger
|
||||||
let bargeInCount = 0;
|
let bargeInCount = 0;
|
||||||
|
|
||||||
|
// --- Text-audio sync state ---
|
||||||
|
let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to arrive
|
||||||
|
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
|
||||||
|
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
|
||||||
|
|
||||||
const chatArea = document.getElementById("chat-area");
|
const chatArea = document.getElementById("chat-area");
|
||||||
const statusBadge = document.getElementById("status-badge");
|
const statusBadge = document.getElementById("status-badge");
|
||||||
const micBtn = document.getElementById("mic-btn");
|
const micBtn = document.getElementById("mic-btn");
|
||||||
@@ -54,8 +59,8 @@ function handleJSON(msg) {
|
|||||||
|
|
||||||
case "interrupt":
|
case "interrupt":
|
||||||
stopPlayback();
|
stopPlayback();
|
||||||
// Trim the assistant message to what was spoken, then finalize
|
// Finalize with interrupted marker — text already reflects only what was heard
|
||||||
finalizeAssistantMessage();
|
finalizeAssistantMessage(true);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case "transcript":
|
case "transcript":
|
||||||
@@ -64,9 +69,15 @@ function handleJSON(msg) {
|
|||||||
|
|
||||||
case "response_text":
|
case "response_text":
|
||||||
if (msg.final) {
|
if (msg.final) {
|
||||||
finalizeAssistantMessage();
|
// All chunks sent; finalize will happen when last audio chunk plays
|
||||||
|
// (or immediately if nothing was queued)
|
||||||
|
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
|
||||||
|
finalizeAssistantMessage(false);
|
||||||
|
}
|
||||||
|
// Otherwise, playAudioChunk will finalize after the last scheduled text
|
||||||
} else {
|
} else {
|
||||||
appendAssistantText(msg.text);
|
// Queue text — it will be displayed when corresponding audio starts playing
|
||||||
|
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -113,9 +124,20 @@ function appendAssistantText(text) {
|
|||||||
chatArea.scrollTop = chatArea.scrollHeight;
|
chatArea.scrollTop = chatArea.scrollHeight;
|
||||||
}
|
}
|
||||||
|
|
||||||
function finalizeAssistantMessage() {
|
function finalizeAssistantMessage(interrupted = false) {
|
||||||
|
if (interrupted && currentAssistantEl && currentAssistantText) {
|
||||||
|
const marker = document.createElement("span");
|
||||||
|
marker.className = "interrupted-marker";
|
||||||
|
marker.textContent = " [interrupted]";
|
||||||
|
currentAssistantEl.appendChild(marker);
|
||||||
|
}
|
||||||
currentAssistantEl = null;
|
currentAssistantEl = null;
|
||||||
currentAssistantText = "";
|
currentAssistantText = "";
|
||||||
|
// Reset sync state
|
||||||
|
pendingTextChunks = [];
|
||||||
|
for (const tid of scheduledTextTimers) clearTimeout(tid);
|
||||||
|
scheduledTextTimers = [];
|
||||||
|
lastDisplayedChunkId = -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Audio Playback ---
|
// --- Audio Playback ---
|
||||||
@@ -146,18 +168,38 @@ function playAudioChunk(arrayBuffer) {
|
|||||||
|
|
||||||
activeSources.push(source);
|
activeSources.push(source);
|
||||||
isPlaying = true;
|
isPlaying = true;
|
||||||
source.onended = () => {
|
|
||||||
activeSources = activeSources.filter((s) => s !== source);
|
// Pair this audio chunk with the next queued text chunk
|
||||||
if (activeSources.length === 0) {
|
const textEntry = pendingTextChunks.shift();
|
||||||
isPlaying = false;
|
|
||||||
bargeInCount = 0;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const now = ctx.currentTime;
|
const now = ctx.currentTime;
|
||||||
if (nextPlayTime < now) {
|
if (nextPlayTime < now) {
|
||||||
nextPlayTime = now + 0.01;
|
nextPlayTime = now + 0.01;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Schedule text display to coincide with audio playback start
|
||||||
|
if (textEntry) {
|
||||||
|
const delayMs = Math.max(0, (nextPlayTime - now) * 1000);
|
||||||
|
const tid = setTimeout(() => {
|
||||||
|
appendAssistantText(textEntry.text);
|
||||||
|
lastDisplayedChunkId = textEntry.chunkId;
|
||||||
|
scheduledTextTimers = scheduledTextTimers.filter((t) => t !== tid);
|
||||||
|
}, delayMs);
|
||||||
|
scheduledTextTimers.push(tid);
|
||||||
|
}
|
||||||
|
|
||||||
|
source.onended = () => {
|
||||||
|
activeSources = activeSources.filter((s) => s !== source);
|
||||||
|
if (activeSources.length === 0) {
|
||||||
|
isPlaying = false;
|
||||||
|
bargeInCount = 0;
|
||||||
|
// If all audio has finished and no more text pending, finalize
|
||||||
|
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
|
||||||
|
finalizeAssistantMessage(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
source.start(nextPlayTime);
|
source.start(nextPlayTime);
|
||||||
nextPlayTime += buffer.duration;
|
nextPlayTime += buffer.duration;
|
||||||
}
|
}
|
||||||
@@ -172,6 +214,10 @@ function stopPlayback() {
|
|||||||
nextPlayTime = 0;
|
nextPlayTime = 0;
|
||||||
isPlaying = false;
|
isPlaying = false;
|
||||||
bargeInCount = 0;
|
bargeInCount = 0;
|
||||||
|
// Cancel any pending text displays
|
||||||
|
for (const tid of scheduledTextTimers) clearTimeout(tid);
|
||||||
|
scheduledTextTimers = [];
|
||||||
|
pendingTextChunks = [];
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Microphone ---
|
// --- Microphone ---
|
||||||
@@ -229,8 +275,12 @@ async function startMic() {
|
|||||||
if (bargeInCount >= BARGE_IN_FRAMES) {
|
if (bargeInCount >= BARGE_IN_FRAMES) {
|
||||||
// User is speaking over the assistant - interrupt
|
// User is speaking over the assistant - interrupt
|
||||||
stopPlayback();
|
stopPlayback();
|
||||||
finalizeAssistantMessage();
|
const msg = { type: "interrupt" };
|
||||||
ws.send(JSON.stringify({ type: "interrupt" }));
|
if (lastDisplayedChunkId >= 0) {
|
||||||
|
msg.last_chunk_id = lastDisplayedChunkId;
|
||||||
|
}
|
||||||
|
ws.send(JSON.stringify(msg));
|
||||||
|
finalizeAssistantMessage(true);
|
||||||
isPlaying = false;
|
isPlaying = false;
|
||||||
bargeInCount = 0;
|
bargeInCount = 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,6 +84,12 @@ header h1 {
|
|||||||
border-bottom-left-radius: 4px;
|
border-bottom-left-radius: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.interrupted-marker {
|
||||||
|
color: #888;
|
||||||
|
font-style: italic;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
|
||||||
#controls {
|
#controls {
|
||||||
padding: 16px 24px;
|
padding: 16px 24px;
|
||||||
border-top: 1px solid #222;
|
border-top: 1px solid #222;
|
||||||
|
|||||||
Reference in New Issue
Block a user