import asyncio
import logging
import queue
import re
import threading

import numpy as np

from server.audio_utils import float32_to_pcm_bytes
from server.llm import KVCacheState
from server.models import ModelManager
from server.vad import StreamingVAD

log = logging.getLogger(__name__)

# Marker pushed onto worker queues to signal "no more items".
_SENTINEL = None

# Regex: split after sentence-ending punctuation followed by whitespace
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
# Regex: split after clause-level punctuation followed by whitespace
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')

MAX_SEGMENT_WORDS = 20
MIN_SEGMENT_WORDS = 4


def _split_into_segments(text: str) -> list[str]:
    """Split text into small TTS-friendly segments for fine-grained streaming.

    Splits on sentence boundaries first, then breaks sentences longer than
    ``MAX_SEGMENT_WORDS`` at clause boundaries (commas, semicolons, colons,
    em-dashes). Avoids tiny fragments by merging pieces shorter than
    ``MIN_SEGMENT_WORDS`` with their neighbours.

    Returns:
        The non-blank segments in reading order. Empty or whitespace-only
        input yields an empty list.
    """
    sentences = _SENTENCE_RE.split(text.strip())
    segments: list[str] = []
    for sent in sentences:
        if len(sent.split()) <= MAX_SEGMENT_WORDS:
            segments.append(sent)
        else:
            # Split long sentences at clause boundaries, greedily packing
            # clauses while the running piece stays within the word limit.
            clauses = _CLAUSE_RE.split(sent)
            current = ""
            for clause in clauses:
                combined = (current + " " + clause) if current else clause
                if current and len(combined.split()) > MAX_SEGMENT_WORDS:
                    segments.append(current)
                    current = clause
                else:
                    current = combined
            if current:
                segments.append(current)

    # Merge any tiny fragments into their neighbour: a too-short segment
    # absorbs the segment that follows it.
    merged: list[str] = []
    for seg in segments:
        if not seg.strip():
            continue
        if merged and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
            merged[-1] = merged[-1] + " " + seg
        else:
            merged.append(seg)

    # Also merge a trailing runt backwards into the previous segment.
    if len(merged) > 1 and len(merged[-1].split()) < MIN_SEGMENT_WORDS:
        merged[-2] = merged[-2] + " " + merged[-1]
        merged.pop()
    return merged


class ConversationSession:
    """Manages a single client's voice conversation pipeline.

    Orchestrates: VAD -> ASR -> LLM -> TTS streaming with barge-in support.

    Each response pipeline is tagged with a generation number. A pipeline
    treats itself as cancelled when ``cancel_event`` is set *or* when a newer
    pipeline has started, so clearing ``cancel_event`` for a new response can
    never revive an older one that is still winding down.
    """

    def __init__(self, models: ModelManager, send_json, send_bytes):
        """Store the collaborators and create per-session state.

        Args:
            models: Provides the VAD factory and the ASR/LLM/TTS (and optional
                video) engines used by the pipeline.
            send_json: Awaitable callable used to send a JSON-able dict to the
                client.
            send_bytes: Awaitable callable used to send a binary payload to
                the client.
        """
        self.models = models
        self.send_json = send_json
        self.send_bytes = send_bytes
        self.vad: StreamingVAD = models.create_vad()
        self.conversation_history: list[dict] = []
        self.kv_cache_state: KVCacheState | None = None
        self.cancel_event = threading.Event()
        self.is_responding = False
        self._response_task: asyncio.Task | None = None
        # Last chunk id the client reported as actually played (set on
        # interrupt, reset at the start of each TTS stage).
        self._last_played_chunk_id: int | None = None
        # Incremented for every response pipeline; see _is_cancelled().
        self._response_generation = 0

    async def start(self):
        """Tell the client the session is listening."""
        await self.send_json({"type": "status", "state": "listening"})

    async def stop(self):
        """Cancel any in-flight response and drop the KV-cache state."""
        self.cancel_event.set()
        if self._response_task and not self._response_task.done():
            self._response_task.cancel()
        self.kv_cache_state = None

    async def handle_audio_chunk(self, chunk_16k: np.ndarray):
        """Feed one audio chunk to the VAD and react to its result.

        A completed utterance interrupts any response in progress and starts
        a new response pipeline; speech onset during a response triggers
        barge-in.
        """
        utterance = self.vad.process_chunk(chunk_16k)
        if utterance is not None:
            if self.is_responding:
                await self._interrupt()
            # Launch response pipeline as a background task so we don't block receives
            self._response_task = asyncio.create_task(self._process_utterance(utterance))
        elif self.vad.is_speaking and self.is_responding:
            await self._interrupt()

    async def interrupt(self, last_chunk_id: int | None = None):
        """Public interrupt method for WebSocket text messages."""
        if self.is_responding:
            await self._interrupt(last_chunk_id=last_chunk_id)

    async def _interrupt(self, last_chunk_id: int | None = None):
        """Cancel the current response and tell the client to stop audio.

        Args:
            last_chunk_id: Id of the last chunk the client actually played,
                if known; recorded so history reflects what was heard.
        """
        log.info("Barge-in: cancelling response.")
        self.cancel_event.set()
        self.is_responding = False
        if last_chunk_id is not None:
            self._last_played_chunk_id = last_chunk_id
        # Tell client to stop audio immediately (best effort).
        try:
            await self.send_json({"type": "interrupt"})
        except Exception:
            log.debug("Could not deliver interrupt message to client.", exc_info=True)

    def _is_cancelled(self, generation: int) -> bool:
        """Return True if the pipeline tagged ``generation`` must stop."""
        return self.cancel_event.is_set() or generation != self._response_generation

    def _mark_idle(self, generation: int) -> bool:
        """Clear ``is_responding`` if ``generation`` is still the newest pipeline.

        Returns:
            True when this pipeline is still the current one (and the flag was
            cleared); False when a newer pipeline owns the session state.
        """
        if generation != self._response_generation:
            return False
        self.is_responding = False
        return True

    async def _process_utterance(self, audio_16k: np.ndarray):
        """Full pipeline: ASR -> LLM -> TTS streaming."""
        # Tag this run so an older run still winding down stays cancelled
        # even though cancel_event is cleared for the new response.
        self._response_generation += 1
        generation = self._response_generation
        self.is_responding = True
        self.cancel_event.clear()
        try:
            await self._run_pipeline(audio_16k, generation)
        finally:
            # Never leave the flag stuck if the pipeline raised or was cancelled.
            self._mark_idle(generation)

    async def _run_pipeline(self, audio_16k: np.ndarray, generation: int):
        """Run ASR, LLM and speech output for one utterance."""
        # ASR
        await self.send_json({"type": "status", "state": "thinking"})
        text = await asyncio.to_thread(self.models.asr_engine.transcribe, audio_16k)
        if not text:
            log.info("ASR returned empty text, resuming listening.")
            if self._mark_idle(generation):
                await self.send_json({"type": "status", "state": "listening"})
            return

        await self.send_json({"type": "transcript", "text": text, "final": True})
        log.info(f"User: {text}")
        self.conversation_history.append({"role": "user", "content": text})

        if self._is_cancelled(generation):
            self._mark_idle(generation)
            return

        # LLM
        log.info(f"Conversation history ({len(self.conversation_history)} messages): "
                 + str([m['content'][:50] for m in self.conversation_history]))
        response, self.kv_cache_state = await asyncio.to_thread(
            self.models.llm_engine.generate,
            self.conversation_history,
            256,
            self.kv_cache_state,
        )

        if self._is_cancelled(generation):
            self._mark_idle(generation)
            return

        # TTS - stream chunks with per-sentence text
        await self.send_json({"type": "status", "state": "speaking"})

        # Video-mode branch: when a video engine is present and reports ready,
        # TTS audio is buffered instead of streamed, then rendered into
        # speaking clips that carry the audio themselves.
        video_engine = getattr(self.models, "video_engine", None)
        use_video = video_engine is not None and video_engine.is_ready()

        self._last_played_chunk_id = None
        segments = _split_into_segments(response)
        log.info(f"TTS: split response into {len(segments)} segments (video={use_video})")

        spoken_text, chunk_count, chunk_text_map, audio_buffer, tts_finished = (
            await self._stream_tts(segments, use_video, generation)
        )

        # Video mode: stream speaking clips as they're generated.
        if use_video and audio_buffer and not self._is_cancelled(generation):
            await self._stream_video_clips(video_engine, audio_buffer, response, generation)

        # The response counts as fully delivered only if TTS ran to completion
        # and nothing cancelled the pipeline. (Comparing the concatenated TTS
        # text with the LLM response is unreliable: the pieces do not
        # necessarily reproduce the whitespace between segments.)
        was_interrupted = not (tts_finished and not self._is_cancelled(generation))
        self._record_assistant_turn(response, spoken_text, chunk_text_map, was_interrupted)

        if not self._is_cancelled(generation):
            await self.send_json({
                "type": "response_text",
                "text": "",
                "final": True,
                "total_chunks": chunk_count,
            })

        # Only the newest pipeline may report the session as idle; a
        # superseded one must not clobber its successor's state.
        if self._mark_idle(generation):
            try:
                await self.send_json({"type": "status", "state": "listening"})
            except Exception:
                log.debug("Could not deliver listening status to client.", exc_info=True)

    async def _stream_tts(self, segments: list[str], use_video: bool, generation: int):
        """Synthesize ``segments`` in a worker thread and forward the chunks.

        In audio mode each chunk's text and PCM bytes are sent to the client
        as they arrive; in video mode the audio is only buffered.

        Returns:
            A tuple ``(spoken_text, chunk_count, chunk_text_map, audio_buffer,
            finished)`` where ``chunk_text_map`` maps chunk id to the
            cumulative text up to and including that chunk, and ``finished``
            is True only if every segment was synthesized and consumed.
        """
        chunk_queue: queue.Queue = queue.Queue()
        synthesis_done = threading.Event()

        def _tts_worker():
            try:
                for seg_index, segment in enumerate(segments):
                    if self._is_cancelled(generation):
                        return
                    for graphemes, _ps, audio in self.models.tts_engine.pipeline(
                        segment, voice=self.models.tts_engine.voice
                    ):
                        if self._is_cancelled(generation):
                            return
                        if audio is not None and len(audio) > 0:
                            chunk_queue.put((seg_index, graphemes, audio))
                synthesis_done.set()
            except Exception:
                log.exception("TTS generation error")
            finally:
                # Always unblock the consumer, whatever happened above.
                chunk_queue.put(_SENTINEL)

        tts_thread = threading.Thread(target=_tts_worker, daemon=True)
        tts_thread.start()

        spoken_text = ""
        chunk_id = 0
        # Maps chunk_id -> cumulative text up to and including that chunk
        chunk_text_map: dict[int, str] = {}
        # Video mode accumulator: all TTS audio is buffered so the video
        # engine receives the full utterance.
        audio_buffer: list[np.ndarray] = []
        reached_end = False
        last_seg_index: int | None = None

        while True:
            try:
                item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
            except Exception:
                log.warning("Timed out waiting for TTS audio.")
                break
            if item is _SENTINEL:
                reached_end = True
                break
            if self._is_cancelled(generation):
                break

            seg_index, sentence_text, audio = item
            # Segments were produced by splitting on whitespace, so restore a
            # single space when crossing a segment boundary.
            if (
                spoken_text
                and seg_index != last_seg_index
                and not spoken_text[-1].isspace()
                and not sentence_text[:1].isspace()
            ):
                spoken_text += " "
            spoken_text += sentence_text
            last_seg_index = seg_index
            chunk_text_map[chunk_id] = spoken_text

            if use_video:
                # Don't stream text or PCM during video mode; everything is
                # sent with the rendered clips so the client doesn't display
                # text before the video is ready.
                audio_buffer.append(audio)
            else:
                await self.send_json({
                    "type": "response_text",
                    "text": sentence_text,
                    "chunk_id": chunk_id,
                    "final": False,
                })
                pcm_bytes = float32_to_pcm_bytes(audio)
                try:
                    await self.send_bytes(pcm_bytes)
                except Exception:
                    log.warning("Failed to send audio, client disconnected.")
                    self.cancel_event.set()
                    break
            chunk_id += 1

        # Join off the event loop: a direct join() would block every other
        # coroutine for up to the timeout.
        await asyncio.to_thread(tts_thread.join, 2.0)

        finished = reached_end and synthesis_done.is_set()
        return spoken_text, chunk_id, chunk_text_map, audio_buffer, finished

    async def _stream_video_clips(
        self,
        video_engine,
        audio_buffer: list[np.ndarray],
        response: str,
        generation: int,
    ):
        """Render speaking clips from the buffered audio and send them.

        Clips are generated in a worker thread and forwarded one by one; the
        full response text accompanies the first clip. If rendering fails,
        the response text alone is sent as a fallback.
        """
        try:
            full_audio = np.concatenate(audio_buffer).astype(np.float32)
            sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
            log.info(
                "Video: rendering speaking clips (audio=%.1fs, mode=%s)",
                len(full_audio) / sample_rate,
                video_engine.cfg.mode,
            )

            clip_queue: queue.Queue = queue.Queue()

            def _video_worker():
                try:
                    for clip_data in video_engine.generate_speaking_clips_streaming(
                        full_audio, sample_rate, response
                    ):
                        if self._is_cancelled(generation):
                            break
                        clip_queue.put(clip_data)
                except Exception:
                    log.exception("Video clip generation failed")
                finally:
                    clip_queue.put(_SENTINEL)

            video_thread = threading.Thread(target=_video_worker, daemon=True)
            video_thread.start()

            is_first_clip = True
            while not self._is_cancelled(generation):
                try:
                    item = await asyncio.to_thread(clip_queue.get, timeout=120.0)
                except Exception:
                    log.warning("Timed out waiting for video clip.")
                    break
                if item is _SENTINEL:
                    break
                if self._is_cancelled(generation):
                    break

                mp4_bytes, duration_ms = item
                try:
                    await self.send_json({
                        "type": "speaking_clip",
                        "chunk_id": 0,
                        "duration_ms": duration_ms,
                        "text": response if is_first_clip else "",
                        "size_bytes": len(mp4_bytes),
                    })
                    await self.send_bytes(mp4_bytes)
                    is_first_clip = False
                except Exception:
                    log.warning("Failed to send video clip, client disconnected.")
                    self.cancel_event.set()
                    break
        except Exception:
            log.exception("Video speaking-clip render failed; falling back silently.")
            try:
                await self.send_json({
                    "type": "response_text",
                    "text": response,
                    "chunk_id": 0,
                    "final": True,
                })
            except Exception:
                log.debug("Could not deliver fallback response text.", exc_info=True)

    def _record_assistant_turn(
        self,
        response: str,
        spoken_text: str,
        chunk_text_map: dict[int, str],
        was_interrupted: bool,
    ):
        """Update history and KV-cache state with what the client heard."""
        if was_interrupted and self._last_played_chunk_id is not None:
            # Client told us the last chunk whose audio actually played
            heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
            log.info(f"Interrupted: client heard up to chunk {self._last_played_chunk_id}")
        else:
            heard_text = spoken_text

        # Save only what was actually spoken/heard
        if heard_text.strip():
            # Use original LLM response when fully spoken (keeps KV-cache valid);
            # use heard_text only when interrupted.
            final_content = heard_text.strip() if was_interrupted else response
            self.conversation_history.append(
                {"role": "assistant", "content": final_content}
            )
            if was_interrupted and self.kv_cache_state is not None:
                self.kv_cache_state = self.models.llm_engine.trim_cache(
                    self.kv_cache_state, self.conversation_history
                )
        elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
            # Nothing was heard: drop the unanswered user turn and the cache
            # that was built for it.
            self.conversation_history.pop()
            self.kv_cache_state = None