# Files
# live-voice-chat/server/pipeline.py
# T
# 2026-04-08 11:40:59 -04:00
#
# 260 lines
# 9.4 KiB
# Python

import asyncio
import logging
import queue
import re
import threading
import numpy as np
from server.audio_utils import float32_to_pcm_bytes
from server.llm import KVCacheState
from server.models import ModelManager
from server.vad import StreamingVAD
log = logging.getLogger(__name__)
_SENTINEL = None
# Regex: split after sentence-ending punctuation followed by whitespace
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
# Regex: split after clause-level punctuation followed by whitespace
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')
MAX_SEGMENT_WORDS = 20
MIN_SEGMENT_WORDS = 4
def _split_into_segments(text: str) -> list[str]:
"""Split text into small TTS-friendly segments for fine-grained streaming.
Splits on sentence boundaries first, then breaks long sentences at clause
boundaries (commas, semicolons, colons, em-dashes). Avoids tiny fragments
by merging short pieces with their neighbours.
"""
sentences = _SENTENCE_RE.split(text.strip())
segments: list[str] = []
for sent in sentences:
if len(sent.split()) <= MAX_SEGMENT_WORDS:
segments.append(sent)
else:
# Split long sentences at clause boundaries
clauses = _CLAUSE_RE.split(sent)
current = ""
for clause in clauses:
combined = (current + " " + clause) if current else clause
if current and len(combined.split()) > MAX_SEGMENT_WORDS:
segments.append(current)
current = clause
else:
current = combined
if current:
segments.append(current)
# Merge any tiny fragments into their neighbour
merged: list[str] = []
for seg in segments:
if not seg.strip():
continue
if merged and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
merged[-1] = merged[-1] + " " + seg
else:
merged.append(seg)
# Also merge a trailing runt
if len(merged) > 1 and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
merged[-2] = merged[-2] + " " + merged[-1]
merged.pop()
return merged
class ConversationSession:
    """Manages a single client's voice conversation pipeline.

    Orchestrates: VAD -> ASR -> LLM -> TTS streaming with barge-in support.

    Each detected utterance is handled as one *turn* in a background asyncio
    task. Turns are serialized by ``_turn_lock``. A turn that is still winding
    down after a barge-in (e.g. blocked in the LLM call) therefore cannot
    interleave its conversation-history / KV-cache updates with the next turn.
    Every turn owns a fresh ``threading.Event``, and ``self.cancel_event``
    always refers to the most recently started turn's event.
    """

    def __init__(self, models: ModelManager, send_json, send_bytes):
        self.models = models
        self.send_json = send_json
        self.send_bytes = send_bytes
        self.vad: StreamingVAD = models.create_vad()
        self.conversation_history: list[dict] = []
        self.kv_cache_state: KVCacheState | None = None
        # Cancellation flag of the current turn; replaced at the start of each turn.
        self.cancel_event = threading.Event()
        self.is_responding = False
        self._response_task: asyncio.Task | None = None
        # Last chunk id the client reported as actually played before a barge-in.
        self._last_played_chunk_id: int | None = None
        # Only one turn at a time may touch the history, the KV cache and the LLM.
        self._turn_lock = asyncio.Lock()

    async def start(self):
        """Tell the client the session is listening."""
        await self.send_json({"type": "status", "state": "listening"})

    async def stop(self):
        """Cancel the current turn (and its task, if still running) and drop the KV cache."""
        self.cancel_event.set()
        if self._response_task and not self._response_task.done():
            self._response_task.cancel()
        self.kv_cache_state = None

    async def handle_audio_chunk(self, chunk_16k: np.ndarray):
        """Feed one audio chunk to the VAD and react to its result.

        The parameter name suggests 16 kHz samples (verify against the caller).
        A completed utterance starts a new turn, interrupting any reply in
        progress. Speech reported by the VAD while a reply is in progress
        triggers a barge-in.
        """
        utterance = self.vad.process_chunk(chunk_16k)
        if utterance is not None:
            if self.is_responding:
                await self._interrupt()
            # Launch response pipeline as a background task so we don't block receives
            self._response_task = asyncio.create_task(self._process_utterance(utterance))
        elif self.vad.is_speaking and self.is_responding:
            await self._interrupt()

    async def interrupt(self, last_chunk_id: int | None = None):
        """Public interrupt method for WebSocket text messages.

        ``last_chunk_id`` is the id of the last chunk the client actually
        played. It is used to record only the heard part of the reply.
        """
        if self.is_responding:
            await self._interrupt(last_chunk_id=last_chunk_id)

    async def _interrupt(self, last_chunk_id: int | None = None):
        """Cancel the current turn and tell the client to stop playback."""
        log.info("Barge-in: cancelling response.")
        self.cancel_event.set()
        self.is_responding = False
        if last_chunk_id is not None:
            self._last_played_chunk_id = last_chunk_id
        # Tell client to stop audio immediately (best effort: the socket may be gone).
        try:
            await self.send_json({"type": "interrupt"})
        except Exception:
            log.debug("Could not send interrupt message to client.", exc_info=True)

    async def _process_utterance(self, audio_16k: np.ndarray):
        """Full pipeline for one turn: ASR -> LLM -> TTS streaming."""
        # Use a fresh event per turn. Clearing one shared event here would
        # "un-cancel" a previous turn that is still blocked in ASR/LLM/TTS,
        # which would then speak over this one.
        cancel_event = threading.Event()
        self.cancel_event = cancel_event
        self.is_responding = True
        try:
            async with self._turn_lock:
                await self._run_turn(audio_16k, cancel_event)
        finally:
            # Safety net for exceptions and task cancellation.
            self._end_turn(cancel_event)

    def _end_turn(self, cancel_event: threading.Event) -> None:
        """Clear ``is_responding`` only if *cancel_event* belongs to the current turn.

        An older turn finishing late must not clear the flag of a newer turn.
        """
        if self.cancel_event is cancel_event:
            self.is_responding = False

    async def _run_turn(self, audio_16k: np.ndarray, cancel_event: threading.Event) -> None:
        """Run ASR -> LLM -> TTS for one utterance. The caller holds ``_turn_lock``."""
        if cancel_event.is_set():
            # Barged in on while waiting for the previous turn to finish.
            return

        # ASR
        await self.send_json({"type": "status", "state": "thinking"})
        text = await asyncio.to_thread(self.models.asr_engine.transcribe, audio_16k)
        if not text:
            log.info("ASR returned empty text, resuming listening.")
            self._end_turn(cancel_event)
            await self.send_json({"type": "status", "state": "listening"})
            return
        await self.send_json({"type": "transcript", "text": text, "final": True})
        log.info(f"User: {text}")
        self.conversation_history.append({"role": "user", "content": text})
        if cancel_event.is_set():
            # NOTE(review): the user message is kept here, while _record_reply
            # drops it when nothing of the reply was heard. Confirm which is intended.
            return

        # LLM
        log.info(f"Conversation history ({len(self.conversation_history)} messages): "
                 + str([m['content'][:50] for m in self.conversation_history]))
        response, self.kv_cache_state = await asyncio.to_thread(
            self.models.llm_engine.generate, self.conversation_history, 256, self.kv_cache_state
        )
        if cancel_event.is_set():
            # The cache returned by generate() presumably covers this reply,
            # which is never added to the history. Drop it so the cache and
            # history stay consistent.
            self.kv_cache_state = None
            return

        # TTS - stream chunks with per-sentence text
        await self.send_json({"type": "status", "state": "speaking"})
        self._last_played_chunk_id = None
        spoken_text, chunk_text_map, total_chunks, completed = await self._stream_tts(
            response, cancel_event
        )
        self._record_reply(response, spoken_text, chunk_text_map, completed)

        if not cancel_event.is_set():
            await self.send_json({
                "type": "response_text",
                "text": "",
                "final": True,
                "total_chunks": total_chunks,
            })
        self._end_turn(cancel_event)
        try:
            await self.send_json({"type": "status", "state": "listening"})
        except Exception:
            log.debug("Could not send listening status to client.", exc_info=True)

    async def _stream_tts(
        self, response: str, cancel_event: threading.Event
    ) -> tuple[str, dict[int, str], int, bool]:
        """Synthesize *response* segment by segment and stream it to the client.

        Returns ``(spoken_text, chunk_text_map, total_chunks, completed)``:

        - ``spoken_text``: the space-joined grapheme text of every chunk sent.
        - ``chunk_text_map``: maps each chunk id to the cumulative text up to
          and including that chunk.
        - ``total_chunks``: the number of chunks sent.
        - ``completed``: whether the whole reply was synthesized and sent
          without cancellation or a TTS error.
        """
        chunk_queue: queue.Queue = queue.Queue()
        # Set once the consumer stops reading, so the worker never keeps
        # synthesizing into a queue nobody drains.
        stop_worker = threading.Event()
        worker_failed = threading.Event()
        segments = _split_into_segments(response)
        log.info(f"TTS: split response into {len(segments)} segments")
        tts_engine = self.models.tts_engine

        def _stopped() -> bool:
            return cancel_event.is_set() or stop_worker.is_set()

        def _tts_worker():
            try:
                for segment in segments:
                    if _stopped():
                        break
                    for graphemes, _ps, audio in tts_engine.pipeline(
                        segment, voice=tts_engine.voice
                    ):
                        if _stopped():
                            break
                        if audio is not None and len(audio) > 0:
                            chunk_queue.put((graphemes, audio))
            except Exception:
                worker_failed.set()
                log.exception("TTS generation error")
            finally:
                chunk_queue.put(_SENTINEL)

        tts_thread = threading.Thread(target=_tts_worker, daemon=True)
        tts_thread.start()

        spoken_text = ""
        chunk_id = 0
        # Maps chunk_id -> cumulative text up to and including that chunk
        chunk_text_map: dict[int, str] = {}
        completed = False
        try:
            while True:
                try:
                    item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
                except queue.Empty:
                    log.warning("TTS produced no chunk within 10s; abandoning the rest of the reply.")
                    break
                if item is _SENTINEL:
                    completed = not (cancel_event.is_set() or worker_failed.is_set())
                    break
                if cancel_event.is_set():
                    break
                graphemes, audio = item
                # Chunks are synthesized from separately split segments, so join
                # them with a single space. Plain concatenation can fuse words
                # across segment boundaries.
                piece = (graphemes or "").strip()
                if piece:
                    spoken_text = f"{spoken_text} {piece}" if spoken_text else piece
                chunk_text_map[chunk_id] = spoken_text
                try:
                    await self.send_json({
                        "type": "response_text",
                        "text": graphemes,
                        "chunk_id": chunk_id,
                        "final": False,
                    })
                    await self.send_bytes(float32_to_pcm_bytes(audio))
                except Exception:
                    log.warning("Failed to send response chunk, client disconnected.")
                    cancel_event.set()
                    break
                chunk_id += 1
        finally:
            stop_worker.set()
        # Join off the event loop. A synchronous join would stall every session
        # while the worker finishes its current synthesis step.
        await asyncio.to_thread(tts_thread.join, 2.0)
        return spoken_text, chunk_text_map, chunk_id, completed

    def _record_reply(
        self,
        response: str,
        spoken_text: str,
        chunk_text_map: dict[int, str],
        completed: bool,
    ) -> None:
        """Save only what the user actually heard and keep the KV cache consistent."""
        if completed:
            heard_text = response
        elif self._last_played_chunk_id is not None:
            # Client told us the last chunk whose audio actually played
            heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
            log.info(f"Interrupted: client heard up to chunk {self._last_played_chunk_id}")
        else:
            heard_text = spoken_text

        if heard_text.strip():
            # Use the original LLM response when fully spoken (keeps the KV cache
            # valid). Use the heard prefix only when interrupted, then trim the cache.
            final_content = response if completed else heard_text.strip()
            self.conversation_history.append(
                {"role": "assistant", "content": final_content}
            )
            if not completed and self.kv_cache_state is not None:
                self.kv_cache_state = self.models.llm_engine.trim_cache(
                    self.kv_cache_state, self.conversation_history
                )
        elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
            # Nothing was heard: drop the unanswered user turn and the cache
            # built for the unheard reply.
            self.conversation_history.pop()
            self.kv_cache_state = None