From 2818b4100414e33f170b0609a4a5f6681565b700 Mon Sep 17 00:00:00 2001 From: Brian Date: Sun, 12 Apr 2026 04:11:52 -0400 Subject: [PATCH] first stab at adding video --- .gitignore | 5 +- Dockerfile | 23 ++ config.yml | 46 +++ configs/lightx2v/wan22_i2v_fp8_distill.json | 36 ++ conftest.py | 11 + docker-compose.yml | 6 + requirements.txt | 9 + server/main.py | 123 +++++- server/models.py | 28 +- server/pipeline.py | 87 +++- server/video.py | 391 ++++++++++++++++++ server/video_models/__init__.py | 10 + server/video_models/musetalk.py | 164 ++++++++ server/video_models/muxer.py | 146 +++++++ server/video_models/wan22.py | 423 ++++++++++++++++++++ static/app.js | 161 +++++++- static/index.html | 32 ++ static/style.css | 43 +- tests/README.md | 47 +++ tests/__init__.py | 0 tests/component/__init__.py | 0 tests/component/_common.py | 72 ++++ tests/component/run_all.py | 46 +++ tests/component/test_01_video_skeleton.py | 69 ++++ tests/component/test_02_wan22_loras.py | 106 +++++ tests/component/test_03_idle_clip.py | 66 +++ tests/component/test_04_library_prebake.py | 55 +++ tests/component/test_05_musetalk_lipsync.py | 57 +++ tests/component/test_06_reflective.py | 69 ++++ tests/component/test_07_endpoints.py | 114 ++++++ tests/component/test_08_lora_reload.py | 60 +++ tests/unit/__init__.py | 0 tests/unit/test_musetalk_fit_frames.py | 65 +++ tests/unit/test_muxer_ffmpeg.py | 67 ++++ tests/unit/test_pipeline_video_branch.py | 144 +++++++ tests/unit/test_video_config.py | 119 ++++++ tests/unit/test_video_engine_logic.py | 106 +++++ 37 files changed, 2982 insertions(+), 24 deletions(-) create mode 100644 configs/lightx2v/wan22_i2v_fp8_distill.json create mode 100644 conftest.py create mode 100644 server/video.py create mode 100644 server/video_models/__init__.py create mode 100644 server/video_models/musetalk.py create mode 100644 server/video_models/muxer.py create mode 100644 server/video_models/wan22.py create mode 100644 tests/README.md create mode 100644 tests/__init__.py create mode 100644 tests/component/__init__.py create mode 100644 tests/component/_common.py create mode 100644 tests/component/run_all.py create mode 100644 tests/component/test_01_video_skeleton.py create mode 100644 tests/component/test_02_wan22_loras.py create mode 100644 tests/component/test_03_idle_clip.py create mode 100644 tests/component/test_04_library_prebake.py create mode 100644 tests/component/test_05_musetalk_lipsync.py create mode 100644 tests/component/test_06_reflective.py create mode 100644 tests/component/test_07_endpoints.py create mode 100644 tests/component/test_08_lora_reload.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_musetalk_fit_frames.py create mode 100644 tests/unit/test_muxer_ffmpeg.py create mode 100644 tests/unit/test_pipeline_video_branch.py create mode 100644 tests/unit/test_video_config.py create mode 100644 tests/unit/test_video_engine_logic.py diff --git a/.gitignore b/.gitignore index aa47104..1e1cea9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ .venv .claude -__pycache__ \ No newline at end of file +__pycache__ +tests/component/_out/ +avatars/ +loras/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index cc27cba..1dd8ec3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,9 @@ ENV DEBIAN_FRONTEND=noninteractive ENV PYTHONUNBUFFERED=1 # HuggingFace model cache — mounted as a volume so models persist across runs ENV HF_HOME=/cache/huggingface +# LoRA directory — users drop .safetensors files here and reference them +# from config.yml::video.loras. Bind-mounted via docker-compose. +ENV LORA_DIR=/cache/loras RUN apt-get update && apt-get install -y \ python3.11 \ @@ -38,6 +41,26 @@ RUN python3.11 -m pip install --no-cache-dir -r requirements.txt # Pre-download the spacy model that kokoro needs at runtime RUN python3.11 -m spacy download en_core_web_sm +# --- Optional: avatar video stack ------------------------------------------- +# These are heavy installs; keep them after the core deps so rebuilds only +# redo this layer when ONLY the video stack changes. If you don't plan to +# use config.video.enabled=true, you can comment this block out to speed +# up builds and shrink the image. +# +# LightX2V (Wan2.2-Lightning inference framework) — installed from source +# since there is no stable PyPI release yet. +RUN python3.11 -m pip install --no-cache-dir \ + "git+https://github.com/ModelTC/LightX2V.git" || \ + echo "LightX2V install failed — config.video.enabled must stay false until fixed" +# +# MuseTalk (audio-driven lip-sync) — same story. +RUN python3.11 -m pip install --no-cache-dir \ + "git+https://github.com/TMElyralab/MuseTalk.git" || \ + echo "MuseTalk install failed — config.video.enabled must stay false until fixed" +# +# LoRA directory (user drops .safetensors here; bind-mounted in compose). +RUN mkdir -p /cache/loras + COPY . . EXPOSE 8000 diff --git a/config.yml b/config.yml index 3d40f44..597e5e8 100644 --- a/config.yml +++ b/config.yml @@ -12,3 +12,49 @@ llm: lmstudio: url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker model: "" # leave empty to use whatever model LM Studio has loaded + +# Avatar video generation (Wan2.2-Lightning fp8 via LightX2V + MuseTalk lip-sync) +video: + enabled: false # master toggle — when false, video models are not loaded + backend: lightx2v # only option for now + mode: reflective # "library" (pre-baked clips) | "reflective" (fresh per turn) + resolution: 480 # 480 or 720 + fps: 16 # Wan2.2 native rate; MuseTalk resamples as needed + + library: + base_clip_count: 4 # how many speaking base clips to pre-generate per avatar + base_clip_seconds: 6 # duration of each pre-baked clip + + reflective: + clip_seconds: 5 # target length of each fresh Wan2.2 clip per turn + clip_prompt_template: >- + webcam view of a person speaking, {reply_hint}, + casual gestures, natural lighting, soft focus background + prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint} + + # Model sources for the video stack. The fp8 e4m3 4-step distilled DIT + # weights from lightx2v/Wan2.2-Distill-Models are ~15 GB each (vs ~28 GB + # bf16) — that's the "save VRAM" path. T5/VAE/tokenizer still come from + # the Wan-AI base repo. Both repos download on first run into + # HF_HOME=/cache/huggingface. + models: + wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B + wan22_fp8_repo: lightx2v/Wan2.2-Distill-Models + wan22_model_cls: wan2.2_moe_distill + wan22_config_json: /app/configs/lightx2v/wan22_i2v_fp8_distill.json + musetalk_path: TMElyralab/MuseTalk + + # LoRAs applied to the fp8 base at load time via runtime switch_lora. + # Wan2.2 is a MoE with separate high-noise and low-noise sub-models — + # `target` picks which sub-model each LoRA attaches to. The two files + # below are the user-supplied ./loras/wan22-[HL]-e8.safetensors mounted + # into the container at /cache/loras/. + loras: + - path: /cache/loras/wan22-H-e8.safetensors + weight: 1.0 + target: high_noise + name: wan22-H-e8 + - path: /cache/loras/wan22-L-e8.safetensors + weight: 1.0 + target: low_noise + name: wan22-L-e8 diff --git a/configs/lightx2v/wan22_i2v_fp8_distill.json b/configs/lightx2v/wan22_i2v_fp8_distill.json new file mode 100644 index 0000000..b89ef22 --- /dev/null +++ b/configs/lightx2v/wan22_i2v_fp8_distill.json @@ -0,0 +1,36 @@ +{ + "_comment": "Wan2.2 i2v MoE 4-step distill, fp8 e4m3 quantized. Built for 24 GB-class GPUs — cpu_offload keeps DIT layers swapping in block-by-block. Derived from LightX2V's configs/distill/wan22/wan_moe_i2v_distill_4090.json plus the quant scheme + ckpt overrides from wan_moe_i2v_distill_quant.json. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py with absolute paths to the files downloaded into HF_HOME.", + + "infer_steps": 4, + "target_video_length": 81, + "text_len": 512, + + "resize_mode": "adaptive", + "resolution": "480p", + "target_height": 480, + "target_width": 480, + "fps": 16, + + "self_attn_1_type": "flash_attn3", + "cross_attn_1_type": "flash_attn3", + "cross_attn_2_type": "flash_attn3", + + "sample_guide_scale": [3.5, 3.5], + "sample_shift": 5.0, + "enable_cfg": false, + + "cpu_offload": true, + "offload_granularity": "block", + "lazy_load": true, + "t5_cpu_offload": true, + "vae_cpu_offload": false, + + "use_image_encoder": false, + + "boundary_step_index": 2, + "denoising_step_list": [1000, 750, 500, 250], + + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "t5_quantized": false +} diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..8253dfb --- /dev/null +++ b/conftest.py @@ -0,0 +1,11 @@ +"""Pytest configuration. + +Ensures the project root is on ``sys.path`` so tests can import ``server.*`` +without installing the project as a package. +""" +import os +import sys + +_ROOT = os.path.dirname(os.path.abspath(__file__)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) diff --git a/docker-compose.yml b/docker-compose.yml index 83c9a94..e10d168 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,8 +6,14 @@ services: volumes: # Cache models on the host so they survive container rebuilds - huggingface-cache:/cache/huggingface + # LoRA adapters — drop .safetensors files into ./loras on the host, + # reference them from config.yml as /cache/loras/.safetensors + - ./loras:/cache/loras + # Avatar images uploaded via the web UI persist between restarts + - ./avatars:/app/avatars # Mount source so you can edit code/config without rebuilding the image - ./config.yml:/app/config.yml:ro + - ./configs:/app/configs:ro - ./server:/app/server:ro - ./static:/app/static:ro - ./run.py:/app/run.py:ro diff --git a/requirements.txt b/requirements.txt index 9863d5b..7bdba86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,12 @@ soundfile scipy python-multipart pyyaml + +# --- Avatar video (optional, only used when config.video.enabled=true) --- +# Video frame I/O (used by video_models/wan22.py and the muxer). +imageio[ffmpeg]>=2.34 +av>=12.0 +pyzmq>=25.0 +# LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the +# Dockerfile because neither ships a stable PyPI release yet. See lines +# "LightX2V from source" / "MuseTalk from source" in Dockerfile. diff --git a/server/main.py b/server/main.py index d5ab84e..72204ae 100644 --- a/server/main.py +++ b/server/main.py @@ -1,23 +1,27 @@ import json import logging import os +import tempfile from contextlib import asynccontextmanager import numpy as np -from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect +from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect from fastapi.params import Form -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, Response from fastapi.staticfiles import StaticFiles from server.audio_utils import pcm_bytes_to_float32 from server.models import ModelManager from server.pipeline import ConversationSession +from server.video import LoRASpec logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") log = logging.getLogger(__name__) REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio") STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static") +AVATAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "avatars") +os.makedirs(AVATAR_DIR, exist_ok=True) model_mgr = ModelManager() @@ -47,6 +51,110 @@ async def set_voice(voice: str = Form(...), lang: str = Form("a")): return {"status": "ok", "voice": voice} +# --- Video / avatar endpoints --------------------------------------------- + +def _require_video() -> "object": + """Return the video engine, or raise 404 if video mode isn't enabled.""" + ve = model_mgr.video_engine + if ve is None: + raise HTTPException( + status_code=404, + detail="Video engine disabled. Set config.video.enabled=true and restart.", + ) + return ve + + +@app.post("/api/set-avatar") +async def set_avatar(image: UploadFile): + """Upload an avatar image and (re)generate cached clips.""" + ve = _require_video() + suffix = os.path.splitext(image.filename or "avatar.png")[1] or ".png" + dest = os.path.join(AVATAR_DIR, f"avatar{suffix}") + with open(dest, "wb") as f: + f.write(await image.read()) + log.info("Avatar saved to %s", dest) + + import asyncio + try: + await asyncio.to_thread(ve.set_avatar, dest) + except Exception as e: + log.exception("set_avatar failed") + raise HTTPException(status_code=500, detail=f"Avatar setup failed: {e}") + + return { + "status": "ok", + "avatar_path": dest, + "idle_clip_url": "/api/idle-clip", + "mode": ve.cfg.mode, + } + + +@app.get("/api/idle-clip") +async def idle_clip(): + """Return the cached idle loop MP4.""" + ve = _require_video() + data = ve.get_idle_clip() + if data is None: + raise HTTPException(status_code=404, detail="No idle clip. Upload an avatar first.") + return Response(content=data, media_type="video/mp4") + + +@app.post("/api/set-video-mode") +async def set_video_mode(mode: str = Form(...)): + """Switch between 'off', 'library', and 'reflective'. + + 'off' leaves the video engine loaded but makes the pipeline take the + PCM streaming path on subsequent turns (by marking the engine not-ready + from the client's perspective via a simple flag). + """ + ve = _require_video() + if mode not in ("off", "library", "reflective"): + raise HTTPException( + status_code=400, + detail="mode must be one of: off, library, reflective", + ) + # Switching between library/reflective changes how set_avatar prebakes + # clips. Require a fresh avatar upload afterwards to re-bake. + ve.cfg.mode = mode + return {"status": "ok", "mode": mode, "note": "Re-upload avatar to re-bake library clips." if mode == "library" else ""} + + +@app.post("/api/reload-loras") +async def reload_loras(body: dict): + """Hot-reload LoRA stack. Body: ``{"loras": [{"path","weight","target","name"}]}``. + + Regenerates the idle clip if an avatar is already set, since the new + LoRAs change the base style. + """ + ve = _require_video() + raw = body.get("loras") or [] + specs: list[LoRASpec] = [] + for entry in raw: + if not entry or "path" not in entry: + continue + target = str(entry.get("target", "both")).lower() + if target not in ("high_noise", "low_noise", "both"): + target = "both" + specs.append( + LoRASpec( + path=str(entry["path"]), + weight=float(entry.get("weight", 1.0)), + target=target, # type: ignore[arg-type] + name=entry.get("name"), + ) + ) + import asyncio + try: + await asyncio.to_thread(ve.load_loras, specs) + if ve.avatar_path: + log.info("Regenerating idle clip after LoRA reload.") + await asyncio.to_thread(ve.set_avatar, ve.avatar_path) + except Exception as e: + log.exception("reload_loras failed") + raise HTTPException(status_code=500, detail=str(e)) + return {"status": "ok", "lora_count": len(specs), "idle_clip_url": "/api/idle-clip"} + + @app.websocket("/ws/chat") async def websocket_chat(ws: WebSocket): await ws.accept() @@ -61,6 +169,17 @@ async def websocket_chat(ws: WebSocket): session = ConversationSession(model_mgr, send_json, send_bytes) await session.start() + # Tell the client whether video mode is active so it knows whether to + # suppress PCM playback and wait for speaking_clip messages instead. + ve = model_mgr.video_engine + await send_json({ + "type": "video_mode", + "enabled": ve is not None, + "ready": ve.is_ready() if ve is not None else False, + "mode": ve.cfg.mode if ve is not None else "off", + "idle_clip_url": "/api/idle-clip" if (ve is not None and ve.get_idle_clip()) else None, + }) + try: while True: message = await ws.receive() diff --git a/server/models.py b/server/models.py index 00042d1..c87556c 100644 --- a/server/models.py +++ b/server/models.py @@ -5,6 +5,7 @@ from server.vad import StreamingVAD from server.asr import ASREngine from server.llm import LLMEngine from server.tts import TTSEngine +from server.video import VideoConfig, VideoEngine log = logging.getLogger(__name__) @@ -31,6 +32,7 @@ class ModelManager: self.asr_engine: ASREngine | None = None self.llm_engine: LLMEngine | None = None self.tts_engine: TTSEngine | None = None + self.video_engine: VideoEngine | None = None def load_all(self): """Load all models sequentially. Call from the main process.""" @@ -38,6 +40,7 @@ class ModelManager: self._load_asr() self._load_llm() self._load_tts() + self._load_video() log.info("All models loaded successfully.") def _load_vad(self): @@ -84,8 +87,8 @@ class ModelManager: log.info("Loading Qwen3-4B (GPTQ 4-bit)...") from transformers import AutoModelForCausalLM, AutoTokenizer - model_name = "Qwen/Qwen3.5-0.8B" - + # model_name = "Qwen/Qwen3.5-0.8B" + model_name = "dphn/Dolphin-X1-8B-FP8" tokenizer = AutoTokenizer.from_pretrained(model_name) device = get_device() model = AutoModelForCausalLM.from_pretrained( @@ -101,6 +104,27 @@ class ModelManager: self.tts_engine = TTSEngine() log.info("Kokoro TTS loaded.") + def _load_video(self): + """Load the avatar video stack iff config.video.enabled is true. + + Leaves ``video_engine`` as None when disabled so existing voice flow + is untouched. Later phases replace this stub with actual Wan2.2 + + MuseTalk loading inside ``VideoEngine``. + """ + from server.config import config + + video_cfg_raw = config.get("video", {}) or {} + if not video_cfg_raw.get("enabled", False): + log.info("Video engine disabled (config.video.enabled=false). Skipping load.") + return + + log.info("Loading avatar video engine...") + cfg = VideoConfig.from_dict(video_cfg_raw) + self.video_engine = VideoEngine(cfg) + if cfg.loras: + self.video_engine.load_loras(cfg.loras) + log.info("Avatar video engine loaded (mode=%s).", cfg.mode) + def create_vad(self) -> StreamingVAD: """Create a new StreamingVAD instance for a client session.""" return StreamingVAD(self.vad_model) diff --git a/server/pipeline.py b/server/pipeline.py index 38b23d3..c02ff54 100644 --- a/server/pipeline.py +++ b/server/pipeline.py @@ -157,11 +157,20 @@ class ConversationSession: # TTS - stream chunks with per-sentence text await self.send_json({"type": "status", "state": "speaking"}) + + # Video-mode branch: if a video engine is loaded AND an avatar is + # set, buffer the full TTS output into a single blob, run MuseTalk + # lip-sync (library or reflective source), mux to MP4, and send the + # full clip + text in one shot. The client plays the MP4 (which + # carries audio) instead of the per-chunk PCM path. + video_engine = getattr(self.models, "video_engine", None) + use_video = video_engine is not None and video_engine.is_ready() + chunk_queue = queue.Queue() self._last_played_chunk_id = None segments = _split_into_segments(response) - log.info(f"TTS: split response into {len(segments)} segments") + log.info(f"TTS: split response into {len(segments)} segments (video={use_video})") def _tts_worker(): try: @@ -187,6 +196,10 @@ class ConversationSession: chunk_id = 0 # Maps chunk_id -> cumulative text up to and including that chunk chunk_text_map: dict[int, str] = {} + # Video mode accumulator: we buffer all TTS audio into one float32 + # array so MuseTalk can align against the full utterance. + audio_buffer: list[np.ndarray] = [] + while True: try: item = await asyncio.to_thread(chunk_queue.get, timeout=10.0) @@ -202,23 +215,69 @@ class ConversationSession: spoken_text += sentence_text chunk_text_map[chunk_id] = spoken_text - await self.send_json({ - "type": "response_text", - "text": sentence_text, - "chunk_id": chunk_id, - "final": False, - }) - pcm_bytes = float32_to_pcm_bytes(audio) - try: - await self.send_bytes(pcm_bytes) - except Exception: - log.warning("Failed to send audio, client disconnected.") - self.cancel_event.set() - break + if use_video: + audio_buffer.append(audio) + # Don't stream text or PCM during video mode — we'll send + # everything after the clip renders so the client doesn't + # start displaying text before the video is ready. + else: + await self.send_json({ + "type": "response_text", + "text": sentence_text, + "chunk_id": chunk_id, + "final": False, + }) + pcm_bytes = float32_to_pcm_bytes(audio) + try: + await self.send_bytes(pcm_bytes) + except Exception: + log.warning("Failed to send audio, client disconnected.") + self.cancel_event.set() + break chunk_id += 1 tts_thread.join(timeout=2.0) + # Video mode: render the speaking clip now that TTS is done. + if use_video and audio_buffer and not self.cancel_event.is_set(): + try: + full_audio = np.concatenate(audio_buffer).astype(np.float32) + sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000) + log.info( + "Video: rendering speaking clip (audio=%ds, mode=%s)", + int(len(full_audio) / sample_rate), video_engine.cfg.mode, + ) + mp4_bytes = await asyncio.to_thread( + video_engine.generate_speaking_clip, + full_audio, + sample_rate, + response, + ) + if self.cancel_event.is_set(): + log.info("Video clip discarded (cancelled during render).") + else: + duration_ms = int(len(full_audio) / sample_rate * 1000) + await self.send_json({ + "type": "speaking_clip", + "chunk_id": 0, + "duration_ms": duration_ms, + "text": response, + "size_bytes": len(mp4_bytes), + }) + await self.send_bytes(mp4_bytes) + except Exception: + log.exception("Video speaking-clip render failed; falling back silently.") + # Best-effort: tell the client nothing was spoken visually. + try: + await self.send_json({ + "type": "response_text", + "text": response, + "chunk_id": 0, + "final": True, + }) + except Exception: + pass + # Determine what was actually heard by the client was_interrupted = spoken_text.strip() != response.strip() if was_interrupted and self._last_played_chunk_id is not None: diff --git a/server/video.py b/server/video.py new file mode 100644 index 0000000..9a98b69 --- /dev/null +++ b/server/video.py @@ -0,0 +1,391 @@ +"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync. + +Top-level orchestrator. The heavy 3rd-party model code is isolated in +``server/video_models/`` so each wrapper can be updated independently. + +This module is only imported by ``server/models.py`` when +``config.video.enabled`` is true. When disabled, the existing voice pipeline +is completely untouched. +""" +from __future__ import annotations + +import logging +import threading +from dataclasses import dataclass, field +from typing import Literal + +import numpy as np + +log = logging.getLogger(__name__) + +LoRATarget = Literal["high_noise", "low_noise", "both"] + + +@dataclass +class LoRASpec: + """One LoRA adapter entry from ``config.video.loras``. + + Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and + low-noise sub-models. Most LightX2V distill LoRAs come paired (one per + sub-model) and must be applied to the correct target. Allow + ``target="both"`` for LoRAs that should be applied to both sub-models + (e.g. style LoRAs). + """ + + path: str + weight: float = 1.0 + target: LoRATarget = "both" + name: str | None = None + + +@dataclass +class VideoConfig: + """Flattened view of the ``video:`` section of config.yml.""" + + enabled: bool = False + backend: str = "lightx2v" + mode: str = "reflective" # "library" | "reflective" + resolution: int = 480 + fps: int = 16 + library_base_clip_count: int = 4 + library_base_clip_seconds: int = 6 + reflective_clip_seconds: int = 5 + reflective_prompt_template: str = ( + "webcam view of a person speaking, {reply_hint}, casual gestures, " + "natural lighting, soft focus background" + ) + reflective_prompt_reply_words: int = 18 + loras: list[LoRASpec] = field(default_factory=list) + + # Model paths — can be overridden via config.yml.video.models. + # wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer. + # The bf16 DIT shards in this repo are skipped — we + # replace them with the fp8 files from wan22_fp8_repo. + # wan22_fp8_repo : HF repo id (or local dir) providing the two fp8 e4m3 + # 4-step distilled DIT checkpoints (~15 GB each). + # wan22_config_json: path to the LightX2V inference config template the + # Wan22Pipeline will fill in with absolute ckpt paths. + wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B" + wan22_fp8_repo: str = "lightx2v/Wan2.2-Distill-Models" + wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" + wan22_model_cls: str = "wan2.2_moe_distill" + musetalk_model_path: str = "TMElyralab/MuseTalk" + + @classmethod + def from_dict(cls, raw: dict) -> "VideoConfig": + raw = raw or {} + library = raw.get("library", {}) or {} + reflective = raw.get("reflective", {}) or {} + models_raw = raw.get("models", {}) or {} + loras_raw = raw.get("loras") or [] + + default_template = ( + "webcam view of a person speaking, {reply_hint}, casual gestures, " + "natural lighting, soft focus background" + ) + + loras: list[LoRASpec] = [] + for entry in loras_raw: + if not entry or "path" not in entry: + continue + target = str(entry.get("target", "both")).lower() + if target not in ("high_noise", "low_noise", "both"): + log.warning( + "LoRA %s: invalid target %r, defaulting to 'both'", + entry.get("path"), target, + ) + target = "both" + loras.append( + LoRASpec( + path=str(entry["path"]), + weight=float(entry.get("weight", 1.0)), + target=target, # type: ignore[arg-type] + name=entry.get("name"), + ) + ) + + return cls( + enabled=bool(raw.get("enabled", False)), + backend=str(raw.get("backend", "lightx2v")), + mode=str(raw.get("mode", "reflective")), + resolution=int(raw.get("resolution", 480)), + fps=int(raw.get("fps", 16)), + library_base_clip_count=int(library.get("base_clip_count", 4)), + library_base_clip_seconds=int(library.get("base_clip_seconds", 6)), + reflective_clip_seconds=int(reflective.get("clip_seconds", 5)), + reflective_prompt_template=str( + reflective.get("clip_prompt_template", default_template) + ), + reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)), + loras=loras, + wan22_base_repo=str( + models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B") + ), + wan22_fp8_repo=str( + models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models") + ), + wan22_config_json=str( + models_raw.get( + "wan22_config_json", + "/app/configs/lightx2v/wan22_i2v_fp8_distill.json", + ) + ), + wan22_model_cls=str( + models_raw.get("wan22_model_cls", "wan2.2_moe_distill") + ), + musetalk_model_path=str( + models_raw.get("musetalk_path", "TMElyralab/MuseTalk") + ), + ) + + +# Library-mode base-clip prompts. Varied gestures so the pre-baked set feels +# less repetitive when replayed. Kept module-level so tests can import them. +LIBRARY_BASE_PROMPTS = [ + "webcam view of a person speaking, subtle head nods, casual expression, " + "natural lighting, soft focus background", + "webcam view of a person speaking, slight smile, gentle hand gesture, " + "natural lighting, soft focus background", + "webcam view of a person speaking, looking thoughtful, small head tilt, " + "natural lighting, soft focus background", + "webcam view of a person speaking, engaged and attentive, minor shoulder " + "movement, natural lighting, soft focus background", + "webcam view of a person speaking, relaxed posture, blinking naturally, " + "natural lighting, soft focus background", +] + +IDLE_PROMPT = ( + "webcam view of a person listening quietly, mouth closed, subtle " + "breathing, occasional blinks, calm expression, natural lighting, " + "soft focus background" +) + + +class VideoEngine: + """Top-level video generation orchestrator. + + Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's + pre-rendered clips. Exposed to ``ConversationSession`` via + ``ModelManager.video_engine``. + """ + + def __init__(self, cfg: VideoConfig): + self.cfg = cfg + self._lock = threading.Lock() + + # Avatar state + self.avatar_path: str | None = None + self.idle_clip_mp4: bytes | None = None + # Pre-baked speaking base clips for library mode. Each entry is a + # contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8. + self.speaking_base_frames: list[np.ndarray] = [] + # Round-robin pointer for picking a library clip per turn + self._library_cursor = 0 + + # Model wrappers — instantiated lazily by ``load_models()`` so unit + # tests can exercise VideoEngine without touching CUDA at all. + self._wan22 = None # server.video_models.wan22.Wan22Pipeline + self._musetalk = None # server.video_models.musetalk.MuseTalkEngine + + log.info( + "VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).", + cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras), + ) + + # --- Model loading -------------------------------------------------- + + def load_models(self) -> None: + """Instantiate the underlying model wrappers. + + Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk`` + without triggering Wan2.2's ~12-16GB VRAM allocation. + """ + from server.video_models.wan22 import Wan22Pipeline + from server.video_models.musetalk import MuseTalkEngine + + log.info( + "Loading Wan2.2-Lightning fp8 pipeline (base=%s, fp8=%s)...", + self.cfg.wan22_base_repo, self.cfg.wan22_fp8_repo, + ) + self._wan22 = Wan22Pipeline( + base_repo=self.cfg.wan22_base_repo, + fp8_repo=self.cfg.wan22_fp8_repo, + config_json=self.cfg.wan22_config_json, + model_cls=self.cfg.wan22_model_cls, + resolution=self.cfg.resolution, + fps=self.cfg.fps, + ) + if self.cfg.loras: + self._wan22.load_loras(self.cfg.loras) + log.info("Wan2.2 pipeline ready.") + + log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path) + self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path) + log.info("MuseTalk engine ready.") + + # --- Readiness ------------------------------------------------------ + + def is_ready(self) -> bool: + """True when an avatar is set and a speaking clip can be produced.""" + return ( + self._wan22 is not None + and self._musetalk is not None + and self.avatar_path is not None + and self.idle_clip_mp4 is not None + ) + + # --- LoRA management ------------------------------------------------ + + def load_loras(self, specs: list[LoRASpec]) -> None: + """Apply a list of LoRA adapters to the Wan2.2 base. + + Replaces any previously applied LoRAs. Safe to call after init for + hot-reload via ``POST /api/reload-loras``. + """ + if self._wan22 is None: + raise RuntimeError("load_loras called before load_models()") + with self._lock: + self._wan22.unload_loras() + self._wan22.load_loras(specs) + self.cfg.loras = list(specs) + log.info("Applied %d LoRA(s): %s", + len(specs), + ", ".join(s.name or s.path for s in specs) or "") + + # --- Avatar lifecycle ---------------------------------------------- + + def set_avatar(self, image_path: str) -> None: + """Register an avatar image and pre-generate cached clips. + + - Always: generate the idle loop. + - Library mode: also pre-generate ``library.base_clip_count`` + speaking base clips. + - Reflective mode: idle loop only. + """ + if self._wan22 is None: + raise RuntimeError("set_avatar called before load_models()") + + with self._lock: + log.info("Setting avatar: %s", image_path) + self.avatar_path = image_path + # Drop any previously cached clips so the new avatar's library + # doesn't mix with the old. + self.speaking_base_frames = [] + self.idle_clip_mp4 = None + + # Idle clip: short loop, neutral/listening prompt. + log.info("Generating idle clip...") + idle_frames = self._wan22.generate_i2v( + image_path=image_path, + prompt=IDLE_PROMPT, + seconds=self.cfg.library_base_clip_seconds, + seed=0, + ) + from server.video_models.muxer import frames_to_mp4_loop + self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps) + log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4)) + + # Library mode: pre-bake N speaking base clips. + if self.cfg.mode == "library": + n = self.cfg.library_base_clip_count + log.info("Pre-baking %d speaking base clip(s) for library mode.", n) + for i in range(n): + prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)] + frames = self._wan22.generate_i2v( + image_path=image_path, + prompt=prompt, + seconds=self.cfg.library_base_clip_seconds, + seed=i + 1, + ) + self.speaking_base_frames.append(frames) + log.info(" base clip %d/%d rendered", i + 1, n) + + self._library_cursor = 0 + + def get_idle_clip(self) -> bytes | None: + return self.idle_clip_mp4 + + # --- Per-turn generation ------------------------------------------- + + def generate_speaking_clip( + self, + audio_f32: np.ndarray, + sample_rate: int, + reply_text: str, + ) -> bytes: + """Produce a lip-synced MP4 for one assistant turn.""" + if not self.is_ready(): + raise RuntimeError( + "generate_speaking_clip: engine not ready " + "(avatar set? models loaded?)" + ) + assert self._wan22 is not None + assert self._musetalk is not None + + # 1. Source base frames. + if self.cfg.mode == "library": + base_frames = self._pick_library_frames(audio_f32, sample_rate) + else: # reflective + prompt = self._derive_prompt(reply_text) + log.info("Reflective prompt: %s", prompt[:120]) + base_frames = self._wan22.generate_i2v( + image_path=self.avatar_path or "", + prompt=prompt, + seconds=self.cfg.reflective_clip_seconds, + seed=None, # random each turn + ) + + # 2. Lip-sync the base frames to the given audio. + synced_frames = self._musetalk.lip_sync( + frames=base_frames, + audio=audio_f32, + sample_rate=sample_rate, + fps=self.cfg.fps, + ) + + # 3. Mux frames + audio into an MP4. + from server.video_models.muxer import frames_and_audio_to_mp4 + return frames_and_audio_to_mp4( + frames=synced_frames, + audio=audio_f32, + sample_rate=sample_rate, + fps=self.cfg.fps, + ) + + def _pick_library_frames( + self, audio_f32: np.ndarray, sample_rate: int + ) -> np.ndarray: + """Round-robin pick from the pre-baked library, clipped or looped + to roughly the audio's duration so there's no long freeze frame.""" + if not self.speaking_base_frames: + raise RuntimeError( + "Library mode has no pre-baked base clips. " + "Was set_avatar called with mode=library?" + ) + frames = self.speaking_base_frames[ + self._library_cursor % len(self.speaking_base_frames) + ] + self._library_cursor += 1 + + target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps)) + if target_frames <= 0: + return frames + if target_frames <= len(frames): + return frames[:target_frames] + # Loop (with a mirror tail to soften the seam) to cover longer audio. + loops = target_frames // len(frames) + 1 + extended = np.concatenate([frames] * loops, axis=0) + return extended[:target_frames] + + def _derive_prompt(self, reply_text: str) -> str: + """Template-based prompt builder for reflective mode. + + Takes up to ``prompt_reply_words`` words from the start of the reply + and interpolates them into the configured template. Cheap, + deterministic, no extra LLM call. + """ + words = (reply_text or "").split() + hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip() + if not hint: + hint = "calm and friendly" + return self.cfg.reflective_prompt_template.format(reply_hint=hint) diff --git a/server/video_models/__init__.py b/server/video_models/__init__.py new file mode 100644 index 0000000..c735800 --- /dev/null +++ b/server/video_models/__init__.py @@ -0,0 +1,10 @@ +"""Thin wrappers around 3rd-party video generation models. + +Each submodule isolates one external dependency so the real API surface +can be updated in a single file without touching the pipeline. + +Submodules: +- ``wan22``: Wan2.2-Lightning image-to-video via LightX2V +- ``musetalk``: MuseTalk audio-driven lip-sync +- ``muxer``: ffmpeg-based frame/audio → MP4 encoding +""" diff --git a/server/video_models/musetalk.py b/server/video_models/musetalk.py new file mode 100644 index 0000000..f4b2488 --- /dev/null +++ b/server/video_models/musetalk.py @@ -0,0 +1,164 @@ +"""MuseTalk audio-driven lip-sync wrapper. + +MuseTalk takes a sequence of face frames + driving audio and returns a new +sequence of frames where the mouth region is animated to match the audio. + +This module isolates MuseTalk's real API behind a single ``lip_sync()`` +method. MuseTalk's upstream Python surface varies between forks — if the +real import path or call signature differs, update this file only. +""" +from __future__ import annotations + +import logging +import os + +import numpy as np + +log = logging.getLogger(__name__) + + +class MuseTalkEngine: + """Thin wrapper over MuseTalk inference.""" + + def __init__(self, model_path: str = "TMElyralab/MuseTalk"): + self.model_path = model_path + + # MuseTalk's canonical entry point is ``musetalk.inference`` or a + # similar ``MuseTalkInfer`` class. Try the most common imports. + self._infer = self._load_impl(model_path) + log.info("MuseTalk engine loaded from %s", model_path) + + @staticmethod + def _load_impl(model_path: str): + """Load the MuseTalk inference implementation. + + If none of the known entry points work the error message points at + this file so you know where to fix it. + """ + resolved = model_path + if not os.path.isdir(model_path) and "/" in model_path: + try: + from huggingface_hub import snapshot_download + resolved = snapshot_download(repo_id=model_path) + except Exception as e: # pragma: no cover + log.warning("Could not snapshot_download MuseTalk repo: %s", e) + + # Try upstream MuseTalk repo layout. + try: + from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found] + return MuseTalkInference(model_path=resolved) + except ImportError: + pass + try: + from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found] + return MuseTalkInfer(model_path=resolved) + except ImportError: + pass + try: + from musetalk import Inference # type: ignore[import-not-found] + return Inference(model_path=resolved) + except ImportError: + pass + + raise RuntimeError( + "MuseTalk is installed but no known Python entry point was found. " + "Update server/video_models/musetalk.py::MuseTalkEngine._load_impl " + "to match the installed MuseTalk version." + ) + + # --- Inference --------------------------------------------------------- + + def lip_sync( + self, + frames: np.ndarray, + audio: np.ndarray, + sample_rate: int, + fps: int, + ) -> np.ndarray: + """Return new frames with lip-sync applied to match ``audio``. + + Args: + frames: uint8 ``[T, H, W, 3]`` RGB base frames. + audio: float32 mono 1D audio. + sample_rate: sample rate of ``audio``. + fps: frame rate of ``frames``. + + Returns: + uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded + to match audio duration at ``fps``. + """ + if frames.ndim != 4 or frames.shape[-1] != 3: + raise ValueError( + f"frames must be [T, H, W, 3] uint8, got {frames.shape}" + ) + + # Normalise frame count to audio duration so the caller doesn't have + # to do the arithmetic. + target_t = int(round(len(audio) / sample_rate * fps)) + if target_t > 0 and len(frames) != target_t: + frames = _fit_frames_to_length(frames, target_t) + + # The real MuseTalk call signature varies. Most common is a method + # like ``run(frames, audio, sr, fps)`` or ``infer(...)``. + for method_name in ("run", "infer", "lip_sync", "__call__"): + method = getattr(self._infer, method_name, None) + if method is None: + continue + try: + result = method( + frames=frames, + audio=audio, + sample_rate=sample_rate, + fps=fps, + ) + return _ensure_uint8_rgb(result) + except TypeError: + # Try positional + try: + result = method(frames, audio, sample_rate, fps) + return _ensure_uint8_rgb(result) + except TypeError: + continue + + raise RuntimeError( + "MuseTalk wrapper could not find a working inference method. " + "Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync." + ) + + +def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray: + """Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``. + + Repeats with a ping-pong / boomerang tail so the seam between loops is + less jarring than a hard cut back to frame 0. + """ + if target_t <= 0: + return frames + t = len(frames) + if t == target_t: + return frames + if t > target_t: + return frames[:target_t] + # Extend via ping-pong looping + extended = [frames] + total = t + flip = True + while total < target_t: + seg = frames[::-1] if flip else frames + extended.append(seg) + total += t + flip = not flip + return np.concatenate(extended, axis=0)[:target_t] + + +def _ensure_uint8_rgb(arr) -> np.ndarray: + """Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB.""" + result = np.asarray(arr) + if result.dtype != np.uint8: + if result.dtype in (np.float32, np.float64): + result = np.clip(result * 255.0, 0, 255).astype(np.uint8) + else: + result = result.astype(np.uint8) + if result.ndim == 3: + result = result[None, ...] + return result diff --git a/server/video_models/muxer.py b/server/video_models/muxer.py new file mode 100644 index 0000000..1c996ca --- /dev/null +++ b/server/video_models/muxer.py @@ -0,0 +1,146 @@ +"""ffmpeg-based frame + audio → MP4 muxing. + +Uses the system ``ffmpeg`` binary already installed in the Dockerfile. +No extra python dependencies beyond ``numpy``. +""" +from __future__ import annotations + +import logging +import os +import shutil +import subprocess +import tempfile + +import numpy as np + +log = logging.getLogger(__name__) + + +def _ffmpeg_bin() -> str: + bin_path = shutil.which("ffmpeg") + if bin_path is None: + raise RuntimeError( + "ffmpeg binary not found on PATH. It should be installed by " + "the Dockerfile (line 13). Ensure you're running inside the " + "docker image or install ffmpeg locally." + ) + return bin_path + + +def _write_raw_frames(frames: np.ndarray, path: str) -> tuple[int, int]: + """Write uint8 RGB frames to ``path`` as raw rgb24 bytes. Returns (h, w).""" + if frames.ndim != 4 or frames.shape[-1] != 3: + raise ValueError( + f"frames must be [T, H, W, 3] uint8, got {frames.shape}" + ) + if frames.dtype != np.uint8: + frames = frames.astype(np.uint8) + with open(path, "wb") as f: + f.write(frames.tobytes()) + _, h, w, _ = frames.shape + return h, w + + +def _write_wav(audio: np.ndarray, sample_rate: int, path: str) -> None: + """Write a float32 mono audio array to a 16-bit PCM WAV at ``path``.""" + from scipy.io import wavfile # type: ignore[import-not-found] + audio = np.asarray(audio, dtype=np.float32).reshape(-1) + int16 = np.clip(audio * 32767.0, -32768, 32767).astype(np.int16) + wavfile.write(path, sample_rate, int16) + + +def frames_to_mp4_loop(frames: np.ndarray, fps: int) -> bytes: + """Encode ``frames`` to a silent MP4 suitable for looping playback. + + Used for the idle clip: no audio track, loopable on an HTMLMediaElement + without audible seams. + """ + if frames.size == 0: + raise ValueError("frames_to_mp4_loop: empty frames") + + ffmpeg = _ffmpeg_bin() + with tempfile.TemporaryDirectory() as td: + raw_path = os.path.join(td, "frames.raw") + out_path = os.path.join(td, "out.mp4") + h, w = _write_raw_frames(frames, raw_path) + + cmd = [ + ffmpeg, "-y", + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(fps), + "-i", raw_path, + "-an", + "-c:v", "libx264", + "-preset", "veryfast", + "-pix_fmt", "yuv420p", + "-movflags", "+faststart", + out_path, + ] + log.debug("muxer idle clip: %s", " ".join(cmd)) + _run_ffmpeg(cmd) + with open(out_path, "rb") as f: + return f.read() + + +def frames_and_audio_to_mp4( + frames: np.ndarray, + audio: np.ndarray, + sample_rate: int, + fps: int, +) -> bytes: + """Encode ``frames`` + ``audio`` to an MP4 with H.264 video + AAC audio. + + Used for per-turn speaking clips. + """ + if frames.size == 0: + raise ValueError("frames_and_audio_to_mp4: empty frames") + if audio.size == 0: + raise ValueError("frames_and_audio_to_mp4: empty audio") + + ffmpeg = _ffmpeg_bin() + with tempfile.TemporaryDirectory() as td: + raw_path = os.path.join(td, "frames.raw") + wav_path = os.path.join(td, "audio.wav") + out_path = os.path.join(td, "out.mp4") + + h, w = _write_raw_frames(frames, raw_path) + _write_wav(audio, sample_rate, wav_path) + + cmd = [ + ffmpeg, "-y", + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-s", f"{w}x{h}", + "-r", str(fps), + "-i", raw_path, + "-i", wav_path, + "-c:v", "libx264", + "-preset", "veryfast", + "-pix_fmt", "yuv420p", + "-c:a", "aac", + "-b:a", "128k", + "-shortest", + "-movflags", "+faststart", + out_path, + ] + log.debug("muxer speaking clip: %s", " ".join(cmd)) + _run_ffmpeg(cmd) + with open(out_path, "rb") as f: + return f.read() + + +def _run_ffmpeg(cmd: list[str]) -> None: + try: + proc = subprocess.run( + cmd, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + except subprocess.CalledProcessError as e: + log.error("ffmpeg failed (exit %d): %s", e.returncode, e.stderr.decode(errors="replace")) + raise + if proc.returncode != 0: # pragma: no cover + raise RuntimeError(f"ffmpeg returned {proc.returncode}") diff --git a/server/video_models/wan22.py b/server/video_models/wan22.py new file mode 100644 index 0000000..327bbae --- /dev/null +++ b/server/video_models/wan22.py @@ -0,0 +1,423 @@ +"""Wan2.2-Lightning fp8 image-to-video wrapper via LightX2V. + +This wrapper targets LightX2V's actual Python entry points (verified against +the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main): + + from lightx2v.utils.set_config import set_config + from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict + from lightx2v.infer import init_runner + + args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...) + config = set_config(args) + input_info = init_empty_input_info(args.task, args.support_tasks) + runner = init_runner(config) # loads all weights — done ONCE + + update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...}) + runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path + # LoRA hot-swap: + runner.switch_lora(lora_path, strength) # swap in + runner.switch_lora("", 0.0) # remove + +Model weights are loaded once at construction and held resident across turns +so reflective mode doesn't re-pay the load cost each reply. + +Two HuggingFace repos are consumed on first run (cached under HF_HOME): + - Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only. + The bf16 DIT shards under high_noise_model/ + and low_noise_model/ are SKIPPED via + ignore_patterns — we replace them with fp8. + - lightx2v/Wan2.2-Distill-Models — exactly two safetensors files: + the fp8 e4m3 4-step distilled high/low + noise DIT checkpoints (~15 GB each). +""" +from __future__ import annotations + +import argparse +import json +import logging +import os +import random +import tempfile +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from server.video import LoRASpec + +log = logging.getLogger(__name__) + +FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" +FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors" + +# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the +# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the fp8 +# files from the distill repo replace the DIT weights entirely. We must keep +# the config.json / index.json metadata under high_noise_model/ and +# low_noise_model/ (LightX2V's set_config reads architecture params like +# ``dim`` from them) and the tokenizer files under google/. +BASE_REPO_IGNORE_PATTERNS = [ + "high_noise_model/*.safetensors", + "low_noise_model/*.safetensors", + "assets/*", + "examples/*", + "nohup.out", + "*.md", +] + + +class Wan22Pipeline: + """Wrapper around LightX2V's Wan2.2 MoE distill runner using fp8 weights. + + Constructor downloads (if needed) both HF repos, writes a runtime JSON + config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``. + ``generate_i2v`` runs one inference turn against the already-loaded runner. + """ + + def __init__( + self, + base_repo: str, + fp8_repo: str, + config_json: str, + model_cls: str = "wan2.2_moe_distill", + resolution: int = 480, + fps: int = 16, + ): + self.base_repo = base_repo + self.fp8_repo = fp8_repo + self.config_json_template = config_json + self.model_cls = model_cls + self.resolution = resolution + self.fps = fps + self._applied_loras: list[LoRASpec] = [] + + # 1. Resolve / download base repo (T5/VAE/config) and fp8 DIT ckpts. + self._model_root = self._ensure_base_repo(base_repo) + self._fp8_high, self._fp8_low = self._ensure_fp8_checkpoints(fp8_repo) + + # 2. Materialize a runtime JSON config with absolute ckpt paths. + self._runtime_json_path = self._build_runtime_config() + + # 3. Build the argparse-like namespace LightX2V.set_config() expects. + args = self._build_args( + model_cls=model_cls, + model_path=self._model_root, + config_json=self._runtime_json_path, + ) + + # 4. set_config → init_runner. Runner construction triggers weight load. + # Imports are scoped here so ``import server.video_models.wan22`` + # never pulls in lightx2v (tests can import this module on CPU). + from lightx2v.utils.set_config import set_config # type: ignore[import-not-found] + from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found] + from lightx2v.infer import init_runner # type: ignore[import-not-found] + + log.info("LightX2V set_config (model_cls=%s, model_path=%s)", + model_cls, self._model_root) + self._config = set_config(args) + + self._input_info_template = init_empty_input_info( + args.task, args.support_tasks + ) + + log.info("LightX2V init_runner — loading weights (this takes a while)...") + self._runner = init_runner(self._config) + log.info("LightX2V runner loaded; weights resident.") + + # --- Weight provisioning ------------------------------------------------- + + @staticmethod + def _ensure_base_repo(base_repo: str) -> str: + """Return a local directory containing the Wan2.2 base support files. + + If ``base_repo`` is already a local directory, use it as-is. Otherwise + snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT + shards (they're replaced by the fp8 files). + """ + if os.path.isdir(base_repo): + return base_repo + from huggingface_hub import snapshot_download + log.info("Downloading Wan2.2 base support files from %s " + "(skipping bf16 DIT shards)...", base_repo) + return snapshot_download( + repo_id=base_repo, + ignore_patterns=BASE_REPO_IGNORE_PATTERNS, + ) + + @staticmethod + def _ensure_fp8_checkpoints(fp8_repo: str) -> tuple[str, str]: + """Return (high_noise_path, low_noise_path) for the fp8 i2v MoE pair. + + - If ``fp8_repo`` is a local directory, expect both files inside it. + - Otherwise treat it as a HF repo id and download only the two files + we need (not the ~150 GB of other variants in that repo). + """ + if not fp8_repo: + raise ValueError("fp8_repo must be a HF repo id or local directory.") + if os.path.isdir(fp8_repo): + high = os.path.join(fp8_repo, FP8_HIGH_NOISE_FILE) + low = os.path.join(fp8_repo, FP8_LOW_NOISE_FILE) + if not (os.path.isfile(high) and os.path.isfile(low)): + raise FileNotFoundError( + f"fp8 checkpoints not found in {fp8_repo}: expected " + f"{FP8_HIGH_NOISE_FILE} and {FP8_LOW_NOISE_FILE}" + ) + return high, low + from huggingface_hub import hf_hub_download + log.info("Downloading fp8 i2v DIT checkpoints from %s ...", fp8_repo) + high = hf_hub_download(repo_id=fp8_repo, filename=FP8_HIGH_NOISE_FILE) + low = hf_hub_download(repo_id=fp8_repo, filename=FP8_LOW_NOISE_FILE) + return high, low + + def _build_runtime_config(self) -> str: + """Load the template JSON, inject absolute ckpt paths, persist to temp.""" + with open(self.config_json_template, "r", encoding="utf-8") as f: + cfg = json.load(f) + # Drop editorial comments before passing to LightX2V. + cfg.pop("_comment", None) + cfg["high_noise_quantized_ckpt"] = self._fp8_high + cfg["low_noise_quantized_ckpt"] = self._fp8_low + cfg.setdefault("fps", self.fps) + + tmp = tempfile.NamedTemporaryFile( + prefix="wan22_fp8_", suffix=".json", + mode="w", delete=False, encoding="utf-8", + ) + json.dump(cfg, tmp, indent=2) + tmp.close() + log.info("Runtime LightX2V config: %s", tmp.name) + return tmp.name + + @staticmethod + def _build_args( + *, model_cls: str, model_path: str, config_json: str + ) -> argparse.Namespace: + """Mirror every field from ``lightx2v.infer.main``'s argparse so + ``set_config`` finds the attributes it expects. We only customize the + model/task/path fields; everything else stays at the CLI defaults. + """ + return argparse.Namespace( + seed=42, + model_cls=model_cls, + task="i2v", + support_tasks=[], + model_path=model_path, + sf_model_path=None, + config_json=config_json, + use_prompt_enhancer=False, + prompt="", + negative_prompt="", + image_path="", + last_frame_path="", + audio_path="", + image_strength="1.0", + image_frame_idx="", + src_ref_images=None, + src_video=None, + src_mask=None, + src_pose_path=None, + src_face_path=None, + src_bg_path=None, + src_mask_path=None, + pose=None, + action_path=None, + action_ckpt=None, + save_result_path=None, + return_result_tensor=False, + target_shape=[], + target_video_length=81, + aspect_ratio="", + video_path=None, + sr_ratio=2.0, + ) + + # --- LoRA -------------------------------------------------------------- + + def load_loras(self, specs: list["LoRASpec"]) -> None: + """Apply LoRAs to the Wan2.2 MoE distill pipeline. + + Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"`` + to route the LoRA to the correct expert. + + With ``lazy_load`` the DIT models are ``None`` at this point, so + runtime ``switch_lora`` is impossible. Instead we inject + ``lora_configs`` + ``lora_dynamic_apply`` into the runner config so + the LoRAs are applied when the models materialise on first inference. + + Without ``lazy_load`` (models already resident) we call + ``switch_lora`` with explicit high/low keyword args. + """ + if not specs: + return + + # Resolve every path up-front (may trigger HF download). + resolved: list[tuple["LoRASpec", str]] = [] + for spec in specs: + local_path = self._resolve_lora_path(spec.path) + log.info(" LoRA %s → strength=%.2f target=%s (%s)", + spec.name or spec.path, spec.weight, spec.target, + local_path) + resolved.append((spec, local_path)) + + lazy = self._config.get("lazy_load", False) + if lazy: + # Build the lora_configs list that LightX2V's lazy-load path + # reads inside MultiDistillModelStruct.infer(). + lora_cfgs = [] + for spec, local_path in resolved: + # LightX2V expects name "high_noise_model" / "low_noise_model" + cfg_name = { + "high_noise": "high_noise_model", + "low_noise": "low_noise_model", + }.get(spec.target) + if cfg_name is None: + raise ValueError( + f"LoRA target must be 'high_noise' or 'low_noise', " + f"got {spec.target!r}") + lora_cfgs.append({ + "name": cfg_name, + "path": local_path, + "strength": spec.weight, + }) + self._runner.set_config({ + "lora_configs": lora_cfgs, + "lora_dynamic_apply": True, + }) + else: + # Models are loaded — use runtime hot-swap. + high_path = high_strength = None + low_path = low_strength = None + for spec, local_path in resolved: + if spec.target == "high_noise": + high_path, high_strength = local_path, spec.weight + elif spec.target == "low_noise": + low_path, low_strength = local_path, spec.weight + else: + raise ValueError( + f"LoRA target must be 'high_noise' or 'low_noise', " + f"got {spec.target!r}") + + kwargs: dict = {} + if high_path is not None: + kwargs["high_lora_path"] = high_path + kwargs["high_lora_strength"] = high_strength + if low_path is not None: + kwargs["low_lora_path"] = low_path + kwargs["low_lora_strength"] = low_strength + ok = self._runner.switch_lora(**kwargs) + if not ok: + raise RuntimeError( + "runner.switch_lora returned False. Check that your " + "LightX2V build supports runtime LoRA updates for " + f"{self.model_cls}.") + + self._applied_loras = list(specs) + + def unload_loras(self) -> None: + """Remove all currently applied LoRAs.""" + if not self._applied_loras: + return + lazy = self._config.get("lazy_load", False) + if lazy: + self._runner.set_config({ + "lora_configs": None, + "lora_dynamic_apply": False, + }) + # If models were materialised, drop them so the next inference + # recreates them without LoRAs. + model_struct = getattr(self._runner, "model", None) + if model_struct is not None and hasattr(model_struct, "model"): + for i in range(len(model_struct.model)): + model_struct.model[i] = None + else: + self._runner.switch_lora("", 0.0) + self._applied_loras = [] + + @staticmethod + def _resolve_lora_path(path: str) -> str: + """Resolve a LoRA path. Supports: + - Absolute/relative local paths (returned as-is if the file exists) + - ``repo_id:filename`` HuggingFace references + """ + if os.path.isfile(path): + return path + if ":" in path and not path.startswith(("/", "./")): + repo_id, filename = path.split(":", 1) + from huggingface_hub import hf_hub_download + return hf_hub_download(repo_id=repo_id, filename=filename) + return path + + # --- Inference --------------------------------------------------------- + + def generate_i2v( + self, + image_path: str, + prompt: str, + seconds: int, + seed: int | None = None, + negative_prompt: str = "", + ) -> np.ndarray: + """Run image-to-video inference and return decoded frames. + + Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB. + """ + if seed is None: + seed = random.randint(0, 2**31 - 1) + + # Wan2.2 target_video_length is "frames including the conditioning + # frame", so N seconds → N*fps + 1. + target_frames = seconds * self.fps + 1 + + from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found] + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf: + out_path = tf.name + try: + log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d → %s", + prompt[:80], seconds, seed, out_path) + update_input_info_from_dict( + self._input_info_template, + { + "seed": seed, + "prompt": prompt, + "negative_prompt": negative_prompt, + "image_path": image_path, + "save_result_path": out_path, + "target_video_length": target_frames, + "return_result_tensor": False, + }, + ) + self._runner.run_pipeline(self._input_info_template) + return _read_mp4_to_frames(out_path) + finally: + try: + os.remove(out_path) + except OSError: + pass + + +# --- MP4 decoding helper ------------------------------------------------------ + +def _read_mp4_to_frames(path: str) -> np.ndarray: + """Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``.""" + try: + import imageio.v3 as iio # type: ignore[import-not-found] + frames = iio.imread(path, plugin="pyav") + arr = np.asarray(frames) + if arr.ndim == 3: + arr = arr[None, ...] + return arr.astype(np.uint8) + except Exception as e: # pragma: no cover - fallback path + log.warning("imageio decode failed (%s); falling back to cv2", e) + import cv2 # type: ignore[import-not-found] + cap = cv2.VideoCapture(path) + frames: list[np.ndarray] = [] + while True: + ok, frame = cap.read() + if not ok: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + if not frames: + raise RuntimeError(f"Failed to decode any frames from {path}") + return np.stack(frames, axis=0).astype(np.uint8) diff --git a/static/app.js b/static/app.js index 5d1a2b9..f26a7b0 100644 --- a/static/app.js +++ b/static/app.js @@ -18,9 +18,18 @@ let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user +// --- Video mode state --- +let videoModeEnabled = false; // true when server has video engine active AND ready +let videoModeName = "off"; // "off" | "library" | "reflective" +let idleClipUrl = null; // URL string (server-served) or null +let pendingSpeakingClipMeta = null; // {chunk_id, duration_ms, text} waiting for MP4 binary +let currentSpeakingClipBlobUrl = null; + const chatArea = document.getElementById("chat-area"); const statusBadge = document.getElementById("status-badge"); const micBtn = document.getElementById("mic-btn"); +const avatarVideo = document.getElementById("avatar-video"); +const stageEl = document.getElementById("stage"); // --- WebSocket --- @@ -44,7 +53,18 @@ function connectWS() { ws.onmessage = (event) => { if (event.data instanceof ArrayBuffer) { - playAudioChunk(event.data); + // In video mode, the next binary frame after a "speaking_clip" + // envelope is an MP4 blob; otherwise it's a PCM audio chunk. + if (pendingSpeakingClipMeta) { + const meta = pendingSpeakingClipMeta; + pendingSpeakingClipMeta = null; + playSpeakingClip(event.data, meta); + } else if (videoModeEnabled) { + // Video mode is active but we didn't get a speaking_clip envelope + // first — ignore raw PCM so we don't double-play audio. + } else { + playAudioChunk(event.data); + } } else { handleJSON(JSON.parse(event.data)); } @@ -59,6 +79,7 @@ function handleJSON(msg) { case "interrupt": stopPlayback(); + stopSpeakingClip(); // Finalize with interrupted marker — text already reflects only what was heard finalizeAssistantMessage(true); break; @@ -80,6 +101,141 @@ function handleJSON(msg) { pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text }); } break; + + case "video_mode": + // Sent once on WS open. Toggles the video element + speaking-clip path. + applyVideoModeState(msg); + break; + + case "speaking_clip": + // Envelope preceding an MP4 binary frame with the full turn. + pendingSpeakingClipMeta = { + chunk_id: msg.chunk_id, + duration_ms: msg.duration_ms, + text: msg.text, + }; + break; + } +} + +// --- Video mode ------------------------------------------------------------ + +function applyVideoModeState(msg) { + videoModeEnabled = !!msg.enabled && !!msg.ready; + videoModeName = msg.mode || "off"; + idleClipUrl = msg.idle_clip_url || null; + refreshStage(); +} + +function refreshStage() { + if (videoModeEnabled && idleClipUrl) { + stageEl.classList.add("active"); + if (avatarVideo.src !== location.origin + idleClipUrl) { + avatarVideo.src = idleClipUrl; + avatarVideo.loop = true; + avatarVideo.muted = true; + avatarVideo.play().catch(() => {}); + } + } else { + stageEl.classList.remove("active"); + } +} + +function playSpeakingClip(arrayBuffer, meta) { + // Replace the idle loop with the speaking clip. + stopSpeakingClip(); + const blob = new Blob([arrayBuffer], { type: "video/mp4" }); + currentSpeakingClipBlobUrl = URL.createObjectURL(blob); + + avatarVideo.loop = false; + avatarVideo.muted = false; + avatarVideo.src = currentSpeakingClipBlobUrl; + + // Show the full reply text now — the MP4 plays it in one shot so there's + // no per-chunk sync to do. + if (meta && meta.text) { + appendAssistantText(meta.text); + } + isPlaying = true; + + avatarVideo.onended = () => { + isPlaying = false; + finalizeAssistantMessage(false); + // Return to idle loop. + if (idleClipUrl) { + avatarVideo.loop = true; + avatarVideo.muted = true; + avatarVideo.src = idleClipUrl; + avatarVideo.play().catch(() => {}); + } + if (currentSpeakingClipBlobUrl) { + URL.revokeObjectURL(currentSpeakingClipBlobUrl); + currentSpeakingClipBlobUrl = null; + } + }; + avatarVideo.play().catch((e) => { + console.error("speaking clip play failed:", e); + }); +} + +function stopSpeakingClip() { + if (!currentSpeakingClipBlobUrl) return; + try { + avatarVideo.pause(); + } catch (_) {} + URL.revokeObjectURL(currentSpeakingClipBlobUrl); + currentSpeakingClipBlobUrl = null; + if (idleClipUrl) { + avatarVideo.loop = true; + avatarVideo.muted = true; + avatarVideo.src = idleClipUrl; + avatarVideo.play().catch(() => {}); + } + isPlaying = false; +} + +async function uploadAvatar() { + const fileInput = document.getElementById("avatar-file"); + const status = document.getElementById("avatar-status"); + if (!fileInput.files || !fileInput.files[0]) { + status.textContent = "Pick an image first."; + return; + } + status.textContent = "Uploading and rendering idle clip (this takes a while)..."; + const fd = new FormData(); + fd.append("image", fileInput.files[0]); + try { + const resp = await fetch("/api/set-avatar", { method: "POST", body: fd }); + if (!resp.ok) throw new Error(await resp.text()); + const data = await resp.json(); + idleClipUrl = data.idle_clip_url + "?t=" + Date.now(); // cache-bust + videoModeEnabled = true; + videoModeName = data.mode || videoModeName; + refreshStage(); + status.textContent = "Avatar ready (" + data.mode + ")"; + } catch (err) { + console.error(err); + status.textContent = "Failed: " + err.message; + } +} + +async function applyVideoMode() { + const sel = document.getElementById("video-mode-select"); + const status = document.getElementById("avatar-status"); + const fd = new FormData(); + fd.append("mode", sel.value); + try { + const resp = await fetch("/api/set-video-mode", { method: "POST", body: fd }); + if (!resp.ok) throw new Error(await resp.text()); + const data = await resp.json(); + videoModeName = data.mode; + if (data.mode === "off") { + videoModeEnabled = false; + stageEl.classList.remove("active"); + } + status.textContent = "Mode: " + data.mode + (data.note ? " — " + data.note : ""); + } catch (err) { + status.textContent = "Failed: " + err.message; } } @@ -275,6 +431,7 @@ async function startMic() { if (bargeInCount >= BARGE_IN_FRAMES) { // User is speaking over the assistant - interrupt stopPlayback(); + stopSpeakingClip(); const msg = { type: "interrupt" }; if (lastDisplayedChunkId >= 0) { msg.last_chunk_id = lastDisplayedChunkId; @@ -353,3 +510,5 @@ async function applyVoice() { // Expose to HTML onclick window.toggleMic = toggleMic; window.applyVoice = applyVoice; +window.uploadAvatar = uploadAvatar; +window.applyVideoMode = applyVideoMode; diff --git a/static/index.html b/static/index.html index 83a88ca..eef437d 100644 --- a/static/index.html +++ b/static/index.html @@ -12,6 +12,17 @@ Disconnected +
+ +
+
@@ -40,6 +51,27 @@
+
+ Avatar / Video +
+ + + + + +
+
+
diff --git a/static/style.css b/static/style.css index bbb5301..369eac1 100644 --- a/static/style.css +++ b/static/style.css @@ -52,6 +52,28 @@ header h1 { color: #a78bfa; } +#stage { + display: none; /* toggled on when video mode is enabled */ + align-items: center; + justify-content: center; + padding: 16px 24px 0; + background: #0a0a0a; +} + +#stage.active { + display: flex; +} + +#avatar-video { + width: 100%; + max-width: 480px; + aspect-ratio: 16 / 9; + background: #000; + border-radius: 12px; + object-fit: cover; + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4); +} + #chat-area { flex: 1; overflow-y: auto; @@ -130,21 +152,34 @@ header h1 { 50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); } } -/* Voice clone panel */ -#voice-panel { +/* Voice + avatar panels */ +#voice-panel, +#avatar-panel { padding: 12px 24px; border-top: 1px solid #222; background: #0a0a0a; } -#voice-panel summary { +#voice-panel select, +#avatar-panel select { + background: #1a1a1a; + border: 1px solid #333; + border-radius: 6px; + padding: 6px 10px; + color: #e0e0e0; + font-size: 13px; +} + +#voice-panel summary, +#avatar-panel summary { cursor: pointer; font-size: 13px; color: #888; user-select: none; } -#voice-panel .panel-content { +#voice-panel .panel-content, +#avatar-panel .panel-content { margin-top: 12px; display: flex; gap: 12px; diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..c2a13d5 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,47 @@ +# Voice-chat tests + +Two tiers. + +## Unit tests — fast, GPU-free + +``` +python -m pytest tests/unit -v +``` + +These exercise pure logic: config parsing, prompt derivation, LoRA spec +parsing, frame-length fitting, library round-robin selection. They do not +touch CUDA, Wan2.2, MuseTalk, or ffmpeg. Safe to run on Windows, outside +Docker, without any models installed. + +## Component tests — slow, GPU-required, run inside Docker + +Each script in `tests/component/` exercises one subsystem end-to-end against +the real models. They are ordered to match the implementation phases: + +| Script | Phase | Tests | +|---|---|---| +| `test_01_video_skeleton.py` | 1 | VideoEngine loads, config gate respected | +| `test_02_wan22_loras.py` | 2 | Wan2.2 pipeline loads, LoRA stack applies | +| `test_03_idle_clip.py` | 3 | set_avatar → idle MP4, written to disk for eyeballing | +| `test_04_library_prebake.py` | 4 | library mode pre-bakes N base clips | +| `test_05_musetalk_lipsync.py` | 5 | MuseTalk lip-sync on library frames + ffmpeg mux | +| `test_06_reflective.py` | 6 | reflective mode: fresh Wan2.2 per reply | +| `test_07_endpoints.py` | 7 | HTTP endpoints return sane responses | +| `test_08_lora_reload.py` | 8 | /api/reload-loras swaps LoRAs live | + +Run one: + +``` +# Inside the container: +docker compose exec voice-chat python -m tests.component.test_03_idle_clip +``` + +Run all (slow, ~20+ minutes on 5090): + +``` +docker compose exec voice-chat python -m tests.component.run_all +``` + +Each component script writes its artifacts (MP4s, PNG frame dumps, logs) +to `tests/component/_out/` so you can visually inspect results. That +directory is gitignored. diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/component/__init__.py b/tests/component/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/component/_common.py b/tests/component/_common.py new file mode 100644 index 0000000..34b6122 --- /dev/null +++ b/tests/component/_common.py @@ -0,0 +1,72 @@ +"""Shared utilities for component tests. + +Component tests run inside the Docker image against real GPU models. They +write their output artefacts (MP4s, PNGs, logs) to ``_out/`` so you can +visually inspect results. +""" +from __future__ import annotations + +import logging +import os +import sys + +import numpy as np + + +OUT_DIR = os.path.join(os.path.dirname(__file__), "_out") +os.makedirs(OUT_DIR, exist_ok=True) + +# A tiny 256x256 portrait PNG lives next to the component tests so tests +# don't need a user-supplied file. If it's missing we synthesise one on +# the fly. +SAMPLE_AVATAR = os.path.join(os.path.dirname(__file__), "sample_avatar.png") + + +def get_logger(name: str) -> logging.Logger: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s %(message)s", + stream=sys.stdout, + ) + return logging.getLogger(name) + + +def ensure_sample_avatar() -> str: + """Guarantee a usable avatar image exists. Returns its path.""" + if os.path.isfile(SAMPLE_AVATAR): + return SAMPLE_AVATAR + # Synthesise a simple gradient PNG as a last resort (won't look like a + # person but is valid input for Wan2.2 so the pipeline doesn't fail). + try: + from PIL import Image # type: ignore[import-not-found] + except ImportError: + import imageio.v3 as iio # type: ignore[import-not-found] + arr = np.zeros((256, 256, 3), dtype=np.uint8) + for y in range(256): + arr[y, :, 0] = y + arr[y, :, 1] = 255 - y + arr[y, :, 2] = 128 + iio.imwrite(SAMPLE_AVATAR, arr) + return SAMPLE_AVATAR + + arr = np.zeros((256, 256, 3), dtype=np.uint8) + for y in range(256): + arr[y, :, 0] = y + arr[y, :, 1] = 255 - y + arr[y, :, 2] = 128 + Image.fromarray(arr).save(SAMPLE_AVATAR) + return SAMPLE_AVATAR + + +def write_bytes(name: str, data: bytes) -> str: + """Write an artefact to _out/ and return the full path.""" + path = os.path.join(OUT_DIR, name) + with open(path, "wb") as f: + f.write(data) + return path + + +def synth_tone(seconds: float, sample_rate: int = 24000, freq: float = 220.0) -> np.ndarray: + """Return a float32 sine tone usable as stand-in TTS audio.""" + t = np.arange(int(seconds * sample_rate), dtype=np.float32) / sample_rate + return (0.2 * np.sin(2 * np.pi * freq * t)).astype(np.float32) diff --git a/tests/component/run_all.py b/tests/component/run_all.py new file mode 100644 index 0000000..ee7f2fd --- /dev/null +++ b/tests/component/run_all.py @@ -0,0 +1,46 @@ +"""Run every component test in order. Stops at first failure. + + docker compose exec voice-chat python -m tests.component.run_all +""" +import importlib +import sys +import traceback + + +SCRIPTS = [ + "tests.component.test_01_video_skeleton", + "tests.component.test_02_wan22_loras", + "tests.component.test_03_idle_clip", + "tests.component.test_04_library_prebake", + "tests.component.test_05_musetalk_lipsync", + "tests.component.test_06_reflective", + "tests.component.test_07_endpoints", + "tests.component.test_08_lora_reload", +] + + +def main() -> int: + failed: list[str] = [] + for name in SCRIPTS: + print(f"\n{'=' * 70}\nRUNNING: {name}\n{'=' * 70}") + try: + mod = importlib.import_module(name) + mod.run() + except SystemExit as e: + if e.code: + print(f"FAILED: {name} (exit {e.code})") + failed.append(name) + break # hard-stop on failure + except Exception: + traceback.print_exc() + failed.append(name) + break + if failed: + print(f"\n{len(failed)} failed: {failed}") + return 1 + print("\nALL COMPONENT TESTS PASSED") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/component/test_01_video_skeleton.py b/tests/component/test_01_video_skeleton.py new file mode 100644 index 0000000..49f7d92 --- /dev/null +++ b/tests/component/test_01_video_skeleton.py @@ -0,0 +1,69 @@ +"""Phase 1 component test: VideoEngine skeleton + config gate. + +Verifies: +- ``ModelManager`` can be imported and constructed. +- When ``config.video.enabled=false``, ``_load_video`` skips and leaves + ``video_engine=None`` (existing voice path unaffected). +- When ``config.video.enabled=true``, a ``VideoEngine`` instance is created + and ``is_ready()`` returns False (no models loaded yet). + +Does NOT load Wan2.2 or MuseTalk — this test is safe to run on any machine +with the python deps installed (no GPU needed). + +Run inside Docker: + docker compose exec voice-chat python -m tests.component.test_01_video_skeleton +""" +from __future__ import annotations + +import sys + +from server.models import ModelManager +from server.video import VideoConfig, VideoEngine + +from tests.component._common import get_logger + +log = get_logger("test_01") + + +def run(): + # --- disabled path --- + log.info("[case 1] config.video.enabled=False → engine skipped") + mgr = ModelManager() + # Monkey-patch the config module to simulate disabled + import server.config as cfgmod + original = cfgmod.config + cfgmod.config = {"video": {"enabled": False}, **{k: v for k, v in original.items() if k != "video"}} + try: + mgr._load_video() + assert mgr.video_engine is None, "video_engine should be None when disabled" + log.info(" PASS: video_engine is None") + finally: + cfgmod.config = original + + # --- enabled path (no models loaded) --- + log.info("[case 2] config.video.enabled=True → engine created, not ready") + mgr2 = ModelManager() + cfgmod.config = { + **original, + "video": {"enabled": True, "mode": "reflective", "loras": []}, + } + try: + mgr2._load_video() + assert mgr2.video_engine is not None, "video_engine should be created" + assert isinstance(mgr2.video_engine, VideoEngine) + assert mgr2.video_engine.is_ready() is False + log.info(" PASS: engine=%s, ready=%s", + type(mgr2.video_engine).__name__, mgr2.video_engine.is_ready()) + finally: + cfgmod.config = original + + log.info("ALL PASSED") + + +if __name__ == "__main__": + try: + run() + sys.exit(0) + except AssertionError as e: + log.error("FAILED: %s", e) + sys.exit(1) diff --git a/tests/component/test_02_wan22_loras.py b/tests/component/test_02_wan22_loras.py new file mode 100644 index 0000000..f41e857 --- /dev/null +++ b/tests/component/test_02_wan22_loras.py @@ -0,0 +1,106 @@ +"""Phase 2 component test: Wan2.2-Lightning fp8 pipeline + LoRA stacking. + +Verifies: +- ``Wan22Pipeline`` loads successfully against the fp8 distill path + (exercises the real LightX2V set_config → init_runner flow). +- ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at + ``/cache/loras/wan22-[HL]-e8.safetensors``. + +Requires GPU and a first-run download of both HF repos (base support files +~12 GB, fp8 DIT ~30 GB). If LightX2V isn't installed the test is skipped. + +Run: + docker compose exec voice-chat python -m tests.component.test_02_wan22_loras +""" +from __future__ import annotations + +import os +import sys + +from tests.component._common import get_logger + +log = get_logger("test_02") + +CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json" +LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors" +LORA_LOW = "/cache/loras/wan22-L-e8.safetensors" + + +def run(): + try: + from server.video_models.wan22 import Wan22Pipeline + except ImportError as e: + log.error("Wan22Pipeline import failed: %s", e) + log.warning("SKIP: phase 2 deps not installed") + sys.exit(0) + + from server.video import LoRASpec + + log.info("[case 1] Instantiate Wan22Pipeline " + "(first run downloads ~42 GB total)...") + try: + pipe = Wan22Pipeline( + base_repo="Wan-AI/Wan2.2-I2V-A14B", + fp8_repo="lightx2v/Wan2.2-Distill-Models", + config_json=CONFIG_JSON, + model_cls="wan2.2_moe_distill", + resolution=480, + fps=16, + ) + except Exception as e: + log.error("FAIL: Wan22Pipeline construction raised: %s", e) + log.error("Check: LightX2V install, HF cache at /cache/huggingface, " + "VRAM headroom, and that %s exists inside the container.", + CONFIG_JSON) + sys.exit(2) + log.info(" PASS: pipeline constructed") + + # --- LoRAs --- + log.info("[case 2] load_loras with empty list → no-op") + pipe.load_loras([]) + log.info(" PASS") + + if not (os.path.isfile(LORA_HIGH) and os.path.isfile(LORA_LOW)): + log.warning("SKIP: expected LoRA files not found at %s / %s", + LORA_HIGH, LORA_LOW) + log.info("ALL PASSED (partial — LoRA cases skipped)") + return + + log.info("[case 3] load_loras with the two MoE distill LoRAs") + specs = [ + LoRASpec( + path=LORA_HIGH, + weight=1.0, + target="high_noise", + name="wan22-H-e8", + ), + LoRASpec( + path=LORA_LOW, + weight=1.0, + target="low_noise", + name="wan22-L-e8", + ), + ] + try: + pipe.load_loras(specs) + except Exception as e: + log.error("FAIL: load_loras raised: %s", e) + log.error("Check: switch_lora support for wan2.2_moe_distill in the " + "installed LightX2V build. If it errors there, pre-declare " + "LoRAs in the config_json 'lora_configs' field instead.") + sys.exit(3) + log.info(" PASS: LoRAs applied") + + log.info("[case 4] unload_loras") + try: + pipe.unload_loras() + except Exception as e: + log.error("FAIL: unload_loras raised: %s", e) + sys.exit(4) + log.info(" PASS") + + log.info("ALL PASSED") + + +if __name__ == "__main__": + run() diff --git a/tests/component/test_03_idle_clip.py b/tests/component/test_03_idle_clip.py new file mode 100644 index 0000000..ae48020 --- /dev/null +++ b/tests/component/test_03_idle_clip.py @@ -0,0 +1,66 @@ +"""Phase 3 component test: avatar upload → idle clip generation. + +Verifies: +- ``VideoEngine.load_models()`` + ``set_avatar(image)`` produces a non-empty + idle MP4 blob. +- The blob decodes as a valid MP4 (ftyp header). + +Writes the idle clip to ``tests/component/_out/phase3_idle.mp4`` so you can +inspect it visually. + +Run: + docker compose exec voice-chat python -m tests.component.test_03_idle_clip +""" +from __future__ import annotations + +import sys + +from server.video import VideoConfig, VideoEngine +from tests.component._common import ensure_sample_avatar, get_logger, write_bytes + +log = get_logger("test_03") + + +def run(): + avatar_path = ensure_sample_avatar() + log.info("Using avatar: %s", avatar_path) + + cfg = VideoConfig.from_dict( + { + "enabled": True, + "mode": "reflective", # reflective skips the library prebake + "resolution": 480, + "fps": 16, + "library": {"base_clip_count": 0, "base_clip_seconds": 3}, + } + ) + engine = VideoEngine(cfg) + + log.info("Loading models (Wan2.2 + MuseTalk)...") + try: + engine.load_models() + except Exception as e: + log.error("FAIL: load_models raised: %s", e) + sys.exit(2) + log.info("Models loaded.") + + log.info("Generating idle clip for avatar...") + try: + engine.set_avatar(avatar_path) + except Exception as e: + log.error("FAIL: set_avatar raised: %s", e) + sys.exit(3) + + idle = engine.get_idle_clip() + assert idle is not None and len(idle) > 0, "idle clip is empty" + assert idle[4:8] == b"ftyp", "idle clip is not a valid MP4" + + out_path = write_bytes("phase3_idle.mp4", idle) + log.info("PASS: idle clip written to %s (%d bytes)", out_path, len(idle)) + + assert engine.is_ready() is True + log.info(" engine.is_ready() = True (avatar + models present)") + + +if __name__ == "__main__": + run() diff --git a/tests/component/test_04_library_prebake.py b/tests/component/test_04_library_prebake.py new file mode 100644 index 0000000..f726e6d --- /dev/null +++ b/tests/component/test_04_library_prebake.py @@ -0,0 +1,55 @@ +"""Phase 4 component test: library mode pre-bake of speaking-base clips. + +Verifies: +- ``set_avatar`` under ``mode=library`` populates ``speaking_base_frames`` + with ``library_base_clip_count`` entries. +- Each cached entry has shape ``[T, H, W, 3]`` uint8. + +Run: + docker compose exec voice-chat python -m tests.component.test_04_library_prebake +""" +from __future__ import annotations + +import sys + +import numpy as np + +from server.video import VideoConfig, VideoEngine +from tests.component._common import ensure_sample_avatar, get_logger + +log = get_logger("test_04") + + +def run(): + avatar_path = ensure_sample_avatar() + cfg = VideoConfig.from_dict( + { + "enabled": True, + "mode": "library", + "resolution": 480, + "fps": 16, + "library": {"base_clip_count": 2, "base_clip_seconds": 3}, + } + ) + engine = VideoEngine(cfg) + + log.info("Loading models...") + engine.load_models() + + log.info("Pre-baking 2 library clips...") + engine.set_avatar(avatar_path) + + assert len(engine.speaking_base_frames) == 2, \ + f"expected 2 base clips, got {len(engine.speaking_base_frames)}" + for i, frames in enumerate(engine.speaking_base_frames): + assert isinstance(frames, np.ndarray) + assert frames.ndim == 4 and frames.shape[-1] == 3 + assert frames.dtype == np.uint8 + log.info(" clip %d: shape=%s", i, frames.shape) + + assert engine.get_idle_clip() is not None + log.info("PASS: library pre-bake complete") + + +if __name__ == "__main__": + run() diff --git a/tests/component/test_05_musetalk_lipsync.py b/tests/component/test_05_musetalk_lipsync.py new file mode 100644 index 0000000..bcd4a69 --- /dev/null +++ b/tests/component/test_05_musetalk_lipsync.py @@ -0,0 +1,57 @@ +"""Phase 5 component test: MuseTalk lip-sync + ffmpeg mux. + +Verifies the full library-mode per-turn path: +- Pre-bake a library clip. +- Generate a stand-in TTS waveform (sine tone). +- Call ``VideoEngine.generate_speaking_clip`` and get a valid MP4 back. + +Writes the resulting clip to ``tests/component/_out/phase5_speaking.mp4``. + +Run: + docker compose exec voice-chat python -m tests.component.test_05_musetalk_lipsync +""" +from __future__ import annotations + +import sys + +from server.video import VideoConfig, VideoEngine +from tests.component._common import ( + ensure_sample_avatar, + get_logger, + synth_tone, + write_bytes, +) + +log = get_logger("test_05") + + +def run(): + avatar_path = ensure_sample_avatar() + cfg = VideoConfig.from_dict( + { + "enabled": True, + "mode": "library", + "resolution": 480, + "fps": 16, + "library": {"base_clip_count": 1, "base_clip_seconds": 4}, + } + ) + engine = VideoEngine(cfg) + engine.load_models() + engine.set_avatar(avatar_path) + + audio = synth_tone(seconds=3.0, sample_rate=24000, freq=220.0) + log.info("Generating library-mode speaking clip (3s audio)...") + mp4 = engine.generate_speaking_clip( + audio_f32=audio, + sample_rate=24000, + reply_text="Hello, this is a lip-sync test.", + ) + assert isinstance(mp4, bytes) and len(mp4) > 0 + assert mp4[4:8] == b"ftyp" + out = write_bytes("phase5_speaking.mp4", mp4) + log.info("PASS: speaking clip written to %s (%d bytes)", out, len(mp4)) + + +if __name__ == "__main__": + run() diff --git a/tests/component/test_06_reflective.py b/tests/component/test_06_reflective.py new file mode 100644 index 0000000..0105326 --- /dev/null +++ b/tests/component/test_06_reflective.py @@ -0,0 +1,69 @@ +"""Phase 6 component test: reflective mode (fresh Wan2.2 clip per turn). + +Verifies that with ``mode=reflective``, ``generate_speaking_clip`` runs +the Wan2.2 image-to-video pipeline once per call (so the base frames +differ from turn to turn) and the prompt is derived from the reply text. + +Run: + docker compose exec voice-chat python -m tests.component.test_06_reflective +""" +from __future__ import annotations + +import numpy as np + +from server.video import VideoConfig, VideoEngine +from tests.component._common import ( + ensure_sample_avatar, + get_logger, + synth_tone, + write_bytes, +) + +log = get_logger("test_06") + + +def run(): + avatar_path = ensure_sample_avatar() + cfg = VideoConfig.from_dict( + { + "enabled": True, + "mode": "reflective", + "resolution": 480, + "fps": 16, + "reflective": {"clip_seconds": 3}, + } + ) + engine = VideoEngine(cfg) + engine.load_models() + engine.set_avatar(avatar_path) + + # Verify prompt derivation includes the reply hint + prompt = engine._derive_prompt( + "The assistant walks along a sunny beach watching seagulls." + ) + log.info("derived prompt: %s", prompt) + assert "beach" in prompt, "reply_hint did not survive template interpolation" + + audio = synth_tone(seconds=3.0) + log.info("Generating reflective speaking clip #1...") + mp4_a = engine.generate_speaking_clip( + audio, 24000, "The assistant walks along a sunny beach watching seagulls." + ) + write_bytes("phase6_reflective_beach.mp4", mp4_a) + + log.info("Generating reflective speaking clip #2...") + mp4_b = engine.generate_speaking_clip( + audio, 24000, "Now the character stands in a snow-covered forest at dusk." + ) + write_bytes("phase6_reflective_snow.mp4", mp4_b) + + # Not a strict assertion (same prompt could yield identical bytes if seeded), + # but with different prompts and random seeds the blobs should differ. + if mp4_a != mp4_b: + log.info("PASS: reflective clips differ as expected") + else: + log.warning("clips are byte-identical — check that seeds are random") + + +if __name__ == "__main__": + run() diff --git a/tests/component/test_07_endpoints.py b/tests/component/test_07_endpoints.py new file mode 100644 index 0000000..d2eacd7 --- /dev/null +++ b/tests/component/test_07_endpoints.py @@ -0,0 +1,114 @@ +"""Phase 7 component test: HTTP endpoints (/api/set-avatar, /api/idle-clip, +/api/set-video-mode, /api/reload-loras, WebSocket handshake video_mode msg). + +Uses FastAPI's ``TestClient`` so we don't need a running uvicorn server. +Stubs the model manager to avoid loading Wan2.2 — we only care that the +HTTP surface is plumbed correctly. + +Run: + docker compose exec voice-chat python -m tests.component.test_07_endpoints +""" +from __future__ import annotations + +import io +import json +import sys + +from tests.component._common import get_logger + +log = get_logger("test_07") + + +def _stub_video_engine(): + class StubCfg: + mode = "reflective" + class StubEngine: + cfg = StubCfg() + avatar_path = None + def __init__(self): self.idle = b"FAKE_MP4" + def is_ready(self): return bool(self.avatar_path) + def get_idle_clip(self): return self.idle + def set_avatar(self, path): self.avatar_path = path + def load_loras(self, specs): self._last_loras = specs + return StubEngine() + + +def run(): + from fastapi.testclient import TestClient + import server.main as main_mod + + # Inject a stub engine so we never touch Wan2.2. + main_mod.model_mgr.video_engine = _stub_video_engine() + + # Bypass the heavy lifespan (model loading) so TestClient starts fast. + main_mod.app.router.lifespan_context = None # type: ignore[attr-defined] + + client = TestClient(main_mod.app) + + # --- set-avatar --- + log.info("[case 1] POST /api/set-avatar") + fake_png = b"\x89PNG\r\n\x1a\n" + b"\x00" * 64 # minimal PNG header + resp = client.post( + "/api/set-avatar", + files={"image": ("avatar.png", io.BytesIO(fake_png), "image/png")}, + ) + assert resp.status_code == 200, f"got {resp.status_code}: {resp.text}" + data = resp.json() + assert data["status"] == "ok" + assert data["idle_clip_url"] == "/api/idle-clip" + log.info(" PASS: %s", data) + + # --- idle-clip --- + log.info("[case 2] GET /api/idle-clip") + resp = client.get("/api/idle-clip") + assert resp.status_code == 200 + assert resp.content == b"FAKE_MP4" + assert resp.headers["content-type"] == "video/mp4" + log.info(" PASS") + + # --- set-video-mode --- + log.info("[case 3] POST /api/set-video-mode") + for mode in ("off", "library", "reflective"): + resp = client.post("/api/set-video-mode", data={"mode": mode}) + assert resp.status_code == 200 + assert resp.json()["mode"] == mode + resp = client.post("/api/set-video-mode", data={"mode": "bogus"}) + assert resp.status_code == 400 + log.info(" PASS") + + # --- reload-loras --- + log.info("[case 4] POST /api/reload-loras") + body = { + "loras": [ + {"path": "/cache/loras/a.safetensors", "weight": 0.8, + "target": "high_noise", "name": "test-a"}, + {"path": "/cache/loras/b.safetensors", "weight": 0.4, + "target": "low_noise"}, + ] + } + resp = client.post("/api/reload-loras", json=body) + assert resp.status_code == 200, resp.text + data = resp.json() + assert data["lora_count"] == 2 + log.info(" PASS: %s", data) + + # --- WebSocket video_mode handshake --- + log.info("[case 5] WebSocket /ws/chat → video_mode announcement") + with client.websocket_connect("/ws/chat") as websocket: + msgs = [] + for _ in range(5): + try: + msg = websocket.receive_json() + msgs.append(msg) + if msg.get("type") == "video_mode": + break + except Exception: + break + assert any(m.get("type") == "video_mode" for m in msgs), msgs + log.info(" PASS") + + log.info("ALL PASSED") + + +if __name__ == "__main__": + run() diff --git a/tests/component/test_08_lora_reload.py b/tests/component/test_08_lora_reload.py new file mode 100644 index 0000000..9cf252c --- /dev/null +++ b/tests/component/test_08_lora_reload.py @@ -0,0 +1,60 @@ +"""Phase 8 component test: /api/reload-loras hot-swap. + +Verifies that ``VideoEngine.load_loras`` can be called again after startup +and the idle clip is regenerated to reflect the new style. + +This test is the 'real model' version of test_07's reload endpoint stub. + +Run: + docker compose exec voice-chat python -m tests.component.test_08_lora_reload +""" +from __future__ import annotations + +import hashlib + +from server.video import LoRASpec, VideoConfig, VideoEngine +from tests.component._common import ensure_sample_avatar, get_logger, write_bytes + +log = get_logger("test_08") + + +def run(): + avatar_path = ensure_sample_avatar() + cfg = VideoConfig.from_dict({"enabled": True, "mode": "reflective"}) + engine = VideoEngine(cfg) + engine.load_models() + + # Initial state: no LoRAs + engine.set_avatar(avatar_path) + idle_a = engine.get_idle_clip() + assert idle_a is not None + hash_a = hashlib.sha256(idle_a).hexdigest() + write_bytes("phase8_idle_noloras.mp4", idle_a) + log.info("idle (no LoRAs) sha256=%s", hash_a[:16]) + + # Hot-reload with a distill LoRA + specs = [ + LoRASpec( + path="lightx2v/Wan2.2-Distill-Loras:" + "wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step.safetensors", + weight=1.0, + target="high_noise", + name="distill-hi", + ), + ] + engine.load_loras(specs) + engine.set_avatar(avatar_path) + idle_b = engine.get_idle_clip() + assert idle_b is not None + hash_b = hashlib.sha256(idle_b).hexdigest() + write_bytes("phase8_idle_withlora.mp4", idle_b) + log.info("idle (with LoRA) sha256=%s", hash_b[:16]) + + if hash_a != hash_b: + log.info("PASS: idle clip changed after LoRA reload") + else: + log.warning("clips identical — LoRA may not be applied; eyeball _out/*.mp4") + + +if __name__ == "__main__": + run() diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_musetalk_fit_frames.py b/tests/unit/test_musetalk_fit_frames.py new file mode 100644 index 0000000..5f27c59 --- /dev/null +++ b/tests/unit/test_musetalk_fit_frames.py @@ -0,0 +1,65 @@ +"""Unit tests for the frame-length fitting helper in server.video_models.musetalk. + +Pure-python: does not import MuseTalk itself. +""" +import numpy as np + +from server.video_models.musetalk import _fit_frames_to_length, _ensure_uint8_rgb + + +def _make_frames(t, h=2, w=2): + return np.arange(t * h * w * 3, dtype=np.uint8).reshape(t, h, w, 3) + + +def test_fit_frames_trim(): + frames = _make_frames(10) + out = _fit_frames_to_length(frames, 4) + assert out.shape == (4, 2, 2, 3) + np.testing.assert_array_equal(out, frames[:4]) + + +def test_fit_frames_passthrough_when_equal(): + frames = _make_frames(5) + out = _fit_frames_to_length(frames, 5) + assert out is frames or np.array_equal(out, frames) + + +def test_fit_frames_extends_with_pingpong(): + frames = _make_frames(3) + out = _fit_frames_to_length(frames, 8) + assert out.shape == (8, 2, 2, 3) + # First 3 frames match the original + np.testing.assert_array_equal(out[:3], frames) + # Next 3 are the reverse (ping-pong) + np.testing.assert_array_equal(out[3:6], frames[::-1]) + # Then forward again + np.testing.assert_array_equal(out[6:8], frames[:2]) + + +def test_fit_frames_zero_target_returns_original(): + frames = _make_frames(3) + out = _fit_frames_to_length(frames, 0) + np.testing.assert_array_equal(out, frames) + + +def test_ensure_uint8_rgb_from_float(): + arr = np.ones((5, 2, 2, 3), dtype=np.float32) * 0.5 + out = _ensure_uint8_rgb(arr) + assert out.dtype == np.uint8 + assert out.shape == (5, 2, 2, 3) + assert out[0, 0, 0, 0] == 127 + + +def test_ensure_uint8_rgb_promotes_3d_to_4d(): + arr = np.zeros((2, 2, 3), dtype=np.uint8) + out = _ensure_uint8_rgb(arr) + assert out.shape == (1, 2, 2, 3) + + +def test_ensure_uint8_rgb_clips_float_out_of_range(): + arr = np.ones((1, 1, 1, 3), dtype=np.float32) * 2.0 # 2.0 → clipped to 255 + out = _ensure_uint8_rgb(arr) + assert out[0, 0, 0, 0] == 255 + arr2 = np.ones((1, 1, 1, 3), dtype=np.float32) * -1.0 + out2 = _ensure_uint8_rgb(arr2) + assert out2[0, 0, 0, 0] == 0 diff --git a/tests/unit/test_muxer_ffmpeg.py b/tests/unit/test_muxer_ffmpeg.py new file mode 100644 index 0000000..2d3f707 --- /dev/null +++ b/tests/unit/test_muxer_ffmpeg.py @@ -0,0 +1,67 @@ +"""Unit tests for the ffmpeg muxer. + +Requires ``ffmpeg`` on PATH. On Windows, if ffmpeg is not installed these +tests are skipped (they will run inside the Docker image where ffmpeg is +always present). +""" +import os +import shutil +import struct + +import numpy as np +import pytest + +from server.video_models.muxer import frames_and_audio_to_mp4, frames_to_mp4_loop + + +pytestmark = pytest.mark.skipif( + shutil.which("ffmpeg") is None, + reason="ffmpeg not installed locally; run these inside Docker", +) + + +def _rgb_frames(t, h=64, w=64): + """Coloured checker frames so the encoder has real content.""" + frames = np.zeros((t, h, w, 3), dtype=np.uint8) + for i in range(t): + frames[i, :, :, 0] = (i * 20) % 255 + frames[i, :h // 2, :, 1] = 255 + frames[i, :, :w // 2, 2] = 255 + return frames + + +def test_frames_to_mp4_loop_produces_mp4_bytes(): + frames = _rgb_frames(8) + data = frames_to_mp4_loop(frames, fps=16) + assert isinstance(data, bytes) + assert len(data) > 0 + # MP4 files start with an ftyp box: 4 bytes size + 'ftyp' + assert data[4:8] == b"ftyp" + + +def test_frames_and_audio_to_mp4_produces_mp4_bytes(): + frames = _rgb_frames(16) + # 1s silent audio at 24kHz + audio = np.zeros(24000, dtype=np.float32) + data = frames_and_audio_to_mp4(frames, audio, sample_rate=24000, fps=16) + assert isinstance(data, bytes) + assert len(data) > 0 + assert data[4:8] == b"ftyp" + + +def test_frames_to_mp4_loop_rejects_empty(): + with pytest.raises(ValueError): + frames_to_mp4_loop(np.empty((0, 64, 64, 3), dtype=np.uint8), fps=16) + + +def test_frames_and_audio_to_mp4_rejects_empty_audio(): + frames = _rgb_frames(4) + with pytest.raises(ValueError): + frames_and_audio_to_mp4( + frames, np.empty(0, dtype=np.float32), sample_rate=24000, fps=16 + ) + + +def test_frames_to_mp4_loop_rejects_wrong_shape(): + with pytest.raises(ValueError): + frames_to_mp4_loop(np.zeros((4, 64, 64), dtype=np.uint8), fps=16) diff --git a/tests/unit/test_pipeline_video_branch.py b/tests/unit/test_pipeline_video_branch.py new file mode 100644 index 0000000..5bbafee --- /dev/null +++ b/tests/unit/test_pipeline_video_branch.py @@ -0,0 +1,144 @@ +"""Unit test for the video-mode branch in ConversationSession. + +Stubs every model involved (ASR, LLM, TTS, VideoEngine) so we can verify: +1. When video_engine is not ready, the existing PCM streaming path runs. +2. When video_engine IS ready, the per-chunk PCM sends are skipped and a + single ``speaking_clip`` JSON + MP4 binary is sent instead. + +Pure asyncio; no CUDA, no real models. +""" +from __future__ import annotations + +import asyncio +import types +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from server.pipeline import ConversationSession + + +class _FakeVAD: + is_speaking = False + def process_chunk(self, _): return None + + +class _FakeASR: + def __init__(self, text="hello"): + self.text = text + def transcribe(self, _): return self.text + + +class _FakeLLM: + def __init__(self, response="Hi there."): + self.response = response + def generate(self, *_a, **_k): + return self.response, None + def trim_cache(self, state, _): return state + + +class _FakeTTSIterable: + """Drop-in replacement for Kokoro's pipeline(..) generator.""" + def __init__(self, chunks): + self._chunks = chunks + def __call__(self, segment, voice=None): + for i, audio in enumerate(self._chunks): + yield f"w{i}", None, audio + + +class _FakeTTSEngine: + def __init__(self, chunks): + self.pipeline = _FakeTTSIterable(chunks) + self.voice = "v" + self.sample_rate = 24000 + + +class _FakeVideoEngineReady: + class _Cfg: + mode = "reflective" + cfg = _Cfg() + def __init__(self): + self.called_with = None + def is_ready(self): return True + def generate_speaking_clip(self, audio, sr, reply_text): + self.called_with = {"len": len(audio), "sr": sr, "reply": reply_text} + return b"FAKE_MP4_BYTES" + + +class _FakeModelsBase: + def __init__(self, tts_chunks): + self.asr_engine = _FakeASR() + self.llm_engine = _FakeLLM() + self.tts_engine = _FakeTTSEngine(tts_chunks) + def create_vad(self): return _FakeVAD() + + +class _FakeModelsStreaming(_FakeModelsBase): + video_engine = None + + +class _FakeModelsVideo(_FakeModelsBase): + def __init__(self, tts_chunks): + super().__init__(tts_chunks) + self.video_engine = _FakeVideoEngineReady() + + +@pytest.mark.asyncio +async def test_streaming_path_when_video_engine_absent(): + json_sent: list = [] + bytes_sent: list = [] + + async def send_json(d): json_sent.append(d) + async def send_bytes(b): bytes_sent.append(b) + + chunks = [ + np.ones(240, dtype=np.float32), + np.ones(480, dtype=np.float32), + ] + models = _FakeModelsStreaming(tts_chunks=chunks) + session = ConversationSession(models, send_json, send_bytes) + await session._process_utterance(np.zeros(16000, dtype=np.float32)) + + # PCM bytes were sent (one per TTS chunk). + assert len(bytes_sent) == 2 + # Per-chunk response_text messages were sent (not video's one-shot). + text_msgs = [m for m in json_sent if m.get("type") == "response_text"] + assert any(not m.get("final") for m in text_msgs) + # No speaking_clip envelope + assert not any(m.get("type") == "speaking_clip" for m in json_sent) + + +@pytest.mark.asyncio +async def test_video_path_when_engine_ready(): + json_sent: list = [] + bytes_sent: list = [] + + async def send_json(d): json_sent.append(d) + async def send_bytes(b): bytes_sent.append(b) + + chunks = [ + np.full(480, 0.5, dtype=np.float32), + np.full(480, 0.25, dtype=np.float32), + ] + models = _FakeModelsVideo(tts_chunks=chunks) + session = ConversationSession(models, send_json, send_bytes) + await session._process_utterance(np.zeros(16000, dtype=np.float32)) + + # MP4 blob was sent once. + assert bytes_sent == [b"FAKE_MP4_BYTES"] + # speaking_clip envelope was sent exactly once. + envelopes = [m for m in json_sent if m.get("type") == "speaking_clip"] + assert len(envelopes) == 1 + assert envelopes[0]["size_bytes"] == len(b"FAKE_MP4_BYTES") + assert envelopes[0]["text"] == "Hi there." + + # The video engine received the concatenated audio. + ve = models.video_engine + assert ve.called_with is not None + assert ve.called_with["len"] == 960 # 480 + 480 + assert ve.called_with["reply"] == "Hi there." + + # No per-chunk PCM bytes were streamed (video path suppresses them). + # Only the MP4 blob is in bytes_sent. + assert len(bytes_sent) == 1 diff --git a/tests/unit/test_video_config.py b/tests/unit/test_video_config.py new file mode 100644 index 0000000..54e5333 --- /dev/null +++ b/tests/unit/test_video_config.py @@ -0,0 +1,119 @@ +"""Unit tests for VideoConfig parsing and LoRASpec validation. + +Pure-python, no model imports, no CUDA, no ffmpeg. Safe for Windows CI. +""" +import pytest + +from server.video import VideoConfig, LoRASpec + + +def test_defaults_when_raw_is_empty(): + cfg = VideoConfig.from_dict({}) + assert cfg.enabled is False + assert cfg.backend == "lightx2v" + assert cfg.mode == "reflective" + assert cfg.resolution == 480 + assert cfg.fps == 16 + assert cfg.library_base_clip_count == 4 + assert cfg.reflective_prompt_reply_words == 18 + assert cfg.loras == [] + + +def test_defaults_when_raw_is_none(): + cfg = VideoConfig.from_dict(None) # type: ignore[arg-type] + assert cfg.enabled is False + + +def test_library_section_override(): + cfg = VideoConfig.from_dict( + {"enabled": True, "mode": "library", "library": {"base_clip_count": 7, "base_clip_seconds": 3}} + ) + assert cfg.enabled is True + assert cfg.mode == "library" + assert cfg.library_base_clip_count == 7 + assert cfg.library_base_clip_seconds == 3 + + +def test_reflective_section_override(): + cfg = VideoConfig.from_dict( + { + "reflective": { + "clip_seconds": 9, + "clip_prompt_template": "my template: {reply_hint}", + "prompt_reply_words": 5, + } + } + ) + assert cfg.reflective_clip_seconds == 9 + assert cfg.reflective_prompt_template == "my template: {reply_hint}" + assert cfg.reflective_prompt_reply_words == 5 + + +def test_lora_parse_minimal(): + cfg = VideoConfig.from_dict({"loras": [{"path": "/tmp/a.safetensors"}]}) + assert len(cfg.loras) == 1 + lora = cfg.loras[0] + assert lora.path == "/tmp/a.safetensors" + assert lora.weight == 1.0 + assert lora.target == "both" + assert lora.name is None + + +def test_lora_parse_full(): + cfg = VideoConfig.from_dict( + { + "loras": [ + { + "path": "/tmp/hi.safetensors", + "weight": 0.7, + "target": "high_noise", + "name": "hi-noise-style", + }, + { + "path": "/tmp/lo.safetensors", + "weight": 0.4, + "target": "low_noise", + "name": "lo-noise-style", + }, + ] + } + ) + assert len(cfg.loras) == 2 + assert cfg.loras[0].target == "high_noise" + assert cfg.loras[0].name == "hi-noise-style" + assert cfg.loras[1].target == "low_noise" + assert cfg.loras[1].weight == 0.4 + + +def test_lora_invalid_target_falls_back_to_both(): + cfg = VideoConfig.from_dict( + {"loras": [{"path": "/tmp/x.safetensors", "target": "bogus"}]} + ) + assert cfg.loras[0].target == "both" + + +def test_lora_entries_without_path_are_dropped(): + cfg = VideoConfig.from_dict( + {"loras": [{"weight": 0.5}, {"path": "/tmp/ok.safetensors"}, None]} + ) + assert len(cfg.loras) == 1 + assert cfg.loras[0].path == "/tmp/ok.safetensors" + + +def test_models_section_override(): + cfg = VideoConfig.from_dict( + { + "models": { + "wan22_base_repo": "/local/weights/wan22", + "wan22_fp8_repo": "/local/weights/wan22-fp8", + "wan22_config_json": "/local/cfg/fp8.json", + "wan22_model_cls": "wan2.2_moe", + "musetalk_path": "/local/weights/musetalk", + } + } + ) + assert cfg.wan22_base_repo == "/local/weights/wan22" + assert cfg.wan22_fp8_repo == "/local/weights/wan22-fp8" + assert cfg.wan22_config_json == "/local/cfg/fp8.json" + assert cfg.wan22_model_cls == "wan2.2_moe" + assert cfg.musetalk_model_path == "/local/weights/musetalk" diff --git a/tests/unit/test_video_engine_logic.py b/tests/unit/test_video_engine_logic.py new file mode 100644 index 0000000..a32d8c5 --- /dev/null +++ b/tests/unit/test_video_engine_logic.py @@ -0,0 +1,106 @@ +"""Unit tests for pure-python logic inside VideoEngine. + +No models are loaded: we instantiate ``VideoEngine`` and hand-stub its +``_wan22`` / ``_musetalk`` attributes to test prompt derivation, library +round-robin, and frame fitting. +""" +import numpy as np +import pytest + +from server.video import VideoConfig, VideoEngine + + +@pytest.fixture +def engine(): + cfg = VideoConfig.from_dict( + { + "enabled": True, + "mode": "reflective", + "fps": 16, + "reflective": { + "clip_prompt_template": "A: {reply_hint} B", + "prompt_reply_words": 5, + }, + } + ) + return VideoEngine(cfg) + + +def test_derive_prompt_truncates_to_word_limit(engine): + out = engine._derive_prompt("one two three four five six seven eight") + assert out == "A: one two three four five B" + + +def test_derive_prompt_handles_empty_reply(engine): + out = engine._derive_prompt("") + assert out == "A: calm and friendly B" + out2 = engine._derive_prompt(None) # type: ignore[arg-type] + assert out2 == "A: calm and friendly B" + + +def test_derive_prompt_strips_and_passes_through(engine): + out = engine._derive_prompt(" hello world ") + assert out == "A: hello world B" + + +def test_is_ready_false_without_models(engine): + # Models haven't been loaded — is_ready must be False so the pipeline + # falls back to the PCM streaming path. + assert engine.is_ready() is False + + +def test_pick_library_frames_round_robin(engine): + engine.cfg.mode = "library" + engine.cfg.fps = 2 + # Two base clips, 4 frames each. + a = np.tile(np.array([[[[0, 0, 0]]]], dtype=np.uint8), (4, 1, 1, 1)) + b = np.tile(np.array([[[[255, 255, 255]]]], dtype=np.uint8), (4, 1, 1, 1)) + engine.speaking_base_frames = [a, b] + # 2s of audio at 16kHz → 4 frames at fps=2 + audio = np.zeros(16000 * 2, dtype=np.float32) + + f1 = engine._pick_library_frames(audio, 16000) + f2 = engine._pick_library_frames(audio, 16000) + f3 = engine._pick_library_frames(audio, 16000) + assert f1.shape == (4, 1, 1, 3) + assert f1[0, 0, 0, 0] == 0 # first pick = clip A + assert f2[0, 0, 0, 0] == 255 # second pick = clip B + assert f3[0, 0, 0, 0] == 0 # wraps back to A + + +def test_pick_library_frames_trims_to_audio_duration(engine): + engine.cfg.mode = "library" + engine.cfg.fps = 4 + frames = np.zeros((20, 1, 1, 3), dtype=np.uint8) + engine.speaking_base_frames = [frames] + # 1s audio → 4 frames + audio = np.zeros(16000, dtype=np.float32) + out = engine._pick_library_frames(audio, 16000) + assert out.shape == (4, 1, 1, 3) + + +def test_pick_library_frames_loops_for_long_audio(engine): + engine.cfg.mode = "library" + engine.cfg.fps = 4 + frames = np.zeros((4, 1, 1, 3), dtype=np.uint8) + engine.speaking_base_frames = [frames] + # 3s audio → 12 frames, base has only 4 + audio = np.zeros(16000 * 3, dtype=np.float32) + out = engine._pick_library_frames(audio, 16000) + assert out.shape == (12, 1, 1, 3) + + +def test_pick_library_frames_raises_when_empty(engine): + engine.cfg.mode = "library" + engine.speaking_base_frames = [] + with pytest.raises(RuntimeError, match="no pre-baked base clips"): + engine._pick_library_frames(np.zeros(100, dtype=np.float32), 16000) + + +def test_generate_speaking_clip_raises_when_not_ready(engine): + with pytest.raises(RuntimeError, match="not ready"): + engine.generate_speaking_clip( + audio_f32=np.zeros(100, dtype=np.float32), + sample_rate=16000, + reply_text="hi", + )