barge-in changes
This commit is contained in:
+3
-4
@@ -4,10 +4,9 @@ llm:
|
||||
max_cache_tokens: 4096 # max KV-cache size per session (tokens); 0 to disable caching
|
||||
|
||||
system_prompt: >-
|
||||
You are a helpful voice assistant. Keep your responses concise and natural
|
||||
for spoken conversation. Respond in 1-3 short sentences.
|
||||
Do not use markdown, bullet points, code blocks, emojis, or any
|
||||
formatting that doesn't work in speech.
|
||||
You are a helpful voice assistant.
|
||||
Keep your responses extremely concise but natural for spoken conversation.
|
||||
Do not use markdown, bullet points, code blocks, emojis, or any formatting that doesn't work in speech.
|
||||
|
||||
# Settings used only when backend = "lmstudio"
|
||||
lmstudio:
|
||||
|
||||
+4
-1
@@ -6,8 +6,11 @@ services:
|
||||
volumes:
|
||||
# Cache models on the host so they survive container rebuilds
|
||||
- huggingface-cache:/cache/huggingface
|
||||
# Mount config so you can edit backend settings without rebuilding the image
|
||||
# Mount source so you can edit code/config without rebuilding the image
|
||||
- ./config.yml:/app/config.yml:ro
|
||||
- ./server:/app/server:ro
|
||||
- ./static:/app/static:ro
|
||||
- ./run.py:/app/run.py:ro
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
|
||||
+3
-1
@@ -77,7 +77,9 @@ async def websocket_chat(ws: WebSocket):
|
||||
continue
|
||||
|
||||
if msg.get("type") == "interrupt":
|
||||
await session.interrupt()
|
||||
await session.interrupt(
|
||||
last_chunk_id=msg.get("last_chunk_id")
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
log.info("WebSocket client disconnected.")
|
||||
|
||||
+99
-13
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
@@ -14,6 +15,56 @@ log = logging.getLogger(__name__)
|
||||
|
||||
_SENTINEL = None
|
||||
|
||||
# Regex: split after sentence-ending punctuation followed by whitespace
|
||||
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
|
||||
# Regex: split after clause-level punctuation followed by whitespace
|
||||
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')
|
||||
|
||||
MAX_SEGMENT_WORDS = 20
|
||||
MIN_SEGMENT_WORDS = 4
|
||||
|
||||
|
||||
def _split_into_segments(text: str) -> list[str]:
|
||||
"""Split text into small TTS-friendly segments for fine-grained streaming.
|
||||
|
||||
Splits on sentence boundaries first, then breaks long sentences at clause
|
||||
boundaries (commas, semicolons, colons, em-dashes). Avoids tiny fragments
|
||||
by merging short pieces with their neighbours.
|
||||
"""
|
||||
sentences = _SENTENCE_RE.split(text.strip())
|
||||
segments: list[str] = []
|
||||
for sent in sentences:
|
||||
if len(sent.split()) <= MAX_SEGMENT_WORDS:
|
||||
segments.append(sent)
|
||||
else:
|
||||
# Split long sentences at clause boundaries
|
||||
clauses = _CLAUSE_RE.split(sent)
|
||||
current = ""
|
||||
for clause in clauses:
|
||||
combined = (current + " " + clause) if current else clause
|
||||
if current and len(combined.split()) > MAX_SEGMENT_WORDS:
|
||||
segments.append(current)
|
||||
current = clause
|
||||
else:
|
||||
current = combined
|
||||
if current:
|
||||
segments.append(current)
|
||||
|
||||
# Merge any tiny fragments into their neighbour
|
||||
merged: list[str] = []
|
||||
for seg in segments:
|
||||
if not seg.strip():
|
||||
continue
|
||||
if merged and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
|
||||
merged[-1] = merged[-1] + " " + seg
|
||||
else:
|
||||
merged.append(seg)
|
||||
# Also merge a trailing runt
|
||||
if len(merged) > 1 and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
|
||||
merged[-2] = merged[-2] + " " + merged[-1]
|
||||
merged.pop()
|
||||
return merged
|
||||
|
||||
|
||||
class ConversationSession:
|
||||
"""Manages a single client's voice conversation pipeline.
|
||||
@@ -53,15 +104,17 @@ class ConversationSession:
|
||||
elif self.vad.is_speaking and self.is_responding:
|
||||
await self._interrupt()
|
||||
|
||||
async def interrupt(self):
|
||||
async def interrupt(self, last_chunk_id: int | None = None):
|
||||
"""Public interrupt method for WebSocket text messages."""
|
||||
if self.is_responding:
|
||||
await self._interrupt()
|
||||
await self._interrupt(last_chunk_id=last_chunk_id)
|
||||
|
||||
async def _interrupt(self):
|
||||
async def _interrupt(self, last_chunk_id: int | None = None):
|
||||
log.info("Barge-in: cancelling response.")
|
||||
self.cancel_event.set()
|
||||
self.is_responding = False
|
||||
if last_chunk_id is not None:
|
||||
self._last_played_chunk_id = last_chunk_id
|
||||
# Tell client to stop audio immediately
|
||||
try:
|
||||
await self.send_json({"type": "interrupt"})
|
||||
@@ -105,16 +158,23 @@ class ConversationSession:
|
||||
# TTS - stream chunks with per-sentence text
|
||||
await self.send_json({"type": "status", "state": "speaking"})
|
||||
chunk_queue = queue.Queue()
|
||||
self._last_played_chunk_id = None
|
||||
|
||||
segments = _split_into_segments(response)
|
||||
log.info(f"TTS: split response into {len(segments)} segments")
|
||||
|
||||
def _tts_worker():
|
||||
try:
|
||||
for graphemes, _ps, audio in self.models.tts_engine.pipeline(
|
||||
response, voice=self.models.tts_engine.voice
|
||||
):
|
||||
for segment in segments:
|
||||
if self.cancel_event.is_set():
|
||||
break
|
||||
if audio is not None and len(audio) > 0:
|
||||
chunk_queue.put((graphemes, audio))
|
||||
for graphemes, _ps, audio in self.models.tts_engine.pipeline(
|
||||
segment, voice=self.models.tts_engine.voice
|
||||
):
|
||||
if self.cancel_event.is_set():
|
||||
break
|
||||
if audio is not None and len(audio) > 0:
|
||||
chunk_queue.put((graphemes, audio))
|
||||
except Exception:
|
||||
log.exception("TTS generation error")
|
||||
finally:
|
||||
@@ -124,6 +184,9 @@ class ConversationSession:
|
||||
tts_thread.start()
|
||||
|
||||
spoken_text = ""
|
||||
chunk_id = 0
|
||||
# Maps chunk_id -> cumulative text up to and including that chunk
|
||||
chunk_text_map: dict[int, str] = {}
|
||||
while True:
|
||||
try:
|
||||
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
|
||||
@@ -137,8 +200,14 @@ class ConversationSession:
|
||||
|
||||
sentence_text, audio = item
|
||||
spoken_text += sentence_text
|
||||
chunk_text_map[chunk_id] = spoken_text
|
||||
|
||||
await self.send_json({"type": "response_text", "text": sentence_text, "final": False})
|
||||
await self.send_json({
|
||||
"type": "response_text",
|
||||
"text": sentence_text,
|
||||
"chunk_id": chunk_id,
|
||||
"final": False,
|
||||
})
|
||||
pcm_bytes = float32_to_pcm_bytes(audio)
|
||||
try:
|
||||
await self.send_bytes(pcm_bytes)
|
||||
@@ -146,14 +215,26 @@ class ConversationSession:
|
||||
log.warning("Failed to send audio, client disconnected.")
|
||||
self.cancel_event.set()
|
||||
break
|
||||
chunk_id += 1
|
||||
|
||||
tts_thread.join(timeout=2.0)
|
||||
|
||||
# Save only what was actually spoken
|
||||
# Determine what was actually heard by the client
|
||||
was_interrupted = spoken_text.strip() != response.strip()
|
||||
if spoken_text.strip():
|
||||
if was_interrupted and self._last_played_chunk_id is not None:
|
||||
# Client told us the last chunk whose audio actually played
|
||||
heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
|
||||
log.info(f"Interrupted: client heard up to chunk {self._last_played_chunk_id}")
|
||||
else:
|
||||
heard_text = spoken_text
|
||||
|
||||
# Save only what was actually spoken/heard
|
||||
if heard_text.strip():
|
||||
# Use original LLM response when fully spoken (keeps KV-cache valid);
|
||||
# use heard_text only when interrupted.
|
||||
final_content = heard_text.strip() if was_interrupted else response
|
||||
self.conversation_history.append(
|
||||
{"role": "assistant", "content": spoken_text.strip()}
|
||||
{"role": "assistant", "content": final_content}
|
||||
)
|
||||
if was_interrupted and self.kv_cache_state is not None:
|
||||
self.kv_cache_state = self.models.llm_engine.trim_cache(
|
||||
@@ -164,7 +245,12 @@ class ConversationSession:
|
||||
self.kv_cache_state = None
|
||||
|
||||
if not self.cancel_event.is_set():
|
||||
await self.send_json({"type": "response_text", "text": "", "final": True})
|
||||
await self.send_json({
|
||||
"type": "response_text",
|
||||
"text": "",
|
||||
"final": True,
|
||||
"total_chunks": chunk_id,
|
||||
})
|
||||
|
||||
self.is_responding = False
|
||||
try:
|
||||
|
||||
+64
-14
@@ -13,6 +13,11 @@ const BARGE_IN_THRESHOLD = 0.03; // RMS energy threshold for barge-in
|
||||
const BARGE_IN_FRAMES = 2; // Consecutive frames above threshold to trigger
|
||||
let bargeInCount = 0;
|
||||
|
||||
// --- Text-audio sync state ---
|
||||
let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to arrive
|
||||
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
|
||||
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
|
||||
|
||||
const chatArea = document.getElementById("chat-area");
|
||||
const statusBadge = document.getElementById("status-badge");
|
||||
const micBtn = document.getElementById("mic-btn");
|
||||
@@ -54,8 +59,8 @@ function handleJSON(msg) {
|
||||
|
||||
case "interrupt":
|
||||
stopPlayback();
|
||||
// Trim the assistant message to what was spoken, then finalize
|
||||
finalizeAssistantMessage();
|
||||
// Finalize with interrupted marker — text already reflects only what was heard
|
||||
finalizeAssistantMessage(true);
|
||||
break;
|
||||
|
||||
case "transcript":
|
||||
@@ -64,9 +69,15 @@ function handleJSON(msg) {
|
||||
|
||||
case "response_text":
|
||||
if (msg.final) {
|
||||
finalizeAssistantMessage();
|
||||
// All chunks sent; finalize will happen when last audio chunk plays
|
||||
// (or immediately if nothing was queued)
|
||||
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
|
||||
finalizeAssistantMessage(false);
|
||||
}
|
||||
// Otherwise, playAudioChunk will finalize after the last scheduled text
|
||||
} else {
|
||||
appendAssistantText(msg.text);
|
||||
// Queue text — it will be displayed when corresponding audio starts playing
|
||||
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -113,9 +124,20 @@ function appendAssistantText(text) {
|
||||
chatArea.scrollTop = chatArea.scrollHeight;
|
||||
}
|
||||
|
||||
function finalizeAssistantMessage() {
|
||||
function finalizeAssistantMessage(interrupted = false) {
|
||||
if (interrupted && currentAssistantEl && currentAssistantText) {
|
||||
const marker = document.createElement("span");
|
||||
marker.className = "interrupted-marker";
|
||||
marker.textContent = " [interrupted]";
|
||||
currentAssistantEl.appendChild(marker);
|
||||
}
|
||||
currentAssistantEl = null;
|
||||
currentAssistantText = "";
|
||||
// Reset sync state
|
||||
pendingTextChunks = [];
|
||||
for (const tid of scheduledTextTimers) clearTimeout(tid);
|
||||
scheduledTextTimers = [];
|
||||
lastDisplayedChunkId = -1;
|
||||
}
|
||||
|
||||
// --- Audio Playback ---
|
||||
@@ -146,18 +168,38 @@ function playAudioChunk(arrayBuffer) {
|
||||
|
||||
activeSources.push(source);
|
||||
isPlaying = true;
|
||||
source.onended = () => {
|
||||
activeSources = activeSources.filter((s) => s !== source);
|
||||
if (activeSources.length === 0) {
|
||||
isPlaying = false;
|
||||
bargeInCount = 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Pair this audio chunk with the next queued text chunk
|
||||
const textEntry = pendingTextChunks.shift();
|
||||
|
||||
const now = ctx.currentTime;
|
||||
if (nextPlayTime < now) {
|
||||
nextPlayTime = now + 0.01;
|
||||
}
|
||||
|
||||
// Schedule text display to coincide with audio playback start
|
||||
if (textEntry) {
|
||||
const delayMs = Math.max(0, (nextPlayTime - now) * 1000);
|
||||
const tid = setTimeout(() => {
|
||||
appendAssistantText(textEntry.text);
|
||||
lastDisplayedChunkId = textEntry.chunkId;
|
||||
scheduledTextTimers = scheduledTextTimers.filter((t) => t !== tid);
|
||||
}, delayMs);
|
||||
scheduledTextTimers.push(tid);
|
||||
}
|
||||
|
||||
source.onended = () => {
|
||||
activeSources = activeSources.filter((s) => s !== source);
|
||||
if (activeSources.length === 0) {
|
||||
isPlaying = false;
|
||||
bargeInCount = 0;
|
||||
// If all audio has finished and no more text pending, finalize
|
||||
if (pendingTextChunks.length === 0 && scheduledTextTimers.length === 0) {
|
||||
finalizeAssistantMessage(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
source.start(nextPlayTime);
|
||||
nextPlayTime += buffer.duration;
|
||||
}
|
||||
@@ -172,6 +214,10 @@ function stopPlayback() {
|
||||
nextPlayTime = 0;
|
||||
isPlaying = false;
|
||||
bargeInCount = 0;
|
||||
// Cancel any pending text displays
|
||||
for (const tid of scheduledTextTimers) clearTimeout(tid);
|
||||
scheduledTextTimers = [];
|
||||
pendingTextChunks = [];
|
||||
}
|
||||
|
||||
// --- Microphone ---
|
||||
@@ -229,8 +275,12 @@ async function startMic() {
|
||||
if (bargeInCount >= BARGE_IN_FRAMES) {
|
||||
// User is speaking over the assistant - interrupt
|
||||
stopPlayback();
|
||||
finalizeAssistantMessage();
|
||||
ws.send(JSON.stringify({ type: "interrupt" }));
|
||||
const msg = { type: "interrupt" };
|
||||
if (lastDisplayedChunkId >= 0) {
|
||||
msg.last_chunk_id = lastDisplayedChunkId;
|
||||
}
|
||||
ws.send(JSON.stringify(msg));
|
||||
finalizeAssistantMessage(true);
|
||||
isPlaying = false;
|
||||
bargeInCount = 0;
|
||||
}
|
||||
|
||||
@@ -84,6 +84,12 @@ header h1 {
|
||||
border-bottom-left-radius: 4px;
|
||||
}
|
||||
|
||||
.interrupted-marker {
|
||||
color: #888;
|
||||
font-style: italic;
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
#controls {
|
||||
padding: 16px 24px;
|
||||
border-top: 1px solid #222;
|
||||
|
||||
Reference in New Issue
Block a user