barge-in changes

This commit is contained in:
2026-04-08 11:40:59 -04:00
parent 175ed943df
commit d509f92a9d
6 changed files with 179 additions and 33 deletions
+3 -4
View File
@@ -4,10 +4,9 @@ llm:
max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 to disable caching
system_prompt: >-
You are a helpful voice assistant. Keep your responses concise and natural
for spoken conversation. Respond in 1-3 short sentences.
Do not use markdown, bullet points, code blocks, emojis, or any
formatting that doesn't work in speech.
You are a helpful voice assistant.
Keep your responses extremely concise but natural for spoken conversation.
Do not use markdown, bullet points, code blocks, emojis, or any formatting that doesn't work in speech.
# Settings used only when backend = "lmstudio"
lmstudio:
+4 -1
View File
@@ -6,8 +6,11 @@ services:
volumes:
# Cache models on the host so they survive container rebuilds
- huggingface-cache:/cache/huggingface
# Mount config so you can edit backend settings without rebuilding the image
# Mount source so you can edit code/config without rebuilding the image
- ./config.yml:/app/config.yml:ro
- ./server:/app/server:ro
- ./static:/app/static:ro
- ./run.py:/app/run.py:ro
deploy:
resources:
reservations:
+3 -1
View File
@@ -77,7 +77,9 @@ async def websocket_chat(ws: WebSocket):
continue
if msg.get("type") == "interrupt":
await session.interrupt()
await session.interrupt(
last_chunk_id=msg.get("last_chunk_id")
)
except WebSocketDisconnect:
log.info("WebSocket client disconnected.")
+95 -9
View File
@@ -1,6 +1,7 @@
import asyncio
import logging
import queue
import re
import threading
import numpy as np
@@ -14,6 +15,56 @@ log = logging.getLogger(__name__)
_SENTINEL = None
# Regex: split after sentence-ending punctuation followed by whitespace
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')

# Regex: split after clause-level punctuation followed by whitespace
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')

# Default word-count bounds for a single TTS segment.
MAX_SEGMENT_WORDS = 20
MIN_SEGMENT_WORDS = 4


def _word_count(text: str) -> int:
    """Return the number of whitespace-separated words in *text*."""
    return len(text.split())


def _split_into_segments(
    text: str,
    *,
    max_words: int = MAX_SEGMENT_WORDS,
    min_words: int = MIN_SEGMENT_WORDS,
) -> list[str]:
    """Split text into small TTS-friendly segments for fine-grained streaming.

    Splits on sentence boundaries first, then breaks sentences longer than
    ``max_words`` at clause boundaries (commas, semicolons, colons,
    em-dashes). Avoids tiny fragments by merging pieces shorter than
    ``min_words`` with their neighbours.

    Args:
        text: The full response text. Leading/trailing whitespace is ignored.
        max_words: Sentences with more words than this are broken up at
            clause punctuation. Defaults to ``MAX_SEGMENT_WORDS``.
        min_words: Segments with fewer words than this are merged into a
            neighbouring segment. Defaults to ``MIN_SEGMENT_WORDS``.

    Returns:
        The segments in order. Empty or whitespace-only input yields ``[]``.
        The whitespace consumed at a split point is not preserved; pieces
        that get merged back together are joined with a single space.

    Notes:
        ``max_words`` is a splitting threshold, not a hard cap: a sentence
        with no clause punctuation is kept whole however long it is, and
        merging a short fragment into its neighbour can push that neighbour
        past ``max_words``.
    """
    segments: list[str] = []
    for sentence in _SENTENCE_RE.split(text.strip()):
        if _word_count(sentence) <= max_words:
            segments.append(sentence)
            continue

        # Long sentence: greedily pack consecutive clauses into one segment
        # until adding the next clause would exceed the limit.
        current = ""
        for clause in _CLAUSE_RE.split(sentence):
            combined = f"{current} {clause}" if current else clause
            if current and _word_count(combined) > max_words:
                segments.append(current)
                current = clause
            else:
                current = combined
        if current:
            segments.append(current)

    # Merge forward: when the previously kept segment is still too short,
    # glue the next segment onto it instead of starting a new one.
    merged: list[str] = []
    for segment in segments:
        if not segment.strip():
            continue
        if merged and _word_count(merged[-1]) < min_words:
            merged[-1] = f"{merged[-1]} {segment}"
        else:
            merged.append(segment)

    # A short final segment has nothing after it to absorb it, so fold it
    # back into its predecessor.
    if len(merged) > 1 and _word_count(merged[-1]) < min_words:
        runt = merged.pop()
        merged[-1] = f"{merged[-1]} {runt}"

    return merged
class ConversationSession:
"""Manages a single client's voice conversation pipeline.
@@ -53,15 +104,17 @@ class ConversationSession:
elif self.vad.is_speaking and self.is_responding:
await self._interrupt()
async def interrupt(self):
async def interrupt(self, last_chunk_id: int | None = None):
"""Public interrupt method for WebSocket text messages."""
if self.is_responding:
await self._interrupt()
await self._interrupt(last_chunk_id=last_chunk_id)
async def _interrupt(self):
async def _interrupt(self, last_chunk_id: int | None = None):
log.info("Barge-in: cancelling response.")
self.cancel_event.set()
self.is_responding = False
if last_chunk_id is not None:
self._last_played_chunk_id = last_chunk_id
# Tell client to stop audio immediately
try:
await self.send_json({"type": "interrupt"})
@@ -105,11 +158,18 @@ class ConversationSession:
# TTS - stream chunks with per-sentence text
await self.send_json({"type": "status", "state": "speaking"})
chunk_queue = queue.Queue()
self._last_played_chunk_id = None
segments = _split_into_segments(response)
log.info(f"TTS: split response into {len(segments)} segments")
def _tts_worker():
try:
for segment in segments:
if self.cancel_event.is_set():
break
for graphemes, _ps, audio in self.models.tts_engine.pipeline(
response, voice=self.models.tts_engine.voice
segment, voice=self.models.tts_engine.voice
):
if self.cancel_event.is_set():
break
@@ -124,6 +184,9 @@ class ConversationSession:
tts_thread.start()
spoken_text = ""
chunk_id = 0
# Maps chunk_id -> cumulative text up to and including that chunk
chunk_text_map: dict[int, str] = {}
while True:
try:
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
@@ -137,8 +200,14 @@ class ConversationSession:
sentence_text, audio = item
spoken_text += sentence_text
chunk_text_map[chunk_id] = spoken_text
await self.send_json({"type": "response_text", "text": sentence_text, "final": False})
await self.send_json({
"type": "response_text",
"text": sentence_text,
"chunk_id": chunk_id,
"final": False,
})
pcm_bytes = float32_to_pcm_bytes(audio)
try:
await self.send_bytes(pcm_bytes)
@@ -146,14 +215,26 @@ class ConversationSession:
log.warning("Failed to send audio, client disconnected.")
self.cancel_event.set()
break
chunk_id += 1
tts_thread.join(timeout=2.0)
# Save only what was actually spoken
# Determine what was actually heard by the client
was_interrupted = spoken_text.strip() != response.strip()
if spoken_text.strip():
if was_interrupted and self._last_played_chunk_id is not None:
# Client told us the last chunk whose audio actually played
heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
log.info(f"Interrupted: client heard up to chunk {self._last_played_chunk_id}")
else:
heard_text = spoken_text
# Save only what was actually spoken/heard
if heard_text.strip():
# Use original LLM response when fully spoken (keeps KV-cache valid);
# use heard_text only when interrupted.
final_content = heard_text.strip() if was_interrupted else response
self.conversation_history.append(
{"role": "assistant", "content": spoken_text.strip()}
{"role": "assistant", "content": final_content}
)
if was_interrupted and self.kv_cache_state is not None:
self.kv_cache_state = self.models.llm_engine.trim_cache(
@@ -164,7 +245,12 @@ class ConversationSession:
self.kv_cache_state = None
if not self.cancel_event.is_set():
await self.send_json({"type": "response_text", "text": "", "final": True})
await self.send_json({
"type": "response_text",
"text": "",
"final": True,
"total_chunks": chunk_id,
})
self.is_responding = False
try:
+64 -14
View File
@@ -13,6 +13,11 @@ const BARGE_IN_THRESHOLD = 0.03; // RMS energy threshold for barge-in
const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger
let bargeInCount = 0;
// --- Text-audio sync state ---
let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to arrive
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
const chatArea = document.getElementById("chat-area");
const statusBadge = document.getElementById("status-badge");
const micBtn = document.getElementById("mic-btn");
@@ -54,8 +59,8 @@ function handleJSON(msg) {
case "interrupt":
stopPlayback();
// Trim the assistant message to what was spoken, then finalize
finalizeAssistantMessage();
// Finalize with interrupted marker — text already reflects only what was heard
finalizeAssistantMessage(true);
break;
case "transcript":
@@ -64,9 +69,15 @@ function handleJSON(msg) {
case "response_text":
if (msg.final) {
finalizeAssistantMessage();
// All chunks sent; finalize will happen when last audio chunk plays
// (or immediately if nothing was queued)
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
finalizeAssistantMessage(false);
}
// Otherwise, playAudioChunk will finalize after the last scheduled text
} else {
appendAssistantText(msg.text);
// Queue text — it will be displayed when corresponding audio starts playing
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
}
break;
}
@@ -113,9 +124,20 @@ function appendAssistantText(text) {
chatArea.scrollTop = chatArea.scrollHeight;
}
function finalizeAssistantMessage() {
function finalizeAssistantMessage(interrupted = false) {
if (interrupted && currentAssistantEl && currentAssistantText) {
const marker = document.createElement("span");
marker.className = "interrupted-marker";
marker.textContent = " [interrupted]";
currentAssistantEl.appendChild(marker);
}
currentAssistantEl = null;
currentAssistantText = "";
// Reset sync state
pendingTextChunks = [];
for (const tid of scheduledTextTimers) clearTimeout(tid);
scheduledTextTimers = [];
lastDisplayedChunkId = -1;
}
// --- Audio Playback ---
@@ -146,18 +168,38 @@ function playAudioChunk(arrayBuffer) {
activeSources.push(source);
isPlaying = true;
source.onended = () => {
activeSources = activeSources.filter((s) => s !== source);
if (activeSources.length === 0) {
isPlaying = false;
bargeInCount = 0;
}
};
// Pair this audio chunk with the next queued text chunk
const textEntry = pendingTextChunks.shift();
const now = ctx.currentTime;
if (nextPlayTime < now) {
nextPlayTime = now + 0.01;
}
// Schedule text display to coincide with audio playback start
if (textEntry) {
const delayMs = Math.max(0, (nextPlayTime - now) * 1000);
const tid = setTimeout(() => {
appendAssistantText(textEntry.text);
lastDisplayedChunkId = textEntry.chunkId;
scheduledTextTimers = scheduledTextTimers.filter((t) => t !== tid);
}, delayMs);
scheduledTextTimers.push(tid);
}
source.onended = () => {
activeSources = activeSources.filter((s) => s !== source);
if (activeSources.length === 0) {
isPlaying = false;
bargeInCount = 0;
// If all audio has finished and no more text pending, finalize
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
finalizeAssistantMessage(false);
}
}
};
source.start(nextPlayTime);
nextPlayTime += buffer.duration;
}
@@ -172,6 +214,10 @@ function stopPlayback() {
nextPlayTime = 0;
isPlaying = false;
bargeInCount = 0;
// Cancel any pending text displays
for (const tid of scheduledTextTimers) clearTimeout(tid);
scheduledTextTimers = [];
pendingTextChunks = [];
}
// --- Microphone ---
@@ -229,8 +275,12 @@ async function startMic() {
if (bargeInCount >= BARGE_IN_FRAMES) {
// User is speaking over the assistant - interrupt
stopPlayback();
finalizeAssistantMessage();
ws.send(JSON.stringify({ type: "interrupt" }));
const msg = { type: "interrupt" };
if (lastDisplayedChunkId >= 0) {
msg.last_chunk_id = lastDisplayedChunkId;
}
ws.send(JSON.stringify(msg));
finalizeAssistantMessage(true);
isPlaying = false;
bargeInCount = 0;
}
+6
View File
@@ -84,6 +84,12 @@ header h1 {
border-bottom-left-radius: 4px;
}
/* Muted, italic styling for the " [interrupted]" span that
   finalizeAssistantMessage(true) appends to an assistant message
   when playback is cut off by an interrupt. */
.interrupted-marker {
color: #888;
font-style: italic;
font-size: 13px;
}
#controls {
padding: 16px 24px;
border-top: 1px solid #222;