145 lines
4.3 KiB
Python
145 lines
4.3 KiB
Python
"""Unit tests for the video-mode branch in ConversationSession.

Stubs every model involved (ASR, LLM, TTS, VideoEngine) so we can verify:

1. When video_engine is not ready, the existing PCM streaming path runs.
2. When video_engine IS ready, the per-chunk PCM sends are skipped and a
   single ``speaking_clip`` JSON + MP4 binary is sent instead.

Pure asyncio; no CUDA, no real models.
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import types
|
|
from unittest.mock import MagicMock
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from server.pipeline import ConversationSession
|
|
|
|
|
|
class _FakeVAD:
|
|
is_speaking = False
|
|
def process_chunk(self, _): return None
|
|
|
|
|
|
class _FakeASR:
|
|
def __init__(self, text="hello"):
|
|
self.text = text
|
|
def transcribe(self, _): return self.text
|
|
|
|
|
|
class _FakeLLM:
|
|
def __init__(self, response="Hi there."):
|
|
self.response = response
|
|
def generate(self, *_a, **_k):
|
|
return self.response, None
|
|
def trim_cache(self, state, _): return state
|
|
|
|
|
|
class _FakeTTSIterable:
|
|
"""Drop-in replacement for Kokoro's pipeline(..) generator."""
|
|
def __init__(self, chunks):
|
|
self._chunks = chunks
|
|
def __call__(self, segment, voice=None):
|
|
for i, audio in enumerate(self._chunks):
|
|
yield f"w{i}", None, audio
|
|
|
|
|
|
class _FakeTTSEngine:
    """TTS engine stub exposing ``pipeline``, ``voice`` and ``sample_rate``."""

    def __init__(self, chunks):
        # Callable that replays ``chunks`` instead of synthesising audio.
        self.pipeline = _FakeTTSIterable(chunks)
        # Placeholder voice identifier.
        self.voice = "v"
        # Sample rate reported by the engine.
        self.sample_rate = 24000
|
|
|
|
|
|
class _FakeVideoEngineReady:
|
|
class _Cfg:
|
|
mode = "reflective"
|
|
cfg = _Cfg()
|
|
def __init__(self):
|
|
self.called_with = None
|
|
def is_ready(self): return True
|
|
def generate_speaking_clip(self, audio, sr, reply_text):
|
|
self.called_with = {"len": len(audio), "sr": sr, "reply": reply_text}
|
|
return b"FAKE_MP4_BYTES"
|
|
|
|
|
|
class _FakeModelsBase:
    """Model container stub wiring together the fake ASR, LLM and TTS."""

    def __init__(self, tts_chunks):
        # Each engine slot is filled with its in-file fake.
        self.asr_engine = _FakeASR()
        self.llm_engine = _FakeLLM()
        self.tts_engine = _FakeTTSEngine(tts_chunks)

    def create_vad(self):
        # A fresh fake VAD is built on every call.
        return _FakeVAD()
|
|
|
|
|
|
class _FakeModelsStreaming(_FakeModelsBase):
    """Model container with no video engine attached."""

    # Class-level None: there is no engine object at all.
    video_engine = None
|
|
|
|
|
|
class _FakeModelsVideo(_FakeModelsBase):
    """Model container whose video engine always reports ready."""

    def __init__(self, tts_chunks):
        super().__init__(tts_chunks)
        # Per-instance engine so each test inspects its own ``called_with``.
        self.video_engine = _FakeVideoEngineReady()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_streaming_path_when_video_engine_absent():
    """With ``video_engine = None`` the session streams per-chunk PCM."""
    json_sent: list = []
    bytes_sent: list = []

    async def send_json(d):
        json_sent.append(d)

    async def send_bytes(b):
        bytes_sent.append(b)

    # Two TTS chunks of different lengths (240 and 480 samples).
    tts_chunks = [np.ones(size, dtype=np.float32) for size in (240, 480)]
    models = _FakeModelsStreaming(tts_chunks=tts_chunks)
    session = ConversationSession(models, send_json, send_bytes)

    utterance = np.zeros(16000, dtype=np.float32)
    await session._process_utterance(utterance)

    # PCM bytes were sent (one per TTS chunk).
    assert len(bytes_sent) == 2

    # Per-chunk response_text messages were sent (not video's one-shot):
    # at least one response_text message is not flagged as final.
    assert any(
        msg.get("type") == "response_text" and not msg.get("final")
        for msg in json_sent
    )

    # No speaking_clip envelope.
    assert all(msg.get("type") != "speaking_clip" for msg in json_sent)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_video_path_when_engine_ready():
    """With a ready video engine, one envelope and one MP4 blob are sent."""
    json_sent: list = []
    bytes_sent: list = []

    async def send_json(d):
        json_sent.append(d)

    async def send_bytes(b):
        bytes_sent.append(b)

    # Two equally sized TTS chunks with distinct fill values.
    tts_chunks = [
        np.full(480, level, dtype=np.float32) for level in (0.5, 0.25)
    ]
    models = _FakeModelsVideo(tts_chunks=tts_chunks)
    session = ConversationSession(models, send_json, send_bytes)

    utterance = np.zeros(16000, dtype=np.float32)
    await session._process_utterance(utterance)

    expected_clip = b"FAKE_MP4_BYTES"
    expected_reply = "Hi there."

    # MP4 blob was sent once.
    assert bytes_sent == [expected_clip]

    # speaking_clip envelope was sent exactly once.
    envelopes = [
        msg for msg in json_sent if msg.get("type") == "speaking_clip"
    ]
    assert len(envelopes) == 1
    envelope = envelopes[0]
    assert envelope["size_bytes"] == len(expected_clip)
    assert envelope["text"] == expected_reply

    # The video engine received the concatenated audio.
    recorded = models.video_engine.called_with
    assert recorded is not None
    assert recorded["len"] == 960  # 480 + 480
    assert recorded["reply"] == expected_reply

    # No per-chunk PCM bytes were streamed (video path suppresses them).
    # Only the MP4 blob is in bytes_sent.
    assert len(bytes_sent) == 1
|