Compare commits

...

4 Commits

11 changed files with 400 additions and 72 deletions
+3
View File
@@ -35,6 +35,9 @@ RUN python3.11 -m pip install --no-cache-dir \
COPY requirements.txt . COPY requirements.txt .
RUN python3.11 -m pip install --no-cache-dir -r requirements.txt RUN python3.11 -m pip install --no-cache-dir -r requirements.txt
# Pre-download the spacy model that kokoro needs at runtime
RUN python3.11 -m spacy download en_core_web_sm
COPY . . COPY . .
EXPOSE 8000 EXPOSE 8000
+14
View File
@@ -0,0 +1,14 @@
# LLM backend: "local" or "lmstudio"
llm:
backend: local # change to "lmstudio" to use LM Studio instead
max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 means no size limit (caching stays enabled)
system_prompt: >-
You are a helpful voice assistant.
Keep your responses extremely concise but natural for spoken conversation.
Do not use markdown, bullet points, code blocks, emojis, or any formatting that doesn't work in speech.
# Settings used only when backend = "lmstudio"
lmstudio:
url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker
model: "" # leave empty to use whatever model LM Studio has loaded
+5
View File
@@ -6,6 +6,11 @@ services:
volumes: volumes:
# Cache models on the host so they survive container rebuilds # Cache models on the host so they survive container rebuilds
- huggingface-cache:/cache/huggingface - huggingface-cache:/cache/huggingface
# Mount source so you can edit code/config without rebuilding the image
- ./config.yml:/app/config.yml:ro
- ./server:/app/server:ro
- ./static:/app/static:ro
- ./run.py:/app/run.py:ro
deploy: deploy:
resources: resources:
reservations: reservations:
+1
View File
@@ -13,3 +13,4 @@ numpy
soundfile soundfile
scipy scipy
python-multipart python-multipart
pyyaml
+12
View File
@@ -0,0 +1,12 @@
import pathlib
import yaml
_CONFIG_PATH = pathlib.Path(__file__).parent.parent / "config.yml"
def load_config() -> dict:
    """Load the application configuration from ``config.yml``.

    Returns:
        The parsed top-level YAML value. An empty file yields an empty
        dict instead of ``None`` so callers can use ``.get()`` directly.

    Raises:
        OSError: if the config file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Explicit encoding so parsing does not depend on the process locale.
    with open(_CONFIG_PATH, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty document.
    return {} if data is None else data
config = load_config()
+153 -29
View File
@@ -1,32 +1,38 @@
import copy
import dataclasses
import logging import logging
import threading import threading
from typing import AsyncIterator from typing import AsyncIterator
import torch import torch
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from server.audio_utils import split_sentences from server.audio_utils import split_sentences
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@dataclasses.dataclass
class KVCacheState:
    """Per-session KV-cache persisted across generate() calls."""
    # Cache object passed back to the model on the next call; None when nothing is cached.
    past_key_values: DynamicCache | None
    # Number of tokens the cache is treated as covering.
    cached_token_count: int
    cached_messages: list[dict]  # snapshot of messages when cache was built
class LLMEngine: class LLMEngine:
"""Wraps Qwen3 for conversation generation.""" """Wraps Qwen3 for conversation generation with persistent KV-cache."""
SYSTEM_PROMPT = ( def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, system_prompt: str,
"You are a helpful voice assistant. Keep your responses concise and natural " max_cache_tokens: int = 4096):
"for spoken conversation. Respond in 1-3 short sentences. "
"Do not use markdown, bullet points, code blocks, emojis, or any "
"formatting that doesn't work in speech."
)
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.system_prompt = system_prompt
self.max_cache_tokens = max_cache_tokens
self._generate_lock = threading.Lock()
def _build_inputs(self, messages: list[dict]): def _build_inputs(self, messages: list[dict]):
"""Build input token ids using the model's chat template.""" """Build input token ids using the model's chat template."""
chat_messages = [{"role": "system", "content": self.SYSTEM_PROMPT}] chat_messages = [{"role": "system", "content": self.system_prompt}]
for msg in messages: for msg in messages:
chat_messages.append({"role": msg["role"], "content": msg["content"]}) chat_messages.append({"role": msg["role"], "content": msg["content"]})
@@ -38,26 +44,145 @@ class LLMEngine:
) )
return self.tokenizer(text, return_tensors="pt").to(self.model.device) return self.tokenizer(text, return_tensors="pt").to(self.model.device)
def generate(self, messages: list[dict], max_new_tokens: int = 256) -> str: def _validate_cache(self, messages: list[dict], cache_state: KVCacheState | None) -> DynamicCache | None:
"""Generate a complete response (blocking).""" """Return past_key_values if the cache is valid for the given messages, else None."""
inputs = self._build_inputs(messages) if cache_state is None or cache_state.past_key_values is None:
input_len = inputs["input_ids"].shape[1] return None
if self.max_cache_tokens and cache_state.cached_token_count > self.max_cache_tokens:
log.info("KV-cache exceeds max size, discarding.")
return None
cached = cache_state.cached_messages
# The current messages must start with the cached messages as a prefix
if len(cached) > len(messages):
return None
for cached_msg, current_msg in zip(cached, messages):
if cached_msg["role"] != current_msg["role"] or cached_msg["content"] != current_msg["content"]:
return None
return cache_state.past_key_values
with torch.no_grad(): def generate(
output_ids = self.model.generate( self,
**inputs, messages: list[dict],
max_new_tokens=max_new_tokens, max_new_tokens: int = 256,
temperature=0.7, cache_state: KVCacheState | None = None,
top_p=0.9, ) -> tuple[str, KVCacheState]:
do_sample=True, """Generate a complete response (blocking). Returns (response, updated_cache_state)."""
repetition_penalty=1.2, with self._generate_lock:
inputs = self._build_inputs(messages)
input_ids = inputs["input_ids"]
input_len = input_ids.shape[1]
past_kv = self._validate_cache(messages, cache_state)
cached_len = cache_state.cached_token_count if past_kv is not None else 0
log.info(
f"KV-cache: {cached_len}/{input_len} tokens cached, "
f"processing {input_len - cached_len} new tokens"
) )
# Decode only the generated tokens (skip prompt) with torch.no_grad():
new_ids = output_ids[0][input_len:] outputs = self.model.generate(
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() input_ids=input_ids,
log.info(f"LLM response: {response}") attention_mask=inputs.get("attention_mask"),
return response past_key_values=past_kv,
max_new_tokens=max_new_tokens,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.2,
return_dict_in_generate=True,
use_cache=True,
)
# Decode only the generated tokens (skip prompt)
new_ids = outputs.sequences[0][input_len:]
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
log.info(f"LLM response: {response}")
# Build updated cache state: messages now include the assistant response
new_messages = copy.deepcopy(messages) + [{"role": "assistant", "content": response}]
new_cache = KVCacheState(
past_key_values=outputs.past_key_values,
cached_token_count=outputs.sequences.shape[1],
cached_messages=new_messages,
)
return response, new_cache
def trim_cache(self, cache_state: KVCacheState, messages: list[dict]) -> KVCacheState | None:
    """Shrink a cache so it covers no more tokens than ``messages`` needs.

    Used after a barge-in, when the stored history is shorter than what
    the model generated. The given ``cache_state`` is modified in place.

    Returns:
        ``None`` when there is no usable cache; otherwise ``cache_state``,
        cropped only if it held more tokens than ``messages`` tokenizes to.
    """
    has_cache = cache_state is not None and cache_state.past_key_values is not None
    if not has_cache:
        return None

    wanted_tokens = self._build_inputs(messages)["input_ids"].shape[1]
    if cache_state.cached_token_count > wanted_tokens:
        cache_state.past_key_values.crop(wanted_tokens)
        cache_state.cached_token_count = wanted_tokens
        # Store a private copy so later edits to the caller's list do not
        # change the snapshot.
        cache_state.cached_messages = copy.deepcopy(messages)
    return cache_state
async def generate_sentences(
    self,
    messages: list[dict],
    cancel_event: threading.Event | None = None,
    cache_state: KVCacheState | None = None,
) -> AsyncIterator[str]:
    """Generate a response and yield it sentence by sentence for TTS pipelining.

    Args:
        messages: Conversation history passed to ``generate``.
        cancel_event: When set, generation output is dropped and iteration
            stops before the next sentence is yielded.
        cache_state: Optional KV-cache state forwarded to ``generate``.

    Yields:
        Each complete sentence of the response, then any trailing remainder.

    Note:
        The updated cache state returned by ``generate`` is discarded here,
        because an async generator has no return channel for it.
    """
    import asyncio

    # generate() returns (response, updated_cache_state); only the text is
    # split into sentences.
    response, _updated_cache = await asyncio.to_thread(
        self.generate, messages, 256, cache_state
    )

    if cancel_event and cancel_event.is_set():
        return

    # Split into sentences and yield each
    sentences, remainder = split_sentences(response)
    for sentence in sentences:
        if cancel_event and cancel_event.is_set():
            return
        yield sentence
    if remainder:
        yield remainder
class LMStudioEngine:
"""LLM engine that delegates to an LM Studio server via its OpenAI-compatible API."""
def __init__(self, base_url: str, model: str, system_prompt: str):
self.base_url = base_url.rstrip("/")
self.model = model
self.system_prompt = system_prompt
def generate(
    self,
    messages: list[dict],
    max_new_tokens: int = 256,
    cache_state: KVCacheState | None = None,
) -> tuple[str, None]:
    """Request one non-streamed chat completion from the LM Studio server.

    Args:
        messages: Conversation turns, sent after the system prompt.
        max_new_tokens: Forwarded to the server as ``max_tokens``.
        cache_state: Accepted so the signature matches ``LLMEngine.generate``;
            not used by this backend.

    Returns:
        ``(response_text, None)``. The second element is always ``None``
        because this backend keeps no cache state.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    import requests

    payload_messages = [{"role": "system", "content": self.system_prompt}]
    payload_messages.extend(messages)

    body: dict = {
        "messages": payload_messages,
        "max_tokens": max_new_tokens,
        "temperature": 0.7,
        "stream": False,
    }
    # Omit "model" when unset so the server uses whatever it has loaded.
    if self.model:
        body["model"] = self.model

    resp = requests.post(
        f"{self.base_url}/v1/chat/completions",
        json=body,
        timeout=30,
    )
    resp.raise_for_status()

    # "content" may be null in a chat-completion message; treat that as an
    # empty reply rather than failing on None.strip().
    content = resp.json()["choices"][0]["message"]["content"]
    response = (content or "").strip()
    log.info(f"LM Studio response: {response}")
    return response, None
async def generate_sentences( async def generate_sentences(
self, self,
@@ -72,7 +197,6 @@ class LLMEngine:
if cancel_event and cancel_event.is_set(): if cancel_event and cancel_event.is_set():
return return
# Split into sentences and yield each
sentences, remainder = split_sentences(response) sentences, remainder = split_sentences(response)
for sentence in sentences: for sentence in sentences:
if cancel_event and cancel_event.is_set(): if cancel_event and cancel_event.is_set():
+3 -1
View File
@@ -77,7 +77,9 @@ async def websocket_chat(ws: WebSocket):
continue continue
if msg.get("type") == "interrupt": if msg.get("type") == "interrupt":
await session.interrupt() await session.interrupt(
last_chunk_id=msg.get("last_chunk_id")
)
except WebSocketDisconnect: except WebSocketDisconnect:
log.info("WebSocket client disconnected.") log.info("WebSocket client disconnected.")
+28 -12
View File
@@ -46,7 +46,7 @@ class ModelManager:
from server.vad import SileroVADOnnx from server.vad import SileroVADOnnx
model_path = hf_hub_download( model_path = hf_hub_download(
repo_id="onnx-community/silero-vad", filename="silero_vad.onnx" repo_id="onnx-community/silero-vad", filename="onnx/model.onnx"
) )
self.vad_model = SileroVADOnnx(model_path) self.vad_model = SileroVADOnnx(model_path)
log.info("Silero VAD loaded (ONNX, CPU).") log.info("Silero VAD loaded (ONNX, CPU).")
@@ -66,19 +66,35 @@ class ModelManager:
log.info("Qwen3-ASR-0.6B loaded.") log.info("Qwen3-ASR-0.6B loaded.")
def _load_llm(self): def _load_llm(self):
log.info("Loading Qwen3-4B (GPTQ 4-bit)...") from server.config import config
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Qwen/Qwen3.5-0.8B" llm_config = config.get("llm", {})
backend = llm_config.get("backend", "local")
system_prompt = llm_config.get("system_prompt", "You are a helpful assistant.")
tokenizer = AutoTokenizer.from_pretrained(model_name) if backend == "lmstudio":
device = get_device() from server.llm import LMStudioEngine
model = AutoModelForCausalLM.from_pretrained(
model_name, lms = llm_config.get("lmstudio", {})
device_map=device, url = lms.get("url", "http://host.docker.internal:1234")
) model = lms.get("model", "") or ""
self.llm_engine = LLMEngine(model, tokenizer) log.info(f"Using LM Studio backend at {url} (model={model or 'server default'})")
log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).") self.llm_engine = LMStudioEngine(url, model, system_prompt)
else:
log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Qwen/Qwen3.5-0.8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = get_device()
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map=device,
)
max_cache_tokens = llm_config.get("max_cache_tokens", 4096)
self.llm_engine = LLMEngine(model, tokenizer, system_prompt, max_cache_tokens)
log.info("Qwen3-4B-GPTQ-Int4 loaded (~2.5GB VRAM).")
def _load_tts(self): def _load_tts(self):
log.info("Loading Kokoro TTS...") log.info("Loading Kokoro TTS...")
+110 -15
View File
@@ -1,11 +1,13 @@
import asyncio import asyncio
import logging import logging
import queue import queue
import re
import threading import threading
import numpy as np import numpy as np
from server.audio_utils import float32_to_pcm_bytes from server.audio_utils import float32_to_pcm_bytes
from server.llm import KVCacheState
from server.models import ModelManager from server.models import ModelManager
from server.vad import StreamingVAD from server.vad import StreamingVAD
@@ -13,6 +15,56 @@ log = logging.getLogger(__name__)
_SENTINEL = None _SENTINEL = None
# Regex: split after sentence-ending punctuation followed by whitespace
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
# Regex: split after clause-level punctuation followed by whitespace
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')
MAX_SEGMENT_WORDS = 20
MIN_SEGMENT_WORDS = 4
def _split_into_segments(text: str) -> list[str]:
"""Split text into small TTS-friendly segments for fine-grained streaming.
Splits on sentence boundaries first, then breaks long sentences at clause
boundaries (commas, semicolons, colons, em-dashes). Avoids tiny fragments
by merging short pieces with their neighbours.
"""
sentences = _SENTENCE_RE.split(text.strip())
segments: list[str] = []
for sent in sentences:
if len(sent.split()) <= MAX_SEGMENT_WORDS:
segments.append(sent)
else:
# Split long sentences at clause boundaries
clauses = _CLAUSE_RE.split(sent)
current = ""
for clause in clauses:
combined = (current + " " + clause) if current else clause
if current and len(combined.split()) > MAX_SEGMENT_WORDS:
segments.append(current)
current = clause
else:
current = combined
if current:
segments.append(current)
# Merge any tiny fragments into their neighbour
merged: list[str] = []
for seg in segments:
if not seg.strip():
continue
if merged and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
merged[-1] = merged[-1] + " " + seg
else:
merged.append(seg)
# Also merge a trailing runt
if len(merged) > 1 and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
merged[-2] = merged[-2] + " " + merged[-1]
merged.pop()
return merged
class ConversationSession: class ConversationSession:
"""Manages a single client's voice conversation pipeline. """Manages a single client's voice conversation pipeline.
@@ -27,6 +79,7 @@ class ConversationSession:
self.vad: StreamingVAD = models.create_vad() self.vad: StreamingVAD = models.create_vad()
self.conversation_history: list[dict] = [] self.conversation_history: list[dict] = []
self.kv_cache_state: KVCacheState | None = None
self.cancel_event = threading.Event() self.cancel_event = threading.Event()
self.is_responding = False self.is_responding = False
self._response_task: asyncio.Task | None = None self._response_task: asyncio.Task | None = None
@@ -38,6 +91,7 @@ class ConversationSession:
self.cancel_event.set() self.cancel_event.set()
if self._response_task and not self._response_task.done(): if self._response_task and not self._response_task.done():
self._response_task.cancel() self._response_task.cancel()
self.kv_cache_state = None
async def handle_audio_chunk(self, chunk_16k: np.ndarray): async def handle_audio_chunk(self, chunk_16k: np.ndarray):
utterance = self.vad.process_chunk(chunk_16k) utterance = self.vad.process_chunk(chunk_16k)
@@ -50,15 +104,17 @@ class ConversationSession:
elif self.vad.is_speaking and self.is_responding: elif self.vad.is_speaking and self.is_responding:
await self._interrupt() await self._interrupt()
async def interrupt(self): async def interrupt(self, last_chunk_id: int | None = None):
"""Public interrupt method for WebSocket text messages.""" """Public interrupt method for WebSocket text messages."""
if self.is_responding: if self.is_responding:
await self._interrupt() await self._interrupt(last_chunk_id=last_chunk_id)
async def _interrupt(self): async def _interrupt(self, last_chunk_id: int | None = None):
log.info("Barge-in: cancelling response.") log.info("Barge-in: cancelling response.")
self.cancel_event.set() self.cancel_event.set()
self.is_responding = False self.is_responding = False
if last_chunk_id is not None:
self._last_played_chunk_id = last_chunk_id
# Tell client to stop audio immediately # Tell client to stop audio immediately
try: try:
await self.send_json({"type": "interrupt"}) await self.send_json({"type": "interrupt"})
@@ -91,8 +147,8 @@ class ConversationSession:
# LLM # LLM
log.info(f"Conversation history ({len(self.conversation_history)} messages): " log.info(f"Conversation history ({len(self.conversation_history)} messages): "
+ str([m['content'][:50] for m in self.conversation_history])) + str([m['content'][:50] for m in self.conversation_history]))
response = await asyncio.to_thread( response, self.kv_cache_state = await asyncio.to_thread(
self.models.llm_engine.generate, self.conversation_history self.models.llm_engine.generate, self.conversation_history, 256, self.kv_cache_state
) )
if self.cancel_event.is_set(): if self.cancel_event.is_set():
@@ -102,16 +158,23 @@ class ConversationSession:
# TTS - stream chunks with per-sentence text # TTS - stream chunks with per-sentence text
await self.send_json({"type": "status", "state": "speaking"}) await self.send_json({"type": "status", "state": "speaking"})
chunk_queue = queue.Queue() chunk_queue = queue.Queue()
self._last_played_chunk_id = None
segments = _split_into_segments(response)
log.info(f"TTS: split response into {len(segments)} segments")
def _tts_worker(): def _tts_worker():
try: try:
for graphemes, _ps, audio in self.models.tts_engine.pipeline( for segment in segments:
response, voice=self.models.tts_engine.voice
):
if self.cancel_event.is_set(): if self.cancel_event.is_set():
break break
if audio is not None and len(audio) > 0: for graphemes, _ps, audio in self.models.tts_engine.pipeline(
chunk_queue.put((graphemes, audio)) segment, voice=self.models.tts_engine.voice
):
if self.cancel_event.is_set():
break
if audio is not None and len(audio) > 0:
chunk_queue.put((graphemes, audio))
except Exception: except Exception:
log.exception("TTS generation error") log.exception("TTS generation error")
finally: finally:
@@ -121,6 +184,9 @@ class ConversationSession:
tts_thread.start() tts_thread.start()
spoken_text = "" spoken_text = ""
chunk_id = 0
# Maps chunk_id -> cumulative text up to and including that chunk
chunk_text_map: dict[int, str] = {}
while True: while True:
try: try:
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0) item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
@@ -134,8 +200,14 @@ class ConversationSession:
sentence_text, audio = item sentence_text, audio = item
spoken_text += sentence_text spoken_text += sentence_text
chunk_text_map[chunk_id] = spoken_text
await self.send_json({"type": "response_text", "text": sentence_text, "final": False}) await self.send_json({
"type": "response_text",
"text": sentence_text,
"chunk_id": chunk_id,
"final": False,
})
pcm_bytes = float32_to_pcm_bytes(audio) pcm_bytes = float32_to_pcm_bytes(audio)
try: try:
await self.send_bytes(pcm_bytes) await self.send_bytes(pcm_bytes)
@@ -143,19 +215,42 @@ class ConversationSession:
log.warning("Failed to send audio, client disconnected.") log.warning("Failed to send audio, client disconnected.")
self.cancel_event.set() self.cancel_event.set()
break break
chunk_id += 1
tts_thread.join(timeout=2.0) tts_thread.join(timeout=2.0)
# Save only what was actually spoken # Determine what was actually heard by the client
if spoken_text.strip(): was_interrupted = spoken_text.strip() != response.strip()
if was_interrupted and self._last_played_chunk_id is not None:
# Client told us the last chunk whose audio actually played
heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
log.info(f"Interrupted: client heard up to chunk {self._last_played_chunk_id}")
else:
heard_text = spoken_text
# Save only what was actually spoken/heard
if heard_text.strip():
# Use original LLM response when fully spoken (keeps KV-cache valid);
# use heard_text only when interrupted.
final_content = heard_text.strip() if was_interrupted else response
self.conversation_history.append( self.conversation_history.append(
{"role": "assistant", "content": spoken_text.strip()} {"role": "assistant", "content": final_content}
) )
if was_interrupted and self.kv_cache_state is not None:
self.kv_cache_state = self.models.llm_engine.trim_cache(
self.kv_cache_state, self.conversation_history
)
elif self.conversation_history and self.conversation_history[-1]["role"] == "user": elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
self.conversation_history.pop() self.conversation_history.pop()
self.kv_cache_state = None
if not self.cancel_event.is_set(): if not self.cancel_event.is_set():
await self.send_json({"type": "response_text", "text": "", "final": True}) await self.send_json({
"type": "response_text",
"text": "",
"final": True,
"total_chunks": chunk_id,
})
self.is_responding = False self.is_responding = False
try: try:
+64 -14
View File
@@ -13,6 +13,11 @@ const BARGE_IN_THRESHOLD = 0.03; // RMS energy threshold for barge-in
const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger
let bargeInCount = 0; let bargeInCount = 0;
// --- Text-audio sync state ---
let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to arrive
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
const chatArea = document.getElementById("chat-area"); const chatArea = document.getElementById("chat-area");
const statusBadge = document.getElementById("status-badge"); const statusBadge = document.getElementById("status-badge");
const micBtn = document.getElementById("mic-btn"); const micBtn = document.getElementById("mic-btn");
@@ -54,8 +59,8 @@ function handleJSON(msg) {
case "interrupt": case "interrupt":
stopPlayback(); stopPlayback();
// Trim the assistant message to what was spoken, then finalize // Finalize with interrupted marker — text already reflects only what was heard
finalizeAssistantMessage(); finalizeAssistantMessage(true);
break; break;
case "transcript": case "transcript":
@@ -64,9 +69,15 @@ function handleJSON(msg) {
case "response_text": case "response_text":
if (msg.final) { if (msg.final) {
finalizeAssistantMessage(); // All chunks sent; finalize will happen when last audio chunk plays
// (or immediately if nothing was queued)
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
finalizeAssistantMessage(false);
}
// Otherwise, playAudioChunk will finalize after the last scheduled text
} else { } else {
appendAssistantText(msg.text); // Queue text — it will be displayed when corresponding audio starts playing
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
} }
break; break;
} }
@@ -113,9 +124,20 @@ function appendAssistantText(text) {
chatArea.scrollTop = chatArea.scrollHeight; chatArea.scrollTop = chatArea.scrollHeight;
} }
function finalizeAssistantMessage() { function finalizeAssistantMessage(interrupted = false) {
if (interrupted && currentAssistantEl && currentAssistantText) {
const marker = document.createElement("span");
marker.className = "interrupted-marker";
marker.textContent = " [interrupted]";
currentAssistantEl.appendChild(marker);
}
currentAssistantEl = null; currentAssistantEl = null;
currentAssistantText = ""; currentAssistantText = "";
// Reset sync state
pendingTextChunks = [];
for (const tid of scheduledTextTimers) clearTimeout(tid);
scheduledTextTimers = [];
lastDisplayedChunkId = -1;
} }
// --- Audio Playback --- // --- Audio Playback ---
@@ -146,18 +168,38 @@ function playAudioChunk(arrayBuffer) {
activeSources.push(source); activeSources.push(source);
isPlaying = true; isPlaying = true;
source.onended = () => {
activeSources = activeSources.filter((s) => s !== source); // Pair this audio chunk with the next queued text chunk
if (activeSources.length === 0) { const textEntry = pendingTextChunks.shift();
isPlaying = false;
bargeInCount = 0;
}
};
const now = ctx.currentTime; const now = ctx.currentTime;
if (nextPlayTime < now) { if (nextPlayTime < now) {
nextPlayTime = now + 0.01; nextPlayTime = now + 0.01;
} }
// Schedule text display to coincide with audio playback start
if (textEntry) {
const delayMs = Math.max(0, (nextPlayTime - now) * 1000);
const tid = setTimeout(() => {
appendAssistantText(textEntry.text);
lastDisplayedChunkId = textEntry.chunkId;
scheduledTextTimers = scheduledTextTimers.filter((t) => t !== tid);
}, delayMs);
scheduledTextTimers.push(tid);
}
source.onended = () => {
activeSources = activeSources.filter((s) => s !== source);
if (activeSources.length === 0) {
isPlaying = false;
bargeInCount = 0;
// If all audio has finished and no more text pending, finalize
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
finalizeAssistantMessage(false);
}
}
};
source.start(nextPlayTime); source.start(nextPlayTime);
nextPlayTime += buffer.duration; nextPlayTime += buffer.duration;
} }
@@ -172,6 +214,10 @@ function stopPlayback() {
nextPlayTime = 0; nextPlayTime = 0;
isPlaying = false; isPlaying = false;
bargeInCount = 0; bargeInCount = 0;
// Cancel any pending text displays
for (const tid of scheduledTextTimers) clearTimeout(tid);
scheduledTextTimers = [];
pendingTextChunks = [];
} }
// --- Microphone --- // --- Microphone ---
@@ -229,8 +275,12 @@ async function startMic() {
if (bargeInCount >= BARGE_IN_FRAMES) { if (bargeInCount >= BARGE_IN_FRAMES) {
// User is speaking over the assistant - interrupt // User is speaking over the assistant - interrupt
stopPlayback(); stopPlayback();
finalizeAssistantMessage(); const msg = { type: "interrupt" };
ws.send(JSON.stringify({ type: "interrupt" })); if (lastDisplayedChunkId >= 0) {
msg.last_chunk_id = lastDisplayedChunkId;
}
ws.send(JSON.stringify(msg));
finalizeAssistantMessage(true);
isPlaying = false; isPlaying = false;
bargeInCount = 0; bargeInCount = 0;
} }
+6
View File
@@ -84,6 +84,12 @@ header h1 {
border-bottom-left-radius: 4px; border-bottom-left-radius: 4px;
} }
.interrupted-marker {
color: #888;
font-style: italic;
font-size: 13px;
}
#controls { #controls {
padding: 16px 24px; padding: 16px 24px;
border-top: 1px solid #222; border-top: 1px solid #222;