first stab at adding video

This commit is contained in:
2026-04-12 04:11:52 -04:00
parent 680c5b04cc
commit 2818b41004
37 changed files with 2982 additions and 24 deletions
+4 -1
View File
@@ -1,3 +1,6 @@
.venv .venv
.claude .claude
__pycache__ __pycache__
tests/component/_out/
avatars/
loras/
+23
View File
@@ -4,6 +4,9 @@ ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
# HuggingFace model cache — mounted as a volume so models persist across runs # HuggingFace model cache — mounted as a volume so models persist across runs
ENV HF_HOME=/cache/huggingface ENV HF_HOME=/cache/huggingface
# LoRA directory — users drop .safetensors files here and reference them
# from config.yml::video.loras. Bind-mounted via docker-compose.
ENV LORA_DIR=/cache/loras
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
python3.11 \ python3.11 \
@@ -38,6 +41,26 @@ RUN python3.11 -m pip install --no-cache-dir -r requirements.txt
# Pre-download the spacy model that kokoro needs at runtime # Pre-download the spacy model that kokoro needs at runtime
RUN python3.11 -m spacy download en_core_web_sm RUN python3.11 -m spacy download en_core_web_sm
# --- Optional: avatar video stack -------------------------------------------
# These are heavy installs; keep them after the core deps so rebuilds only
# redo this layer when ONLY the video stack changes. If you don't plan to
# use config.video.enabled=true, you can comment this block out to speed
# up builds and shrink the image.
#
# LightX2V (Wan2.2-Lightning inference framework) — installed from source
# since there is no stable PyPI release yet.
RUN python3.11 -m pip install --no-cache-dir \
"git+https://github.com/ModelTC/LightX2V.git" || \
echo "LightX2V install failed — config.video.enabled must stay false until fixed"
#
# MuseTalk (audio-driven lip-sync) — same story.
RUN python3.11 -m pip install --no-cache-dir \
"git+https://github.com/TMElyralab/MuseTalk.git" || \
echo "MuseTalk install failed — config.video.enabled must stay false until fixed"
#
# LoRA directory (user drops .safetensors here; bind-mounted in compose).
RUN mkdir -p /cache/loras
COPY . . COPY . .
EXPOSE 8000 EXPOSE 8000
+46
View File
@@ -12,3 +12,49 @@ llm:
lmstudio: lmstudio:
url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker
model: "" # leave empty to use whatever model LM Studio has loaded model: "" # leave empty to use whatever model LM Studio has loaded
# Avatar video generation (Wan2.2-Lightning fp8 via LightX2V + MuseTalk lip-sync)
video:
enabled: false # master toggle — when false, video models are not loaded
backend: lightx2v # only option for now
mode: reflective # "library" (pre-baked clips) | "reflective" (fresh per turn)
resolution: 480 # 480 or 720
fps: 16 # Wan2.2 native rate; MuseTalk resamples as needed
library:
base_clip_count: 4 # how many speaking base clips to pre-generate per avatar
base_clip_seconds: 6 # duration of each pre-baked clip
reflective:
clip_seconds: 5 # target length of each fresh Wan2.2 clip per turn
clip_prompt_template: >-
webcam view of a person speaking, {reply_hint},
casual gestures, natural lighting, soft focus background
prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint}
# Model sources for the video stack. The fp8 e4m3 4-step distilled DIT
# weights from lightx2v/Wan2.2-Distill-Models are ~15 GB each (vs ~28 GB
# bf16) — that's the "save VRAM" path. T5/VAE/tokenizer still come from
# the Wan-AI base repo. Both repos download on first run into
# HF_HOME=/cache/huggingface.
models:
wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B
wan22_fp8_repo: lightx2v/Wan2.2-Distill-Models
wan22_model_cls: wan2.2_moe_distill
wan22_config_json: /app/configs/lightx2v/wan22_i2v_fp8_distill.json
musetalk_path: TMElyralab/MuseTalk
# LoRAs applied to the fp8 base at load time via runtime switch_lora.
# Wan2.2 is a MoE with separate high-noise and low-noise sub-models —
# `target` picks which sub-model each LoRA attaches to. The two files
# below are the user-supplied ./loras/wan22-[HL]-e8.safetensors mounted
# into the container at /cache/loras/.
loras:
- path: /cache/loras/wan22-H-e8.safetensors
weight: 1.0
target: high_noise
name: wan22-H-e8
- path: /cache/loras/wan22-L-e8.safetensors
weight: 1.0
target: low_noise
name: wan22-L-e8
@@ -0,0 +1,36 @@
{
"_comment": "Wan2.2 i2v MoE 4-step distill, fp8 e4m3 quantized. Built for 24 GB-class GPUs — cpu_offload keeps DIT layers swapping in block-by-block. Derived from LightX2V's configs/distill/wan22/wan_moe_i2v_distill_4090.json plus the quant scheme + ckpt overrides from wan_moe_i2v_distill_quant.json. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py with absolute paths to the files downloaded into HF_HOME.",
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"resize_mode": "adaptive",
"resolution": "480p",
"target_height": 480,
"target_width": 480,
"fps": 16,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"lazy_load": true,
"t5_cpu_offload": true,
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250],
"dit_quantized": true,
"dit_quant_scheme": "fp8-sgl",
"t5_quantized": false
}
+11
View File
@@ -0,0 +1,11 @@
"""Pytest configuration.
Ensures the project root is on ``sys.path`` so tests can import ``server.*``
without installing the project as a package.
"""
import os
import sys
# Absolute path of the directory that contains this file (the project root).
_ROOT = os.path.dirname(os.path.abspath(__file__))

# Put the project root at the front of the import path exactly once so that
# ``server.*`` resolves without installing the project as a package.
if sys.path.count(_ROOT) == 0:
    sys.path.insert(0, _ROOT)
+6
View File
@@ -6,8 +6,14 @@ services:
volumes: volumes:
# Cache models on the host so they survive container rebuilds # Cache models on the host so they survive container rebuilds
- huggingface-cache:/cache/huggingface - huggingface-cache:/cache/huggingface
# LoRA adapters — drop .safetensors files into ./loras on the host,
# reference them from config.yml as /cache/loras/<file>.safetensors
- ./loras:/cache/loras
# Avatar images uploaded via the web UI persist between restarts
- ./avatars:/app/avatars
# Mount source so you can edit code/config without rebuilding the image # Mount source so you can edit code/config without rebuilding the image
- ./config.yml:/app/config.yml:ro - ./config.yml:/app/config.yml:ro
- ./configs:/app/configs:ro
- ./server:/app/server:ro - ./server:/app/server:ro
- ./static:/app/static:ro - ./static:/app/static:ro
- ./run.py:/app/run.py:ro - ./run.py:/app/run.py:ro
+9
View File
@@ -14,3 +14,12 @@ soundfile
scipy scipy
python-multipart python-multipart
pyyaml pyyaml
# --- Avatar video (optional, only used when config.video.enabled=true) ---
# Video frame I/O (used by video_models/wan22.py and the muxer).
imageio[ffmpeg]>=2.34
av>=12.0
pyzmq>=25.0
# LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the
# Dockerfile because neither ships a stable PyPI release yet. See the
# "Optional: avatar video stack" block in the Dockerfile.
+121 -2
View File
@@ -1,23 +1,27 @@
import json import json
import logging import logging
import os import os
import tempfile
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import numpy as np import numpy as np
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.params import Form from fastapi.params import Form
from fastapi.responses import FileResponse from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from server.audio_utils import pcm_bytes_to_float32 from server.audio_utils import pcm_bytes_to_float32
from server.models import ModelManager from server.models import ModelManager
from server.pipeline import ConversationSession from server.pipeline import ConversationSession
from server.video import LoRASpec
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio") REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio")
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static") STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
AVATAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "avatars")
os.makedirs(AVATAR_DIR, exist_ok=True)
model_mgr = ModelManager() model_mgr = ModelManager()
@@ -47,6 +51,110 @@ async def set_voice(voice: str = Form(...), lang: str = Form("a")):
return {"status": "ok", "voice": voice} return {"status": "ok", "voice": voice}
# --- Video / avatar endpoints ---------------------------------------------
def _require_video() -> "object":
    """Fetch the shared video engine or abort the request with HTTP 404.

    Reads ``model_mgr.video_engine``; when it is ``None`` (video mode was not
    enabled) an ``HTTPException`` with status 404 is raised so every video
    endpoint fails the same way.
    """
    engine = model_mgr.video_engine
    if engine is not None:
        return engine
    raise HTTPException(
        status_code=404,
        detail="Video engine disabled. Set config.video.enabled=true and restart.",
    )
@app.post("/api/set-avatar")
async def set_avatar(image: UploadFile):
    """Upload an avatar image and (re)generate cached clips.

    The upload is stored as ``avatar<ext>`` inside ``AVATAR_DIR`` and the
    video engine's ``set_avatar`` is run in a worker thread so the event loop
    stays responsive while clips render.

    Raises:
        HTTPException: 404 when the video engine is disabled; 500 when the
            engine fails to process the avatar (original error is chained).
    """
    import asyncio

    ve = _require_video()

    # The filename is client-controlled. Only its extension is used, and only
    # when it is purely alphanumeric; anything else (missing, ".", or odd
    # characters) falls back to ".png" so the on-disk name stays predictable.
    suffix = os.path.splitext(image.filename or "avatar.png")[1]
    if not suffix[1:].isalnum():
        suffix = ".png"
    dest = os.path.join(AVATAR_DIR, f"avatar{suffix}")

    payload = await image.read()
    with open(dest, "wb") as f:
        f.write(payload)
    log.info("Avatar saved to %s", dest)

    try:
        await asyncio.to_thread(ve.set_avatar, dest)
    except Exception as e:
        log.exception("set_avatar failed")
        # Chain the cause so the traceback shows the underlying failure.
        raise HTTPException(
            status_code=500, detail=f"Avatar setup failed: {e}"
        ) from e

    return {
        "status": "ok",
        "avatar_path": dest,
        "idle_clip_url": "/api/idle-clip",
        "mode": ve.cfg.mode,
    }
@app.get("/api/idle-clip")
async def idle_clip():
    """Serve the engine's cached idle-loop clip as ``video/mp4``.

    Responds 404 when the video engine is disabled or when the engine has no
    idle clip yet (``get_idle_clip()`` returned ``None``).
    """
    clip = _require_video().get_idle_clip()
    if clip is None:
        raise HTTPException(status_code=404, detail="No idle clip. Upload an avatar first.")
    return Response(content=clip, media_type="video/mp4")
@app.post("/api/set-video-mode")
async def set_video_mode(mode: str = Form(...)):
    """Set the engine's mode to 'off', 'library', or 'reflective'.

    The value is validated and then stored on ``ve.cfg.mode``; nothing else
    is touched here. The response carries a hint to re-upload the avatar when
    switching to 'library'.

    NOTE(review): 'off' is intended to send subsequent turns down the PCM
    streaming path while leaving the engine loaded. This endpoint only
    records the mode, so that relies on the engine's readiness check
    honouring it — confirm ``VideoEngine.is_ready()`` treats 'off' as not
    ready.

    Raises:
        HTTPException: 404 when the video engine is disabled; 400 for an
            unrecognised mode.
    """
    engine = _require_video()

    allowed_modes = ("off", "library", "reflective")
    if mode not in allowed_modes:
        raise HTTPException(
            status_code=400,
            detail="mode must be one of: off, library, reflective",
        )

    # Library clips are only baked while an avatar is being set, so a switch
    # to 'library' needs a fresh upload before clips exist.
    engine.cfg.mode = mode
    if mode == "library":
        note = "Re-upload avatar to re-bake library clips."
    else:
        note = ""
    return {"status": "ok", "mode": mode, "note": note}
@app.post("/api/reload-loras")
async def reload_loras(body: dict):
    """Hot-reload LoRA stack. Body: ``{"loras": [{"path","weight","target","name"}]}``.

    Each entry is converted to a ``LoRASpec``. Empty entries and entries
    without a ``path`` are skipped; an unknown ``target`` falls back to
    ``"both"``. After the engine applies the new stack, the avatar clips are
    regenerated if an avatar is already set, since the new LoRAs change the
    base style.

    Raises:
        HTTPException: 404 when the video engine is disabled; 400 when
            ``loras`` is not a list, an entry is not an object, or a weight
            is not numeric; 500 when the engine fails to apply the LoRAs or
            regenerate the clips.
    """
    import asyncio

    ve = _require_video()

    raw = body.get("loras") or []
    if not isinstance(raw, list):
        raise HTTPException(status_code=400, detail="'loras' must be a list")

    valid_targets = ("high_noise", "low_noise", "both")
    specs: list[LoRASpec] = []
    for index, entry in enumerate(raw):
        if not entry:
            continue
        if not isinstance(entry, dict):
            # Client error, not a server fault: report it as 400.
            raise HTTPException(
                status_code=400, detail=f"loras[{index}] must be an object"
            )
        if "path" not in entry:
            continue
        try:
            weight = float(entry.get("weight", 1.0))
        except (TypeError, ValueError) as e:
            raise HTTPException(
                status_code=400,
                detail=f"loras[{index}].weight must be a number",
            ) from e
        target = str(entry.get("target", "both")).lower()
        if target not in valid_targets:
            target = "both"
        # NOTE(review): ``path`` is taken from the request as-is, so a client
        # can point the engine at any file the process can read. Consider
        # restricting it to the LoRA directory.
        specs.append(
            LoRASpec(
                path=str(entry["path"]),
                weight=weight,
                target=target,  # type: ignore[arg-type]
                name=entry.get("name"),
            )
        )

    try:
        await asyncio.to_thread(ve.load_loras, specs)
        if ve.avatar_path:
            log.info("Regenerating idle clip after LoRA reload.")
            await asyncio.to_thread(ve.set_avatar, ve.avatar_path)
    except Exception as e:
        log.exception("reload_loras failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"status": "ok", "lora_count": len(specs), "idle_clip_url": "/api/idle-clip"}
@app.websocket("/ws/chat") @app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket): async def websocket_chat(ws: WebSocket):
await ws.accept() await ws.accept()
@@ -61,6 +169,17 @@ async def websocket_chat(ws: WebSocket):
session = ConversationSession(model_mgr, send_json, send_bytes) session = ConversationSession(model_mgr, send_json, send_bytes)
await session.start() await session.start()
# Tell the client whether video mode is active so it knows whether to
# suppress PCM playback and wait for speaking_clip messages instead.
ve = model_mgr.video_engine
await send_json({
"type": "video_mode",
"enabled": ve is not None,
"ready": ve.is_ready() if ve is not None else False,
"mode": ve.cfg.mode if ve is not None else "off",
"idle_clip_url": "/api/idle-clip" if (ve is not None and ve.get_idle_clip()) else None,
})
try: try:
while True: while True:
message = await ws.receive() message = await ws.receive()
+26 -2
View File
@@ -5,6 +5,7 @@ from server.vad import StreamingVAD
from server.asr import ASREngine from server.asr import ASREngine
from server.llm import LLMEngine from server.llm import LLMEngine
from server.tts import TTSEngine from server.tts import TTSEngine
from server.video import VideoConfig, VideoEngine
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -31,6 +32,7 @@ class ModelManager:
self.asr_engine: ASREngine | None = None self.asr_engine: ASREngine | None = None
self.llm_engine: LLMEngine | None = None self.llm_engine: LLMEngine | None = None
self.tts_engine: TTSEngine | None = None self.tts_engine: TTSEngine | None = None
self.video_engine: VideoEngine | None = None
def load_all(self): def load_all(self):
"""Load all models sequentially. Call from the main process.""" """Load all models sequentially. Call from the main process."""
@@ -38,6 +40,7 @@ class ModelManager:
self._load_asr() self._load_asr()
self._load_llm() self._load_llm()
self._load_tts() self._load_tts()
self._load_video()
log.info("All models loaded successfully.") log.info("All models loaded successfully.")
def _load_vad(self): def _load_vad(self):
@@ -84,8 +87,8 @@ class ModelManager:
log.info("Loading Qwen3-4B (GPTQ 4-bit)...") log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Qwen/Qwen3.5-0.8B" # model_name = "Qwen/Qwen3.5-0.8B"
model_name = "dphn/Dolphin-X1-8B-FP8"
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
device = get_device() device = get_device()
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
@@ -101,6 +104,27 @@ class ModelManager:
self.tts_engine = TTSEngine() self.tts_engine = TTSEngine()
log.info("Kokoro TTS loaded.") log.info("Kokoro TTS loaded.")
def _load_video(self):
    """Load the avatar video stack iff config.video.enabled is true.

    Leaves ``video_engine`` as None when disabled so the existing voice flow
    is untouched. When enabled, builds a ``VideoEngine`` from the ``video``
    config section and calls its ``load_models()``, which instantiates the
    model wrappers and applies the LoRAs listed in the config.

    ``VideoEngine.load_loras()`` is deliberately not called here: it raises
    ``RuntimeError`` until ``load_models()`` has run, and ``load_models()``
    already applies ``cfg.loras`` itself.

    Any exception from model loading propagates to the caller; in that case
    ``video_engine`` stays None.
    """
    from server.config import config

    video_cfg_raw = config.get("video", {}) or {}
    if not video_cfg_raw.get("enabled", False):
        log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
        return

    log.info("Loading avatar video engine...")
    cfg = VideoConfig.from_dict(video_cfg_raw)
    engine = VideoEngine(cfg)
    engine.load_models()
    # Publish the engine only once its models loaded successfully, so
    # endpoints never see a half-initialised engine.
    self.video_engine = engine

    log.info("Avatar video engine loaded (mode=%s).", cfg.mode)
def create_vad(self) -> StreamingVAD: def create_vad(self) -> StreamingVAD:
"""Create a new StreamingVAD instance for a client session.""" """Create a new StreamingVAD instance for a client session."""
return StreamingVAD(self.vad_model) return StreamingVAD(self.vad_model)
+73 -14
View File
@@ -157,11 +157,20 @@ class ConversationSession:
# TTS - stream chunks with per-sentence text # TTS - stream chunks with per-sentence text
await self.send_json({"type": "status", "state": "speaking"}) await self.send_json({"type": "status", "state": "speaking"})
# Video-mode branch: if a video engine is loaded AND an avatar is
# set, buffer the full TTS output into a single blob, run MuseTalk
# lip-sync (library or reflective source), mux to MP4, and send the
# full clip + text in one shot. The client plays the MP4 (which
# carries audio) instead of the per-chunk PCM path.
video_engine = getattr(self.models, "video_engine", None)
use_video = video_engine is not None and video_engine.is_ready()
chunk_queue = queue.Queue() chunk_queue = queue.Queue()
self._last_played_chunk_id = None self._last_played_chunk_id = None
segments = _split_into_segments(response) segments = _split_into_segments(response)
log.info(f"TTS: split response into {len(segments)} segments") log.info(f"TTS: split response into {len(segments)} segments (video={use_video})")
def _tts_worker(): def _tts_worker():
try: try:
@@ -187,6 +196,10 @@ class ConversationSession:
chunk_id = 0 chunk_id = 0
# Maps chunk_id -> cumulative text up to and including that chunk # Maps chunk_id -> cumulative text up to and including that chunk
chunk_text_map: dict[int, str] = {} chunk_text_map: dict[int, str] = {}
# Video mode accumulator: we buffer all TTS audio into one float32
# array so MuseTalk can align against the full utterance.
audio_buffer: list[np.ndarray] = []
while True: while True:
try: try:
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0) item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
@@ -202,23 +215,69 @@ class ConversationSession:
spoken_text += sentence_text spoken_text += sentence_text
chunk_text_map[chunk_id] = spoken_text chunk_text_map[chunk_id] = spoken_text
await self.send_json({ if use_video:
"type": "response_text", audio_buffer.append(audio)
"text": sentence_text, # Don't stream text or PCM during video mode — we'll send
"chunk_id": chunk_id, # everything after the clip renders so the client doesn't
"final": False, # start displaying text before the video is ready.
}) else:
pcm_bytes = float32_to_pcm_bytes(audio) await self.send_json({
try: "type": "response_text",
await self.send_bytes(pcm_bytes) "text": sentence_text,
except Exception: "chunk_id": chunk_id,
log.warning("Failed to send audio, client disconnected.") "final": False,
self.cancel_event.set() })
break pcm_bytes = float32_to_pcm_bytes(audio)
try:
await self.send_bytes(pcm_bytes)
except Exception:
log.warning("Failed to send audio, client disconnected.")
self.cancel_event.set()
break
chunk_id += 1 chunk_id += 1
tts_thread.join(timeout=2.0) tts_thread.join(timeout=2.0)
# Video mode: render the speaking clip now that TTS is done.
if use_video and audio_buffer and not self.cancel_event.is_set():
try:
full_audio = np.concatenate(audio_buffer).astype(np.float32)
sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
log.info(
"Video: rendering speaking clip (audio=%ds, mode=%s)",
int(len(full_audio) / sample_rate), video_engine.cfg.mode,
)
mp4_bytes = await asyncio.to_thread(
video_engine.generate_speaking_clip,
full_audio,
sample_rate,
response,
)
if self.cancel_event.is_set():
log.info("Video clip discarded (cancelled during render).")
else:
duration_ms = int(len(full_audio) / sample_rate * 1000)
await self.send_json({
"type": "speaking_clip",
"chunk_id": 0,
"duration_ms": duration_ms,
"text": response,
"size_bytes": len(mp4_bytes),
})
await self.send_bytes(mp4_bytes)
except Exception:
log.exception("Video speaking-clip render failed; falling back silently.")
# Best-effort: tell the client nothing was spoken visually.
try:
await self.send_json({
"type": "response_text",
"text": response,
"chunk_id": 0,
"final": True,
})
except Exception:
pass
# Determine what was actually heard by the client # Determine what was actually heard by the client
was_interrupted = spoken_text.strip() != response.strip() was_interrupted = spoken_text.strip() != response.strip()
if was_interrupted and self._last_played_chunk_id is not None: if was_interrupted and self._last_played_chunk_id is not None:
+391
View File
@@ -0,0 +1,391 @@
"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.
Top-level orchestrator. The heavy 3rd-party model code is isolated in
``server/video_models/`` so each wrapper can be updated independently.
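This module is imported unconditionally (``server/models.py`` and the FastAPI
app module both import names from it), so module-level imports are kept light
and the wrappers in ``server/video_models/`` are imported lazily inside the
methods that use them. A ``VideoEngine`` is only constructed when
``config.video.enabled`` is true; when disabled, the existing voice pipeline
is untouched.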
"""
from __future__ import annotations
import logging
import threading
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
log = logging.getLogger(__name__)
LoRATarget = Literal["high_noise", "low_noise", "both"]
@dataclass
class LoRASpec:
    """A single LoRA adapter to apply to the Wan2.2 base.

    Instances are built from the entries of ``config.video.loras`` and from
    the body of the LoRA hot-reload endpoint.

    Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and
    low-noise sub-models. Most LightX2V distill LoRAs come paired (one per
    sub-model) and must be applied to the correct target; ``target="both"``
    is for LoRAs that should be applied to both sub-models (e.g. style
    LoRAs).
    """

    # Location of the adapter file.
    path: str
    # Blend strength; 1.0 when not specified.
    weight: float = 1.0
    # Which sub-model the adapter attaches to.
    target: LoRATarget = "both"
    # Optional label; log output falls back to ``path`` when it is unset.
    name: str | None = None
@dataclass
class VideoConfig:
    """Flattened view of the ``video:`` section of config.yml.

    The field defaults below are the single source of truth: ``from_dict``
    falls back to them for any key that is missing or explicitly null.
    """

    enabled: bool = False
    backend: str = "lightx2v"
    # "library" | "reflective" from config; the set-video-mode endpoint may
    # also store "off" here at runtime.
    mode: str = "reflective"
    resolution: int = 480
    fps: int = 16

    library_base_clip_count: int = 4
    library_base_clip_seconds: int = 6

    reflective_clip_seconds: int = 5
    reflective_prompt_template: str = (
        "webcam view of a person speaking, {reply_hint}, casual gestures, "
        "natural lighting, soft focus background"
    )
    reflective_prompt_reply_words: int = 18

    loras: list[LoRASpec] = field(default_factory=list)

    # Model paths — can be overridden via config.yml.video.models.
    # wan22_base_repo  : HF repo id (or local dir) providing T5/VAE/tokenizer.
    #                    The bf16 DIT shards in this repo are skipped — we
    #                    replace them with the fp8 files from wan22_fp8_repo.
    # wan22_fp8_repo   : HF repo id (or local dir) providing the two fp8 e4m3
    #                    4-step distilled DIT checkpoints (~15 GB each).
    # wan22_config_json: path to the LightX2V inference config template the
    #                    Wan22Pipeline will fill in with absolute ckpt paths.
    wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
    wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models"
    wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
    wan22_model_cls: str = "wan2.2_moe_distill"
    musetalk_model_path: str = "TMElyralab/MuseTalk"

    @staticmethod
    def _parse_loras(loras_raw) -> list[LoRASpec]:
        """Convert raw ``video.loras`` entries into ``LoRASpec`` objects.

        Empty entries, non-mapping entries, and entries without a ``path``
        are skipped. An unrecognised ``target`` is logged and replaced with
        ``"both"``.
        """
        specs: list[LoRASpec] = []
        for entry in loras_raw:
            if not entry:
                continue
            if not isinstance(entry, dict):
                log.warning("Ignoring non-mapping LoRA entry: %r", entry)
                continue
            if "path" not in entry:
                continue
            target = str(entry.get("target", "both")).lower()
            if target not in ("high_noise", "low_noise", "both"):
                log.warning(
                    "LoRA %s: invalid target %r, defaulting to 'both'",
                    entry.get("path"), target,
                )
                target = "both"
            weight = entry.get("weight")
            specs.append(
                LoRASpec(
                    path=str(entry["path"]),
                    weight=float(1.0 if weight is None else weight),
                    target=target,  # type: ignore[arg-type]
                    name=entry.get("name"),
                )
            )
        return specs

    @classmethod
    def from_dict(cls, raw: dict) -> "VideoConfig":
        """Build a ``VideoConfig`` from the raw ``video:`` mapping.

        Args:
            raw: Parsed ``video`` section (may be ``None`` or empty). The
                nested ``library``, ``reflective`` and ``models`` mappings
                are flattened into the prefixed fields of this class.

        Returns:
            A fully populated config. Keys that are absent or null take the
            dataclass field defaults.

        Raises:
            ValueError / TypeError: when a present value cannot be coerced
                to the field's type (e.g. ``fps: "fast"``).
        """
        raw = raw or {}
        library = raw.get("library") or {}
        reflective = raw.get("reflective") or {}
        models_raw = raw.get("models") or {}
        defaults = cls()

        def pick(section: dict, key: str, fallback):
            # A YAML key with no value parses to None; treat it as "unset"
            # rather than feeding None into int()/str().
            value = section.get(key)
            return fallback if value is None else value

        return cls(
            enabled=bool(pick(raw, "enabled", defaults.enabled)),
            backend=str(pick(raw, "backend", defaults.backend)),
            mode=str(pick(raw, "mode", defaults.mode)),
            resolution=int(pick(raw, "resolution", defaults.resolution)),
            fps=int(pick(raw, "fps", defaults.fps)),
            library_base_clip_count=int(
                pick(library, "base_clip_count", defaults.library_base_clip_count)
            ),
            library_base_clip_seconds=int(
                pick(library, "base_clip_seconds", defaults.library_base_clip_seconds)
            ),
            reflective_clip_seconds=int(
                pick(reflective, "clip_seconds", defaults.reflective_clip_seconds)
            ),
            reflective_prompt_template=str(
                pick(
                    reflective,
                    "clip_prompt_template",
                    defaults.reflective_prompt_template,
                )
            ),
            reflective_prompt_reply_words=int(
                pick(
                    reflective,
                    "prompt_reply_words",
                    defaults.reflective_prompt_reply_words,
                )
            ),
            loras=cls._parse_loras(raw.get("loras") or []),
            wan22_base_repo=str(
                pick(models_raw, "wan22_base_repo", defaults.wan22_base_repo)
            ),
            wan22_fp8_repo=str(
                pick(models_raw, "wan22_fp8_repo", defaults.wan22_fp8_repo)
            ),
            wan22_config_json=str(
                pick(models_raw, "wan22_config_json", defaults.wan22_config_json)
            ),
            wan22_model_cls=str(
                pick(models_raw, "wan22_model_cls", defaults.wan22_model_cls)
            ),
            # Note the key mismatch: the config key is ``musetalk_path``.
            musetalk_model_path=str(
                pick(models_raw, "musetalk_path", defaults.musetalk_model_path)
            ),
        )
# Library-mode base-clip prompts. Varied gestures so the pre-baked set feels
# less repetitive when replayed. Kept module-level so tests can import them.
# VideoEngine.set_avatar uses prompt ``i % len(LIBRARY_BASE_PROMPTS)`` for
# base clip ``i``, so the list wraps around when more clips than prompts are
# requested.
LIBRARY_BASE_PROMPTS = [
    "webcam view of a person speaking, subtle head nods, casual expression, "
    "natural lighting, soft focus background",
    "webcam view of a person speaking, slight smile, gentle hand gesture, "
    "natural lighting, soft focus background",
    "webcam view of a person speaking, looking thoughtful, small head tilt, "
    "natural lighting, soft focus background",
    "webcam view of a person speaking, engaged and attentive, minor shoulder "
    "movement, natural lighting, soft focus background",
    "webcam view of a person speaking, relaxed posture, blinking naturally, "
    "natural lighting, soft focus background",
]
# Prompt for the idle loop that VideoEngine.set_avatar renders for every
# avatar, regardless of mode.
IDLE_PROMPT = (
    "webcam view of a person listening quietly, mouth closed, subtle "
    "breathing, occasional blinks, calm expression, natural lighting, "
    "soft focus background"
)
class VideoEngine:
    """Top-level video generation orchestrator.

    Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
    pre-rendered clips. Exposed to ``ConversationSession`` via
    ``ModelManager.video_engine``.
    """

    def __init__(self, cfg: VideoConfig):
        """Store the config and initialise empty avatar/model state.

        No models are loaded here; call ``load_models()`` for that.
        """
        self.cfg = cfg
        # Held while LoRAs are swapped and while avatar clips are rendered.
        self._lock = threading.Lock()

        # Avatar state
        self.avatar_path: str | None = None
        self.idle_clip_mp4: bytes | None = None
        # Pre-baked speaking base clips for library mode. Each entry is a
        # contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
        self.speaking_base_frames: list[np.ndarray] = []
        # Round-robin pointer for picking a library clip per turn
        self._library_cursor = 0

        # Model wrappers — instantiated lazily by ``load_models()`` so unit
        # tests can exercise VideoEngine without touching CUDA at all.
        self._wan22 = None      # server.video_models.wan22.Wan22Pipeline
        self._musetalk = None   # server.video_models.musetalk.MuseTalkEngine

        log.info(
            "VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
            cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
        )

    # --- Model loading --------------------------------------------------

    def load_models(self) -> None:
        """Instantiate the underlying model wrappers.

        Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
        without triggering Wan2.2's ~12-16GB VRAM allocation. Also applies
        ``cfg.loras`` to the freshly built Wan2.2 pipeline.
        """
        from server.video_models.wan22 import Wan22Pipeline
        from server.video_models.musetalk import MuseTalkEngine

        log.info(
            "Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...",
            self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo,
        )
        self._wan22 = Wan22Pipeline(
            base_repo=self.cfg.wan22_base_repo,
            fp8_repo=self.cfg.wan22_fp8_repo,
            config_json=self.cfg.wan22_config_json,
            model_cls=self.cfg.wan22_model_cls,
            resolution=self.cfg.resolution,
            fps=self.cfg.fps,
        )
        if self.cfg.loras:
            self._wan22.load_loras(self.cfg.loras)
        log.info("Wan2.2 pipeline ready.")

        log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
        self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
        log.info("MuseTalk engine ready.")

    # --- Readiness ------------------------------------------------------

    def is_ready(self) -> bool:
        """True when an avatar is set and a speaking clip can be produced.

        Mode ``"off"`` (settable at runtime) always reports not-ready, so
        callers that gate on this method fall back to their non-video path
        while the models stay loaded.
        """
        return (
            self.cfg.mode != "off"
            and self._wan22 is not None
            and self._musetalk is not None
            and self.avatar_path is not None
            and self.idle_clip_mp4 is not None
        )

    # --- LoRA management ------------------------------------------------

    def load_loras(self, specs: list[LoRASpec]) -> None:
        """Apply a list of LoRA adapters to the Wan2.2 base.

        Replaces any previously applied LoRAs. Safe to call after init for
        hot-reload via ``POST /api/reload-loras``.

        Raises:
            RuntimeError: if ``load_models()`` has not been called yet.
        """
        if self._wan22 is None:
            raise RuntimeError("load_loras called before load_models()")
        with self._lock:
            self._wan22.unload_loras()
            self._wan22.load_loras(specs)
            self.cfg.loras = list(specs)
            log.info("Applied %d LoRA(s): %s",
                     len(specs),
                     ", ".join(s.name or s.path for s in specs) or "<none>")

    # --- Avatar lifecycle ----------------------------------------------

    def set_avatar(self, image_path: str) -> None:
        """Register an avatar image and pre-generate cached clips.

        - Always: generate the idle loop.
        - Library mode: also pre-generate ``library.base_clip_count``
          speaking base clips.
        - Reflective mode: idle loop only.

        Cached clips of the previous avatar are dropped up front and the new
        ones are published only after *all* rendering has finished, so
        ``is_ready()`` stays False for the whole render and never exposes a
        partially baked library.

        Raises:
            RuntimeError: if ``load_models()`` has not been called yet.
        """
        if self._wan22 is None:
            raise RuntimeError("set_avatar called before load_models()")

        # Imported before rendering so a missing muxer dependency fails
        # before any clip is generated.
        from server.video_models.muxer import frames_to_mp4_loop

        with self._lock:
            log.info("Setting avatar: %s", image_path)
            self.avatar_path = image_path
            # Drop any previously cached clips so the new avatar's library
            # doesn't mix with the old.
            self.speaking_base_frames = []
            self.idle_clip_mp4 = None

            # Idle clip: short loop, neutral/listening prompt.
            log.info("Generating idle clip...")
            idle_frames = self._wan22.generate_i2v(
                image_path=image_path,
                prompt=IDLE_PROMPT,
                seconds=self.cfg.library_base_clip_seconds,
                seed=0,
            )
            idle_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
            log.info("Idle clip ready (%d bytes).", len(idle_mp4))

            # Library mode: pre-bake N speaking base clips.
            base_clips: list[np.ndarray] = []
            if self.cfg.mode == "library":
                n = self.cfg.library_base_clip_count
                log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
                for i in range(n):
                    prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
                    frames = self._wan22.generate_i2v(
                        image_path=image_path,
                        prompt=prompt,
                        seconds=self.cfg.library_base_clip_seconds,
                        seed=i + 1,
                    )
                    base_clips.append(frames)
                    log.info("  base clip %d/%d rendered", i + 1, n)

            # Publish everything together; the idle clip goes last because
            # it is what flips is_ready() to True.
            self.speaking_base_frames = base_clips
            self._library_cursor = 0
            self.idle_clip_mp4 = idle_mp4

    def get_idle_clip(self) -> bytes | None:
        """Return the cached idle-loop MP4 bytes, or None if none is cached."""
        return self.idle_clip_mp4

    # --- Per-turn generation -------------------------------------------

    def generate_speaking_clip(
        self,
        audio_f32: np.ndarray,
        sample_rate: int,
        reply_text: str,
    ) -> bytes:
        """Produce a lip-synced MP4 for one assistant turn.

        Args:
            audio_f32: Audio samples for the full utterance.
            sample_rate: Sample rate of ``audio_f32`` in Hz.
            reply_text: Assistant reply; in reflective mode its first words
                are injected into the generation prompt.

        Returns:
            The bytes produced by ``frames_and_audio_to_mp4``.

        Raises:
            RuntimeError: if the engine is not ready, or library mode has no
                pre-baked clips.
        """
        if not self.is_ready():
            raise RuntimeError(
                "generate_speaking_clip: engine not ready "
                "(avatar set? models loaded?)"
            )

        # 1. Source base frames.
        if self.cfg.mode == "library":
            base_frames = self._pick_library_frames(audio_f32, sample_rate)
        else:  # reflective
            prompt = self._derive_prompt(reply_text)
            log.info("Reflective prompt: %s", prompt[:120])
            base_frames = self._wan22.generate_i2v(
                image_path=self.avatar_path or "",
                prompt=prompt,
                seconds=self.cfg.reflective_clip_seconds,
                seed=None,  # random each turn
            )

        # 2. Lip-sync the base frames to the given audio.
        synced_frames = self._musetalk.lip_sync(
            frames=base_frames,
            audio=audio_f32,
            sample_rate=sample_rate,
            fps=self.cfg.fps,
        )

        # 3. Mux frames + audio into an MP4.
        from server.video_models.muxer import frames_and_audio_to_mp4
        return frames_and_audio_to_mp4(
            frames=synced_frames,
            audio=audio_f32,
            sample_rate=sample_rate,
            fps=self.cfg.fps,
        )

    def _pick_library_frames(
        self, audio_f32: np.ndarray, sample_rate: int
    ) -> np.ndarray:
        """Round-robin pick from the pre-baked library, clipped or looped
        to roughly the audio's duration so there's no long freeze frame.

        Raises:
            RuntimeError: if no base clips have been pre-baked.
        """
        clips = self.speaking_base_frames
        if not clips:
            raise RuntimeError(
                "Library mode has no pre-baked base clips. "
                "Was set_avatar called with mode=library?"
            )
        frames = clips[self._library_cursor % len(clips)]
        self._library_cursor += 1

        available = len(frames)
        target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
        # Nothing to fit against (no audio) or nothing to loop (empty clip).
        if target_frames <= 0 or available == 0:
            return frames
        if target_frames <= available:
            return frames[:target_frames]

        # Longer audio: repeat the clip from its first frame (a plain loop,
        # no seam smoothing). Modulo indexing builds exactly target_frames
        # frames without first allocating whole extra copies of the clip.
        wrap_index = np.arange(target_frames) % available
        return np.asarray(frames)[wrap_index]

    def _derive_prompt(self, reply_text: str) -> str:
        """Template-based prompt builder for reflective mode.

        Takes up to ``prompt_reply_words`` words from the start of the reply
        and interpolates them into the configured template. Cheap,
        deterministic, no extra LLM call. An empty reply (or a non-positive
        word limit) yields the hint ``"calm and friendly"``.
        """
        words = (reply_text or "").split()
        # Clamp so a negative limit cannot turn into "all but the last N".
        limit = max(0, self.cfg.reflective_prompt_reply_words)
        hint = " ".join(words[:limit]).strip()
        if not hint:
            hint = "calm and friendly"
        return self.cfg.reflective_prompt_template.format(reply_hint=hint)
+10
View File
@@ -0,0 +1,10 @@
"""Thin wrappers around 3rd-party video generation models.
Each submodule isolates one external dependency so the real API surface
can be updated in a single file without touching the pipeline.
Submodules:
- ``wan22``: Wan2.2-Lightning image-to-video via LightX2V
- ``musetalk``: MuseTalk audio-driven lip-sync
- ``muxer``: ffmpeg-based frame/audio → MP4 encoding
"""
+164
View File
@@ -0,0 +1,164 @@
"""MuseTalk audio-driven lip-sync wrapper.
MuseTalk takes a sequence of face frames + driving audio and returns a new
sequence of frames where the mouth region is animated to match the audio.
This module isolates MuseTalk's real API behind a single ``lip_sync()``
method. MuseTalk's upstream Python surface varies between forks — if the
real import path or call signature differs, update this file only.
"""
from __future__ import annotations
import logging
import os
import numpy as np
log = logging.getLogger(__name__)
class MuseTalkEngine:
"""Thin wrapper over MuseTalk inference."""
def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
    """Load the MuseTalk backend once and keep it on the instance.

    Args:
        model_path: local directory or HuggingFace repo id; handed to
            ``_load_impl``, which probes the known MuseTalk entry points.
    """
    self.model_path = model_path
    # The concrete class differs between MuseTalk forks; ``_load_impl``
    # tries each known import path in turn.
    backend = self._load_impl(model_path)
    self._infer = backend
    log.info("MuseTalk engine loaded from %s", model_path)
@staticmethod
def _load_impl(model_path: str):
"""Load the MuseTalk inference implementation.
If none of the known entry points work the error message points at
this file so you know where to fix it.
"""
resolved = model_path
if not os.path.isdir(model_path) and "/" in model_path:
try:
from huggingface_hub import snapshot_download
resolved = snapshot_download(repo_id=model_path)
except Exception as e: # pragma: no cover
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
# Try upstream MuseTalk repo layout.
try:
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
return MuseTalkInference(model_path=resolved)
except ImportError:
pass
try:
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
return MuseTalkInfer(model_path=resolved)
except ImportError:
pass
try:
from musetalk import Inference # type: ignore[import-not-found]
return Inference(model_path=resolved)
except ImportError:
pass
raise RuntimeError(
"MuseTalk is installed but no known Python entry point was found. "
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
"to match the installed MuseTalk version."
)
# --- Inference ---------------------------------------------------------
def lip_sync(
    self,
    frames: np.ndarray,
    audio: np.ndarray,
    sample_rate: int,
    fps: int,
) -> np.ndarray:
    """Return new frames with lip-sync applied to match ``audio``.

    Args:
        frames: uint8 ``[T, H, W, 3]`` RGB base frames.
        audio: float32 mono 1D audio.
        sample_rate: sample rate of ``audio``; must be positive.
        fps: frame rate of ``frames``; must be positive.

    Returns:
        uint8 ``[T', H, W, 3]`` RGB frames. When the audio maps to at
        least one frame, the input is trimmed or padded to that frame
        count before inference.

    Raises:
        ValueError: if ``frames`` is not ``[T, H, W, 3]`` or if
            ``sample_rate`` / ``fps`` is not positive.
        RuntimeError: if the loaded backend exposes no method that
            accepts the arguments (by keyword or by position).
    """
    import inspect  # local import: keeps the module import block untouched

    if frames.ndim != 4 or frames.shape[-1] != 3:
        raise ValueError(
            f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
        )
    if sample_rate <= 0 or fps <= 0:
        # Fail loudly instead of ZeroDivisionError / a silent zero target.
        raise ValueError(
            "sample_rate and fps must be positive, got "
            f"sample_rate={sample_rate}, fps={fps}"
        )

    # Normalise frame count to audio duration so the caller doesn't have
    # to do the arithmetic.
    target_t = int(round(len(audio) / sample_rate * fps))
    if target_t > 0 and len(frames) != target_t:
        frames = _fit_frames_to_length(frames, target_t)

    # The two calling conventions we support, keyword first.
    conventions = (
        ((), {"frames": frames, "audio": audio,
              "sample_rate": sample_rate, "fps": fps}),
        ((frames, audio, sample_rate, fps), {}),
    )

    # The real MuseTalk call signature varies between forks.
    for method_name in ("run", "infer", "lip_sync", "__call__"):
        method = getattr(self._infer, method_name, None)
        if not callable(method):
            continue
        try:
            signature = inspect.signature(method)
        except (TypeError, ValueError):
            signature = None  # not introspectable (e.g. C-implemented)

        for args, kwargs in conventions:
            if signature is not None:
                # Check the convention against the signature instead of
                # calling and catching TypeError: a TypeError raised
                # *inside* inference then propagates as the real error
                # rather than triggering a second (expensive) attempt.
                try:
                    signature.bind(*args, **kwargs)
                except TypeError:
                    continue
                return _ensure_uint8_rgb(method(*args, **kwargs))
            # No signature available: probe by calling, as before.
            try:
                return _ensure_uint8_rgb(method(*args, **kwargs))
            except TypeError:
                continue

    raise RuntimeError(
        "MuseTalk wrapper could not find a working inference method. "
        "Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
    )
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
Repeats with a ping-pong / boomerang tail so the seam between loops is
less jarring than a hard cut back to frame 0.
"""
if target_t <= 0:
return frames
t = len(frames)
if t == target_t:
return frames
if t > target_t:
return frames[:target_t]
# Extend via ping-pong looping
extended = [frames]
total = t
flip = True
while total < target_t:
seg = frames[::-1] if flip else frames
extended.append(seg)
total += t
flip = not flip
return np.concatenate(extended, axis=0)[:target_t]
def _ensure_uint8_rgb(arr) -> np.ndarray:
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
result = np.asarray(arr)
if result.dtype != np.uint8:
if result.dtype in (np.float32, np.float64):
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
else:
result = result.astype(np.uint8)
if result.ndim == 3:
result = result[None, ...]
return result
+146
View File
@@ -0,0 +1,146 @@
"""ffmpeg-based frame + audio → MP4 muxing.
Uses the system ``ffmpeg`` binary already installed in the Dockerfile.
No extra python dependencies beyond ``numpy``.
"""
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import tempfile
import numpy as np
log = logging.getLogger(__name__)
def _ffmpeg_bin() -> str:
bin_path = shutil.which("ffmpeg")
if bin_path is None:
raise RuntimeError(
"ffmpeg binary not found on PATH. It should be installed by "
"the Dockerfile (line 13). Ensure you're running inside the "
"docker image or install ffmpeg locally."
)
return bin_path
def _write_raw_frames(frames: np.ndarray, path: str) -> tuple[int, int]:
"""Write uint8 RGB frames to ``path`` as raw rgb24 bytes. Returns (h, w)."""
if frames.ndim != 4 or frames.shape[-1] != 3:
raise ValueError(
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
)
if frames.dtype != np.uint8:
frames = frames.astype(np.uint8)
with open(path, "wb") as f:
f.write(frames.tobytes())
_, h, w, _ = frames.shape
return h, w
def _write_wav(audio: np.ndarray, sample_rate: int, path: str) -> None:
"""Write a float32 mono audio array to a 16-bit PCM WAV at ``path``."""
from scipy.io import wavfile # type: ignore[import-not-found]
audio = np.asarray(audio, dtype=np.float32).reshape(-1)
int16 = np.clip(audio * 32767.0, -32768, 32767).astype(np.int16)
wavfile.write(path, sample_rate, int16)
def frames_to_mp4_loop(frames: np.ndarray, fps: int) -> bytes:
    """Encode ``frames`` as a silent H.264 MP4 and return the file bytes.

    Intended for the idle clip: the output has no audio track (``-an``).

    Raises:
        ValueError: if ``frames`` has no elements.
    """
    if frames.size == 0:
        raise ValueError("frames_to_mp4_loop: empty frames")

    ffmpeg = _ffmpeg_bin()
    with tempfile.TemporaryDirectory() as workdir:
        raw_path = os.path.join(workdir, "frames.raw")
        out_path = os.path.join(workdir, "out.mp4")
        height, width = _write_raw_frames(frames, raw_path)

        # How ffmpeg should interpret the headerless raw frame dump.
        input_opts = [
            "-f", "rawvideo",
            "-pix_fmt", "rgb24",
            "-s", f"{width}x{height}",
            "-r", str(fps),
            "-i", raw_path,
        ]
        # Video-only H.264 with the moov atom moved to the front.
        output_opts = [
            "-an",
            "-c:v", "libx264",
            "-preset", "veryfast",
            "-pix_fmt", "yuv420p",
            "-movflags", "+faststart",
            out_path,
        ]
        cmd = [ffmpeg, "-y", *input_opts, *output_opts]
        log.debug("muxer idle clip: %s", " ".join(cmd))
        _run_ffmpeg(cmd)

        with open(out_path, "rb") as encoded:
            return encoded.read()
def frames_and_audio_to_mp4(
    frames: np.ndarray,
    audio: np.ndarray,
    sample_rate: int,
    fps: int,
) -> bytes:
    """Encode ``frames`` + ``audio`` as H.264/AAC MP4 and return the bytes.

    Used for per-turn speaking clips. ``-shortest`` ends the output with
    whichever of the two streams finishes first.

    Raises:
        ValueError: if ``frames`` or ``audio`` has no elements.
    """
    if frames.size == 0:
        raise ValueError("frames_and_audio_to_mp4: empty frames")
    if audio.size == 0:
        raise ValueError("frames_and_audio_to_mp4: empty audio")

    ffmpeg = _ffmpeg_bin()
    with tempfile.TemporaryDirectory() as workdir:
        raw_path = os.path.join(workdir, "frames.raw")
        wav_path = os.path.join(workdir, "audio.wav")
        out_path = os.path.join(workdir, "out.mp4")
        height, width = _write_raw_frames(frames, raw_path)
        _write_wav(audio, sample_rate, wav_path)

        # Input 0: headerless rgb24 frames; input 1: the WAV just written.
        video_input = [
            "-f", "rawvideo",
            "-pix_fmt", "rgb24",
            "-s", f"{width}x{height}",
            "-r", str(fps),
            "-i", raw_path,
        ]
        audio_input = ["-i", wav_path]
        output_opts = [
            "-c:v", "libx264",
            "-preset", "veryfast",
            "-pix_fmt", "yuv420p",
            "-c:a", "aac",
            "-b:a", "128k",
            "-shortest",
            "-movflags", "+faststart",
            out_path,
        ]
        cmd = [ffmpeg, "-y", *video_input, *audio_input, *output_opts]
        log.debug("muxer speaking clip: %s", " ".join(cmd))
        _run_ffmpeg(cmd)

        with open(out_path, "rb") as encoded:
            return encoded.read()
def _run_ffmpeg(cmd: list[str]) -> None:
try:
proc = subprocess.run(
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
log.error("ffmpeg failed (exit %d): %s", e.returncode, e.stderr.decode(errors="replace"))
raise
if proc.returncode != 0: # pragma: no cover
raise RuntimeError(f"ffmpeg returned {proc.returncode}")
+423
View File
@@ -0,0 +1,423 @@
"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V.
This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner
args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
config = set_config(args)
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config) # loads all weights — done ONCE
update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path
# LoRA hot-swap:
runner.switch_lora(lora_path, strength) # swap in
runner.switch_lora("", 0.0) # remove
Model weights are loaded once at construction and held resident across turns
so reflective mode doesn't re-pay the load cost each reply.
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards under high_noise_model/
and low_noise_model/ are SKIPPED via
ignore_patterns — we replace them with fp8.
- lightx2v/Wan2.2-Distill-Models — exactly two safetensors files:
the fp8 e4m3 4-step distilled high/low
noise DIT checkpoints (~15 GB each).
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import random
import tempfile
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from server.video import LoRASpec
log = logging.getLogger(__name__)
# File names of the two fp8 DIT checkpoints (high-noise / low-noise expert).
# ``Wan22Pipeline._ensure_fp8_checkpoints`` looks for exactly these names,
# either inside a local directory or as files of the fp8 HuggingFace repo.
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8
# files from the distill repo replace the DIT weights entirely. We must keep
# the config.json / index.json metadata under high_noise_model/ and
# low_noise_model/ (LightX2V's set_config reads architecture params like
# ``dim`` from them) and the tokenizer files under google/.
#
# Passed as ``ignore_patterns`` to ``snapshot_download`` in
# ``Wan22Pipeline._ensure_base_repo``; only the ``*.safetensors`` shards of
# the two expert folders are excluded, so their JSON metadata still downloads.
BASE_REPO_IGNORE_PATTERNS = [
"high_noise_model/*.safetensors",
"low_noise_model/*.safetensors",
"assets/*",
"examples/*",
"nohup.out",
"*.md",
]
class Wan22Pipeline:
"""Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights.
Constructor downloads (if needed) both HF repos, writes a runtime JSON
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
``generate_i2v`` runs one inference turn against the already-loaded runner.
"""
def __init__(
    self,
    base_repo: str,
    fp8_repo: str,
    config_json: str,
    model_cls: str = "wan2.2_moe_distill",
    resolution: int = 480,
    fps: int = 16,
):
    """Provision weights, build the LightX2V config and load the runner.

    Args:
        base_repo: local directory or HF repo id with the base support
            files (see ``_ensure_base_repo``).
        fp8_repo: local directory or HF repo id holding the two fp8 DIT
            checkpoints (see ``_ensure_fp8_checkpoints``).
        config_json: path of the JSON template that
            ``_build_runtime_config`` completes with checkpoint paths.
        model_cls: LightX2V model class name passed to ``set_config``.
        resolution: stored on the instance; not otherwise used here.
        fps: stored, and used as the default ``fps`` config value.
    """
    # Plain attribute capture.
    self.base_repo = base_repo
    self.fp8_repo = fp8_repo
    self.config_json_template = config_json
    self.model_cls = model_cls
    self.resolution = resolution
    self.fps = fps
    self._applied_loras: list[LoRASpec] = []

    # Step 1 - make sure the support files and fp8 checkpoints are local.
    self._model_root = self._ensure_base_repo(base_repo)
    fp8_pair = self._ensure_fp8_checkpoints(fp8_repo)
    self._fp8_high, self._fp8_low = fp8_pair

    # Step 2 - runtime JSON config carrying the checkpoint paths.
    self._runtime_json_path = self._build_runtime_config()

    # Step 3 - the argparse-style namespace that set_config() consumes.
    cli_args = self._build_args(
        model_cls=model_cls,
        model_path=self._model_root,
        config_json=self._runtime_json_path,
    )

    # Step 4 - set_config, then init_runner. lightx2v is imported only
    # here so that importing this module never pulls it in.
    from lightx2v.utils.set_config import set_config  # type: ignore[import-not-found]
    from lightx2v.utils.input_info import init_empty_input_info  # type: ignore[import-not-found]
    from lightx2v.infer import init_runner  # type: ignore[import-not-found]

    log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
             model_cls, self._model_root)
    self._config = set_config(cli_args)
    self._input_info_template = init_empty_input_info(
        cli_args.task, cli_args.support_tasks
    )
    log.info("LightX2V init_runner — loading weights (this takes a while)...")
    self._runner = init_runner(self._config)
    log.info("LightX2V runner loaded; weights resident.")
# --- Weight provisioning -------------------------------------------------
@staticmethod
def _ensure_base_repo(base_repo: str) -> str:
"""Return a local directory containing the Wan2.2 base support files.
If ``base_repo`` is already a local directory, use it as-is. Otherwise
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
shards (they're replaced by the fp8 files).
"""
if os.path.isdir(base_repo):
return base_repo
from huggingface_hub import snapshot_download
log.info("Downloading Wan2.2 base support files from %s "
"(skipping bf16 DIT shards)...", base_repo)
return snapshot_download(
repo_id=base_repo,
ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
)
@staticmethod
def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]:
"""Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair.
- If ``fp8_repo`` is a local directory, expect both files inside it.
- Otherwise treat it as a HF repo id and download only the two files
we need (not the ~150 GB of other variants in that repo).
"""
if not fp8_repo:
raise ValueError("fp8_repo must be a HF repo id or local directory.")
if os.path.isdir(fp8_repo):
high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE)
low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE)
if not (os.path.isfile(high) and os.path.isfile(low)):
raise FileNotFoundError(
f"fp8 checkpoints not found in {fp8_repo}: expected "
f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}"
)
return high, low
from huggingface_hub import hf_hub_download
log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo)
high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE)
low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE)
return high, low
def _build_runtime_config(self) -> str:
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
with open(self.config_json_template, "r", encoding="utf-8") as f:
cfg = json.load(f)
# Drop editorial comments before passing to LightX2V.
cfg.pop("_comment", None)
cfg["high_noise_quantized_ckpt"] = self._fp8_high
cfg["low_noise_quantized_ckpt"] = self._fp8_low
cfg.setdefault("fps", self.fps)
tmp = tempfile.NamedTemporaryFile(
prefix="wan22_fp8_", suffix=".json",
mode="w", delete=False, encoding="utf-8",
)
json.dump(cfg, tmp, indent=2)
tmp.close()
log.info("Runtime LightX2V config: %s", tmp.name)
return tmp.name
@staticmethod
def _build_args(
*, model_cls: str, model_path: str, config_json: str
) -> argparse.Namespace:
"""Mirror every field from ``lightx2v.infer.main``'s argparse so
``set_config`` finds the attributes it expects. We only customize the
model/task/path fields; everything else stays at the CLI defaults.
"""
return argparse.Namespace(
seed=42,
model_cls=model_cls,
task="i2v",
support_tasks=[],
model_path=model_path,
sf_model_path=None,
config_json=config_json,
use_prompt_enhancer=False,
prompt="",
negative_prompt="",
image_path="",
last_frame_path="",
audio_path="",
image_strength="1.0",
image_frame_idx="",
src_ref_images=None,
src_video=None,
src_mask=None,
src_pose_path=None,
src_face_path=None,
src_bg_path=None,
src_mask_path=None,
pose=None,
action_path=None,
action_ckpt=None,
save_result_path=None,
return_result_tensor=False,
target_shape=[],
target_video_length=81,
aspect_ratio="",
video_path=None,
sr_ratio=2.0,
)
# --- LoRA --------------------------------------------------------------
def load_loras(self, specs: list["LoRASpec"]) -> None:
    """Apply LoRAs to the Wan2.2 MoE distill pipeline.

    Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
    to route the LoRA to the correct expert. All targets are validated
    before any path is resolved, so a typo fails fast instead of after a
    HuggingFace download.

    With ``lazy_load`` set in the runner config, ``lora_configs`` and
    ``lora_dynamic_apply`` are injected via ``runner.set_config`` so the
    LoRAs are applied when the models materialise. Otherwise
    ``runner.switch_lora`` is called with explicit high/low keyword
    arguments; that call carries at most one LoRA per expert, so a later
    spec for the same expert replaces an earlier one (now logged as a
    warning).

    Args:
        specs: LoRA specs; this method reads ``path``, ``name``,
            ``weight`` and ``target`` from each. An empty list is a no-op.

    Raises:
        ValueError: if any spec has an unknown ``target``.
        RuntimeError: if ``runner.switch_lora`` returns a falsy value.
    """
    if not specs:
        return

    # spec.target -> the expert name used in LightX2V ``lora_configs``.
    expert_names = {
        "high_noise": "high_noise_model",
        "low_noise": "low_noise_model",
    }
    for spec in specs:
        if spec.target not in expert_names:
            raise ValueError(
                f"LoRA target must be 'high_noise' or 'low_noise', "
                f"got {spec.target!r}")

    # Resolve every path up-front (may trigger HF download).
    resolved: list[tuple["LoRASpec", str]] = []
    for spec in specs:
        local_path = self._resolve_lora_path(spec.path)
        log.info("  LoRA %s → strength=%.2f target=%s (%s)",
                 spec.name or spec.path, spec.weight, spec.target,
                 local_path)
        resolved.append((spec, local_path))

    if self._config.get("lazy_load", False):
        # Models are not resident yet: hand the list to the runner config.
        lora_cfgs = [
            {
                "name": expert_names[spec.target],
                "path": local_path,
                "strength": spec.weight,
            }
            for spec, local_path in resolved
        ]
        self._runner.set_config({
            "lora_configs": lora_cfgs,
            "lora_dynamic_apply": True,
        })
    else:
        # Models are loaded — use runtime hot-swap.
        kwargs: dict = {}
        for spec, local_path in resolved:
            prefix = "high" if spec.target == "high_noise" else "low"
            path_key = f"{prefix}_lora_path"
            if path_key in kwargs:
                log.warning(
                    "Multiple LoRAs target %s; only one per expert is "
                    "passed to switch_lora, so %s replaces %s",
                    spec.target, local_path, kwargs[path_key])
            kwargs[path_key] = local_path
            kwargs[f"{prefix}_lora_strength"] = spec.weight
        ok = self._runner.switch_lora(**kwargs)
        if not ok:
            raise RuntimeError(
                "runner.switch_lora returned False. Check that your "
                "LightX2V build supports runtime LoRA updates for "
                f"{self.model_cls}.")

    self._applied_loras = list(specs)
def unload_loras(self) -> None:
    """Remove every LoRA recorded in ``_applied_loras``; no-op if none.

    Hot-swap mode (no ``lazy_load``) calls ``runner.switch_lora("", 0.0)``.
    Lazy-load mode clears the LoRA config keys and sets each entry of
    ``runner.model.model`` (when present) to ``None`` so the next
    inference recreates the models without LoRAs.
    """
    if not self._applied_loras:
        return

    lazy = self._config.get("lazy_load", False)
    if not lazy:
        self._runner.switch_lora("", 0.0)
    else:
        self._runner.set_config({
            "lora_configs": None,
            "lora_dynamic_apply": False,
        })
        # Drop any already-materialised expert models.
        struct = getattr(self._runner, "model", None)
        has_experts = struct is not None and hasattr(struct, "model")
        if has_experts:
            for slot in range(len(struct.model)):
                struct.model[slot] = None

    self._applied_loras = []
@staticmethod
def _resolve_lora_path(path: str) -> str:
"""Resolve a LoRA path. Supports:
- Absolute/relative local paths (returned as-is if the file exists)
- ``repo_id:filename`` HuggingFace references
"""
if os.path.isfile(path):
return path
if ":" in path and not path.startswith(("/", "./")):
repo_id, filename = path.split(":", 1)
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=repo_id, filename=filename)
return path
# --- Inference ---------------------------------------------------------
def generate_i2v(
    self,
    image_path: str,
    prompt: str,
    seconds: int,
    seed: int | None = None,
    negative_prompt: str = "",
) -> np.ndarray:
    """Run image-to-video inference and return decoded frames.

    The runner writes an MP4 to a temp file, which is decoded into an
    array and then deleted (also when inference fails).

    Args:
        image_path: conditioning image handed to the runner.
        prompt: text prompt.
        seconds: clip length; converted to ``seconds * fps + 1`` frames.
        seed: RNG seed; ``None`` draws a random one for this call.
        negative_prompt: optional negative prompt.

    Returns:
        ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.
    """
    if seed is None:
        seed = random.randint(0, 2**31 - 1)

    # Wan2.2 target_video_length is "frames including the conditioning
    # frame", so N seconds → N*fps + 1.
    target_frames = seconds * self.fps + 1

    from lightx2v.utils.input_info import update_input_info_from_dict  # type: ignore[import-not-found]

    # Only the unique name is needed; the runner writes the file itself.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
        out_path = tf.name
    try:
        # The format used to read "seed=%d%s", which glued the seed and
        # the output path into one unreadable token.
        log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d → %s",
                 prompt[:80], seconds, seed, out_path)
        # The shared input-info object is updated in place on each call.
        update_input_info_from_dict(
            self._input_info_template,
            {
                "seed": seed,
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "image_path": image_path,
                "save_result_path": out_path,
                "target_video_length": target_frames,
                "return_result_tensor": False,
            },
        )
        self._runner.run_pipeline(self._input_info_template)
        return _read_mp4_to_frames(out_path)
    finally:
        # Best-effort cleanup of the temp MP4.
        try:
            os.remove(out_path)
        except OSError:
            pass
# --- MP4 decoding helper ------------------------------------------------------
def _read_mp4_to_frames(path: str) -> np.ndarray:
    """Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``.

    imageio (pyav plugin) is tried first. Any failure there, including
    imageio not being installed, is logged and the file is decoded with
    OpenCV instead.

    Raises:
        RuntimeError: if the OpenCV fallback decodes no frames.
    """
    try:
        import imageio.v3 as iio  # type: ignore[import-not-found]
        arr = np.asarray(iio.imread(path, plugin="pyav"))
        if arr.ndim == 3:
            # Single image: add the time axis.
            arr = arr[None, ...]
        return arr.astype(np.uint8)
    except Exception as e:  # pragma: no cover - fallback path
        log.warning("imageio decode failed (%s); falling back to cv2", e)

    import cv2  # type: ignore[import-not-found]
    cap = cv2.VideoCapture(path)
    decoded: list[np.ndarray] = []
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV yields BGR; callers expect RGB.
            decoded.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if a frame conversion raises.
        cap.release()
    if not decoded:
        raise RuntimeError(f"Failed to decode any frames from {path}")
    return np.stack(decoded, axis=0).astype(np.uint8)
+160 -1
View File
@@ -18,9 +18,18 @@ let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
// --- Video mode state ---
let videoModeEnabled = false; // true when server has video engine active AND ready
let videoModeName = "off"; // "off" | "library" | "reflective"
let idleClipUrl = null; // URL string (server-served) or null
let pendingSpeakingClipMeta = null; // {chunk_id, duration_ms, text} waiting for MP4 binary
let currentSpeakingClipBlobUrl = null;
const chatArea = document.getElementById("chat-area"); const chatArea = document.getElementById("chat-area");
const statusBadge = document.getElementById("status-badge"); const statusBadge = document.getElementById("status-badge");
const micBtn = document.getElementById("mic-btn"); const micBtn = document.getElementById("mic-btn");
const avatarVideo = document.getElementById("avatar-video");
const stageEl = document.getElementById("stage");
// --- WebSocket --- // --- WebSocket ---
@@ -44,7 +53,18 @@ function connectWS() {
ws.onmessage = (event) => { ws.onmessage = (event) => {
if (event.data instanceof ArrayBuffer) { if (event.data instanceof ArrayBuffer) {
playAudioChunk(event.data); // In video mode, the next binary frame after a "speaking_clip"
// envelope is an MP4 blob; otherwise it's a PCM audio chunk.
if (pendingSpeakingClipMeta) {
const meta = pendingSpeakingClipMeta;
pendingSpeakingClipMeta = null;
playSpeakingClip(event.data, meta);
} else if (videoModeEnabled) {
// Video mode is active but we didn't get a speaking_clip envelope
// first — ignore raw PCM so we don't double-play audio.
} else {
playAudioChunk(event.data);
}
} else { } else {
handleJSON(JSON.parse(event.data)); handleJSON(JSON.parse(event.data));
} }
@@ -59,6 +79,7 @@ function handleJSON(msg) {
case "interrupt": case "interrupt":
stopPlayback(); stopPlayback();
stopSpeakingClip();
// Finalize with interrupted marker — text already reflects only what was heard // Finalize with interrupted marker — text already reflects only what was heard
finalizeAssistantMessage(true); finalizeAssistantMessage(true);
break; break;
@@ -80,6 +101,141 @@ function handleJSON(msg) {
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text }); pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
} }
break; break;
case "video_mode":
// Sent once on WS open. Toggles the video element + speaking-clip path.
applyVideoModeState(msg);
break;
case "speaking_clip":
// Envelope preceding an MP4 binary frame with the full turn.
pendingSpeakingClipMeta = {
chunk_id: msg.chunk_id,
duration_ms: msg.duration_ms,
text: msg.text,
};
break;
}
}
// --- Video mode ------------------------------------------------------------
/**
 * Apply a "video_mode" message from the server to the client state.
 * Video mode counts as enabled only when the message reports it as both
 * enabled and ready; the stage is then refreshed to match.
 */
function applyVideoModeState(msg) {
  const isEnabled = Boolean(msg.enabled);
  const isReady = Boolean(msg.ready);
  videoModeEnabled = isEnabled && isReady;
  videoModeName = msg.mode ? msg.mode : "off";
  idleClipUrl = msg.idle_clip_url ? msg.idle_clip_url : null;
  refreshStage();
}
/**
 * Show or hide the avatar stage to match the current video-mode state.
 * When shown, the idle clip is (re)loaded as a muted loop, but only if
 * the <video> element is not already pointing at it.
 */
function refreshStage() {
  const showStage = videoModeEnabled && idleClipUrl;
  if (!showStage) {
    stageEl.classList.remove("active");
    return;
  }

  stageEl.classList.add("active");
  // video.src reads back as an absolute URL, so compare against
  // origin + the server-relative idle clip URL.
  const expectedSrc = location.origin + idleClipUrl;
  if (avatarVideo.src === expectedSrc) return;

  avatarVideo.src = idleClipUrl;
  avatarVideo.loop = true;
  avatarVideo.muted = true;
  avatarVideo.play().catch(() => {});
}
/**
 * Play one speaking clip (an MP4 received as an ArrayBuffer) in place of
 * the idle loop, show the reply text, and return to the idle loop when
 * the clip ends.
 *
 * @param {ArrayBuffer} arrayBuffer - MP4 bytes for the whole turn.
 * @param {{chunk_id: *, duration_ms: *, text: string}} meta - envelope
 *   that preceded the binary frame; only `text` is used here.
 */
function playSpeakingClip(arrayBuffer, meta) {
  // Replace the idle loop with the speaking clip.
  stopSpeakingClip();

  const blob = new Blob([arrayBuffer], { type: "video/mp4" });
  const clipUrl = URL.createObjectURL(blob);
  currentSpeakingClipBlobUrl = clipUrl;
  avatarVideo.loop = false;
  avatarVideo.muted = false;
  avatarVideo.src = clipUrl;

  // Show the full reply text now — the MP4 plays it in one shot so there's
  // no per-chunk sync to do.
  if (meta && meta.text) {
    appendAssistantText(meta.text);
  }

  isPlaying = true;
  avatarVideo.onended = () => {
    isPlaying = false;
    finalizeAssistantMessage(false);
    // Return to idle loop.
    if (idleClipUrl) {
      avatarVideo.loop = true;
      avatarVideo.muted = true;
      avatarVideo.src = idleClipUrl;
      avatarVideo.play().catch(() => {});
    }
    if (currentSpeakingClipBlobUrl) {
      URL.revokeObjectURL(currentSpeakingClipBlobUrl);
      currentSpeakingClipBlobUrl = null;
    }
  };

  avatarVideo.play().catch((e) => {
    console.error("speaking clip play failed:", e);
    // play() also rejects when this clip is superseded or interrupted;
    // in that case whoever replaced it has already cleaned up.
    if (currentSpeakingClipBlobUrl !== clipUrl) return;
    // Genuine failure (e.g. autoplay blocked): "ended" will never fire,
    // so unwind here instead of leaving isPlaying stuck at true, the
    // blob URL unreleased and the message never finalized.
    stopSpeakingClip();
    finalizeAssistantMessage(false);
  });
}
/**
 * Abort the current speaking clip, if any, and fall back to the idle loop.
 *
 * Also forgets a pending "speaking_clip" envelope: if the turn is
 * interrupted after the envelope arrived but before its MP4 frame did,
 * the late binary frame must not start playing the interrupted reply.
 * With the envelope cleared, ws.onmessage no longer routes that frame
 * to playSpeakingClip.
 */
function stopSpeakingClip() {
  pendingSpeakingClipMeta = null;
  if (!currentSpeakingClipBlobUrl) return;
  try {
    avatarVideo.pause();
  } catch (_) {}
  // Release the blob backing the clip that was playing.
  URL.revokeObjectURL(currentSpeakingClipBlobUrl);
  currentSpeakingClipBlobUrl = null;
  if (idleClipUrl) {
    avatarVideo.loop = true;
    avatarVideo.muted = true;
    avatarVideo.src = idleClipUrl;
    avatarVideo.play().catch(() => {});
  }
  isPlaying = false;
}
/**
 * Upload the selected avatar image to /api/set-avatar and, on success,
 * switch the stage to the idle clip the server returns. Progress and
 * errors are reported in the #avatar-status element.
 */
async function uploadAvatar() {
  const fileInput = document.getElementById("avatar-file");
  const status = document.getElementById("avatar-status");

  const picked = fileInput.files && fileInput.files[0];
  if (!picked) {
    status.textContent = "Pick an image first.";
    return;
  }

  status.textContent = "Uploading and rendering idle clip (this takes a while)...";
  const form = new FormData();
  form.append("image", picked);

  try {
    const resp = await fetch("/api/set-avatar", { method: "POST", body: form });
    if (!resp.ok) {
      throw new Error(await resp.text());
    }
    const data = await resp.json();
    // The timestamp query defeats caching of a previous idle clip.
    idleClipUrl = `${data.idle_clip_url}?t=${Date.now()}`;
    videoModeEnabled = true;
    videoModeName = data.mode || videoModeName;
    refreshStage();
    status.textContent = `Avatar ready (${data.mode})`;
  } catch (err) {
    console.error(err);
    status.textContent = `Failed: ${err.message}`;
  }
}
/**
 * Send the mode chosen in #video-mode-select to /api/set-video-mode and
 * mirror the server's answer locally. Switching to "off" disables video
 * mode and hides the stage; other modes only update the mode name here.
 */
async function applyVideoMode() {
  const modeSelect = document.getElementById("video-mode-select");
  const status = document.getElementById("avatar-status");

  const form = new FormData();
  form.append("mode", modeSelect.value);

  try {
    const resp = await fetch("/api/set-video-mode", { method: "POST", body: form });
    if (!resp.ok) {
      throw new Error(await resp.text());
    }
    const data = await resp.json();
    videoModeName = data.mode;
    if (data.mode === "off") {
      videoModeEnabled = false;
      stageEl.classList.remove("active");
    }
    const note = data.note ? " — " + data.note : "";
    status.textContent = "Mode: " + data.mode + note;
  } catch (err) {
    status.textContent = "Failed: " + err.message;
  }
}
@@ -275,6 +431,7 @@ async function startMic() {
if (bargeInCount >= BARGE_IN_FRAMES) { if (bargeInCount >= BARGE_IN_FRAMES) {
// User is speaking over the assistant - interrupt // User is speaking over the assistant - interrupt
stopPlayback(); stopPlayback();
stopSpeakingClip();
const msg = { type: "interrupt" }; const msg = { type: "interrupt" };
if (lastDisplayedChunkId >= 0) { if (lastDisplayedChunkId >= 0) {
msg.last_chunk_id = lastDisplayedChunkId; msg.last_chunk_id = lastDisplayedChunkId;
@@ -353,3 +510,5 @@ async function applyVoice() {
// Expose to HTML onclick // Expose to HTML onclick
window.toggleMic = toggleMic; window.toggleMic = toggleMic;
window.applyVoice = applyVoice; window.applyVoice = applyVoice;
window.uploadAvatar = uploadAvatar;
window.applyVideoMode = applyVideoMode;
+32
View File
@@ -12,6 +12,17 @@
<span id="status-badge">Disconnected</span> <span id="status-badge">Disconnected</span>
</header> </header>
<div id="stage">
<video
id="avatar-video"
autoplay
muted
loop
playsinline
preload="auto"
></video>
</div>
<div id="chat-area"></div> <div id="chat-area"></div>
<details id="voice-panel"> <details id="voice-panel">
@@ -40,6 +51,27 @@
</div> </div>
</details> </details>
<details id="avatar-panel">
<summary>Avatar / Video</summary>
<div class="panel-content">
<label>
Avatar image
<input type="file" id="avatar-file" accept="image/*" />
</label>
<button id="upload-avatar-btn" onclick="uploadAvatar()">Upload</button>
<label>
Mode
<select id="video-mode-select">
<option value="off">Off</option>
<option value="library">Library (pre-baked)</option>
<option value="reflective" selected>Reflective (per-turn)</option>
</select>
</label>
<button id="apply-mode-btn" onclick="applyVideoMode()">Apply mode</button>
<span id="avatar-status"></span>
</div>
</details>
<div id="controls"> <div id="controls">
<button id="mic-btn" onclick="toggleMic()">&#x1F3A4;</button> <button id="mic-btn" onclick="toggleMic()">&#x1F3A4;</button>
</div> </div>
+39 -4
View File
@@ -52,6 +52,28 @@ header h1 {
color: #a78bfa; color: #a78bfa;
} }
/* Avatar stage: wrapper around the avatar <video>. Hidden by default;
   the script adds/removes the "active" class to show or hide it. */
#stage {
display: none; /* toggled on when video mode is enabled */
align-items: center;
justify-content: center;
padding: 16px 24px 0;
background: #0a0a0a;
}
/* Shown state: flex container so the centering rules above take effect. */
#stage.active {
display: flex;
}
/* The avatar player: fluid width capped at 480px in a fixed 16:9 box;
   object-fit: cover crops video that has a different aspect ratio. */
#avatar-video {
width: 100%;
max-width: 480px;
aspect-ratio: 16 / 9;
background: #000;
border-radius: 12px;
object-fit: cover;
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4);
}
#chat-area { #chat-area {
flex: 1; flex: 1;
overflow-y: auto; overflow-y: auto;
@@ -130,21 +152,34 @@ header h1 {
50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); } 50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); }
} }
/* Voice clone panel */ /* Voice + avatar panels */
#voice-panel { #voice-panel,
#avatar-panel {
padding: 12px 24px; padding: 12px 24px;
border-top: 1px solid #222; border-top: 1px solid #222;
background: #0a0a0a; background: #0a0a0a;
} }
#voice-panel summary { #voice-panel select,
#avatar-panel select {
background: #1a1a1a;
border: 1px solid #333;
border-radius: 6px;
padding: 6px 10px;
color: #e0e0e0;
font-size: 13px;
}
#voice-panel summary,
#avatar-panel summary {
cursor: pointer; cursor: pointer;
font-size: 13px; font-size: 13px;
color: #888; color: #888;
user-select: none; user-select: none;
} }
#voice-panel .panel-content { #voice-panel .panel-content,
#avatar-panel .panel-content {
margin-top: 12px; margin-top: 12px;
display: flex; display: flex;
gap: 12px; gap: 12px;
+47
View File
@@ -0,0 +1,47 @@
# Voice-chat tests
Two tiers.
## Unit tests — fast, GPU-free
```
python -m pytest tests/unit -v
```
These exercise pure logic: config parsing, prompt derivation, LoRA spec
parsing, frame-length fitting, library round-robin selection. They do not
touch CUDA, Wan2.2, or MuseTalk. The muxer tests need `ffmpeg` on PATH and
are skipped automatically when it is missing. Safe to run on Windows, outside
Docker, without any models installed.
## Component tests — slow, GPU-required, run inside Docker
Each script in `tests/component/` exercises one subsystem end-to-end against
the real models. They are ordered to match the implementation phases:
| Script | Phase | Tests |
|---|---|---|
| `test_01_video_skeleton.py` | 1 | VideoEngine loads, config gate respected |
| `test_02_wan22_loras.py` | 2 | Wan2.2 pipeline loads, LoRA stack applies |
| `test_03_idle_clip.py` | 3 | set_avatar → idle MP4, written to disk for eyeballing |
| `test_04_library_prebake.py` | 4 | library mode pre-bakes N base clips |
| `test_05_musetalk_lipsync.py` | 5 | MuseTalk lip-sync on library frames + ffmpeg mux |
| `test_06_reflective.py` | 6 | reflective mode: fresh Wan2.2 per reply |
| `test_07_endpoints.py` | 7 | HTTP endpoints return sane responses |
| `test_08_lora_reload.py` | 8 | /api/reload-loras swaps LoRAs live |
Run one:
```
# From the host — the command executes inside the running container:
docker compose exec voice-chat python -m tests.component.test_03_idle_clip
```
Run all (slow: 20+ minutes on a 5090):
```
docker compose exec voice-chat python -m tests.component.run_all
```
Each component script writes its artifacts (MP4s, PNG frame dumps, logs)
to `tests/component/_out/` so you can visually inspect results. That
directory is gitignored.
View File
View File
+72
View File
@@ -0,0 +1,72 @@
"""Shared utilities for component tests.
Component tests run inside the Docker image against real GPU models. They
write their output artefacts (MP4s, PNGs, logs) to ``_out/`` so you can
visually inspect results.
"""
from __future__ import annotations
import logging
import os
import sys
import numpy as np
OUT_DIR = os.path.join(os.path.dirname(__file__), "_out")
os.makedirs(OUT_DIR, exist_ok=True)
# A tiny 256x256 portrait PNG lives next to the component tests so tests
# don't need a user-supplied file. If it's missing we synthesise one on
# the fly.
SAMPLE_AVATAR = os.path.join(os.path.dirname(__file__), "sample_avatar.png")
def get_logger(name: str) -> logging.Logger:
    """Set up INFO-level logging to stdout and return the logger called *name*.

    ``logging.basicConfig`` is invoked on every call with the same settings.
    """
    root_settings = {
        "level": logging.INFO,
        "format": "%(asctime)s %(name)s %(levelname)s %(message)s",
        "stream": sys.stdout,
    }
    logging.basicConfig(**root_settings)
    named_logger = logging.getLogger(name)
    return named_logger
def ensure_sample_avatar() -> str:
    """Guarantee a usable avatar image exists and return its path.

    If ``SAMPLE_AVATAR`` is already on disk it is returned untouched.
    Otherwise a 256x256 RGB gradient is synthesised and written to that
    path. It won't look like a person, but it is a valid image, so the
    pipeline has something to consume. Pillow is the preferred writer;
    ``imageio`` is used only when Pillow cannot be imported.

    Returns:
        The value of ``SAMPLE_AVATAR``.
    """
    if os.path.isfile(SAMPLE_AVATAR):
        return SAMPLE_AVATAR

    # Build the gradient once so both writer back-ends share the same pixels:
    # red ramps up with the row index, green ramps down, blue is constant.
    rows = np.arange(256, dtype=np.uint8)[:, np.newaxis]
    arr = np.empty((256, 256, 3), dtype=np.uint8)
    arr[:, :, 0] = rows
    arr[:, :, 1] = 255 - rows
    arr[:, :, 2] = 128

    try:
        from PIL import Image  # type: ignore[import-not-found]
    except ImportError:
        import imageio.v3 as iio  # type: ignore[import-not-found]

        iio.imwrite(SAMPLE_AVATAR, arr)
    else:
        Image.fromarray(arr).save(SAMPLE_AVATAR)
    return SAMPLE_AVATAR
def write_bytes(name: str, data: bytes) -> str:
    """Store *data* as the artefact ``_out/<name>`` and return the full path."""
    destination = os.path.join(OUT_DIR, name)
    with open(destination, "wb") as handle:
        handle.write(data)
    return destination
def synth_tone(seconds: float, sample_rate: int = 24000, freq: float = 220.0) -> np.ndarray:
    """Return a float32 sine tone usable as stand-in TTS audio.

    Args:
        seconds: Tone duration; ``int(seconds * sample_rate)`` samples are
            produced.
        sample_rate: Samples per second.
        freq: Tone frequency in Hz.

    Returns:
        A 1-D float32 array with peak amplitude 0.2.
    """
    n_samples = int(seconds * sample_rate)
    # Build the time base and phase in float64 and cast once at the end.
    # Doing it in float32 quantises the phase (thousands of radians after a
    # few seconds), which adds audible-scale error to the waveform.
    t = np.arange(n_samples, dtype=np.float64) / sample_rate
    return (0.2 * np.sin(2 * np.pi * freq * t)).astype(np.float32)
+46
View File
@@ -0,0 +1,46 @@
"""Run every component test in order. Stops at first failure.
docker compose exec voice-chat python -m tests.component.run_all
"""
import importlib
import sys
import traceback
SCRIPTS = [
"tests.component.test_01_video_skeleton",
"tests.component.test_02_wan22_loras",
"tests.component.test_03_idle_clip",
"tests.component.test_04_library_prebake",
"tests.component.test_05_musetalk_lipsync",
"tests.component.test_06_reflective",
"tests.component.test_07_endpoints",
"tests.component.test_08_lora_reload",
]
def main() -> int:
    """Import and run each entry of ``SCRIPTS`` in order; stop at the first failure.

    A script fails when its ``run()`` raises ``SystemExit`` with a truthy
    code or raises any other ``Exception`` (whose traceback is printed).
    A ``SystemExit`` with a falsy code counts as success.

    Returns:
        ``1`` if a script failed, ``0`` if every script passed.
    """
    broken: list[str] = []
    rule = "=" * 70
    for script in SCRIPTS:
        print(f"\n{rule}\nRUNNING: {script}\n{rule}")
        try:
            importlib.import_module(script).run()
        except SystemExit as stop:
            if not stop.code:
                continue
            print(f"FAILED: {script} (exit {stop.code})")
            broken.append(script)
            break  # hard-stop on failure
        except Exception:
            traceback.print_exc()
            broken.append(script)
            break

    if not broken:
        print("\nALL COMPONENT TESTS PASSED")
        return 0
    print(f"\n{len(broken)} failed: {broken}")
    return 1
if __name__ == "__main__":
sys.exit(main())
+69
View File
@@ -0,0 +1,69 @@
"""Phase 1 component test: VideoEngine skeleton + config gate.
Verifies:
- ``ModelManager`` can be imported and constructed.
- When ``config.video.enabled=false``, ``_load_video`` skips and leaves
``video_engine=None`` (existing voice path unaffected).
- When ``config.video.enabled=true``, a ``VideoEngine`` instance is created
and ``is_ready()`` returns False (no models loaded yet).
Does NOT load Wan2.2 or MuseTalk — this test is safe to run on any machine
with the python deps installed (no GPU needed).
Run inside Docker:
docker compose exec voice-chat python -m tests.component.test_01_video_skeleton
"""
from __future__ import annotations
import sys
from server.models import ModelManager
from server.video import VideoConfig, VideoEngine
from tests.component._common import get_logger
log = get_logger("test_01")
def run():
    """Exercise ``ModelManager._load_video`` with video disabled, then enabled.

    ``server.config.config`` is replaced by a patched dict for each case and
    restored in a ``finally`` so a failing assertion cannot leak the override.
    """
    # --- disabled path ---
    log.info("[case 1] config.video.enabled=False → engine skipped")
    disabled_mgr = ModelManager()
    # Monkey-patch the config module to simulate disabled
    import server.config as cfgmod

    original = cfgmod.config
    without_video = {key: value for key, value in original.items() if key != "video"}
    cfgmod.config = {"video": {"enabled": False}, **without_video}
    try:
        disabled_mgr._load_video()
        assert disabled_mgr.video_engine is None, "video_engine should be None when disabled"
        log.info(" PASS: video_engine is None")
    finally:
        cfgmod.config = original

    # --- enabled path (no models loaded) ---
    log.info("[case 2] config.video.enabled=True → engine created, not ready")
    enabled_mgr = ModelManager()
    enabled_video = {"enabled": True, "mode": "reflective", "loras": []}
    cfgmod.config = {**original, "video": enabled_video}
    try:
        enabled_mgr._load_video()
        assert enabled_mgr.video_engine is not None, "video_engine should be created"
        assert isinstance(enabled_mgr.video_engine, VideoEngine)
        assert enabled_mgr.video_engine.is_ready() is False
        engine_type = type(enabled_mgr.video_engine).__name__
        log.info(" PASS: engine=%s, ready=%s",
                 engine_type, enabled_mgr.video_engine.is_ready())
    finally:
        cfgmod.config = original

    log.info("ALL PASSED")
if __name__ == "__main__":
try:
run()
sys.exit(0)
except AssertionError as e:
log.error("FAILED: %s", e)
sys.exit(1)
+106
View File
@@ -0,0 +1,106 @@
"""Phase 2 component test: Wan2.2-Lightning fp8 pipeline + LoRA stacking.
Verifies:
- ``Wan22Pipeline`` loads successfully against the fp8 distill path
(exercises the real LightX2V set_config → init_runner flow).
- ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at
``/cache/loras/wan22-[HL]-e8.safetensors``.
Requires GPU and a first-run download of both HF repos (base support files
~12 GB, fp8 DIT ~30 GB). If LightX2V isn't installed the test is skipped.
Run:
docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
"""
from __future__ import annotations
import os
import sys
from tests.component._common import get_logger
log = get_logger("test_02")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors"
LORA_LOW = "/cache/loras/wan22-L-e8.safetensors"
def run():
    """Load the Wan2.2 fp8 pipeline, then apply and remove the two user LoRAs.

    Terminates the process through ``sys.exit`` on the non-success paths:
    0 when ``Wan22Pipeline`` cannot be imported (skip), 2 when pipeline
    construction raises, 3 when ``load_loras`` raises and 4 when
    ``unload_loras`` raises. Returns normally on success, including the
    partial run where the LoRA files are not present on disk.
    """
    try:
        from server.video_models.wan22 import Wan22Pipeline
    except ImportError as e:
        log.error("Wan22Pipeline import failed: %s", e)
        log.warning("SKIP: phase 2 deps not installed")
        sys.exit(0)
    from server.video import LoRASpec

    log.info("[case 1] Instantiate Wan22Pipeline "
             "(first run downloads ~42 GB total)...")
    try:
        pipe = Wan22Pipeline(
            base_repo="Wan-AI/Wan2.2-I2V-A14B",
            fp8_repo="lightx2v/Wan2.2-Distill-Models",
            config_json=CONFIG_JSON,
            model_cls="wan2.2_moe_distill",
            resolution=480,
            fps=16,
        )
    except Exception as e:
        # log.exception records the traceback; the bare message alone is
        # rarely enough to tell where inside the pipeline the failure began.
        log.exception("FAIL: Wan22Pipeline construction raised: %s", e)
        log.error("Check: LightX2V install, HF cache at /cache/huggingface, "
                  "VRAM headroom, and that %s exists inside the container.",
                  CONFIG_JSON)
        sys.exit(2)
    log.info(" PASS: pipeline constructed")

    # --- LoRAs ---
    log.info("[case 2] load_loras with empty list → no-op")
    pipe.load_loras([])
    log.info(" PASS")

    if not (os.path.isfile(LORA_HIGH) and os.path.isfile(LORA_LOW)):
        log.warning("SKIP: expected LoRA files not found at %s / %s",
                    LORA_HIGH, LORA_LOW)
        log.info("ALL PASSED (partial — LoRA cases skipped)")
        return

    log.info("[case 3] load_loras with the two MoE distill LoRAs")
    specs = [
        LoRASpec(
            path=LORA_HIGH,
            weight=1.0,
            target="high_noise",
            name="wan22-H-e8",
        ),
        LoRASpec(
            path=LORA_LOW,
            weight=1.0,
            target="low_noise",
            name="wan22-L-e8",
        ),
    ]
    try:
        pipe.load_loras(specs)
    except Exception as e:
        log.exception("FAIL: load_loras raised: %s", e)
        log.error("Check: switch_lora support for wan2.2_moe_distill in the "
                  "installed LightX2V build. If it errors there, pre-declare "
                  "LoRAs in the config_json 'lora_configs' field instead.")
        sys.exit(3)
    log.info(" PASS: LoRAs applied")

    log.info("[case 4] unload_loras")
    try:
        pipe.unload_loras()
    except Exception as e:
        log.exception("FAIL: unload_loras raised: %s", e)
        sys.exit(4)
    log.info(" PASS")

    log.info("ALL PASSED")
if __name__ == "__main__":
run()
+66
View File
@@ -0,0 +1,66 @@
"""Phase 3 component test: avatar upload → idle clip generation.
Verifies:
- ``VideoEngine.load_models()`` + ``set_avatar(image)`` produces a non-empty
idle MP4 blob.
- The blob starts with an MP4 ``ftyp`` box header (the only validity check made).
Writes the idle clip to ``tests/component/_out/phase3_idle.mp4`` so you can
inspect it visually.
Run:
docker compose exec voice-chat python -m tests.component.test_03_idle_clip
"""
from __future__ import annotations
import sys
from server.video import VideoConfig, VideoEngine
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_03")
def run():
    """Load the video models, set an avatar, and save the resulting idle clip.

    Terminates the process through ``sys.exit`` with code 2 when
    ``load_models`` raises and 3 when ``set_avatar`` raises. Otherwise it
    asserts that the idle clip is non-empty and carries the ``ftyp`` marker
    at bytes 4-8, writes it to ``phase3_idle.mp4`` via ``write_bytes``, and
    finally asserts that the engine reports ready.
    """
    avatar_path = ensure_sample_avatar()
    log.info("Using avatar: %s", avatar_path)

    cfg = VideoConfig.from_dict(
        {
            "enabled": True,
            "mode": "reflective",  # reflective skips the library prebake
            "resolution": 480,
            "fps": 16,
            "library": {"base_clip_count": 0, "base_clip_seconds": 3},
        }
    )
    engine = VideoEngine(cfg)

    log.info("Loading models (Wan2.2 + MuseTalk)...")
    try:
        engine.load_models()
    except Exception as e:
        # log.exception keeps the traceback, which the message alone loses.
        log.exception("FAIL: load_models raised: %s", e)
        sys.exit(2)
    log.info("Models loaded.")

    log.info("Generating idle clip for avatar...")
    try:
        engine.set_avatar(avatar_path)
    except Exception as e:
        log.exception("FAIL: set_avatar raised: %s", e)
        sys.exit(3)

    idle = engine.get_idle_clip()
    assert idle is not None and len(idle) > 0, "idle clip is empty"
    assert idle[4:8] == b"ftyp", "idle clip is not a valid MP4"
    out_path = write_bytes("phase3_idle.mp4", idle)
    log.info("PASS: idle clip written to %s (%d bytes)", out_path, len(idle))

    assert engine.is_ready() is True
    log.info(" engine.is_ready() = True (avatar + models present)")
if __name__ == "__main__":
run()
@@ -0,0 +1,55 @@
"""Phase 4 component test: library mode pre-bake of speaking-base clips.
Verifies:
- ``set_avatar`` under ``mode=library`` populates ``speaking_base_frames``
with ``library_base_clip_count`` entries.
- Each cached entry has shape ``[T, H, W, 3]`` uint8.
Run:
docker compose exec voice-chat python -m tests.component.test_04_library_prebake
"""
from __future__ import annotations
import sys
import numpy as np
from server.video import VideoConfig, VideoEngine
from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_04")
def run():
    """Pre-bake two library clips and check the cached frame arrays.

    Asserts that ``speaking_base_frames`` holds exactly two entries, that
    each one is a 4-D uint8 ``numpy`` array whose last axis has length 3,
    and that an idle clip is available afterwards.
    """
    avatar_path = ensure_sample_avatar()
    library_settings = {
        "enabled": True,
        "mode": "library",
        "resolution": 480,
        "fps": 16,
        "library": {"base_clip_count": 2, "base_clip_seconds": 3},
    }
    engine = VideoEngine(VideoConfig.from_dict(library_settings))

    log.info("Loading models...")
    engine.load_models()

    log.info("Pre-baking 2 library clips...")
    engine.set_avatar(avatar_path)

    assert len(engine.speaking_base_frames) == 2, \
        f"expected 2 base clips, got {len(engine.speaking_base_frames)}"
    for index, clip in enumerate(engine.speaking_base_frames):
        assert isinstance(clip, np.ndarray)
        assert clip.ndim == 4 and clip.shape[-1] == 3
        assert clip.dtype == np.uint8
        log.info(" clip %d: shape=%s", index, clip.shape)

    assert engine.get_idle_clip() is not None
    log.info("PASS: library pre-bake complete")
if __name__ == "__main__":
run()
@@ -0,0 +1,57 @@
"""Phase 5 component test: MuseTalk lip-sync + ffmpeg mux.
Verifies the full library-mode per-turn path:
- Pre-bake a library clip.
- Generate a stand-in TTS waveform (sine tone).
- Call ``VideoEngine.generate_speaking_clip`` and get a valid MP4 back.
Writes the resulting clip to ``tests/component/_out/phase5_speaking.mp4``.
Run:
docker compose exec voice-chat python -m tests.component.test_05_musetalk_lipsync
"""
from __future__ import annotations
import sys
from server.video import VideoConfig, VideoEngine
from tests.component._common import (
ensure_sample_avatar,
get_logger,
synth_tone,
write_bytes,
)
log = get_logger("test_05")
def run():
    """Produce one library-mode speaking clip from a sine tone and save it.

    Asserts that ``generate_speaking_clip`` returns non-empty ``bytes`` with
    the ``ftyp`` marker at bytes 4-8, then writes ``phase5_speaking.mp4``.
    """
    avatar_path = ensure_sample_avatar()
    library_settings = {
        "enabled": True,
        "mode": "library",
        "resolution": 480,
        "fps": 16,
        "library": {"base_clip_count": 1, "base_clip_seconds": 4},
    }
    engine = VideoEngine(VideoConfig.from_dict(library_settings))
    engine.load_models()
    engine.set_avatar(avatar_path)

    tone = synth_tone(seconds=3.0, sample_rate=24000, freq=220.0)
    log.info("Generating library-mode speaking clip (3s audio)...")
    clip = engine.generate_speaking_clip(
        audio_f32=tone,
        sample_rate=24000,
        reply_text="Hello, this is a lip-sync test.",
    )
    assert isinstance(clip, bytes) and len(clip) > 0
    assert clip[4:8] == b"ftyp"
    saved_to = write_bytes("phase5_speaking.mp4", clip)
    log.info("PASS: speaking clip written to %s (%d bytes)", saved_to, len(clip))
if __name__ == "__main__":
run()
+69
View File
@@ -0,0 +1,69 @@
"""Phase 6 component test: reflective mode (fresh Wan2.2 clip per turn).
Verifies that with ``mode=reflective``, ``generate_speaking_clip`` runs
the Wan2.2 image-to-video pipeline once per call (so the base frames
differ from turn to turn) and the prompt is derived from the reply text.
Run:
docker compose exec voice-chat python -m tests.component.test_06_reflective
"""
from __future__ import annotations
import numpy as np
from server.video import VideoConfig, VideoEngine
from tests.component._common import (
ensure_sample_avatar,
get_logger,
synth_tone,
write_bytes,
)
log = get_logger("test_06")
def run():
    """Generate two reflective-mode clips from different replies and compare them.

    The only hard assertion is that the derived prompt contains "beach".
    Whether the two MP4 blobs differ is reported through the log (info when
    they differ, warning when byte-identical) rather than asserted.
    """
    avatar_path = ensure_sample_avatar()
    reflective_settings = {
        "enabled": True,
        "mode": "reflective",
        "resolution": 480,
        "fps": 16,
        "reflective": {"clip_seconds": 3},
    }
    engine = VideoEngine(VideoConfig.from_dict(reflective_settings))
    engine.load_models()
    engine.set_avatar(avatar_path)

    # Verify prompt derivation includes the reply hint
    prompt = engine._derive_prompt(
        "The assistant walks along a sunny beach watching seagulls."
    )
    log.info("derived prompt: %s", prompt)
    assert "beach" in prompt, "reply_hint did not survive template interpolation"

    tone = synth_tone(seconds=3.0)

    log.info("Generating reflective speaking clip #1...")
    beach_clip = engine.generate_speaking_clip(
        tone, 24000, "The assistant walks along a sunny beach watching seagulls."
    )
    write_bytes("phase6_reflective_beach.mp4", beach_clip)

    log.info("Generating reflective speaking clip #2...")
    snow_clip = engine.generate_speaking_clip(
        tone, 24000, "Now the character stands in a snow-covered forest at dusk."
    )
    write_bytes("phase6_reflective_snow.mp4", snow_clip)

    # Not a strict assertion (same prompt could yield identical bytes if seeded),
    # but with different prompts and random seeds the blobs should differ.
    if beach_clip == snow_clip:
        log.warning("clips are byte-identical — check that seeds are random")
    else:
        log.info("PASS: reflective clips differ as expected")
if __name__ == "__main__":
run()
+114
View File
@@ -0,0 +1,114 @@
"""Phase 7 component test: HTTP endpoints (/api/set-avatar, /api/idle-clip,
/api/set-video-mode, /api/reload-loras, WebSocket handshake video_mode msg).
Uses FastAPI's ``TestClient`` so we don't need a running uvicorn server.
Stubs the model manager to avoid loading Wan2.2 — we only care that the
HTTP surface is plumbed correctly.
Run:
docker compose exec voice-chat python -m tests.component.test_07_endpoints
"""
from __future__ import annotations
import io
import json
import sys
from tests.component._common import get_logger
log = get_logger("test_07")
def _stub_video_engine():
    """Build a lightweight stand-in for the video engine used by the endpoints.

    The stub counts as ready once ``set_avatar`` has stored a truthy path,
    always serves ``b"FAKE_MP4"`` as the idle clip, and remembers the last
    LoRA specs it was given in ``_last_loras``.
    """

    class StubCfg:
        mode = "reflective"

    class StubEngine:
        cfg = StubCfg()
        avatar_path = None

        def __init__(self):
            self.idle = b"FAKE_MP4"

        def is_ready(self):
            return bool(self.avatar_path)

        def get_idle_clip(self):
            return self.idle

        def set_avatar(self, path):
            self.avatar_path = path

        def load_loras(self, specs):
            self._last_loras = specs

    return StubEngine()
def run():
    """Drive the video HTTP endpoints and the WebSocket handshake via ``TestClient``.

    A stub engine replaces ``model_mgr.video_engine`` and the app's lifespan
    context is cleared before the client is built. Five cases run in order:
    set-avatar, idle-clip, set-video-mode, reload-loras and the
    ``video_mode`` announcement on ``/ws/chat``.
    """
    from fastapi.testclient import TestClient
    import server.main as main_mod

    # Inject a stub engine so we never touch Wan2.2.
    main_mod.model_mgr.video_engine = _stub_video_engine()

    # Bypass the heavy lifespan (model loading) so TestClient starts fast.
    main_mod.app.router.lifespan_context = None  # type: ignore[attr-defined]

    client = TestClient(main_mod.app)

    # --- set-avatar ---
    log.info("[case 1] POST /api/set-avatar")
    fake_png = b"\x89PNG\r\n\x1a\n" + b"\x00" * 64  # minimal PNG header
    upload = {"image": ("avatar.png", io.BytesIO(fake_png), "image/png")}
    reply = client.post("/api/set-avatar", files=upload)
    assert reply.status_code == 200, f"got {reply.status_code}: {reply.text}"
    payload = reply.json()
    assert payload["status"] == "ok"
    assert payload["idle_clip_url"] == "/api/idle-clip"
    log.info(" PASS: %s", payload)

    # --- idle-clip ---
    log.info("[case 2] GET /api/idle-clip")
    reply = client.get("/api/idle-clip")
    assert reply.status_code == 200
    assert reply.content == b"FAKE_MP4"
    assert reply.headers["content-type"] == "video/mp4"
    log.info(" PASS")

    # --- set-video-mode ---
    log.info("[case 3] POST /api/set-video-mode")
    for mode in ("off", "library", "reflective"):
        reply = client.post("/api/set-video-mode", data={"mode": mode})
        assert reply.status_code == 200
        assert reply.json()["mode"] == mode
    reply = client.post("/api/set-video-mode", data={"mode": "bogus"})
    assert reply.status_code == 400
    log.info(" PASS")

    # --- reload-loras ---
    log.info("[case 4] POST /api/reload-loras")
    high_noise = {"path": "/cache/loras/a.safetensors", "weight": 0.8,
                  "target": "high_noise", "name": "test-a"}
    low_noise = {"path": "/cache/loras/b.safetensors", "weight": 0.4,
                 "target": "low_noise"}
    reply = client.post("/api/reload-loras", json={"loras": [high_noise, low_noise]})
    assert reply.status_code == 200, reply.text
    payload = reply.json()
    assert payload["lora_count"] == 2
    log.info(" PASS: %s", payload)

    # --- WebSocket video_mode handshake ---
    log.info("[case 5] WebSocket /ws/chat → video_mode announcement")
    with client.websocket_connect("/ws/chat") as websocket:
        received = []
        # Read at most five frames; stop early on the announcement or on any
        # receive error.
        for _ in range(5):
            try:
                frame = websocket.receive_json()
                received.append(frame)
                if frame.get("type") == "video_mode":
                    break
            except Exception:
                break
        assert any(m.get("type") == "video_mode" for m in received), received
    log.info(" PASS")

    log.info("ALL PASSED")
if __name__ == "__main__":
run()
+60
View File
@@ -0,0 +1,60 @@
"""Phase 8 component test: /api/reload-loras hot-swap.
Verifies that ``VideoEngine.load_loras`` can be called again after startup
and the idle clip is regenerated to reflect the new style.
This test is the 'real model' version of test_07's reload endpoint stub.
Run:
docker compose exec voice-chat python -m tests.component.test_08_lora_reload
"""
from __future__ import annotations
import hashlib
from server.video import LoRASpec, VideoConfig, VideoEngine
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_08")
def run():
    """Compare the idle clip before and after hot-loading a distill LoRA.

    Both clips are written to the artefact directory and their SHA-256
    digests logged. A difference is reported at info level; identical
    digests only produce a warning, not a failure.
    """
    avatar_path = ensure_sample_avatar()
    engine = VideoEngine(VideoConfig.from_dict({"enabled": True, "mode": "reflective"}))
    engine.load_models()

    # Initial state: no LoRAs
    engine.set_avatar(avatar_path)
    baseline_clip = engine.get_idle_clip()
    assert baseline_clip is not None
    baseline_digest = hashlib.sha256(baseline_clip).hexdigest()
    write_bytes("phase8_idle_noloras.mp4", baseline_clip)
    log.info("idle (no LoRAs) sha256=%s", baseline_digest[:16])

    # Hot-reload with a distill LoRA
    distill_spec = LoRASpec(
        path="lightx2v/Wan2.2-Distill-Loras:"
             "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step.safetensors",
        weight=1.0,
        target="high_noise",
        name="distill-hi",
    )
    engine.load_loras([distill_spec])
    engine.set_avatar(avatar_path)
    styled_clip = engine.get_idle_clip()
    assert styled_clip is not None
    styled_digest = hashlib.sha256(styled_clip).hexdigest()
    write_bytes("phase8_idle_withlora.mp4", styled_clip)
    log.info("idle (with LoRA) sha256=%s", styled_digest[:16])

    if baseline_digest == styled_digest:
        log.warning("clips identical — LoRA may not be applied; eyeball _out/*.mp4")
    else:
        log.info("PASS: idle clip changed after LoRA reload")
if __name__ == "__main__":
run()
View File
+65
View File
@@ -0,0 +1,65 @@
"""Unit tests for the frame-length fitting helper in server.video_models.musetalk.
Pure-python: does not import MuseTalk itself.
"""
import numpy as np
from server.video_models.musetalk import _fit_frames_to_length, _ensure_uint8_rgb
def _make_frames(t, h=2, w=2):
    """Return a ``(t, h, w, 3)`` uint8 array filled with consecutive values."""
    element_count = t * h * w * 3
    flat = np.arange(element_count, dtype=np.uint8)
    return flat.reshape(t, h, w, 3)
def test_fit_frames_trim():
    """Fitting 10 frames to a target of 4 must keep exactly the first 4."""
    source = _make_frames(10)
    trimmed = _fit_frames_to_length(source, 4)
    np.testing.assert_array_equal(trimmed, source[:4])
    assert trimmed.shape == (4, 2, 2, 3)
def test_fit_frames_passthrough_when_equal():
    """A target equal to the clip length yields the same object or an equal array."""
    source = _make_frames(5)
    result = _fit_frames_to_length(source, 5)
    if result is not source:
        assert np.array_equal(result, source)
def test_fit_frames_extends_with_pingpong():
    """Extending 3 frames to 8 must go forward, backward, then forward again."""
    source = _make_frames(3)
    extended = _fit_frames_to_length(source, 8)
    assert extended.shape == (8, 2, 2, 3)
    expected_segments = (
        (slice(0, 3), source),        # first 3 frames match the original
        (slice(3, 6), source[::-1]),  # next 3 are the reverse (ping-pong)
        (slice(6, 8), source[:2]),    # then forward again
    )
    for window, expected in expected_segments:
        np.testing.assert_array_equal(extended[window], expected)
def test_fit_frames_zero_target_returns_original():
    """A target length of 0 must hand back frames equal to the input."""
    source = _make_frames(3)
    result = _fit_frames_to_length(source, 0)
    np.testing.assert_array_equal(result, source)
def test_ensure_uint8_rgb_from_float():
    """Float input of 0.5 must come back as uint8 127 with the shape unchanged."""
    half_grey = np.full((5, 2, 2, 3), 0.5, dtype=np.float32)
    converted = _ensure_uint8_rgb(half_grey)
    assert converted.dtype == np.uint8
    assert converted.shape == (5, 2, 2, 3)
    first_value = converted[0, 0, 0, 0]
    assert first_value == 127
def test_ensure_uint8_rgb_promotes_3d_to_4d():
    """A single ``(H, W, 3)`` image must gain a leading axis of length 1."""
    single_image = np.zeros((2, 2, 3), dtype=np.uint8)
    promoted = _ensure_uint8_rgb(single_image)
    assert promoted.shape == (1, 2, 2, 3)
def test_ensure_uint8_rgb_clips_float_out_of_range():
    """Floats above 1 must saturate at 255 and negative floats at 0."""
    cases = (
        (2.0, 255),  # 2.0 → clipped to 255
        (-1.0, 0),
    )
    for fill_value, expected in cases:
        pixel = np.full((1, 1, 1, 3), fill_value, dtype=np.float32)
        converted = _ensure_uint8_rgb(pixel)
        assert converted[0, 0, 0, 0] == expected
+67
View File
@@ -0,0 +1,67 @@
"""Unit tests for the ffmpeg muxer.
Requires ``ffmpeg`` on PATH. On Windows, if ffmpeg is not installed these
tests are skipped (they will run inside the Docker image where ffmpeg is
always present).
"""
import os
import shutil
import struct
import numpy as np
import pytest
from server.video_models.muxer import frames_and_audio_to_mp4, frames_to_mp4_loop
pytestmark = pytest.mark.skipif(
shutil.which("ffmpeg") is None,
reason="ffmpeg not installed locally; run these inside Docker",
)
def _rgb_frames(t, h=64, w=64):
    """Coloured checker frames so the encoder has real content.

    Red is a per-frame constant ``(i * 20) % 255``; green fills the top half
    of every frame and blue the left half.
    """
    frames = np.zeros((t, h, w, 3), dtype=np.uint8)
    red_per_frame = (np.arange(t) * 20) % 255
    frames[:, :, :, 0] = red_per_frame[:, np.newaxis, np.newaxis]
    frames[:, : h // 2, :, 1] = 255
    frames[:, :, : w // 2, 2] = 255
    return frames
def test_frames_to_mp4_loop_produces_mp4_bytes():
    """Encoding 8 frames must return non-empty bytes carrying an ``ftyp`` marker."""
    blob = frames_to_mp4_loop(_rgb_frames(8), fps=16)
    assert isinstance(blob, bytes)
    assert len(blob) > 0
    # MP4 files start with an ftyp box: 4 bytes size + 'ftyp'
    box_type = blob[4:8]
    assert box_type == b"ftyp"
def test_frames_and_audio_to_mp4_produces_mp4_bytes():
    """Muxing 16 frames with silent audio must return bytes with an ``ftyp`` marker."""
    video = _rgb_frames(16)
    # 1s silent audio at 24kHz
    silence = np.zeros(24000, dtype=np.float32)
    blob = frames_and_audio_to_mp4(video, silence, sample_rate=24000, fps=16)
    assert isinstance(blob, bytes)
    assert len(blob) > 0
    assert blob[4:8] == b"ftyp"
def test_frames_to_mp4_loop_rejects_empty():
    """A zero-frame array must raise ``ValueError``."""
    with pytest.raises(ValueError):
        frames_to_mp4_loop(np.empty((0, 64, 64, 3), dtype=np.uint8), fps=16)
def test_frames_and_audio_to_mp4_rejects_empty_audio():
    """Valid frames paired with zero-length audio must raise ``ValueError``."""
    video = _rgb_frames(4)
    with pytest.raises(ValueError):
        frames_and_audio_to_mp4(
            video, np.empty(0, dtype=np.float32), sample_rate=24000, fps=16
        )
def test_frames_to_mp4_loop_rejects_wrong_shape():
    """A 3-D array (no channel axis) must raise ``ValueError``."""
    with pytest.raises(ValueError):
        frames_to_mp4_loop(np.zeros((4, 64, 64), dtype=np.uint8), fps=16)
+144
View File
@@ -0,0 +1,144 @@
"""Unit test for the video-mode branch in ConversationSession.
Stubs every model involved (ASR, LLM, TTS, VideoEngine) so we can verify:
1. When video_engine is not ready, the existing PCM streaming path runs.
2. When video_engine IS ready, the per-chunk PCM sends are skipped and a
single ``speaking_clip`` JSON + MP4 binary is sent instead.
Pure asyncio; no CUDA, no real models.
"""
from __future__ import annotations
import asyncio
import types
from unittest.mock import MagicMock
import numpy as np
import pytest
from server.pipeline import ConversationSession
class _FakeVAD:
    """VAD stand-in: ``is_speaking`` is False and every chunk yields ``None``."""

    is_speaking = False

    def process_chunk(self, _):
        """Ignore the chunk and report nothing."""
        return None
class _FakeASR:
    """ASR stand-in that returns a fixed transcript for any audio."""

    def __init__(self, text="hello"):
        self.text = text

    def transcribe(self, _):
        """Return the canned transcript, ignoring the audio argument."""
        return self.text
class _FakeLLM:
    """LLM stand-in: fixed reply, no cache state, pass-through cache trimming."""

    def __init__(self, response="Hi there."):
        self.response = response

    def generate(self, *_a, **_k):
        """Return ``(response, None)`` whatever the arguments."""
        reply, cache_state = self.response, None
        return reply, cache_state

    def trim_cache(self, state, _):
        """Return *state* unchanged."""
        return state
class _FakeTTSIterable:
    """Drop-in replacement for Kokoro's pipeline(..) generator."""

    def __init__(self, chunks):
        self._chunks = chunks

    def __call__(self, segment, voice=None):
        """Yield ``("w<i>", None, audio)`` for each stored chunk, in order."""
        position = 0
        for audio in self._chunks:
            yield f"w{position}", None, audio
            position += 1
class _FakeTTSEngine:
    """TTS engine stand-in exposing ``pipeline``, ``voice`` and ``sample_rate``."""

    def __init__(self, chunks):
        self.sample_rate = 24000
        self.voice = "v"
        self.pipeline = _FakeTTSIterable(chunks)
class _FakeVideoEngineReady:
    """Video engine stand-in that is always ready and records its last call."""

    class _Cfg:
        mode = "reflective"

    cfg = _Cfg()

    def __init__(self):
        self.called_with = None

    def is_ready(self):
        return True

    def generate_speaking_clip(self, audio, sr, reply_text):
        """Record the audio length, sample rate and reply; return fake MP4 bytes."""
        recorded = {"len": len(audio), "sr": sr, "reply": reply_text}
        self.called_with = recorded
        return b"FAKE_MP4_BYTES"
class _FakeModelsBase:
    """Model-manager stand-in wiring together the fake ASR, LLM and TTS engines."""

    def __init__(self, tts_chunks):
        self.asr_engine = _FakeASR()
        self.llm_engine = _FakeLLM()
        self.tts_engine = _FakeTTSEngine(tts_chunks)

    def create_vad(self):
        """Return a fresh fake VAD."""
        return _FakeVAD()
class _FakeModelsStreaming(_FakeModelsBase):
    """Fake models without a video engine (``video_engine`` is ``None``)."""

    video_engine = None
class _FakeModelsVideo(_FakeModelsBase):
    """Fake models whose ``video_engine`` is an always-ready stub."""

    def __init__(self, tts_chunks):
        super().__init__(tts_chunks)
        ready_engine = _FakeVideoEngineReady()
        self.video_engine = ready_engine
@pytest.mark.asyncio
async def test_streaming_path_when_video_engine_absent():
    """With ``video_engine`` set to None, audio goes out as per-chunk PCM.

    Expects one binary send per TTS chunk, at least one non-final
    ``response_text`` message, and no ``speaking_clip`` envelope.
    """
    sent_json: list = []
    sent_bytes: list = []

    async def send_json(d):
        sent_json.append(d)

    async def send_bytes(b):
        sent_bytes.append(b)

    tts_chunks = [np.ones(length, dtype=np.float32) for length in (240, 480)]
    models = _FakeModelsStreaming(tts_chunks=tts_chunks)
    session = ConversationSession(models, send_json, send_bytes)

    await session._process_utterance(np.zeros(16000, dtype=np.float32))

    # PCM bytes were sent (one per TTS chunk).
    assert len(sent_bytes) == 2
    # Per-chunk response_text messages were sent (not video's one-shot).
    text_msgs = [m for m in sent_json if m.get("type") == "response_text"]
    assert any(not m.get("final") for m in text_msgs)
    # No speaking_clip envelope
    clip_envelopes = [m for m in sent_json if m.get("type") == "speaking_clip"]
    assert not clip_envelopes
@pytest.mark.asyncio
async def test_video_path_when_engine_ready():
    """With a ready video engine, one ``speaking_clip`` envelope plus one MP4 is sent.

    Also checks that the engine was called with the concatenated audio
    length (960 samples) and the LLM reply text.
    """
    sent_json: list = []
    sent_bytes: list = []

    async def send_json(d):
        sent_json.append(d)

    async def send_bytes(b):
        sent_bytes.append(b)

    tts_chunks = [np.full(480, level, dtype=np.float32) for level in (0.5, 0.25)]
    models = _FakeModelsVideo(tts_chunks=tts_chunks)
    session = ConversationSession(models, send_json, send_bytes)

    await session._process_utterance(np.zeros(16000, dtype=np.float32))

    # MP4 blob was sent once.
    assert sent_bytes == [b"FAKE_MP4_BYTES"]

    # speaking_clip envelope was sent exactly once.
    envelopes = [m for m in sent_json if m.get("type") == "speaking_clip"]
    assert len(envelopes) == 1
    envelope = envelopes[0]
    assert envelope["size_bytes"] == len(b"FAKE_MP4_BYTES")
    assert envelope["text"] == "Hi there."

    # The video engine received the concatenated audio.
    recorded = models.video_engine.called_with
    assert recorded is not None
    assert recorded["len"] == 960  # 480 + 480
    assert recorded["reply"] == "Hi there."

    # No per-chunk PCM bytes were streamed (video path suppresses them).
    # Only the MP4 blob is in sent_bytes.
    assert len(sent_bytes) == 1
+119
View File
@@ -0,0 +1,119 @@
"""Unit tests for VideoConfig parsing and LoRASpec validation.
Pure-python, no model imports, no CUDA, no ffmpeg. Safe for Windows CI.
"""
import pytest
from server.video import VideoConfig, LoRASpec
def test_defaults_when_raw_is_empty():
    """An empty dict must produce the default values asserted below."""
    cfg = VideoConfig.from_dict({})
    assert cfg.enabled is False
    expected_defaults = {
        "backend": "lightx2v",
        "mode": "reflective",
        "resolution": 480,
        "fps": 16,
        "library_base_clip_count": 4,
        "reflective_prompt_reply_words": 18,
        "loras": [],
    }
    for field_name, expected in expected_defaults.items():
        assert getattr(cfg, field_name) == expected, field_name
def test_defaults_when_raw_is_none():
    """Passing ``None`` instead of a dict must still yield a disabled config."""
    parsed = VideoConfig.from_dict(None)  # type: ignore[arg-type]
    assert parsed.enabled is False
def test_library_section_override():
    """Values under ``library`` must land on the ``library_*`` attributes."""
    raw = {
        "enabled": True,
        "mode": "library",
        "library": {"base_clip_count": 7, "base_clip_seconds": 3},
    }
    cfg = VideoConfig.from_dict(raw)
    assert cfg.enabled is True
    assert cfg.mode == "library"
    assert (cfg.library_base_clip_count, cfg.library_base_clip_seconds) == (7, 3)
def test_reflective_section_override():
    """Values under ``reflective`` must land on the ``reflective_*`` attributes."""
    reflective_raw = {
        "clip_seconds": 9,
        "clip_prompt_template": "my template: {reply_hint}",
        "prompt_reply_words": 5,
    }
    cfg = VideoConfig.from_dict({"reflective": reflective_raw})
    assert cfg.reflective_clip_seconds == 9
    assert cfg.reflective_prompt_template == "my template: {reply_hint}"
    assert cfg.reflective_prompt_reply_words == 5
def test_lora_parse_minimal():
    """A LoRA entry with only ``path`` gets weight 1.0, target "both" and no name."""
    cfg = VideoConfig.from_dict({"loras": [{"path": "/tmp/a.safetensors"}]})
    assert len(cfg.loras) == 1
    (only_lora,) = cfg.loras
    assert only_lora.path == "/tmp/a.safetensors"
    assert only_lora.weight == 1.0
    assert only_lora.target == "both"
    assert only_lora.name is None
def test_lora_parse_full():
    """Fully specified LoRA entries must keep their target, name and weight."""
    high_noise = {
        "path": "/tmp/hi.safetensors",
        "weight": 0.7,
        "target": "high_noise",
        "name": "hi-noise-style",
    }
    low_noise = {
        "path": "/tmp/lo.safetensors",
        "weight": 0.4,
        "target": "low_noise",
        "name": "lo-noise-style",
    }
    cfg = VideoConfig.from_dict({"loras": [high_noise, low_noise]})
    assert len(cfg.loras) == 2
    first, second = cfg.loras
    assert first.target == "high_noise"
    assert first.name == "hi-noise-style"
    assert second.target == "low_noise"
    assert second.weight == 0.4
def test_lora_invalid_target_falls_back_to_both():
    """An unrecognised ``target`` value must be replaced by "both"."""
    bad_entry = {"path": "/tmp/x.safetensors", "target": "bogus"}
    cfg = VideoConfig.from_dict({"loras": [bad_entry]})
    assert cfg.loras[0].target == "both"
def test_lora_entries_without_path_are_dropped():
    """Entries lacking ``path`` (and ``None`` entries) must be filtered out."""
    entries = [{"weight": 0.5}, {"path": "/tmp/ok.safetensors"}, None]
    cfg = VideoConfig.from_dict({"loras": entries})
    assert len(cfg.loras) == 1
    assert cfg.loras[0].path == "/tmp/ok.safetensors"
def test_models_section_override():
    """Keys under ``models`` override the model repo/path attributes."""
    models = {
        "wan22_base_repo": "/local/weights/wan22",
        "wan22_fp8_repo": "/local/weights/wan22-fp8",
        "wan22_config_json": "/local/cfg/fp8.json",
        "wan22_model_cls": "wan2.2_moe",
        "musetalk_path": "/local/weights/musetalk",
    }
    parsed = VideoConfig.from_dict({"models": models})

    # Attribute name on the parsed config -> value expected there. Note that
    # the ``musetalk_path`` key is exposed as ``musetalk_model_path``.
    expected = {
        "wan22_base_repo": "/local/weights/wan22",
        "wan22_fp8_repo": "/local/weights/wan22-fp8",
        "wan22_config_json": "/local/cfg/fp8.json",
        "wan22_model_cls": "wan2.2_moe",
        "musetalk_model_path": "/local/weights/musetalk",
    }
    for attr, want in expected.items():
        assert getattr(parsed, attr) == want
+106
View File
@@ -0,0 +1,106 @@
"""Unit tests for pure-python logic inside VideoEngine.
No models are loaded: we instantiate ``VideoEngine`` and hand-stub its
``_wan22`` / ``_musetalk`` attributes to test prompt derivation, library
round-robin, and frame fitting.
"""
import numpy as np
import pytest
from server.video import VideoConfig, VideoEngine
@pytest.fixture
def engine():
    """Provide a ``VideoEngine`` built from an enabled, reflective-mode config.

    The prompt template and reply word limit are set to small, fixed values
    so tests can assert exact ``_derive_prompt`` output.
    """
    reflective = {
        "clip_prompt_template": "A: {reply_hint} B",
        "prompt_reply_words": 5,
    }
    settings = VideoConfig.from_dict(
        {
            "enabled": True,
            "mode": "reflective",
            "fps": 16,
            "reflective": reflective,
        }
    )
    return VideoEngine(settings)
def test_derive_prompt_truncates_to_word_limit(engine):
    """Only the first five reply words (the fixture's limit) reach the prompt."""
    reply = "one two three four five six seven eight"
    prompt = engine._derive_prompt(reply)
    assert prompt == "A: one two three four five B"
def test_derive_prompt_handles_empty_reply(engine):
    """An empty or ``None`` reply is replaced by the "calm and friendly" hint."""
    fallback_prompt = "A: calm and friendly B"

    from_empty = engine._derive_prompt("")
    assert from_empty == fallback_prompt

    from_none = engine._derive_prompt(None)  # type: ignore[arg-type]
    assert from_none == fallback_prompt
def test_derive_prompt_strips_and_passes_through(engine):
    """Surrounding whitespace is removed; a short reply is used unchanged."""
    prompt = engine._derive_prompt(" hello world ")
    assert prompt == "A: hello world B"
def test_is_ready_false_without_models(engine):
    """A freshly constructed engine, with no models loaded, is not ready.

    ``is_ready`` must be False here so the pipeline falls back to the PCM
    streaming path.
    """
    ready = engine.is_ready()
    assert ready is False
def test_pick_library_frames_round_robin(engine):
    """Successive picks cycle through the base clips and wrap around."""
    engine.cfg.mode = "library"
    engine.cfg.fps = 2

    # Two 4-frame, 1x1 RGB base clips: one all black, one all white, so the
    # first pixel identifies which clip was chosen.
    clip_shape = (4, 1, 1, 3)
    black_clip = np.zeros(clip_shape, dtype=np.uint8)
    white_clip = np.full(clip_shape, 255, dtype=np.uint8)
    engine.speaking_base_frames = [black_clip, white_clip]

    # 2 s of audio at 16 kHz -> 4 frames at fps=2.
    sample_rate = 16000
    audio = np.zeros(sample_rate * 2, dtype=np.float32)

    # All three picks are made before any assertion runs.
    picks = [engine._pick_library_frames(audio, sample_rate) for _ in range(3)]
    first, second, third = picks

    assert first.shape == (4, 1, 1, 3)
    assert first[0, 0, 0, 0] == 0  # first pick = black clip
    assert second[0, 0, 0, 0] == 255  # second pick = white clip
    assert third[0, 0, 0, 0] == 0  # wraps back to the black clip
def test_pick_library_frames_trims_to_audio_duration(engine):
    """A base clip longer than the audio is cut to the audio's frame count."""
    engine.cfg.mode = "library"
    engine.cfg.fps = 4

    # A single 20-frame clip, far longer than needed.
    long_clip = np.zeros((20, 1, 1, 3), dtype=np.uint8)
    engine.speaking_base_frames = [long_clip]

    # 1 s of audio at 16 kHz -> 4 frames at fps=4.
    sample_rate = 16000
    one_second = np.zeros(sample_rate, dtype=np.float32)

    picked = engine._pick_library_frames(one_second, sample_rate)
    assert picked.shape == (4, 1, 1, 3)
def test_pick_library_frames_loops_for_long_audio(engine):
    """A base clip shorter than the audio is extended to the audio's length."""
    engine.cfg.mode = "library"
    engine.cfg.fps = 4

    # The only base clip has just 4 frames.
    short_clip = np.zeros((4, 1, 1, 3), dtype=np.uint8)
    engine.speaking_base_frames = [short_clip]

    # 3 s of audio at 16 kHz -> 12 frames at fps=4, three times the clip.
    sample_rate = 16000
    three_seconds = np.zeros(sample_rate * 3, dtype=np.float32)

    picked = engine._pick_library_frames(three_seconds, sample_rate)
    assert picked.shape == (12, 1, 1, 3)
def test_pick_library_frames_raises_when_empty(engine):
    """With no base clips available, picking frames raises ``RuntimeError``."""
    engine.cfg.mode = "library"
    engine.speaking_base_frames = []

    short_audio = np.zeros(100, dtype=np.float32)
    with pytest.raises(RuntimeError, match="no pre-baked base clips"):
        engine._pick_library_frames(short_audio, 16000)
def test_generate_speaking_clip_raises_when_not_ready(engine):
    """Generating a clip on an engine without models raises ``RuntimeError``."""
    request = {
        "audio_f32": np.zeros(100, dtype=np.float32),
        "sample_rate": 16000,
        "reply_text": "hi",
    }
    with pytest.raises(RuntimeError, match="not ready"):
        engine.generate_speaking_clip(**request)