initial commit
This commit is contained in:
@@ -0,0 +1,164 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
from server.audio_utils import float32_to_pcm_bytes
|
||||
from server.models import ModelManager
|
||||
from server.vad import StreamingVAD
|
||||
|
||||
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# End-of-stream marker that the TTS worker thread puts on its chunk queue.
# A unique object (always compared with ``is``) can never be confused with a
# real payload, which a bare ``None`` could be.
_SENTINEL = object()
|
||||
|
||||
|
||||
class ConversationSession:
    """Manages a single client's voice conversation pipeline.

    Orchestrates: VAD -> ASR -> LLM -> TTS streaming with barge-in support.

    Every response runs as a background asyncio task and is tagged with a
    generation number. A response counts as cancelled when ``cancel_event``
    is set *or* when a newer response has started. Without the generation
    check, a superseded response that is still inside a worker thread would
    be "un-cancelled" as soon as the next response clears the shared event.
    """

    # Seconds to wait for the next synthesized chunk before giving up.
    TTS_CHUNK_TIMEOUT = 10.0
    # Seconds to wait for the TTS worker thread to exit once streaming ends.
    TTS_JOIN_TIMEOUT = 2.0

    def __init__(self, models: "ModelManager", send_json, send_bytes):
        """Create a session.

        Args:
            models: Provides ``create_vad()`` plus the ``asr_engine``,
                ``llm_engine`` and ``tts_engine`` used by the pipeline.
            send_json: Awaitable callable taking one dict message.
            send_bytes: Awaitable callable taking one bytes payload.
        """
        self.models = models
        self.send_json = send_json
        self.send_bytes = send_bytes

        self.vad: StreamingVAD = models.create_vad()
        self.conversation_history: list[dict] = []
        self.cancel_event = threading.Event()
        self.is_responding = False
        self._response_task: asyncio.Task | None = None
        # Strong references to in-flight response tasks; the event loop only
        # keeps weak ones, and stop() needs to reach superseded tasks too.
        self._tasks: set[asyncio.Task] = set()
        # Bumped once per response; identifies the response that is current.
        self._generation = 0

    async def start(self) -> None:
        """Tell the client that the session is listening."""
        await self.send_json({"type": "status", "state": "listening"})

    async def stop(self) -> None:
        """Cancel every in-flight response and wait for it to unwind."""
        self.cancel_event.set()

        pending = {task for task in self._tasks if not task.done()}
        if self._response_task and not self._response_task.done():
            pending.add(self._response_task)

        for task in pending:
            task.cancel()
        if pending:
            # asyncio.wait never re-raises the tasks' own exceptions, so a
            # cancelled task cannot be mistaken for stop() being cancelled.
            await asyncio.wait(pending)

    async def handle_audio_chunk(self, chunk_16k: np.ndarray) -> None:
        """Feed one audio chunk to the VAD and react to its result.

        A completed utterance starts a new response (interrupting the current
        one first); speech detected while responding triggers a barge-in.
        """
        utterance = self.vad.process_chunk(chunk_16k)

        if utterance is not None:
            if self.is_responding:
                await self._interrupt()
            # Launch response pipeline as a background task so we don't block receives
            task = asyncio.create_task(self._process_utterance(utterance))
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)
            self._response_task = task
        elif self.vad.is_speaking and self.is_responding:
            await self._interrupt()

    async def interrupt(self) -> None:
        """Public interrupt method for WebSocket text messages."""
        if self.is_responding:
            await self._interrupt()

    async def _interrupt(self) -> None:
        """Cancel the current response and tell the client to stop playback."""
        log.info("Barge-in: cancelling response.")
        self.cancel_event.set()
        self.is_responding = False
        # Tell client to stop audio immediately
        try:
            await self.send_json({"type": "interrupt"})
        except Exception:
            # Best effort: the client may already be gone.
            log.debug("Could not deliver interrupt message.", exc_info=True)

    def _is_cancelled(self, generation: int) -> bool:
        """Return True if response *generation* was cancelled or superseded."""
        return self.cancel_event.is_set() or generation != self._generation

    async def _notify_listening(self) -> None:
        """Best-effort "listening" status; send failures are only logged."""
        try:
            await self.send_json({"type": "status", "state": "listening"})
        except Exception:
            log.debug("Could not deliver listening status.", exc_info=True)

    async def _process_utterance(self, audio_16k: np.ndarray) -> None:
        """Full pipeline: ASR -> LLM -> TTS streaming.

        Failures are logged rather than left as unretrieved task exceptions,
        and ``is_responding`` is always reset unless a newer response has
        already taken over.
        """
        # No await between these statements, so taking over as the current
        # response is atomic with respect to other coroutines.
        self._generation += 1
        generation = self._generation
        self.is_responding = True
        self.cancel_event.clear()

        try:
            await self._run_pipeline(audio_16k, generation)
        except Exception:
            log.exception("Response pipeline failed")
            if generation == self._generation:
                self.is_responding = False
                await self._notify_listening()
        finally:
            # Also reached on task cancellation (CancelledError).
            if generation == self._generation:
                self.is_responding = False

    async def _run_pipeline(self, audio_16k: np.ndarray, generation: int) -> None:
        """Run ASR, LLM and TTS for one utterance, honouring cancellation."""
        # ASR
        await self.send_json({"type": "status", "state": "thinking"})
        text = await asyncio.to_thread(self.models.asr_engine.transcribe, audio_16k)

        if not text:
            log.info("ASR returned empty text, resuming listening.")
            if generation == self._generation:
                self.is_responding = False
                await self._notify_listening()
            return

        await self.send_json({"type": "transcript", "text": text, "final": True})
        log.info("User: %s", text)
        user_msg = {"role": "user", "content": text}
        self.conversation_history.append(user_msg)

        if self._is_cancelled(generation):
            return

        # LLM
        log.info(
            "Conversation history (%d messages): %s",
            len(self.conversation_history),
            [m["content"][:50] for m in self.conversation_history],
        )
        response = await asyncio.to_thread(
            self.models.llm_engine.generate, self.conversation_history
        )

        if self._is_cancelled(generation):
            return

        # TTS - stream chunks with per-sentence text
        await self.send_json({"type": "status", "state": "speaking"})
        spoken_text = (await self._stream_tts(response, generation)).strip()

        # Save only what was actually spoken
        self._record_reply(user_msg, spoken_text)

        if not self._is_cancelled(generation):
            await self.send_json({"type": "response_text", "text": "", "final": True})

        if generation == self._generation:
            self.is_responding = False
            await self._notify_listening()

    async def _stream_tts(self, response, generation: int) -> str:
        """Synthesize *response* in a thread and stream it to the client.

        Returns the concatenated text of the chunks that were sent.
        """
        chunk_queue: queue.Queue = queue.Queue()
        # Lets the consumer stop the worker without touching cancel_event,
        # e.g. after a timeout or when this task itself is cancelled.
        stop_worker = threading.Event()

        def _tts_worker():
            try:
                for graphemes, _ps, audio in self.models.tts_engine.pipeline(
                    response, voice=self.models.tts_engine.voice
                ):
                    if stop_worker.is_set() or self._is_cancelled(generation):
                        break
                    if audio is not None and len(audio) > 0:
                        chunk_queue.put((graphemes, audio))
            except Exception:
                log.exception("TTS generation error")
            finally:
                chunk_queue.put(_SENTINEL)

        tts_thread = threading.Thread(target=_tts_worker, daemon=True)
        tts_thread.start()

        spoken_parts: list[str] = []
        try:
            while True:
                try:
                    item = await asyncio.to_thread(
                        chunk_queue.get, timeout=self.TTS_CHUNK_TIMEOUT
                    )
                except queue.Empty:
                    log.warning("Timed out waiting for TTS audio.")
                    break

                if item is _SENTINEL:
                    break
                if self._is_cancelled(generation):
                    break

                sentence_text, audio = item
                pcm_bytes = float32_to_pcm_bytes(audio)
                try:
                    await self.send_json(
                        {"type": "response_text", "text": sentence_text, "final": False}
                    )
                    await self.send_bytes(pcm_bytes)
                except Exception:
                    log.warning("Failed to send audio, client disconnected.")
                    self.cancel_event.set()
                    break
                # Count a sentence as spoken only once it has been delivered.
                spoken_parts.append(sentence_text)
        finally:
            stop_worker.set()

        # Thread.join blocks, so keep it off the event loop.
        await asyncio.to_thread(tts_thread.join, self.TTS_JOIN_TIMEOUT)
        return "".join(spoken_parts)

    def _record_reply(self, user_msg: dict, spoken_text: str) -> None:
        """Store the assistant reply for *user_msg*, or drop the unanswered turn.

        The user message is located by identity so that a newer turn appended
        by an overlapping response is never removed or reordered.
        """
        history = self.conversation_history
        index = next(
            (i for i in range(len(history) - 1, -1, -1) if history[i] is user_msg),
            None,
        )

        if spoken_text:
            reply = {"role": "assistant", "content": spoken_text}
            if index is None:
                history.append(reply)
            else:
                history.insert(index + 1, reply)
        elif index is not None:
            del history[index]
|
||||
Reference in New Issue
Block a user