barge-in changes

This commit is contained in:
2026-04-08 11:40:59 -04:00
parent 175ed943df
commit d509f92a9d
6 changed files with 179 additions and 33 deletions
+3 -4
View File
@@ -4,10 +4,9 @@ llm:
max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 to disable caching max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 to disable caching
system_prompt: >- system_prompt: >-
You are a helpful voice assistant. Keep your responses concise and natural You are a helpful voice assistant.
for spoken conversation. Respond in 1-3 short sentences. Keep your responses extremely concise but natural for spoken conversation.
Do not use markdown, bullet points, code blocks, emojis, or any Do not use markdown, bullet points, code blocks, emojis, or any formatting that doesn't work in speech.
formatting that doesn't work in speech.
# Settings used only when backend = "lmstudio" # Settings used only when backend = "lmstudio"
lmstudio: lmstudio:
+4 -1
View File
@@ -6,8 +6,11 @@ services:
volumes: volumes:
# Cache models on the host so they survive container rebuilds # Cache models on the host so they survive container rebuilds
- huggingface-cache:/cache/huggingface - huggingface-cache:/cache/huggingface
# Mount config so you can edit backend settings without rebuilding the image # Mount source so you can edit code/config without rebuilding the image
- ./config.yml:/app/config.yml:ro - ./config.yml:/app/config.yml:ro
- ./server:/app/server:ro
- ./static:/app/static:ro
- ./run.py:/app/run.py:ro
deploy: deploy:
resources: resources:
reservations: reservations:
+3 -1
View File
@@ -77,7 +77,9 @@ async def websocket_chat(ws: WebSocket):
continue continue
if msg.get("type") == "interrupt": if msg.get("type") == "interrupt":
await session.interrupt() await session.interrupt(
last_chunk_id=msg.get("last_chunk_id")
)
except WebSocketDisconnect: except WebSocketDisconnect:
log.info("WebSocket client disconnected.") log.info("WebSocket client disconnected.")
+95 -9
View File
@@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
import queue import queue
import re
import threading import threading
import numpy as np import numpy as np
@@ -14,6 +15,56 @@ log = logging.getLogger(__name__)
_SENTINEL = None _SENTINEL = None
# Regex: split after sentence-ending punctuation followed by whitespace
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
# Regex: split after clause-level punctuation followed by whitespace
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')
MAX_SEGMENT_WORDS = 20
MIN_SEGMENT_WORDS = 4
def _split_into_segments(text: str) -> list[str]:
"""Split text into small TTS-friendly segments for fine-grained streaming.
Splits on sentence boundaries first, then breaks long sentences at clause
boundaries (commas, semicolons, colons, em-dashes). Avoids tiny fragments
by merging short pieces with their neighbours.
"""
sentences = _SENTENCE_RE.split(text.strip())
segments: list[str] = []
for sent in sentences:
if len(sent.split()) <= MAX_SEGMENT_WORDS:
segments.append(sent)
else:
# Split long sentences at clause boundaries
clauses = _CLAUSE_RE.split(sent)
current = ""
for clause in clauses:
combined = (current + " " + clause) if current else clause
if current and len(combined.split()) > MAX_SEGMENT_WORDS:
segments.append(current)
current = clause
else:
current = combined
if current:
segments.append(current)
# Merge any tiny fragments into their neighbour
merged: list[str] = []
for seg in segments:
if not seg.strip():
continue
if merged and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
merged[-1] = merged[-1] + " " + seg
else:
merged.append(seg)
# Also merge a trailing runt
if len(merged) > 1 and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
merged[-2] = merged[-2] + " " + merged[-1]
merged.pop()
return merged
class ConversationSession: class ConversationSession:
"""Manages a single client's voice conversation pipeline. """Manages a single client's voice conversation pipeline.
@@ -53,15 +104,17 @@ class ConversationSession:
elif self.vad.is_speaking and self.is_responding: elif self.vad.is_speaking and self.is_responding:
await self._interrupt() await self._interrupt()
async def interrupt(self): async def interrupt(self, last_chunk_id: int | None = None):
"""Public interrupt method for WebSocket text messages.""" """Public interrupt method for WebSocket text messages."""
if self.is_responding: if self.is_responding:
await self._interrupt() await self._interrupt(last_chunk_id=last_chunk_id)
async def _interrupt(self): async def _interrupt(self, last_chunk_id: int | None = None):
log.info("Barge-in: cancelling response.") log.info("Barge-in: cancelling response.")
self.cancel_event.set() self.cancel_event.set()
self.is_responding = False self.is_responding = False
if last_chunk_id is not None:
self._last_played_chunk_id = last_chunk_id
# Tell client to stop audio immediately # Tell client to stop audio immediately
try: try:
await self.send_json({"type": "interrupt"}) await self.send_json({"type": "interrupt"})
@@ -105,11 +158,18 @@ class ConversationSession:
# TTS - stream chunks with per-sentence text # TTS - stream chunks with per-sentence text
await self.send_json({"type": "status", "state": "speaking"}) await self.send_json({"type": "status", "state": "speaking"})
chunk_queue = queue.Queue() chunk_queue = queue.Queue()
self._last_played_chunk_id = None
segments = _split_into_segments(response)
log.info(f"TTS: split response into {len(segments)} segments")
def _tts_worker(): def _tts_worker():
try: try:
for segment in segments:
if self.cancel_event.is_set():
break
for graphemes, _ps, audio in self.models.tts_engine.pipeline( for graphemes, _ps, audio in self.models.tts_engine.pipeline(
response, voice=self.models.tts_engine.voice segment, voice=self.models.tts_engine.voice
): ):
if self.cancel_event.is_set(): if self.cancel_event.is_set():
break break
@@ -124,6 +184,9 @@ class ConversationSession:
tts_thread.start() tts_thread.start()
spoken_text = "" spoken_text = ""
chunk_id = 0
# Maps chunk_id -> cumulative text up to and including that chunk
chunk_text_map: dict[int, str] = {}
while True: while True:
try: try:
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0) item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
@@ -137,8 +200,14 @@ class ConversationSession:
sentence_text, audio = item sentence_text, audio = item
spoken_text += sentence_text spoken_text += sentence_text
chunk_text_map[chunk_id] = spoken_text
await self.send_json({"type": "response_text", "text": sentence_text, "final": False}) await self.send_json({
"type": "response_text",
"text": sentence_text,
"chunk_id": chunk_id,
"final": False,
})
pcm_bytes = float32_to_pcm_bytes(audio) pcm_bytes = float32_to_pcm_bytes(audio)
try: try:
await self.send_bytes(pcm_bytes) await self.send_bytes(pcm_bytes)
@@ -146,14 +215,26 @@ class ConversationSession:
log.warning("Failed to send audio, client disconnected.") log.warning("Failed to send audio, client disconnected.")
self.cancel_event.set() self.cancel_event.set()
break break
chunk_id += 1
tts_thread.join(timeout=2.0) tts_thread.join(timeout=2.0)
# Save only what was actually spoken # Determine what was actually heard by the client
was_interrupted = spoken_text.strip() != response.strip() was_interrupted = spoken_text.strip() != response.strip()
if spoken_text.strip(): if was_interrupted and self._last_played_chunk_id is not None:
# Client told us the last chunk whose audio actually played
heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
log.info(f"Interrupted: client heard up to chunk {self._last_played_chunk_id}")
else:
heard_text = spoken_text
# Save only what was actually spoken/heard
if heard_text.strip():
# Use original LLM response when fully spoken (keeps KV-cache valid);
# use heard_text only when interrupted.
final_content = heard_text.strip() if was_interrupted else response
self.conversation_history.append( self.conversation_history.append(
{"role": "assistant", "content": spoken_text.strip()} {"role": "assistant", "content": final_content}
) )
if was_interrupted and self.kv_cache_state is not None: if was_interrupted and self.kv_cache_state is not None:
self.kv_cache_state = self.models.llm_engine.trim_cache( self.kv_cache_state = self.models.llm_engine.trim_cache(
@@ -164,7 +245,12 @@ class ConversationSession:
self.kv_cache_state = None self.kv_cache_state = None
if not self.cancel_event.is_set(): if not self.cancel_event.is_set():
await self.send_json({"type": "response_text", "text": "", "final": True}) await self.send_json({
"type": "response_text",
"text": "",
"final": True,
"total_chunks": chunk_id,
})
self.is_responding = False self.is_responding = False
try: try:
+64 -14
View File
@@ -13,6 +13,11 @@ const BARGE_IN_THRESHOLD = 0.03; // RMS energy threshold for barge-in
const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger
let bargeInCount = 0; let bargeInCount = 0;
// --- Text-audio sync state ---
let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to arrive
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
const chatArea = document.getElementById("chat-area"); const chatArea = document.getElementById("chat-area");
const statusBadge = document.getElementById("status-badge"); const statusBadge = document.getElementById("status-badge");
const micBtn = document.getElementById("mic-btn"); const micBtn = document.getElementById("mic-btn");
@@ -54,8 +59,8 @@ function handleJSON(msg) {
case "interrupt": case "interrupt":
stopPlayback(); stopPlayback();
// Trim the assistant message to what was spoken, then finalize // Finalize with interrupted marker — text already reflects only what was heard
finalizeAssistantMessage(); finalizeAssistantMessage(true);
break; break;
case "transcript": case "transcript":
@@ -64,9 +69,15 @@ function handleJSON(msg) {
case "response_text": case "response_text":
if (msg.final) { if (msg.final) {
finalizeAssistantMessage(); // All chunks sent; finalize will happen when last audio chunk plays
// (or immediately if nothing was queued)
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
finalizeAssistantMessage(false);
}
// Otherwise, playAudioChunk will finalize after the last scheduled text
} else { } else {
appendAssistantText(msg.text); // Queue text — it will be displayed when corresponding audio starts playing
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
} }
break; break;
} }
@@ -113,9 +124,20 @@ function appendAssistantText(text) {
chatArea.scrollTop = chatArea.scrollHeight; chatArea.scrollTop = chatArea.scrollHeight;
} }
function finalizeAssistantMessage() { function finalizeAssistantMessage(interrupted = false) {
if (interrupted && currentAssistantEl && currentAssistantText) {
const marker = document.createElement("span");
marker.className = "interrupted-marker";
marker.textContent = " [interrupted]";
currentAssistantEl.appendChild(marker);
}
currentAssistantEl = null; currentAssistantEl = null;
currentAssistantText = ""; currentAssistantText = "";
// Reset sync state
pendingTextChunks = [];
for (const tid of scheduledTextTimers) clearTimeout(tid);
scheduledTextTimers = [];
lastDisplayedChunkId = -1;
} }
// --- Audio Playback --- // --- Audio Playback ---
@@ -146,18 +168,38 @@ function playAudioChunk(arrayBuffer) {
activeSources.push(source); activeSources.push(source);
isPlaying = true; isPlaying = true;
source.onended = () => {
activeSources = activeSources.filter((s) => s !== source); // Pair this audio chunk with the next queued text chunk
if (activeSources.length === 0) { const textEntry = pendingTextChunks.shift();
isPlaying = false;
bargeInCount = 0;
}
};
const now = ctx.currentTime; const now = ctx.currentTime;
if (nextPlayTime < now) { if (nextPlayTime < now) {
nextPlayTime = now + 0.01; nextPlayTime = now + 0.01;
} }
// Schedule text display to coincide with audio playback start
if (textEntry) {
const delayMs = Math.max(0, (nextPlayTime - now) * 1000);
const tid = setTimeout(() => {
appendAssistantText(textEntry.text);
lastDisplayedChunkId = textEntry.chunkId;
scheduledTextTimers = scheduledTextTimers.filter((t) => t !== tid);
}, delayMs);
scheduledTextTimers.push(tid);
}
source.onended = () => {
activeSources = activeSources.filter((s) => s !== source);
if (activeSources.length === 0) {
isPlaying = false;
bargeInCount = 0;
// If all audio has finished and no more text pending, finalize
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
finalizeAssistantMessage(false);
}
}
};
source.start(nextPlayTime); source.start(nextPlayTime);
nextPlayTime += buffer.duration; nextPlayTime += buffer.duration;
} }
@@ -172,6 +214,10 @@ function stopPlayback() {
nextPlayTime = 0; nextPlayTime = 0;
isPlaying = false; isPlaying = false;
bargeInCount = 0; bargeInCount = 0;
// Cancel any pending text displays
for (const tid of scheduledTextTimers) clearTimeout(tid);
scheduledTextTimers = [];
pendingTextChunks = [];
} }
// --- Microphone --- // --- Microphone ---
@@ -229,8 +275,12 @@ async function startMic() {
if (bargeInCount >= BARGE_IN_FRAMES) { if (bargeInCount >= BARGE_IN_FRAMES) {
// User is speaking over the assistant - interrupt // User is speaking over the assistant - interrupt
stopPlayback(); stopPlayback();
finalizeAssistantMessage(); const msg = { type: "interrupt" };
ws.send(JSON.stringify({ type: "interrupt" })); if (lastDisplayedChunkId >= 0) {
msg.last_chunk_id = lastDisplayedChunkId;
}
ws.send(JSON.stringify(msg));
finalizeAssistantMessage(true);
isPlaying = false; isPlaying = false;
bargeInCount = 0; bargeInCount = 0;
} }
+6
View File
@@ -84,6 +84,12 @@ header h1 {
border-bottom-left-radius: 4px; border-bottom-left-radius: 4px;
} }
.interrupted-marker {
color: #888;
font-style: italic;
font-size: 13px;
}
#controls { #controls {
padding: 16px 24px; padding: 16px 24px;
border-top: 1px solid #222; border-top: 1px solid #222;