319 lines
12 KiB
Python
319 lines
12 KiB
Python
import asyncio
|
|
import logging
|
|
import queue
|
|
import re
|
|
import threading
|
|
|
|
import numpy as np
|
|
|
|
from server.audio_utils import float32_to_pcm_bytes
|
|
from server.llm import KVCacheState
|
|
from server.models import ModelManager
|
|
from server.vad import StreamingVAD
|
|
|
|
log = logging.getLogger(__name__)

# Value pushed onto the TTS chunk queue by the worker thread to mark the end
# of the stream; consumers compare against it with `is`.
_SENTINEL = None

# Word-count limits for the pieces produced by _split_into_segments():
#   * sentences with more than MAX_SEGMENT_WORDS words are re-packed from
#     their clauses;
#   * pieces with fewer than MIN_SEGMENT_WORDS words are merged into a
#     neighbouring piece.
MAX_SEGMENT_WORDS = 20
MIN_SEGMENT_WORDS = 4

# Split points are the whitespace runs that follow sentence-ending
# punctuation (. ! ?) ...
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
# ... or clause-level punctuation (, ; : and the em-dash U+2014).
_CLAUSE_RE = re.compile(r'(?<=[,;:\u2014])\s+')
|
|
|
|
|
|
def _split_into_segments(text: str) -> list[str]:
    """Break *text* into short pieces suited to incremental TTS streaming.

    The text is first cut at sentence boundaries (``_SENTENCE_RE``). Any
    sentence with more than ``MAX_SEGMENT_WORDS`` words is re-packed greedily
    from its clauses (``_CLAUSE_RE``): a new piece is started whenever adding
    the next clause would push the current piece over the limit. Finally,
    blank pieces are dropped and pieces are merged forward so that a piece
    with fewer than ``MIN_SEGMENT_WORDS`` words absorbs the one after it; a
    short final piece is folded into its predecessor.

    Returns an empty list for empty or whitespace-only input.
    """
    def word_count(s: str) -> int:
        return len(s.split())

    pieces: list[str] = []
    for sentence in _SENTENCE_RE.split(text.strip()):
        if word_count(sentence) <= MAX_SEGMENT_WORDS:
            pieces.append(sentence)
            continue

        # Too long: pack clauses greedily up to the word limit.
        pending = ""
        for clause in _CLAUSE_RE.split(sentence):
            candidate = f"{pending} {clause}" if pending else clause
            if pending and word_count(candidate) > MAX_SEGMENT_WORDS:
                pieces.append(pending)
                pending = clause
            else:
                pending = candidate
        if pending:
            pieces.append(pending)

    # Merge runts into the following piece, skipping whitespace-only pieces.
    result: list[str] = []
    for piece in filter(str.strip, pieces):
        if result and word_count(result[-1]) < MIN_SEGMENT_WORDS:
            result[-1] = f"{result[-1]} {piece}"
        else:
            result.append(piece)

    # A short last piece has no successor, so fold it backwards instead.
    if len(result) > 1 and word_count(result[-1]) < MIN_SEGMENT_WORDS:
        tail = result.pop()
        result[-1] = f"{result[-1]} {tail}"
    return result
|
|
|
|
|
|
class ConversationSession:
    """Manages a single client's voice conversation pipeline.

    Orchestrates: VAD -> ASR -> LLM -> TTS streaming with barge-in support.

    Every response runs as its own asyncio task with its own
    ``threading.Event`` for cancellation. ``self.cancel_event`` always refers
    to the event of the most recently started response; that is the event
    ``_interrupt`` and ``stop`` set. An older response keeps its own (already
    set) event, so starting a new response can never "un-cancel" it. A new
    response also waits for the previous one to finish before touching the
    conversation history, the LLM or the KV cache, so turns never interleave.
    """

    def __init__(self, models: ModelManager, send_json, send_bytes):
        """
        Args:
            models: container providing ``create_vad()``, ``asr_engine``,
                ``llm_engine``, ``tts_engine`` and, optionally, ``video_engine``.
            send_json: async callable used to send a message dict to the client.
            send_bytes: async callable used to send a binary frame to the client.
        """
        self.models = models
        self.send_json = send_json
        self.send_bytes = send_bytes

        self.vad: StreamingVAD = models.create_vad()
        self.conversation_history: list[dict] = []
        self.kv_cache_state: KVCacheState | None = None
        # Event of the current response; replaced by each new response.
        self.cancel_event = threading.Event()
        self.is_responding = False
        self._response_task: asyncio.Task | None = None
        # Last chunk_id the client reported as actually played (barge-in).
        self._last_played_chunk_id: int | None = None

    async def start(self):
        await self.send_json({"type": "status", "state": "listening"})

    async def stop(self):
        """Cancel the current response task and drop the KV cache."""
        self.cancel_event.set()
        if self._response_task and not self._response_task.done():
            self._response_task.cancel()
        self.kv_cache_state = None

    async def handle_audio_chunk(self, chunk_16k: np.ndarray):
        """Feed one audio chunk to the VAD and react to speech events.

        A completed utterance starts a new response, interrupting any reply in
        progress. Speech detected while a reply is in progress triggers barge-in.
        """
        utterance = self.vad.process_chunk(chunk_16k)

        if utterance is not None:
            if self.is_responding:
                await self._interrupt()
            # Launch response pipeline as a background task so we don't block
            # receives. Hand over the previous task so the new response can
            # wait for it to finish recording its turn.
            previous = self._response_task
            self._response_task = asyncio.create_task(
                self._process_utterance(utterance, previous=previous)
            )
        elif self.vad.is_speaking and self.is_responding:
            await self._interrupt()

    async def interrupt(self, last_chunk_id: int | None = None):
        """Public interrupt method for WebSocket text messages.

        If the server-side VAD has already interrupted the response, the
        client's ``last_chunk_id`` is still recorded while that response is
        winding down. This keeps the history aligned with what actually played.
        """
        if self.is_responding:
            await self._interrupt(last_chunk_id=last_chunk_id)
        elif (
            last_chunk_id is not None
            and self._response_task is not None
            and not self._response_task.done()
        ):
            self._last_played_chunk_id = last_chunk_id

    async def _interrupt(self, last_chunk_id: int | None = None):
        """Cancel the current response and tell the client to stop playback."""
        log.info("Barge-in: cancelling response.")
        self.cancel_event.set()
        self.is_responding = False
        if last_chunk_id is not None:
            self._last_played_chunk_id = last_chunk_id
        # Tell client to stop audio immediately
        try:
            await self.send_json({"type": "interrupt"})
        except Exception:
            # Best-effort: the client may already be gone.
            log.debug("Could not send interrupt to client.", exc_info=True)

    def _is_current(self, cancel: threading.Event) -> bool:
        """Return True if *cancel* belongs to the most recently started response."""
        return self.cancel_event is cancel

    async def _process_utterance(
        self, audio_16k: np.ndarray, previous: asyncio.Task | None = None
    ):
        """Full pipeline: ASR -> LLM -> TTS streaming.

        Args:
            audio_16k: the utterance returned by the VAD. It is passed
                unchanged to ``asr_engine.transcribe``.
            previous: the response task that was running when this utterance
                arrived, if any. It is awaited before this response touches the
                conversation history, the LLM or the KV cache.
        """
        cancel = threading.Event()
        self.cancel_event = cancel
        self.is_responding = True
        try:
            await self._respond(audio_16k, cancel, previous)
        except Exception:
            # Without this the exception would only surface as an unretrieved
            # task exception, leaving is_responding stuck at True.
            log.exception("Response pipeline failed.")
            if self._is_current(cancel):
                self.is_responding = False
                try:
                    await self.send_json({"type": "status", "state": "listening"})
                except Exception:
                    log.debug("Could not send listening status.", exc_info=True)
        finally:
            # Never reset the flag on behalf of a newer response.
            if self._is_current(cancel):
                self.is_responding = False

    async def _respond(
        self,
        audio_16k: np.ndarray,
        cancel: threading.Event,
        previous: asyncio.Task | None,
    ):
        """Run one conversational turn for the response that owns *cancel*."""
        # ASR (independent of any previous response, so it overlaps with that
        # response winding down).
        await self.send_json({"type": "status", "state": "thinking"})
        text = await asyncio.to_thread(self.models.asr_engine.transcribe, audio_16k)

        if not text:
            log.info("ASR returned empty text, resuming listening.")
            if self._is_current(cancel):
                self.is_responding = False
                await self.send_json({"type": "status", "state": "listening"})
            return

        # An interrupted response may still be running, because its LLM/TTS
        # threads cannot be aborted mid-call. Let it record its (partial) turn
        # first so history entries stay in order and the LLM is never driven
        # by two responses at once. asyncio.wait() does not re-raise the
        # task's exception.
        if previous is not None and not previous.done():
            await asyncio.wait({previous})

        await self.send_json({"type": "transcript", "text": text, "final": True})
        log.info("User: %s", text)
        self.conversation_history.append({"role": "user", "content": text})

        if cancel.is_set():
            return

        # LLM
        log.info(
            "Conversation history (%d messages): %s",
            len(self.conversation_history),
            [m['content'][:50] for m in self.conversation_history],
        )
        response, new_cache = await asyncio.to_thread(
            self.models.llm_engine.generate,
            self.conversation_history,
            256,
            self.kv_cache_state,
        )

        if cancel.is_set():
            # The reply is discarded. The cache generate() just returned
            # presumably already covers it (compare the trim_cache call for
            # interrupted replies), so it no longer matches the history.
            # Drop it rather than reuse it.
            self.kv_cache_state = None
            return
        self.kv_cache_state = new_cache

        # TTS - stream chunks with per-sentence text
        await self.send_json({"type": "status", "state": "speaking"})

        # Video-mode branch: if a video engine is loaded AND an avatar is
        # set, buffer the full TTS output into a single blob, run MuseTalk
        # lip-sync (library or reflective source), mux to MP4, and send the
        # full clip + text in one shot. The client plays the MP4 (which
        # carries audio) instead of the per-chunk PCM path.
        video_engine = getattr(self.models, "video_engine", None)
        use_video = video_engine is not None and video_engine.is_ready()

        self._last_played_chunk_id = None
        segments = _split_into_segments(response)
        log.info(
            "TTS: split response into %d segments (video=%s)", len(segments), use_video
        )

        pieces, audio_buffer, finished = await self._stream_tts(
            segments, cancel, use_video
        )

        # In video mode nothing reaches the client until the clip is sent.
        delivered = not use_video
        if use_video and audio_buffer and not cancel.is_set():
            delivered = await self._deliver_speaking_clip(
                video_engine, audio_buffer, response, cancel
            )

        was_interrupted = cancel.is_set() or not finished
        self._record_assistant_turn(response, pieces, was_interrupted, delivered)

        if not cancel.is_set():
            await self.send_json({
                "type": "response_text",
                "text": "",
                "final": True,
                "total_chunks": len(pieces),
            })

        if self._is_current(cancel):
            self.is_responding = False
            try:
                await self.send_json({"type": "status", "state": "listening"})
            except Exception:
                log.debug("Could not send listening status.", exc_info=True)

    async def _stream_tts(
        self, segments: list[str], cancel: threading.Event, use_video: bool
    ) -> tuple[list[str], list[np.ndarray], bool]:
        """Synthesize *segments* on a worker thread and forward the chunks.

        In audio mode, each chunk is sent as a ``response_text`` message
        followed by its PCM bytes. In video mode, the audio is only buffered.

        Returns:
            ``(pieces, audio_buffer, finished)``:
            - ``pieces``: the grapheme text of every chunk (index == chunk_id).
            - ``audio_buffer``: the buffered audio (video mode only).
            - ``finished``: whether synthesis reached its end without being
              cancelled, timing out or raising.
        """
        chunk_queue: queue.Queue = queue.Queue()
        # Stops the worker whenever this coroutine exits early (timeout, send
        # error, task cancellation), so it does not keep synthesizing into a
        # queue nobody reads.
        stop_worker = threading.Event()
        tts_failed = threading.Event()

        def _tts_worker():
            try:
                for segment in segments:
                    if cancel.is_set() or stop_worker.is_set():
                        break
                    for graphemes, _ps, audio in self.models.tts_engine.pipeline(
                        segment, voice=self.models.tts_engine.voice
                    ):
                        if cancel.is_set() or stop_worker.is_set():
                            break
                        if audio is not None and len(audio) > 0:
                            chunk_queue.put((graphemes, audio))
            except Exception:
                log.exception("TTS generation error")
                tts_failed.set()
            finally:
                chunk_queue.put(_SENTINEL)

        tts_thread = threading.Thread(target=_tts_worker, daemon=True)
        tts_thread.start()

        pieces: list[str] = []
        # Video mode accumulator: we buffer all TTS audio so MuseTalk can
        # align against the full utterance.
        audio_buffer: list[np.ndarray] = []
        finished = False
        try:
            while True:
                try:
                    item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
                except queue.Empty:
                    log.warning("TTS produced no audio for 10s; abandoning the reply.")
                    break

                if item is _SENTINEL:
                    finished = not tts_failed.is_set()
                    break
                if cancel.is_set():
                    break

                sentence_text, audio = item
                chunk_id = len(pieces)
                pieces.append(sentence_text)

                if use_video:
                    # Don't stream text or PCM during video mode — we'll send
                    # everything after the clip renders so the client doesn't
                    # start displaying text before the video is ready.
                    audio_buffer.append(audio)
                    continue

                await self.send_json({
                    "type": "response_text",
                    "text": sentence_text,
                    "chunk_id": chunk_id,
                    "final": False,
                })
                pcm_bytes = float32_to_pcm_bytes(audio)
                try:
                    await self.send_bytes(pcm_bytes)
                except Exception:
                    log.warning("Failed to send audio, client disconnected.")
                    cancel.set()
                    break
        finally:
            stop_worker.set()

        # Join off the event loop: the worker may be busy generating a chunk.
        await asyncio.to_thread(tts_thread.join, 2.0)
        return pieces, audio_buffer, finished

    async def _deliver_speaking_clip(
        self,
        video_engine,
        audio_buffer: list[np.ndarray],
        response: str,
        cancel: threading.Event,
    ) -> bool:
        """Render the lip-synced clip for the buffered reply and send it.

        Returns True if the client was sent the reply: either the MP4 clip, or
        the text-only fallback after a failure. Returns False if nothing was
        sent.
        """
        try:
            full_audio = np.concatenate(audio_buffer).astype(np.float32)
            sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
            log.info(
                "Video: rendering speaking clip (audio=%ds, mode=%s)",
                int(len(full_audio) / sample_rate), video_engine.cfg.mode,
            )
            mp4_bytes = await asyncio.to_thread(
                video_engine.generate_speaking_clip,
                full_audio,
                sample_rate,
                response,
            )
            if cancel.is_set():
                log.info("Video clip discarded (cancelled during render).")
                return False
            duration_ms = int(len(full_audio) / sample_rate * 1000)
            await self.send_json({
                "type": "speaking_clip",
                "chunk_id": 0,
                "duration_ms": duration_ms,
                "text": response,
                "size_bytes": len(mp4_bytes),
            })
            await self.send_bytes(mp4_bytes)
            return True
        except Exception:
            log.exception("Video speaking-clip render failed; falling back silently.")
            # Best-effort: deliver the reply as text only.
            try:
                await self.send_json({
                    "type": "response_text",
                    "text": response,
                    "chunk_id": 0,
                    "final": True,
                })
            except Exception:
                return False
            return True

    def _record_assistant_turn(
        self,
        response: str,
        pieces: list[str],
        was_interrupted: bool,
        delivered: bool,
    ) -> None:
        """Save only what the client actually got, and keep the KV cache in step.

        Args:
            response: the full LLM reply.
            pieces: grapheme text of each TTS chunk, indexed by chunk_id.
            was_interrupted: True if synthesis/streaming did not run to completion.
            delivered: False when nothing reached the client (video mode
                before the clip was sent).
        """
        # Cumulative text per chunk_id. Chunks are joined with a space because
        # segment boundaries were split on whitespace. Plain concatenation
        # would glue sentences together ("here.Next").
        chunk_text_map: dict[int, str] = {}
        spoken_text = ""
        for chunk_id, piece in enumerate(pieces):
            spoken_text = f"{spoken_text} {piece.strip()}".strip()
            chunk_text_map[chunk_id] = spoken_text

        # Determine what was actually heard by the client
        if not delivered:
            heard_text = ""
        elif was_interrupted and self._last_played_chunk_id is not None:
            # Client told us the last chunk whose audio actually played
            heard_text = chunk_text_map.get(self._last_played_chunk_id, "")
            log.info(
                "Interrupted: client heard up to chunk %s", self._last_played_chunk_id
            )
        else:
            heard_text = spoken_text

        if heard_text.strip():
            # Use original LLM response when fully spoken (keeps KV-cache valid);
            # use heard_text only when interrupted.
            final_content = heard_text.strip() if was_interrupted else response
            self.conversation_history.append(
                {"role": "assistant", "content": final_content}
            )
            if was_interrupted and self.kv_cache_state is not None:
                self.kv_cache_state = self.models.llm_engine.trim_cache(
                    self.kv_cache_state, self.conversation_history
                )
        elif self.conversation_history and self.conversation_history[-1]["role"] == "user":
            # Nothing was heard: drop the user turn and reset the KV cache.
            self.conversation_history.pop()
            self.kv_cache_state = None
|