import asyncio
import logging
import queue
import re
import threading

import numpy as np

from server.audio_utils import float32_to_pcm_bytes
from server.llm import KVCacheState
from server.models import ModelManager
from server.vad import StreamingVAD

log = logging.getLogger(__name__)

# Marker the TTS worker puts on the chunk queue once it will send nothing more.
_SENTINEL = None

# Regex: split after sentence-ending punctuation followed by whitespace
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
# Regex: split after clause-level punctuation followed by whitespace
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')

MAX_SEGMENT_WORDS = 20
MIN_SEGMENT_WORDS = 4


def _word_count(text: str) -> int:
    """Return the number of whitespace-separated words in *text*."""
    return len(text.split())


def _append_spoken(spoken: str, piece: str) -> str:
    """Append *piece* to *spoken*, adding one space if neither side has any.

    Segments are produced by splitting on whitespace, so the separator is
    gone by the time chunks come back from TTS; without this the recorded
    text would read "First sentence.Second sentence".
    """
    if spoken and piece and not spoken[-1].isspace() and not piece[0].isspace():
        return spoken + " " + piece
    return spoken + piece


def _split_into_segments(text: str) -> list[str]:
    """Split text into small TTS-friendly segments for fine-grained streaming.

    Splits on sentence boundaries first, then breaks sentences longer than
    ``MAX_SEGMENT_WORDS`` at clause boundaries (commas, semicolons, colons,
    em-dashes). Pieces shorter than ``MIN_SEGMENT_WORDS`` are merged with a
    neighbour so the result has no tiny fragments.

    Returns:
        The segments in order; an empty list for empty/whitespace-only text.
    """
    segments: list[str] = []
    for sentence in _SENTENCE_RE.split(text.strip()):
        if _word_count(sentence) <= MAX_SEGMENT_WORDS:
            segments.append(sentence)
            continue

        # Long sentence: greedily pack clauses up to the word limit.
        current = ""
        for clause in _CLAUSE_RE.split(sentence):
            combined = f"{current} {clause}" if current else clause
            if current and _word_count(combined) > MAX_SEGMENT_WORDS:
                segments.append(current)
                current = clause
            else:
                current = combined
        if current:
            segments.append(current)

    # A short piece absorbs the segment that follows it.
    merged: list[str] = []
    for segment in segments:
        if not segment.strip():
            continue
        if merged and _word_count(merged[-1]) < MIN_SEGMENT_WORDS:
            merged[-1] = f"{merged[-1]} {segment}"
        else:
            merged.append(segment)

    # A short final piece has no successor, so fold it into its predecessor.
    if len(merged) > 1 and _word_count(merged[-1]) < MIN_SEGMENT_WORDS:
        runt = merged.pop()
        merged[-1] = f"{merged[-1]} {runt}"

    return merged


class ConversationSession:
    """Manages a single client's voice conversation pipeline.

    Orchestrates: VAD -> ASR -> LLM -> TTS streaming with barge-in support.
    """

    def __init__(self, models: ModelManager, send_json, send_bytes):
        """Create a session.

        Args:
            models: Provides ``create_vad()`` plus the ``asr_engine``,
                ``llm_engine`` and ``tts_engine`` used by the pipeline.
            send_json: Awaitable callable invoked with a ``dict`` message.
            send_bytes: Awaitable callable invoked with PCM audio bytes.
        """
        self.models = models
        self.send_json = send_json
        self.send_bytes = send_bytes
        self.vad: StreamingVAD = models.create_vad()
        self.conversation_history: list[dict] = []
        self.kv_cache_state: KVCacheState | None = None
        # Set to ask the running pipeline (and its TTS thread) to stop.
        self.cancel_event = threading.Event()
        self.is_responding = False
        self._response_task: asyncio.Task | None = None
        # Last chunk id the client reported as played when it interrupted.
        self._last_played_chunk_id: int | None = None

    async def start(self):
        """Tell the client the session is listening."""
        await self.send_json({"type": "status", "state": "listening"})

    async def stop(self):
        """Cancel any in-flight response and drop the KV-cache state."""
        self.cancel_event.set()
        if self._response_task and not self._response_task.done():
            self._response_task.cancel()
        self.kv_cache_state = None

    async def handle_audio_chunk(self, chunk_16k: np.ndarray):
        """Feed one audio chunk to the VAD and react to its result.

        A completed utterance interrupts any current response and starts a
        new response pipeline as a background task. Speech that starts while
        a response is in progress triggers a barge-in interrupt.
        """
        utterance = self.vad.process_chunk(chunk_16k)
        if utterance is not None:
            if self.is_responding:
                await self._interrupt()
            # Launch response pipeline as a background task so we don't block
            # receives. It is chained behind the previous task: starting it
            # directly would clear ``cancel_event`` while the interrupted
            # pipeline may still be running, reviving it.
            previous = self._response_task
            self._response_task = asyncio.create_task(
                self._run_after_previous(previous, utterance)
            )
        elif self.vad.is_speaking and self.is_responding:
            await self._interrupt()

    async def interrupt(self, last_chunk_id: int | None = None):
        """Public interrupt method for WebSocket text messages."""
        if self.is_responding:
            await self._interrupt(last_chunk_id=last_chunk_id)

    async def _interrupt(self, last_chunk_id: int | None = None):
        """Signal cancellation and tell the client to stop playback.

        Args:
            last_chunk_id: Id of the last chunk the client reports as played,
                if known; recorded so history keeps only what was heard.
        """
        log.info("Barge-in: cancelling response.")
        self.cancel_event.set()
        self.is_responding = False
        if last_chunk_id is not None:
            self._last_played_chunk_id = last_chunk_id
        # Tell client to stop audio immediately (best effort).
        try:
            await self.send_json({"type": "interrupt"})
        except Exception:
            log.debug("Could not deliver interrupt message.", exc_info=True)

    async def _run_after_previous(
        self, previous: asyncio.Task | None, utterance: np.ndarray
    ):
        """Run the pipeline for *utterance* once *previous* has finished."""
        if previous is not None and not previous.done():
            # asyncio.wait neither re-raises the task's exception nor cancels
            # it if this waiter is cancelled.
            await asyncio.wait([previous])
        await self._process_utterance(utterance)

    async def _process_utterance(self, audio_16k: np.ndarray):
        """Full pipeline: ASR -> LLM -> TTS streaming."""
        self.is_responding = True
        self.cancel_event.clear()
        try:
            await self._run_pipeline(audio_16k)
        finally:
            # Never leave the session stuck in "responding" after an error.
            self.is_responding = False

    async def _run_pipeline(self, audio_16k: np.ndarray):
        """Transcribe, generate a reply, stream it, and record the outcome."""
        # ASR
        await self.send_json({"type": "status", "state": "thinking"})
        text = await asyncio.to_thread(
            self.models.asr_engine.transcribe, audio_16k
        )
        if not text:
            log.info("ASR returned empty text, resuming listening.")
            self.is_responding = False
            await self.send_json({"type": "status", "state": "listening"})
            return

        await self.send_json({"type": "transcript", "text": text, "final": True})
        log.info("User: %s", text)
        self.conversation_history.append({"role": "user", "content": text})
        if self.cancel_event.is_set():
            return

        # LLM
        log.info(
            "Conversation history (%d messages): %s",
            len(self.conversation_history),
            [m['content'][:50] for m in self.conversation_history],
        )
        response, self.kv_cache_state = await asyncio.to_thread(
            self.models.llm_engine.generate,
            self.conversation_history,
            256,
            self.kv_cache_state,
        )
        if self.cancel_event.is_set():
            return

        # TTS - stream chunks with per-sentence text
        await self.send_json({"type": "status", "state": "speaking"})
        self._last_played_chunk_id = None
        spoken_text, chunk_text_map, chunk_count, completed = (
            await self._stream_tts(response)
        )

        # Decided from control flow, not by comparing text: the spoken text
        # is rebuilt from segments and rarely equals the raw response.
        was_interrupted = not completed or self.cancel_event.is_set()
        self._record_response(
            response, spoken_text, chunk_text_map, was_interrupted
        )

        if not self.cancel_event.is_set():
            await self.send_json({
                "type": "response_text",
                "text": "",
                "final": True,
                "total_chunks": chunk_count,
            })

        self.is_responding = False
        try:
            await self.send_json({"type": "status", "state": "listening"})
        except Exception:
            log.debug("Could not deliver listening status.", exc_info=True)

    async def _stream_tts(
        self, response: str
    ) -> tuple[str, dict[int, str], int, bool]:
        """Synthesize *response* in a worker thread and stream it out.

        Returns:
            ``(spoken_text, chunk_text_map, chunk_count, completed)`` where
            ``chunk_text_map`` maps chunk id to the cumulative text up to and
            including that chunk, and ``completed`` is True only if every
            segment was synthesized and forwarded without cancellation.
        """
        segments = _split_into_segments(response)
        log.info("TTS: split response into %d segments", len(segments))

        chunk_queue: queue.Queue = queue.Queue()
        all_segments_done = threading.Event()

        def _tts_worker():
            try:
                for segment in segments:
                    if self.cancel_event.is_set():
                        return
                    for graphemes, _ps, audio in self.models.tts_engine.pipeline(
                        segment, voice=self.models.tts_engine.voice
                    ):
                        if self.cancel_event.is_set():
                            return
                        if audio is not None and len(audio) > 0:
                            chunk_queue.put((graphemes, audio))
                all_segments_done.set()
            except Exception:
                log.exception("TTS generation error")
            finally:
                # Always wake the consumer, however the worker ended.
                chunk_queue.put(_SENTINEL)

        tts_thread = threading.Thread(target=_tts_worker, daemon=True)
        tts_thread.start()

        spoken_text = ""
        chunk_id = 0
        # Maps chunk_id -> cumulative text up to and including that chunk
        chunk_text_map: dict[int, str] = {}
        completed = False

        while True:
            try:
                item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
            except queue.Empty:
                log.warning("Timed out waiting for TTS audio.")
                break
            if item is _SENTINEL:
                completed = (
                    all_segments_done.is_set()
                    and not self.cancel_event.is_set()
                )
                break
            if self.cancel_event.is_set():
                break

            sentence_text, audio = item
            spoken_text = _append_spoken(spoken_text, sentence_text)
            chunk_text_map[chunk_id] = spoken_text

            await self.send_json({
                "type": "response_text",
                "text": sentence_text,
                "chunk_id": chunk_id,
                "final": False,
            })
            pcm_bytes = float32_to_pcm_bytes(audio)
            try:
                await self.send_bytes(pcm_bytes)
            except Exception:
                log.warning("Failed to send audio, client disconnected.")
                self.cancel_event.set()
                break
            chunk_id += 1

        # Thread.join blocks, so run it off the event loop.
        await asyncio.to_thread(tts_thread.join, 2.0)
        return spoken_text, chunk_text_map, chunk_id, completed

    def _record_response(
        self,
        response: str,
        spoken_text: str,
        chunk_text_map: dict[int, str],
        was_interrupted: bool,
    ) -> None:
        """Store what the client actually heard in the conversation history."""
        if was_interrupted and self._last_played_chunk_id is not None:
            # Client told us the last chunk whose audio actually played
            heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
            log.info(
                "Interrupted: client heard up to chunk %s",
                self._last_played_chunk_id,
            )
        else:
            heard_text = spoken_text

        if heard_text.strip():
            # Use original LLM response when fully spoken (keeps KV-cache
            # valid); use heard_text only when interrupted.
            final_content = heard_text.strip() if was_interrupted else response
            self.conversation_history.append(
                {"role": "assistant", "content": final_content}
            )
            if was_interrupted and self.kv_cache_state is not None:
                self.kv_cache_state = self.models.llm_engine.trim_cache(
                    self.kv_cache_state, self.conversation_history
                )
        elif (
            self.conversation_history
            and self.conversation_history[-1]["role"] == "user"
        ):
            # Nothing was heard: drop the unanswered user turn and the cache.
            self.conversation_history.pop()
            self.kv_cache_state = None