From d509f92a9d19884097c83d6ed952069b3b24b6e1 Mon Sep 17 00:00:00 2001 From: Brian Date: Wed, 8 Apr 2026 11:40:59 -0400 Subject: [PATCH] barge-in changes --- config.yml | 7 ++- docker-compose.yml | 5 +- server/main.py | 4 +- server/pipeline.py | 112 +++++++++++++++++++++++++++++++++++++++------ static/app.js | 78 +++++++++++++++++++++++++------ static/style.css | 6 +++ 6 files changed, 179 insertions(+), 33 deletions(-) diff --git a/config.yml b/config.yml index ded8996..3d40f44 100644 --- a/config.yml +++ b/config.yml @@ -4,10 +4,9 @@ llm: max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 to disable caching system_prompt: >- - You are a helpful voice assistant. Keep your responses concise and natural - for spoken conversation. Respond in 1-3 short sentences. - Do not use markdown, bullet points, code blocks, emojis, or any - formatting that doesn't work in speech. + You are a helpful voice assistant. + Keep your responses extremely concise but natural for spoken conversation. + Do not use markdown, bullet points, code blocks, emojis, or any formatting that doesn't work in speech. # Settings used only when backend = "lmstudio" lmstudio: diff --git a/docker-compose.yml b/docker-compose.yml index 688e49e..83c9a94 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,8 +6,11 @@ services: volumes: # Cache models on the host so they survive container rebuilds - huggingface-cache:/cache/huggingface - # Mount config so you can edit backend settings without rebuilding the image + # Mount source so you can edit code/config without rebuilding the image - ./config.yml:/app/config.yml:ro + - ./server:/app/server:ro + - ./static:/app/static:ro + - ./run.py:/app/run.py:ro deploy: resources: reservations: diff --git a/server/main.py b/server/main.py index 6d3fdb3..d5ab84e 100644 --- a/server/main.py +++ b/server/main.py @@ -77,7 +77,9 @@ async def websocket_chat(ws: WebSocket): continue if msg.get("type") == "interrupt": - await session.interrupt() + await session.interrupt( + last_chunk_id=msg.get("last_chunk_id") + ) except WebSocketDisconnect: log.info("WebSocket client disconnected.") diff --git a/server/pipeline.py b/server/pipeline.py index 47fa2b9..38b23d3 100644 --- a/server/pipeline.py +++ b/server/pipeline.py @@ -1,6 +1,7 @@ import asyncio import logging import queue +import re import threading import numpy as np @@ -14,6 +15,56 @@ log = logging.getLogger(__name__) _SENTINEL = None +# Regex: split after sentence-ending punctuation followed by whitespace +_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+') +# Regex: split after clause-level punctuation followed by whitespace +_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+') + +MAX_SEGMENT_WORDS = 20 +MIN_SEGMENT_WORDS = 4 + + +def _split_into_segments(text: str) -> list[str]: + """Split text into small TTS-friendly segments for fine-grained streaming. + + Splits on sentence boundaries first, then breaks long sentences at clause + boundaries (commas, semicolons, colons, em-dashes). Avoids tiny fragments + by merging short pieces with their neighbours. + """ + sentences = _SENTENCE_RE.split(text.strip()) + segments: list[str] = [] + for sent in sentences: + if len(sent.split()) <= MAX_SEGMENT_WORDS: + segments.append(sent) + else: + # Split long sentences at clause boundaries + clauses = _CLAUSE_RE.split(sent) + current = "" + for clause in clauses: + combined = (current + " " + clause) if current else clause + if current and len(combined.split()) > MAX_SEGMENT_WORDS: + segments.append(current) + current = clause + else: + current = combined + if current: + segments.append(current) + + # Merge any tiny fragments into their neighbour + merged: list[str] = [] + for seg in segments: + if not seg.strip(): + continue + if merged and len(merged[-1].split()) < MIN_SEGMENT_WORDS: + merged[-1] = merged[-1] + " " + seg + else: + merged.append(seg) + # Also merge a trailing runt + if len(merged) > 1 and len(merged[-1].split()) < MIN_SEGMENT_WORDS: + merged[-2] = merged[-2] + " " + merged[-1] + merged.pop() + return merged + class ConversationSession: """Manages a single client's voice conversation pipeline. @@ -53,15 +104,17 @@ class ConversationSession: elif self.vad.is_speaking and self.is_responding: await self._interrupt() - async def interrupt(self): + async def interrupt(self, last_chunk_id: int | None = None): """Public interrupt method for WebSocket text messages.""" if self.is_responding: - await self._interrupt() + await self._interrupt(last_chunk_id=last_chunk_id) - async def _interrupt(self): + async def _interrupt(self, last_chunk_id: int | None = None): log.info("Barge-in: cancelling response.") self.cancel_event.set() self.is_responding = False + if last_chunk_id is not None: + self._last_played_chunk_id = last_chunk_id # Tell client to stop audio immediately try: await self.send_json({"type": "interrupt"}) @@ -105,16 +158,23 @@ class ConversationSession: # TTS - stream chunks with per-sentence text await self.send_json({"type": "status", "state": "speaking"}) chunk_queue = queue.Queue() + self._last_played_chunk_id = None + + segments = _split_into_segments(response) + log.info(f"TTS: split response into {len(segments)} segments") def _tts_worker(): try: - for graphemes, _ps, audio in self.models.tts_engine.pipeline( - response, voice=self.models.tts_engine.voice - ): + for segment in segments: if self.cancel_event.is_set(): break - if audio is not None and len(audio) > 0: - chunk_queue.put((graphemes, audio)) + for graphemes, _ps, audio in self.models.tts_engine.pipeline( + segment, voice=self.models.tts_engine.voice + ): + if self.cancel_event.is_set(): + break + if audio is not None and len(audio) > 0: + chunk_queue.put((graphemes, audio)) except Exception: log.exception("TTS generation error") finally: @@ -124,6 +184,9 @@ class ConversationSession: tts_thread.start() spoken_text = "" + chunk_id = 0 + # Maps chunk_id -> cumulative text up to and including that chunk + chunk_text_map: dict[int, str] = {} while True: try: item = await asyncio.to_thread(chunk_queue.get, timeout=10.0) @@ -137,8 +200,14 @@ class ConversationSession: sentence_text, audio = item spoken_text += sentence_text + chunk_text_map[chunk_id] = spoken_text - await self.send_json({"type": "response_text", "text": sentence_text, "final": False}) + await self.send_json({ + "type": "response_text", + "text": sentence_text, + "chunk_id": chunk_id, + "final": False, + }) pcm_bytes = float32_to_pcm_bytes(audio) try: await self.send_bytes(pcm_bytes) @@ -146,14 +215,26 @@ class ConversationSession: log.warning("Failed to send audio, client disconnected.") self.cancel_event.set() break + chunk_id += 1 tts_thread.join(timeout=2.0) - # Save only what was actually spoken + # Determine what was actually heard by the client was_interrupted = spoken_text.strip() != response.strip() - if spoken_text.strip(): + if was_interrupted and self._last_played_chunk_id is not None: + # Client told us the last chunk whose audio actually played + heard_text = chunk_text_map.get(self._last_played_chunk_id, "") + log.info(f"Interrupted: client heard up to chunk {self._last_played_chunk_id}") + else: + heard_text = spoken_text + + # Save only what was actually spoken/heard + if heard_text.strip(): + # Use original LLM response when fully spoken (keeps KV-cache valid); + # use heard_text only when interrupted. + final_content = heard_text.strip() if was_interrupted else response self.conversation_history.append( - {"role": "assistant", "content": spoken_text.strip()} + {"role": "assistant", "content": final_content} ) if was_interrupted and self.kv_cache_state is not None: self.kv_cache_state = self.models.llm_engine.trim_cache( @@ -164,7 +245,12 @@ class ConversationSession: self.kv_cache_state = None if not self.cancel_event.is_set(): - await self.send_json({"type": "response_text", "text": "", "final": True}) + await self.send_json({ + "type": "response_text", + "text": "", + "final": True, + "total_chunks": chunk_id, + }) self.is_responding = False try: diff --git a/static/app.js b/static/app.js index f234d8c..5d1a2b9 100644 --- a/static/app.js +++ b/static/app.js @@ -13,6 +13,11 @@ const BARGE_IN_THRESHOLD = 0.03; // RMS energy threshold for barge-in const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger let bargeInCount = 0; +// --- Text-audio sync state --- +let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to arrive +let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback +let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user + const chatArea = document.getElementById("chat-area"); const statusBadge = document.getElementById("status-badge"); const micBtn = document.getElementById("mic-btn"); @@ -54,8 +59,8 @@ function handleJSON(msg) { case "interrupt": stopPlayback(); - // Trim the assistant message to what was spoken, then finalize - finalizeAssistantMessage(); + // Finalize with interrupted marker — text already reflects only what was heard + finalizeAssistantMessage(true); break; case "transcript": @@ -64,9 +69,15 @@ function handleJSON(msg) { case "response_text": if (msg.final) { - finalizeAssistantMessage(); + // All chunks sent; finalize will happen when last audio chunk plays + // (or immediately if nothing was queued) + if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) { + finalizeAssistantMessage(false); + } + // Otherwise, playAudioChunk will finalize after the last scheduled text } else { - appendAssistantText(msg.text); + // Queue text — it will be displayed when corresponding audio starts playing + pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text }); } break; } @@ -113,9 +124,20 @@ function appendAssistantText(text) { chatArea.scrollTop = chatArea.scrollHeight; } -function finalizeAssistantMessage() { +function finalizeAssistantMessage(interrupted = false) { + if (interrupted && currentAssistantEl && currentAssistantText) { + const marker = document.createElement("span"); + marker.className = "interrupted-marker"; + marker.textContent = " [interrupted]"; + currentAssistantEl.appendChild(marker); + } currentAssistantEl = null; currentAssistantText = ""; + // Reset sync state + pendingTextChunks = []; + for (const tid of scheduledTextTimers) clearTimeout(tid); + scheduledTextTimers = []; + lastDisplayedChunkId = -1; } // --- Audio Playback --- @@ -146,18 +168,38 @@ function playAudioChunk(arrayBuffer) { activeSources.push(source); isPlaying = true; - source.onended = () => { - activeSources = activeSources.filter((s) => s !== source); - if (activeSources.length === 0) { - isPlaying = false; - bargeInCount = 0; - } - }; + + // Pair this audio chunk with the next queued text chunk + const textEntry = pendingTextChunks.shift(); const now = ctx.currentTime; if (nextPlayTime < now) { nextPlayTime = now + 0.01; } + + // Schedule text display to coincide with audio playback start + if (textEntry) { + const delayMs = Math.max(0, (nextPlayTime - now) * 1000); + const tid = setTimeout(() => { + appendAssistantText(textEntry.text); + lastDisplayedChunkId = textEntry.chunkId; + scheduledTextTimers = scheduledTextTimers.filter((t) => t !== tid); + }, delayMs); + scheduledTextTimers.push(tid); + } + + source.onended = () => { + activeSources = activeSources.filter((s) => s !== source); + if (activeSources.length === 0) { + isPlaying = false; + bargeInCount = 0; + // If all audio has finished and no more text pending, finalize + if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) { + finalizeAssistantMessage(false); + } + } + }; + source.start(nextPlayTime); nextPlayTime += buffer.duration; } @@ -172,6 +214,10 @@ function stopPlayback() { nextPlayTime = 0; isPlaying = false; bargeInCount = 0; + // Cancel any pending text displays + for (const tid of scheduledTextTimers) clearTimeout(tid); + scheduledTextTimers = []; + pendingTextChunks = []; } // --- Microphone --- @@ -229,8 +275,12 @@ async function startMic() { if (bargeInCount >= BARGE_IN_FRAMES) { // User is speaking over the assistant - interrupt stopPlayback(); - finalizeAssistantMessage(); - ws.send(JSON.stringify({ type: "interrupt" })); + const msg = { type: "interrupt" }; + if (lastDisplayedChunkId >= 0) { + msg.last_chunk_id = lastDisplayedChunkId; + } + ws.send(JSON.stringify(msg)); + finalizeAssistantMessage(true); isPlaying = false; bargeInCount = 0; } diff --git a/static/style.css b/static/style.css index dde72f1..bbb5301 100644 --- a/static/style.css +++ b/static/style.css @@ -84,6 +84,12 @@ header h1 { border-bottom-left-radius: 4px; } +.interrupted-marker { + color: #888; + font-style: italic; + font-size: 13px; +} + #controls { padding: 16px 24px; border-top: 1px solid #222;