Compare commits

...

6 Commits

Author SHA1 Message Date
bhetherman 44a10667c2 Enhance video handling and performance optimizations
- Added environment variables to prevent CPU thread pools from busy-waiting.
- Deferred loading of video models until first use to reduce VRAM footprint.
- Implemented streaming of speaking clips for improved responsiveness.
- Introduced a queue for managing speaking clips to handle multiple requests smoothly.
- Updated video playback logic to ensure proper handling of clip generation.
2026-04-24 00:36:18 -04:00
bhetherman 129df7d1fa working ok 2026-04-16 10:00:37 -04:00
bhetherman 9debc56137 Add LightX2V + Wan2.2-TI2V-5B-Turbo GGUF experiment
Benchmarks the dense 5B Turbo model (Q8_0 GGUF + fp8 T5) as a
lower-VRAM alternative to the 14B MoE pipeline. Includes dtype
patches for dense WanModel, Wan 2.2 VAE config (48 channels, 16x
spatial), and Blackwell fp8 workaround.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-16 01:27:45 -04:00
bhetherman 56923ff424 test passing 2026-04-12 16:38:44 -04:00
bhetherman fcf0be38bc t5 encoder fp8 seems to be working 2026-04-12 13:50:34 -04:00
bhetherman 2818b41004 first stab at adding video 2026-04-12 04:11:52 -04:00
58 changed files with 4573 additions and 39 deletions
+3
View File
@@ -1,3 +1,6 @@
.venv .venv
.claude .claude
__pycache__ __pycache__
tests/component/_out/
avatars/
loras/
+3
View File
@@ -0,0 +1,3 @@
[submodule "third_party/MuseTalk"]
path = third_party/MuseTalk
url = https://git.hetherman.cloud/bhetherman/MuseTalk.git
+49
View File
@@ -0,0 +1,49 @@
# Agent guide — voice-chat
Orientation for AI agents making changes in this repo. Keep edits scoped; this pipeline has several sharp edges that a naive refactor will hit.
## What this project is
A local real-time voice (and optionally video-avatar) chat server. FastAPI + WebSocket on the backend, a small vanilla-JS UI on the frontend. All ML runs locally on a single NVIDIA GPU; there are no cloud API calls at runtime.
## Two-tier architecture
1. **Audio pipeline** (always on): mic PCM → [vad.py](server/vad.py) → [asr.py](server/asr.py) → [llm.py](server/llm.py) → [tts.py](server/tts.py) → PCM back to the browser. Orchestrated by [pipeline.py](server/pipeline.py) via `ConversationSession`. Supports barge-in (user speaks → cancel in-flight reply).
2. **Video pipeline** (optional, gated by `config.video.enabled`): per assistant turn, [video.py](server/video.py)'s `VideoEngine` calls [video_models/wan22.py](server/video_models/wan22.py) (LightX2V Wan2.2-I2V) to produce base frames, then [video_models/musetalk.py](server/video_models/musetalk.py) to lip-sync them to the TTS audio, then [video_models/muxer.py](server/video_models/muxer.py) to produce an MP4. The MP4 is sent over the same `/ws/chat` WebSocket as a `speaking_clip` message.
The audio path must keep working when `video.enabled` is false. Don't make video models load-bearing for the audio pipeline.
## Key files
- [server/main.py](server/main.py) — FastAPI app, WebSocket, video/avatar HTTP endpoints
- [server/models.py](server/models.py) — lifetime of all models; `ModelManager.video_engine` is `None` when video is disabled
- [server/pipeline.py](server/pipeline.py) — per-session audio pipeline + video branch
- [server/config.py](server/config.py) — parses [config.yml](config.yml)
- [server/video.py](server/video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine` (library vs reflective modes)
- [server/video_models/wan22.py](server/video_models/wan22.py) — LightX2V wrapper; fp8 + GGUF loading; Blackwell patches
- [configs/lightx2v/](configs/lightx2v/) — LightX2V inference config templates; must match `wan22_dit_quant_scheme`
- [tests/unit/](tests/unit/) — GPU-free tests, runnable on Windows host
- [tests/component/](tests/component/) — end-to-end tests, must run inside the Docker container
## Conventions
- Config: single source of truth is [config.yml](config.yml) → `server/config.py` dataclasses. Don't read env vars for runtime behaviour; if you need a new knob, add it to the dataclass and document it in `config.yml`.
- Logging: `log = logging.getLogger(__name__)` at module top; log level is set once in `server/main.py`. INFO for lifecycle, DEBUG for per-chunk chatter.
- Async: WebSocket handlers and endpoints are async, but heavy model work is sync — wrap via `asyncio.to_thread(...)` at the call site (see `set_avatar` in `main.py`).
- Concurrency: `VideoEngine` serialises generation with `self._lock`. Don't call model methods without holding it from another thread.
- Tests: every non-trivial logic change in `server/video.py` or `server/pipeline.py` should have a corresponding `tests/unit/` test. GPU-dependent behaviour goes in `tests/component/`.
## Gotchas
- **GPU architecture.** This is tuned for RTX 5090 / Blackwell (SM120) with PyTorch 2.8 + Triton 3.4. Several upstream kernels (flashinfer, flash_attn3, sgl_kernel fp8 matmul, Triton-fused scale/shift) are broken or unavailable there. See [server/video_models/AGENT.md](server/video_models/AGENT.md) before touching the Wan2.2 wrapper.
- **First launch is slow.** Hugging Face downloads land in the `huggingface-cache` Docker volume; a cold run pulls >20 GB.
- **Wan-AI/Wan2.2-I2V-A14B** ships bf16 DIT shards we don't want — `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](server/video_models/wan22.py) excludes them. Keep that list in sync if the repo layout changes.
- **LoRA targets matter.** Wan2.2 is a MoE (high_noise + low_noise sub-models). A LoRA with the wrong `target` loads silently and produces subtly wrong output.
- **Don't mix audio+video state.** The audio pipeline must not block on video generation; video is produced for a turn *after* the full reply audio is available, and sent as a separate message.
## When in doubt
- Run `python -m pytest tests/unit -v` — it's fast and catches most regressions.
- For GPU changes, run the lowest-numbered relevant component test first; they're ordered to isolate failure to a single stage.
- Check [memory](../../.claude/projects/c--Users-bheth-Documents-voice-chat/memory/) (auto-loaded) for prior-session findings that aren't in the code.
+49
View File
@@ -4,6 +4,15 @@ ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
# HuggingFace model cache — mounted as a volume so models persist across runs # HuggingFace model cache — mounted as a volume so models persist across runs
ENV HF_HOME=/cache/huggingface ENV HF_HOME=/cache/huggingface
# LoRA directory — users drop .safetensors files here and reference them
# from config.yml::video.loras. Bind-mounted via docker-compose.
ENV LORA_DIR=/cache/loras
# Prevent PyTorch/OpenMP/MKL thread pools from spin-waiting when idle.
# Without this, loading large models (ASR, LLM, Wan2.2) causes all CPU cores
# to busy-loop even with no connected clients, slowing the whole system.
ENV OMP_WAIT_POLICY=PASSIVE
ENV MKL_WAIT_POLICY=PASSIVE
ENV TOKENIZERS_PARALLELISM=false
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
python3.11 \ python3.11 \
@@ -38,6 +47,46 @@ RUN python3.11 -m pip install --no-cache-dir -r requirements.txt
# Pre-download the spacy model that kokoro needs at runtime # Pre-download the spacy model that kokoro needs at runtime
RUN python3.11 -m spacy download en_core_web_sm RUN python3.11 -m spacy download en_core_web_sm
# --- Optional: avatar video stack -------------------------------------------
# These are heavy installs; keep them after the core deps so rebuilds only
# redo this layer when ONLY the video stack changes. If you don't plan to
# use config.video.enabled=true, you can comment this block out to speed
# up builds and shrink the image.
#
# LightX2V (Wan2.2-Lightning inference framework) — installed from source
# since there is no stable PyPI release yet.
RUN python3.11 -m pip install --no-cache-dir \
"git+https://github.com/ModelTC/LightX2V.git@6db002f2755036b02bd0900bf9b41958bbfb4137" || \
echo "LightX2V install failed — config.video.enabled must stay false until fixed"
# ^ Pinned to 2026-04-14: last commit before WorldMirrorRunner was added to
# pipeline.py (which requires flash_attn + matplotlib) and before the
# dummy_model NameError regression in vae_2_2.py.
#
# sgl-kernel (fp8 T5 encoder acceleration). The PyPI wheel lacks SM120
# (Blackwell) CUTLASS kernels; use SGLang's cu128 wheel index instead.
# Our wan22.py patches fp8_scaled_mm → torch._scaled_mm at runtime for
# Blackwell GPUs, but the sgl_kernel package itself must still be present.
RUN python3.11 -m pip install --no-cache-dir --no-deps \
"sgl-kernel @ https://github.com/sgl-project/whl/releases/download/v0.3.14.post1/sgl_kernel-0.3.14.post1%2Bcu128-cp310-abi3-manylinux2014_x86_64.whl" || \
echo "sgl-kernel install failed — fp8 T5 will fall back to bf16"
#
# MuseTalk (audio-driven lip-sync) — installed from the bhetherman/MuseTalk
# fork checked in as a submodule at third_party/MuseTalk. The upstream repo
# has no setup.py / pyproject.toml; our fork adds them so `pip install .`
# just works. We deliberately do NOT install its requirements.txt (it pins
# numpy==1.23.5, transformers==4.39.2, tensorflow==2.12.0 which conflict
# with the rest of the stack) — instead we install its real runtime deps
# explicitly here.
COPY third_party/MuseTalk /opt/MuseTalk
RUN python3.11 -m pip install --no-cache-dir --no-deps /opt/MuseTalk || \
echo "MuseTalk install failed — config.video.musetalk.enabled must stay false until fixed"
RUN python3.11 -m pip install --no-cache-dir \
librosa einops omegaconf ffmpeg-python || \
echo "MuseTalk runtime deps install failed"
#
# LoRA directory (user drops .safetensors here; bind-mounted in compose).
RUN mkdir -p /cache/loras
COPY . . COPY . .
EXPOSE 8000 EXPOSE 8000
+55 -12
View File
@@ -1,21 +1,29 @@
# Voice Chat # Voice Chat
A real-time voice conversation app powered by local AI models. Speak into your mic and get spoken responses back — all running on your own GPU with no cloud APIs. A real-time voice conversation app powered by local AI models. Speak into your mic and get spoken responses back — and optionally a lip-synced talking-head video of a chosen avatar — all running on your own GPU with no cloud APIs.
## Pipeline ## Pipeline
**Mic input** &rarr; **VAD** (Silero ONNX) &rarr; **ASR** (Qwen3-ASR-0.6B) &rarr; **LLM** (Qwen3.5-0.8B) &rarr; **TTS** (Kokoro) &rarr; **Speaker output** **Mic input** &rarr; **VAD** (Silero ONNX) &rarr; **ASR** (Qwen3-ASR-0.6B) &rarr; **LLM** (Qwen3.5-0.8B) &rarr; **TTS** (Kokoro) &rarr; **Speaker output**
When the optional video stack is enabled, each assistant turn also produces an MP4 via:
**TTS audio + avatar image** &rarr; **Wan2.2-Lightning I2V** (LightX2V, fp8 or GGUF) &rarr; **MuseTalk lip-sync** &rarr; **ffmpeg mux** &rarr; **`speaking_clip` WebSocket message**
- **VAD** — Silero VAD via ONNX Runtime, detects speech/silence boundaries on CPU - **VAD** — Silero VAD via ONNX Runtime, detects speech/silence boundaries on CPU
- **ASR** — Qwen3-ASR-0.6B, bfloat16 on CUDA - **ASR** — Qwen3-ASR-0.6B, bfloat16 on CUDA
- **LLM** — Qwen3.5-0.8B, loaded via transformers - **LLM** — Qwen3.5-0.8B (local) or any model served by LM Studio
- **TTS** — Kokoro, streams sentence-by-sentence audio at 24 kHz - **TTS** — Kokoro, streams sentence-by-sentence audio at 24 kHz
- **Barge-in** — interrupt the assistant mid-response by speaking - **Barge-in** — interrupt the assistant mid-response by speaking
- **Video (optional)** — Wan2.2 I2V 14B MoE with LightX2V distill LoRAs, fp8 or GGUF-quantised DIT, lip-synced by MuseTalk. Two modes:
- `library` — pre-bakes a small set of speaking base clips per avatar, picks round-robin per turn, lip-syncs on the fly
- `reflective` — generates a fresh Wan2.2 clip per turn from a prompt derived from the assistant's reply
## Requirements ## Requirements
- NVIDIA GPU with CUDA 12.8 support - NVIDIA GPU with CUDA 12.8 support (tested on RTX 5090 / SM120 Blackwell)
- Docker + Docker Compose with the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) - Docker + Docker Compose with the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
- ~24 GB VRAM recommended when video is enabled (fp8); ~16 GB with `gguf-Q4_K_M`
## Quick Start ## Quick Start
@@ -27,6 +35,17 @@ Then open [http://localhost:8000](http://localhost:8000) in your browser.
Models are downloaded from Hugging Face on first launch and cached in a Docker volume (`huggingface-cache`) so they persist across rebuilds. Models are downloaded from Hugging Face on first launch and cached in a Docker volume (`huggingface-cache`) so they persist across rebuilds.
## Configuration
All runtime behaviour is driven by [config.yml](config.yml):
- `llm.backend``local` (in-process transformers) or `lmstudio` (talks to a local LM Studio server)
- `llm.system_prompt`, `llm.max_cache_tokens` — prompt and KV-cache limit per session
- `video.enabled` — master toggle for the avatar video stack. When `false`, no video models load and the app behaves exactly like the audio-only pipeline
- `video.mode``library` or `reflective`
- `video.models.wan22_dit_quant_scheme``gguf-Q8_0` (default) or `gguf-Q4_K_M` for lower VRAM; any GGUF level LightX2V supports (dense 5B Turbo is GGUF-only)
- `video.loras` — list of LoRA adapters applied to the dense Wan2.2-TI2V-5B DIT at load time. Each entry has `path`, `weight`, `target` (always `both` — the 5B DIT is not MoE; legacy `high_noise`/`low_noise` values are coerced), and optional `name`. User LoRAs are mounted from `./loras/` into the container at `/cache/loras/`
## Local Development (without Docker) ## Local Development (without Docker)
```bash ```bash
@@ -45,17 +64,41 @@ python run.py
The server starts on port 8000. The server starts on port 8000.
## API
- `GET /` — browser UI
- `WS /ws/chat` — bidirectional audio + control WebSocket
- `POST /api/set-voice` — switch the Kokoro voice
- `POST /api/set-avatar` *(video only)* — upload an avatar image; (re)generates idle + library clips
- `GET /api/idle-clip` *(video only)* — cached idle loop MP4
- `POST /api/set-video-mode` *(video only)* — switch between `off` / `library` / `reflective`
- `POST /api/reload-loras` *(video only)* — hot-swap the LoRA stack; regenerates the idle clip
## Project Structure ## Project Structure
``` ```
server/ server/
main.py — FastAPI app, WebSocket endpoint main.py — FastAPI app, WebSocket + video endpoints
models.py — Model loading and management models.py — Model loading and management (audio + optional video)
pipeline.py — VAD -> ASR -> LLM -> TTS orchestration pipeline.py — VAD -> ASR -> LLM -> TTS orchestration, video branch
vad.py — Silero VAD (ONNX) streaming wrapper config.py — config.yml parsing
asr.py — Speech recognition engine vad.py — Silero VAD (ONNX) streaming wrapper
llm.py — Language model engine asr.py — Speech recognition engine
tts.py — Kokoro TTS engine llm.py — Language model engine (local + LM Studio backends)
audio_utils.py — PCM/float32 conversion helpers tts.py — Kokoro TTS engine
static/ — Browser UI (HTML/JS/CSS) audio_utils.py — PCM/float32 conversion helpers
video.py — VideoEngine orchestrator + VideoConfig + LoRASpec
video_models/
wan22.py — LightX2V Wan2.2 I2V pipeline wrapper (fp8 + GGUF)
musetalk.py — MuseTalk lip-sync wrapper
muxer.py — ffmpeg helpers: frames -> MP4, frames+audio -> MP4
configs/
lightx2v/ — LightX2V inference config templates (fp8 + GGUF variants)
static/ — Browser UI (HTML/JS/CSS)
avatars/ — uploaded avatar images (gitignored)
loras/ — user-supplied Wan2.2 LoRAs, mounted into the container
reference_audio/ — Kokoro voice reference samples
tests/ — unit + component tests (see tests/README.md)
``` ```
+57
View File
@@ -12,3 +12,60 @@ llm:
lmstudio: lmstudio:
url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker
model: "" # leave empty to use whatever model LM Studio has loaded model: "" # leave empty to use whatever model LM Studio has loaded
# Avatar video generation (Wan2.2-TI2V-5B-Turbo GGUF via LightX2V + MuseTalk lip-sync)
video:
enabled: true # master toggle — when false, video models are not loaded
backend: lightx2v # only option for now
mode: reflective # "library" (pre-baked clips) | "reflective" (fresh per turn)
resolution: 480 # 480 or 720
fps: 16 # Wan2.2 native rate; MuseTalk resamples as needed
library:
base_clip_count: 4 # how many speaking base clips to pre-generate per avatar
base_clip_seconds: 6 # duration of each pre-baked clip
# MuseTalk audio-driven lip-sync. When disabled, Wan2.2 base frames are
# used as-is without a lip-sync pass — useful when MuseTalk isn't installed
# or while iterating on the base pipeline.
musetalk:
enabled: false # toggle lip-sync on/off
reflective:
clip_seconds: 5 # target length of each fresh Wan2.2 clip per turn
clip_prompt_template: >-
webcam view of a person speaking, {reply_hint},
casual gestures, natural lighting, soft focus background
prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint}
# Model sources for the video stack. T5/VAE/tokenizer come from the
# Wan-AI base repo. The single dense DIT comes from wan22_dit_repo as
# GGUF (Turbo 4-step distill). Both repos download on first run into
# HF_HOME=/cache/huggingface.
#
# Supported dit_quant_scheme values (dense 5B Turbo — GGUF only):
# gguf-Q8_0 — 8-bit, ~6 GB DIT, ~6.5 GB VRAM at load (default)
# gguf-Q4_K_M — 4-bit, ~3.5 GB DIT, lower VRAM for tight budgets
# (any gguf-<level> published in hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF)
models:
wan22_base_repo: Wan-AI/Wan2.2-TI2V-5B
wan22_dit_repo: hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF
wan22_dit_quant_scheme: gguf-Q8_0
wan22_t5_quantized: true
wan22_model_cls: wan2.2
wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json
musetalk_path: TMElyralab/MuseTalk
# LoRAs applied to the dense 5B DIT at load time via LightX2V's
# lora_dynamic_apply path (merged during GGUF dequant). Dense has a
# single set of weights so `target` is always `both`.
#
# The old MoE-trained wan22-H-e8 / wan22-L-e8 LoRAs are NOT compatible
# with the 5B DIT and are disabled here. Future 5B-compatible LoRAs
# should follow the shape shown below.
loras: []
# loras:
# - path: /cache/loras/your-5b-lora.safetensors
# weight: 1.0
# target: both
# name: your-5b-lora
@@ -0,0 +1,40 @@
{
"_comment": "LightX2V config for Wan2.2-TI2V-5B-Turbo (dense, GGUF). Single DIT checkpoint (not MoE). dit_quantized_ckpt is filled in at runtime by Wan22Pipeline.",
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"resize_mode": "adaptive",
"resolution": "480p",
"target_height": 480,
"target_width": 480,
"fps": 16,
"vae_stride": [4, 16, 16],
"num_channels_latents": 48,
"self_attn_1_type": "torch_sdpa",
"cross_attn_1_type": "torch_sdpa",
"cross_attn_2_type": "torch_sdpa",
"modulate_type": "torch",
"rope_type": "torch",
"sample_guide_scale": 1.0,
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"offload_granularity": "model",
"t5_cpu_offload": true,
"vae_cpu_offload": false,
"use_image_encoder": false,
"denoising_step_list": [1000, 750, 500, 250],
"dit_quantized": true,
"dit_quant_scheme": "gguf-Q8_0",
"t5_quantized": true,
"t5_quant_scheme": "fp8-sgl"
}
+11
View File
@@ -0,0 +1,11 @@
"""Pytest configuration.
Ensures the project root is on ``sys.path`` so tests can import ``server.*``
without installing the project as a package.
"""
import os
import sys
_ROOT = os.path.dirname(os.path.abspath(__file__))
if _ROOT not in sys.path:
sys.path.insert(0, _ROOT)
+7
View File
@@ -6,10 +6,17 @@ services:
volumes: volumes:
# Cache models on the host so they survive container rebuilds # Cache models on the host so they survive container rebuilds
- huggingface-cache:/cache/huggingface - huggingface-cache:/cache/huggingface
# LoRA adapters — drop .safetensors files into ./loras on the host,
# reference them from config.yml as /cache/loras/<file>.safetensors
- ./loras:/cache/loras
# Avatar images uploaded via the web UI persist between restarts
- ./avatars:/app/avatars
# Mount source so you can edit code/config without rebuilding the image # Mount source so you can edit code/config without rebuilding the image
- ./config.yml:/app/config.yml:ro - ./config.yml:/app/config.yml:ro
- ./configs:/app/configs:ro
- ./server:/app/server:ro - ./server:/app/server:ro
- ./static:/app/static:ro - ./static:/app/static:ro
- ./tests:/app/tests
- ./run.py:/app/run.py:ro - ./run.py:/app/run.py:ro
deploy: deploy:
resources: resources:
+1
View File
@@ -0,0 +1 @@
_out/
+65
View File
@@ -0,0 +1,65 @@
# LightX2V + Wan2.2-TI2V-5B-Turbo (GGUF) Experiment
Swap the 14B MoE distill for the dense 5B Turbo model, keeping the LightX2V backend.
Hypothesis: half the parameters → lower VRAM footprint (can coexist with the running
server) and faster per-step compute, with the Turbo 4-step distill preserving wall time.
## Config
- **Model**: `hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF` (Q8_0 by default — swap to Q4_K_M via env)
- **Base repo** (configs, T5, VAE): `Wan-AI/Wan2.2-TI2V-5B`
- **model_cls**: `wan2.2` (dense, single DIT — not MoE)
- **Steps**: 4 (Turbo distill)
- **Resolution**: 480×480, 81 frames @ 16 fps
## Key implementation details
- **Dense model (`wan2.2`)**: Uses single DIT checkpoint, not MoE — requires different dtype patching than the 14B pipeline
- **GGUF dequant → fp16**: Requires `DTYPE=FP16` and patches for T5 (bf16→fp16 wrapper), VAE (→fp16), and DIT pre/post weights (fp32→fp16)
- **Wan 2.2 VAE**: 48 latent channels with 16× spatial compression (vs 16 channels / 8× for Wan 2.1) — config must set `vae_stride: [4,16,16]` and `num_channels_latents: 48`
- **fp8 T5**: Uses `lightx2v/Encoders` fp8 checkpoint (~4.9 GB vs ~11.4 GB bf16)
- **Blackwell (SM120)**: Needs `_patch_fp8_scaled_mm_for_blackwell` to replace sgl_kernel's fp8 GEMM
## Why a separate container
Reuses the existing `voice-chat-voice-chat` image (LightX2V already installed) but runs
under its own compose profile so it doesn't interfere with the live server volumes or
startup. Shares the HF cache volume so model downloads are reused.
## Running
```bash
# Ensure main image is built
docker compose build voice-chat
# Stage model (downloads base + Turbo Q8 GGUF, ~6 GB)
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/setup_model.py
# Run benchmark
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
```
Reports peak VRAM and wall time for an 81-frame 480p clip.
## Results
| Metric | Value |
|--------|-------|
| Model load | ~43s |
| VRAM after load | 6.53 GB |
| T5 encode | ~1s |
| VAE encode | ~0.5s |
Awaiting full end-to-end benchmark completion for wall time and peak VRAM.
## Go / no-go criteria
- **Go**: < 45s per 81-frame clip AND peak VRAM < 12 GB (leaves ~20 GB for the server)
- **No-go**: keep the 14B MoE Q4_K_M pipeline
### Baselines
- **vLLM-Omni + fp16 Turbo-5B**: 1663s / 22.5 GB — decisive no-go
- **LightX2V + 14B MoE Q4_K_M**: ~30s/clip, ~14.5 GB VRAM (current production pipeline)
+40
View File
@@ -0,0 +1,40 @@
{
"_comment": "LightX2V config for Wan2.2-TI2V-5B-Turbo (dense, GGUF). Single DIT checkpoint (not MoE). dit_quantized_ckpt is filled in at runtime by setup_model.py / test_i2v.py.",
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"resize_mode": "adaptive",
"resolution": "480p",
"target_height": 480,
"target_width": 480,
"fps": 16,
"vae_stride": [4, 16, 16],
"num_channels_latents": 48,
"self_attn_1_type": "torch_sdpa",
"cross_attn_1_type": "torch_sdpa",
"cross_attn_2_type": "torch_sdpa",
"modulate_type": "torch",
"rope_type": "torch",
"sample_guide_scale": 1.0,
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"offload_granularity": "model",
"t5_cpu_offload": true,
"vae_cpu_offload": false,
"use_image_encoder": false,
"denoising_step_list": [1000, 750, 500, 250],
"dit_quantized": true,
"dit_quant_scheme": "gguf-Q8_0",
"t5_quantized": true,
"t5_quant_scheme": "fp8-sgl"
}
@@ -0,0 +1,26 @@
services:
lightx2v-5b:
image: voice-chat-voice-chat:latest
volumes:
- huggingface-cache:/cache/huggingface
- ../../:/app
working_dir: /app
environment:
- DTYPE=FP16
- HF_HOME=/cache/huggingface
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
shm_size: "8g"
ipc: host
profiles:
- experimental
volumes:
huggingface-cache:
name: voice-chat_huggingface-cache
external: true
+54
View File
@@ -0,0 +1,54 @@
"""Stage Wan2.2-TI2V-5B-Turbo GGUF pipeline for LightX2V.
Downloads:
1. Base `Wan-AI/Wan2.2-TI2V-5B` snapshot (configs, T5, VAE — skip bf16 DIT shards).
2. Turbo Q8 GGUF DIT from `hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF`.
Quant file can be overridden via GGUF_FILE env (default Q8_0).
Idempotent: huggingface_hub handles caching.
"""
from __future__ import annotations
import os
from huggingface_hub import hf_hub_download, snapshot_download
BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
GGUF_FILE = os.environ.get(
"GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf"
)
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
def main() -> None:
print(f"\n=== 1/2 Snapshot base pipeline {BASE_REPO} ===", flush=True)
# The base repo ships bf16 DIT shards we don't need (we use the Turbo GGUF instead).
base_dir = snapshot_download(
repo_id=BASE_REPO,
ignore_patterns=[
"*.pt",
"diffusion_pytorch_model*.safetensors",
],
)
print(f"Base pipeline at: {base_dir}")
print(f"\n=== 2/3 Download {GGUF_FILE} from {GGUF_REPO} ===", flush=True)
gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
print(f"GGUF DIT at: {gguf_path}")
print(f"\n=== 3/3 Download fp8 T5 from {T5_FP8_REPO} ===", flush=True)
t5_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
print(f"fp8 T5 at: {t5_path}")
print(f"\n{'=' * 50}")
print("Ready. Export to test_i2v.py via env:")
print(f" BASE_DIR={base_dir}")
print(f" DIT_GGUF={gguf_path}")
print(f" T5_FP8={t5_path}")
if __name__ == "__main__":
main()
+217
View File
@@ -0,0 +1,217 @@
"""Benchmark Wan2.2-TI2V-5B-Turbo i2v under LightX2V.
Uses the dense `wan2.2` model_cls with a single Q8 GGUF DIT checkpoint.
Applies the same GGUF dtype patches as server/video_models/wan22.py (T5→bf16
wrapper, VAE→fp16, fp32 DIT pre/post weights→fp16).
Measures peak VRAM and wall time for an 81-frame 480p clip from the sample avatar.
Run:
docker compose -f experimental/lightx2v_5b/docker-compose.yml --profile experimental \\
run --rm lightx2v-5b python /app/experimental/lightx2v_5b/test_i2v.py
"""
from __future__ import annotations
import argparse
import json
import os
import tempfile
import time
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download, snapshot_download
BASE_REPO = "Wan-AI/Wan2.2-TI2V-5B"
GGUF_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
GGUF_FILE = os.environ.get("GGUF_FILE", "Wan2_2-TI2V-5B-Turbo-Q8_0.gguf")
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
CONFIG_JSON = Path(__file__).parent / "config.json"
SAMPLE_AVATAR = "/app/tests/component/sample_avatar.png"
OUTPUT = Path("/app/experimental/lightx2v_5b/_out/turbo_5b.mp4")
PROMPT = "a humanoid robot looking at the camera, shaking their head left and right, soft focus background"
SEED = 42
def human_gb(n: int) -> str:
return f"{n / (1024 ** 3):.2f} GB"
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
"""Recursively find and cast fp32 tensors to fp16 on any object tree."""
if visited is None:
visited = set()
obj_id = id(obj)
if obj_id in visited or depth > 6:
return 0
visited.add(obj_id)
n = 0
for attr_name in dir(obj):
if attr_name.startswith("__"):
continue
try:
val = getattr(obj, attr_name)
except Exception:
continue
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
try:
setattr(obj, attr_name, val.to(torch.float16))
n += 1
except Exception:
pass
elif hasattr(val, "__dict__") and not callable(val):
n += _cast_all_fp32_tensors(val, visited, depth + 1)
return n
def build_args(model_path: str, config_json: str) -> argparse.Namespace:
return argparse.Namespace(
seed=SEED,
model_cls="wan2.2",
task="i2v",
support_tasks=[],
model_path=model_path,
sf_model_path=None,
config_json=config_json,
use_prompt_enhancer=False,
prompt="",
negative_prompt="",
image_path="",
last_frame_path="",
audio_path="",
image_strength="1.0",
image_frame_idx="",
src_ref_images=None,
src_video=None,
src_mask=None,
src_pose_path=None,
src_face_path=None,
src_bg_path=None,
src_mask_path=None,
pose=None,
action_path=None,
action_ckpt=None,
save_result_path=None,
return_result_tensor=False,
target_shape=[],
target_video_length=81,
aspect_ratio="",
video_path=None,
sr_ratio=2.0,
)
def main() -> None:
OUTPUT.parent.mkdir(parents=True, exist_ok=True)
print(f"\n=== Stage model ===", flush=True)
base_dir = snapshot_download(
repo_id=BASE_REPO,
ignore_patterns=["*.pt", "diffusion_pytorch_model*.safetensors"],
)
gguf_path = hf_hub_download(repo_id=GGUF_REPO, filename=GGUF_FILE)
t5_fp8_path = hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
print(f" base_dir: {base_dir}")
print(f" dit_gguf: {gguf_path}")
print(f" t5_fp8: {t5_fp8_path}")
with open(CONFIG_JSON, "r", encoding="utf-8") as f:
cfg = json.load(f)
cfg.pop("_comment", None)
cfg["dit_quantized_ckpt"] = gguf_path
cfg["t5_quantized_ckpt"] = t5_fp8_path
tmp = tempfile.NamedTemporaryFile(
prefix="wan22_5b_", suffix=".json",
mode="w", delete=False, encoding="utf-8",
)
json.dump(cfg, tmp, indent=2)
tmp.close()
print(f" runtime config: {tmp.name}")
# Import LightX2V after env is set.
import sys
if "/app" not in sys.path:
sys.path.insert(0, "/app")
from server.video_models.wan22 import _patch_fp8_scaled_mm_for_blackwell
from lightx2v.infer import init_runner
from lightx2v.utils.input_info import (
init_empty_input_info,
update_input_info_from_dict,
)
from lightx2v.utils.set_config import set_config
_patch_fp8_scaled_mm_for_blackwell()
args = build_args(base_dir, tmp.name)
print(f"\n=== set_config + init_runner ===", flush=True)
torch.cuda.reset_peak_memory_stats()
t_load = time.perf_counter()
config = set_config(args)
runner = init_runner(config)
load_time = time.perf_counter() - t_load
vram_load = torch.cuda.memory_allocated()
print(f"[load] time={load_time:.1f}s vram={human_gb(vram_load)}", flush=True)
# GGUF dtype patches: flip DTYPE to FP16, patch T5/VAE/DIT to match.
os.environ["DTYPE"] = "FP16"
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE
GET_DTYPE.cache_clear()
# Reuse the patch methods from Wan22Pipeline as standalone functions
# by constructing a minimal object with ._runner.
from server.video_models.wan22 import Wan22Pipeline
shim = Wan22Pipeline.__new__(Wan22Pipeline)
shim._runner = runner
Wan22Pipeline._patch_t5_dtype_for_gguf(shim)
Wan22Pipeline._patch_vae_dtype_for_gguf(shim)
Wan22Pipeline._patch_dit_fp32_weights_for_gguf(shim)
# Dense model: patch_dit_fp32 above expects MoE wrapper (runner.model.model).
# For dense wan2.2, the model is directly at runner.model — patch it explicitly.
n = Wan22Pipeline._cast_fp32_dit_weights_in_model(runner.model)
print(f"[patch] cast {n} fp32 DIT weights to fp16 (dense model)")
# Sweep all nested objects for any remaining fp32 tensors (conv3d bias, etc).
n_extra = _cast_all_fp32_tensors(runner.model)
print(f"[patch] cast {n_extra} additional fp32 tensors to fp16")
input_info = init_empty_input_info(args.task, args.support_tasks)
print(f"\n=== generate ===", flush=True)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
out_path = tf.name
update_input_info_from_dict(
input_info,
{
"seed": SEED,
"prompt": PROMPT,
"negative_prompt": "",
"image_path": SAMPLE_AVATAR,
"save_result_path": out_path,
"target_video_length": 81,
"return_result_tensor": False,
},
)
t_gen = time.perf_counter()
runner.run_pipeline(input_info)
gen_time = time.perf_counter() - t_gen
peak_vram = torch.cuda.max_memory_allocated()
print(f"\n[generate] wall_time={gen_time:.1f}s peak_vram={human_gb(peak_vram)}", flush=True)
if os.path.exists(out_path):
size = os.path.getsize(out_path)
OUTPUT.write_bytes(Path(out_path).read_bytes())
os.remove(out_path)
print(f"[output] wrote {OUTPUT} ({size} bytes)")
else:
print(f"[output] no file at {out_path}")
if __name__ == "__main__":
main()
+12
View File
@@ -14,3 +14,15 @@ soundfile
scipy scipy
python-multipart python-multipart
pyyaml pyyaml
# --- Avatar video (optional, only used when config.video.enabled=true) ---
# Video frame I/O (used by video_models/wan22.py and the muxer).
imageio[ffmpeg]>=2.34
av>=12.0
pyzmq>=25.0
gguf>=0.6.0
# sgl-kernel: installed from SGLang's cu128 wheel index in Dockerfile
# (PyPI version lacks SM120/Blackwell CUDA kernels)
# LightX2V (Wan2.2-Lightning) and MuseTalk are installed from source in the
# Dockerfile because neither ships a stable PyPI release yet. See lines
# "LightX2V from source" / "MuseTalk from source" in Dockerfile.
+10
View File
@@ -1,5 +1,15 @@
import os
import torch
import uvicorn import uvicorn
# Cap CPU thread pools so PyTorch/OpenMP don't spin-wait on every core at idle.
# Models run on GPU; the CPU thread pool is only needed for small ops.
os.environ.setdefault("OMP_WAIT_POLICY", "PASSIVE")
os.environ.setdefault("MKL_WAIT_POLICY", "PASSIVE")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
torch.set_num_threads(2)
torch.set_num_interop_threads(2)
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run( uvicorn.run(
"server.main:app", "server.main:app",
+57
View File
@@ -0,0 +1,57 @@
# Agent guide — server/
The audio pipeline and FastAPI surface. The video stack lives under [video_models/](video_models/) and has its own guide.
## Module map
- [main.py](main.py) — FastAPI app, lifespan, `/ws/chat` WebSocket, video/avatar HTTP endpoints. Keep business logic out of here; this is a transport layer.
- [models.py](models.py) — `ModelManager.load_all()`. All models are loaded once at startup in a fixed order: VAD → ASR → LLM → TTS → (optional) Video. `video_engine` stays `None` when `config.video.enabled` is false — callers MUST tolerate that.
- [config.py](config.py) — thin YAML loader, exposes a single `config` dict. Don't scatter `yaml.safe_load` elsewhere.
- [pipeline.py](pipeline.py) — `ConversationSession`, one instance per WebSocket. Owns per-session state (VAD stream, conversation history, KV cache, cancel event). Orchestrates VAD → ASR → LLM → TTS and optionally the video branch.
- [vad.py](vad.py) — Silero VAD via ONNX Runtime on CPU. `StreamingVAD.process_chunk(pcm_16k) → utterance | None`. Returns a full utterance only on speech→silence transition.
- [asr.py](asr.py) — Qwen3-ASR wrapper. Sync, called under `asyncio.to_thread`.
- [llm.py](llm.py) — two backends behind a common `generate(history, max_new_tokens, kv_cache_state) → (text, KVCacheState)` signature: `LLMEngine` (local transformers) and `LMStudioEngine` (HTTP, no KV cache).
- [tts.py](tts.py) — Kokoro wrapper. The per-segment generator yields `(graphemes, _ps, audio_f32)` tuples at 24 kHz.
- [audio_utils.py](audio_utils.py) — `pcm_bytes_to_float32` / `float32_to_pcm_bytes`. The WebSocket protocol is 16 kHz int16 PCM in, 24 kHz int16 PCM out.
- [video.py](video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine`. Orchestrates Wan2.2 + MuseTalk. Gated by `config.video.enabled`.
- [video_models/](video_models/) — see [video_models/AGENT.md](video_models/AGENT.md) for Blackwell/GGUF gotchas.
## Session lifecycle (pipeline.py)
1. Client connects → `ConversationSession(models, send_json, send_bytes)``start()` emits `{type: "status", state: "listening"}` and a `{type: "video_mode", ...}` hint.
2. Inbound binary frames are 16 kHz int16 PCM → `handle_audio_chunk` → VAD. On speech→silence the session kicks off `_process_utterance` as an `asyncio.Task` so it never blocks the WebSocket receive loop.
3. `_process_utterance` flow: ASR → LLM → TTS stream. Each blocking call wraps in `asyncio.to_thread`.
4. TTS output is split into short segments via `_split_into_segments` before synthesis so streaming chunks stay small.
5. TTS runs on a background `threading.Thread`, feeding a `queue.Queue` that the async loop drains.
## Barge-in
- `cancel_event: threading.Event` is the single stop signal. Checked between every stage and inside the TTS queue drain loop.
- Two ways to trigger: new VAD utterance while `is_responding` (`handle_audio_chunk`), or an explicit `{type: "interrupt", last_chunk_id}` text message (`interrupt`).
- On cancel: set the event, send `{type: "interrupt"}` to the client so it flushes its audio buffer, and discard any pending video clip.
- Don't add work after `cancel_event.is_set()` without checking — that's how zombie audio/video reaches the client after a barge-in.
## Video branch
The audio pipeline changes shape when `video_engine.is_ready()`:
- PCM chunks and `response_text` are **not** streamed during the turn — they're buffered.
- TTS audio is concatenated into one float32 array.
- After TTS completes, `video_engine.generate_speaking_clip(audio, sr, reply_text)` renders the MP4 (blocking, wrapped in `to_thread`).
- The full clip + final text is sent as a single `speaking_clip` message.
If you extend the audio pipeline, preserve this dual-mode behaviour: the client's UX is very different between the two paths, and mixing them (e.g. sending PCM *and* a clip) will double-play the audio.
## Conventions
- Dataclasses for structured config (see `VideoConfig`, `LoRASpec`). Parse once in `*.from_dict`; don't re-read `config.yml` mid-session.
- `asyncio.to_thread` for any sync model call from an async context. Never call `.generate()` / `.transcribe()` / `.pipeline()` directly on the event loop.
- Locks: `VideoEngine._lock` serialises model state mutations; `ConversationSession` is not locked because each WebSocket gets its own instance.
- Logging: one `log = logging.getLogger(__name__)` per module. INFO for lifecycle and per-turn milestones; avoid DEBUG spam in hot loops.
- Keep `main.py` thin. New endpoints should delegate to a method on `ModelManager` or an engine class.
## Testing
- `tests/unit/test_pipeline_video_branch.py` — the video vs. audio path selection. Update it if you change the `use_video` condition.
- `tests/unit/test_video_config.py` / `test_video_engine_logic.py` — config parsing and the pure logic in `video.py`.
- Component tests live in `tests/component/` and require the Docker GPU environment. See [tests/README.md](../tests/README.md).
+29 -3
View File
@@ -80,11 +80,28 @@ class LLMEngine:
f"processing {input_len - cached_len} new tokens" f"processing {input_len - cached_len} new tokens"
) )
with torch.no_grad(): # Guard: if the cache claims to have seen >= input tokens, it's
outputs = self.model.generate( # stale (can happen after barge-in races or tokenizer mismatches).
# An invalid cache causes an empty cache_position in transformers,
# which raises IndexError inside model.generate().
if past_kv is not None:
cache_seq_len = (
past_kv.get_seq_length()
if hasattr(past_kv, "get_seq_length")
else cached_len
)
if cache_seq_len >= input_len:
log.warning(
f"KV-cache stale (cache_seq={cache_seq_len} >= input={input_len}), discarding."
)
past_kv = None
cached_len = 0
def _do_generate(pkv):
return self.model.generate(
input_ids=input_ids, input_ids=input_ids,
attention_mask=inputs.get("attention_mask"), attention_mask=inputs.get("attention_mask"),
past_key_values=past_kv, past_key_values=pkv,
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
temperature=0.7, temperature=0.7,
top_p=0.9, top_p=0.9,
@@ -94,6 +111,15 @@ class LLMEngine:
use_cache=True, use_cache=True,
) )
with torch.no_grad():
try:
outputs = _do_generate(past_kv)
except IndexError:
log.warning("KV-cache caused IndexError during generate; retrying without cache.")
past_kv = None
cached_len = 0
outputs = _do_generate(None)
# Decode only the generated tokens (skip prompt) # Decode only the generated tokens (skip prompt)
new_ids = outputs.sequences[0][input_len:] new_ids = outputs.sequences[0][input_len:]
response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() response = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
+121 -2
View File
@@ -1,23 +1,27 @@
import json import json
import logging import logging
import os import os
import tempfile
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import numpy as np import numpy as np
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.params import Form from fastapi.params import Form
from fastapi.responses import FileResponse from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from server.audio_utils import pcm_bytes_to_float32 from server.audio_utils import pcm_bytes_to_float32
from server.models import ModelManager from server.models import ModelManager
from server.pipeline import ConversationSession from server.pipeline import ConversationSession
from server.video import LoRASpec
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio") REFERENCE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reference_audio")
STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static") STATIC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static")
AVATAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "avatars")
os.makedirs(AVATAR_DIR, exist_ok=True)
model_mgr = ModelManager() model_mgr = ModelManager()
@@ -47,6 +51,110 @@ async def set_voice(voice: str = Form(...), lang: str = Form("a")):
return {"status": "ok", "voice": voice} return {"status": "ok", "voice": voice}
# --- Video / avatar endpoints ---------------------------------------------
def _require_video() -> "object":
"""Return the video engine, or raise 404 if video mode isn't enabled."""
ve = model_mgr.video_engine
if ve is None:
raise HTTPException(
status_code=404,
detail="Video engine disabled. Set config.video.enabled=true and restart.",
)
return ve
@app.post("/api/set-avatar")
async def set_avatar(image: UploadFile):
"""Upload an avatar image and (re)generate cached clips."""
ve = _require_video()
suffix = os.path.splitext(image.filename or "avatar.png")[1] or ".png"
dest = os.path.join(AVATAR_DIR, f"avatar{suffix}")
with open(dest, "wb") as f:
f.write(await image.read())
log.info("Avatar saved to %s", dest)
import asyncio
try:
await asyncio.to_thread(ve.set_avatar, dest)
except Exception as e:
log.exception("set_avatar failed")
raise HTTPException(status_code=500, detail=f"Avatar setup failed: {e}")
return {
"status": "ok",
"avatar_path": dest,
"idle_clip_url": "/api/idle-clip",
"mode": ve.cfg.mode,
}
@app.get("/api/idle-clip")
async def idle_clip():
"""Return the cached idle loop MP4."""
ve = _require_video()
data = ve.get_idle_clip()
if data is None:
raise HTTPException(status_code=404, detail="No idle clip. Upload an avatar first.")
return Response(content=data, media_type="video/mp4")
@app.post("/api/set-video-mode")
async def set_video_mode(mode: str = Form(...)):
"""Switch between 'off', 'library', and 'reflective'.
'off' leaves the video engine loaded but makes the pipeline take the
PCM streaming path on subsequent turns (by marking the engine not-ready
from the client's perspective via a simple flag).
"""
ve = _require_video()
if mode not in ("off", "library", "reflective"):
raise HTTPException(
status_code=400,
detail="mode must be one of: off, library, reflective",
)
# Switching between library/reflective changes how set_avatar prebakes
# clips. Require a fresh avatar upload afterwards to re-bake.
ve.cfg.mode = mode
return {"status": "ok", "mode": mode, "note": "Re-upload avatar to re-bake library clips." if mode == "library" else ""}
@app.post("/api/reload-loras")
async def reload_loras(body: dict):
"""Hot-reload LoRA stack. Body: ``{"loras": [{"path","weight","target","name"}]}``.
Regenerates the idle clip if an avatar is already set, since the new
LoRAs change the base style.
"""
ve = _require_video()
raw = body.get("loras") or []
specs: list[LoRASpec] = []
for entry in raw:
if not entry or "path" not in entry:
continue
target = str(entry.get("target", "both")).lower()
if target != "both":
target = "both"
specs.append(
LoRASpec(
path=str(entry["path"]),
weight=float(entry.get("weight", 1.0)),
target=target, # type: ignore[arg-type]
name=entry.get("name"),
)
)
import asyncio
try:
await asyncio.to_thread(ve.load_loras, specs)
if ve.avatar_path:
log.info("Regenerating idle clip after LoRA reload.")
await asyncio.to_thread(ve.set_avatar, ve.avatar_path)
except Exception as e:
log.exception("reload_loras failed")
raise HTTPException(status_code=500, detail=str(e))
return {"status": "ok", "lora_count": len(specs), "idle_clip_url": "/api/idle-clip"}
@app.websocket("/ws/chat") @app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket): async def websocket_chat(ws: WebSocket):
await ws.accept() await ws.accept()
@@ -61,6 +169,17 @@ async def websocket_chat(ws: WebSocket):
session = ConversationSession(model_mgr, send_json, send_bytes) session = ConversationSession(model_mgr, send_json, send_bytes)
await session.start() await session.start()
# Tell the client whether video mode is active so it knows whether to
# suppress PCM playback and wait for speaking_clip messages instead.
ve = model_mgr.video_engine
await send_json({
"type": "video_mode",
"enabled": ve is not None,
"ready": ve.is_ready() if ve is not None else False,
"mode": ve.cfg.mode if ve is not None else "off",
"idle_clip_url": "/api/idle-clip" if (ve is not None and ve.get_idle_clip()) else None,
})
try: try:
while True: while True:
message = await ws.receive() message = await ws.receive()
+27 -2
View File
@@ -5,6 +5,7 @@ from server.vad import StreamingVAD
from server.asr import ASREngine from server.asr import ASREngine
from server.llm import LLMEngine from server.llm import LLMEngine
from server.tts import TTSEngine from server.tts import TTSEngine
from server.video import VideoConfig, VideoEngine
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -31,6 +32,7 @@ class ModelManager:
self.asr_engine: ASREngine | None = None self.asr_engine: ASREngine | None = None
self.llm_engine: LLMEngine | None = None self.llm_engine: LLMEngine | None = None
self.tts_engine: TTSEngine | None = None self.tts_engine: TTSEngine | None = None
self.video_engine: VideoEngine | None = None
def load_all(self): def load_all(self):
"""Load all models sequentially. Call from the main process.""" """Load all models sequentially. Call from the main process."""
@@ -38,6 +40,7 @@ class ModelManager:
self._load_asr() self._load_asr()
self._load_llm() self._load_llm()
self._load_tts() self._load_tts()
self._load_video()
log.info("All models loaded successfully.") log.info("All models loaded successfully.")
def _load_vad(self): def _load_vad(self):
@@ -84,8 +87,8 @@ class ModelManager:
log.info("Loading Qwen3-4B (GPTQ 4-bit)...") log.info("Loading Qwen3-4B (GPTQ 4-bit)...")
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "Qwen/Qwen3.5-0.8B" # model_name = "Qwen/Qwen3.5-0.8B"
model_name = "dphn/Dolphin-X1-8B-FP8"
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
device = get_device() device = get_device()
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
@@ -101,6 +104,28 @@ class ModelManager:
self.tts_engine = TTSEngine() self.tts_engine = TTSEngine()
log.info("Kokoro TTS loaded.") log.info("Kokoro TTS loaded.")
def _load_video(self):
"""Load the avatar video stack iff config.video.enabled is true.
Leaves ``video_engine`` as None when disabled so existing voice flow
is untouched. Later phases replace this stub with actual Wan2.2 +
MuseTalk loading inside ``VideoEngine``.
"""
from server.config import config
video_cfg_raw = config.get("video", {}) or {}
if not video_cfg_raw.get("enabled", False):
log.info("Video engine disabled (config.video.enabled=false). Skipping load.")
return
log.info("Video engine configured (models load on first avatar upload).")
cfg = VideoConfig.from_dict(video_cfg_raw)
self.video_engine = VideoEngine(cfg)
# load_models() is intentionally deferred: Wan2.2 + MuseTalk consume
# ~6.5 GB VRAM at idle, which causes WDDM preemption latency on the
# Windows host even with no connected clients. Models are loaded on
# demand when set_avatar() is first called.
def create_vad(self) -> StreamingVAD: def create_vad(self) -> StreamingVAD:
"""Create a new StreamingVAD instance for a client session.""" """Create a new StreamingVAD instance for a client session."""
return StreamingVAD(self.vad_model) return StreamingVAD(self.vad_model)
+100 -14
View File
@@ -157,11 +157,20 @@ class ConversationSession:
# TTS - stream chunks with per-sentence text # TTS - stream chunks with per-sentence text
await self.send_json({"type": "status", "state": "speaking"}) await self.send_json({"type": "status", "state": "speaking"})
# Video-mode branch: if a video engine is loaded AND an avatar is
# set, buffer the full TTS output into a single blob, run MuseTalk
# lip-sync (library or reflective source), mux to MP4, and send the
# full clip + text in one shot. The client plays the MP4 (which
# carries audio) instead of the per-chunk PCM path.
video_engine = getattr(self.models, "video_engine", None)
use_video = video_engine is not None and video_engine.is_ready()
chunk_queue = queue.Queue() chunk_queue = queue.Queue()
self._last_played_chunk_id = None self._last_played_chunk_id = None
segments = _split_into_segments(response) segments = _split_into_segments(response)
log.info(f"TTS: split response into {len(segments)} segments") log.info(f"TTS: split response into {len(segments)} segments (video={use_video})")
def _tts_worker(): def _tts_worker():
try: try:
@@ -187,6 +196,10 @@ class ConversationSession:
chunk_id = 0 chunk_id = 0
# Maps chunk_id -> cumulative text up to and including that chunk # Maps chunk_id -> cumulative text up to and including that chunk
chunk_text_map: dict[int, str] = {} chunk_text_map: dict[int, str] = {}
# Video mode accumulator: we buffer all TTS audio into one float32
# array so MuseTalk can align against the full utterance.
audio_buffer: list[np.ndarray] = []
while True: while True:
try: try:
item = await asyncio.to_thread(chunk_queue.get, timeout=10.0) item = await asyncio.to_thread(chunk_queue.get, timeout=10.0)
@@ -202,23 +215,96 @@ class ConversationSession:
spoken_text += sentence_text spoken_text += sentence_text
chunk_text_map[chunk_id] = spoken_text chunk_text_map[chunk_id] = spoken_text
await self.send_json({ if use_video:
"type": "response_text", audio_buffer.append(audio)
"text": sentence_text, # Don't stream text or PCM during video mode — we'll send
"chunk_id": chunk_id, # everything after the clip renders so the client doesn't
"final": False, # start displaying text before the video is ready.
}) else:
pcm_bytes = float32_to_pcm_bytes(audio) await self.send_json({
try: "type": "response_text",
await self.send_bytes(pcm_bytes) "text": sentence_text,
except Exception: "chunk_id": chunk_id,
log.warning("Failed to send audio, client disconnected.") "final": False,
self.cancel_event.set() })
break pcm_bytes = float32_to_pcm_bytes(audio)
try:
await self.send_bytes(pcm_bytes)
except Exception:
log.warning("Failed to send audio, client disconnected.")
self.cancel_event.set()
break
chunk_id += 1 chunk_id += 1
tts_thread.join(timeout=2.0) tts_thread.join(timeout=2.0)
# Video mode: stream speaking clips as they're generated (one per audio segment).
if use_video and audio_buffer and not self.cancel_event.is_set():
try:
full_audio = np.concatenate(audio_buffer).astype(np.float32)
sample_rate = getattr(self.models.tts_engine, "sample_rate", 24000)
log.info(
"Video: rendering speaking clips (audio=%.1fs, mode=%s)",
len(full_audio) / sample_rate, video_engine.cfg.mode,
)
clip_queue: queue.Queue = queue.Queue()
def _video_worker():
try:
for clip_data in video_engine.generate_speaking_clips_streaming(
full_audio, sample_rate, response
):
if self.cancel_event.is_set():
break
clip_queue.put(clip_data)
except Exception:
log.exception("Video clip generation failed")
finally:
clip_queue.put(_SENTINEL)
video_thread = threading.Thread(target=_video_worker, daemon=True)
video_thread.start()
is_first_clip = True
while not self.cancel_event.is_set():
try:
item = await asyncio.to_thread(clip_queue.get, timeout=120.0)
except Exception:
log.warning("Timed out waiting for video clip.")
break
if item is _SENTINEL:
break
if self.cancel_event.is_set():
break
mp4_bytes, duration_ms = item
try:
await self.send_json({
"type": "speaking_clip",
"chunk_id": 0,
"duration_ms": duration_ms,
"text": response if is_first_clip else "",
"size_bytes": len(mp4_bytes),
})
await self.send_bytes(mp4_bytes)
is_first_clip = False
except Exception:
log.warning("Failed to send video clip, client disconnected.")
self.cancel_event.set()
break
except Exception:
log.exception("Video speaking-clip render failed; falling back silently.")
try:
await self.send_json({
"type": "response_text",
"text": response,
"chunk_id": 0,
"final": True,
})
except Exception:
pass
# Determine what was actually heard by the client # Determine what was actually heard by the client
was_interrupted = spoken_text.strip() != response.strip() was_interrupted = spoken_text.strip() != response.strip()
if was_interrupted and self._last_played_chunk_id is not None: if was_interrupted and self._last_played_chunk_id is not None:
+489
View File
@@ -0,0 +1,489 @@
"""Avatar video generation: Wan2.2-Lightning base + MuseTalk lip-sync.
Top-level orchestrator. The heavy 3rd-party model code is isolated in
``server/video_models/`` so each wrapper can be updated independently.
This module is only imported by ``server/models.py`` when
``config.video.enabled`` is true. When disabled, the existing voice pipeline
is completely untouched.
"""
from __future__ import annotations
import logging
import threading
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Literal
import numpy as np
log = logging.getLogger(__name__)
LoRATarget = Literal["both"]
@dataclass
class LoRASpec:
"""One LoRA adapter entry from ``config.video.loras``.
The dense Wan2.2-TI2V-5B DIT has a single set of weights (no MoE
experts), so ``target`` is always ``"both"``. The field is kept for
forward compatibility and config-file compatibility with older MoE
configs — legacy ``"high_noise"`` / ``"low_noise"`` values are coerced
to ``"both"`` in ``VideoConfig.from_dict``.
"""
path: str
weight: float = 1.0
target: LoRATarget = "both"
name: str | None = None
@dataclass
class VideoConfig:
"""Flattened view of the ``video:`` section of config.yml."""
enabled: bool = False
backend: str = "lightx2v"
mode: str = "reflective" # "library" | "reflective"
resolution: int = 480
fps: int = 16
library_base_clip_count: int = 4
library_base_clip_seconds: int = 6
reflective_clip_seconds: int = 5
reflective_prompt_template: str = (
"webcam view of a person speaking, {reply_hint}, casual gestures, "
"natural lighting, soft focus background"
)
reflective_prompt_reply_words: int = 18
loras: list[LoRASpec] = field(default_factory=list)
# Model paths — can be overridden via config.yml.video.models.
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
# The bf16 DIT shards in this repo are skipped — we
# replace them with a quantised GGUF from wan22_dit_repo.
# wan22_dit_repo : HF repo id (or local dir) providing the single
# dense GGUF DIT checkpoint (5B Turbo).
# wan22_dit_quant_scheme : GGUF quant level, e.g. "gguf-Q8_0" (default)
# or "gguf-Q4_K_M" for lower VRAM.
# wan22_config_json : path to the LightX2V inference config template the
# Wan22Pipeline will fill in with absolute ckpt paths.
wan22_base_repo: str = "Wan-AI/Wan2.2-TI2V-5B"
wan22_dit_repo: str = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
wan22_dit_quant_scheme: str = "gguf-Q8_0"
wan22_t5_quantized: bool = True
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
wan22_model_cls: str = "wan2.2"
musetalk_enabled: bool = True
musetalk_model_path: str = "TMElyralab/MuseTalk"
@classmethod
def from_dict(cls, raw: dict) -> "VideoConfig":
raw = raw or {}
library = raw.get("library", {}) or {}
reflective = raw.get("reflective", {}) or {}
models_raw = raw.get("models", {}) or {}
loras_raw = raw.get("loras") or []
default_template = (
"webcam view of a person speaking, {reply_hint}, casual gestures, "
"natural lighting, soft focus background"
)
loras: list[LoRASpec] = []
for entry in loras_raw:
if not entry or "path" not in entry:
continue
target = str(entry.get("target", "both")).lower()
if target != "both":
log.warning(
"LoRA %s: target %r is MoE-era; coercing to 'both' "
"(dense 5B has a single DIT).",
entry.get("path"), target,
)
target = "both"
loras.append(
LoRASpec(
path=str(entry["path"]),
weight=float(entry.get("weight", 1.0)),
target=target, # type: ignore[arg-type]
name=entry.get("name"),
)
)
return cls(
enabled=bool(raw.get("enabled", False)),
backend=str(raw.get("backend", "lightx2v")),
mode=str(raw.get("mode", "reflective")),
resolution=int(raw.get("resolution", 480)),
fps=int(raw.get("fps", 16)),
library_base_clip_count=int(library.get("base_clip_count", 4)),
library_base_clip_seconds=int(library.get("base_clip_seconds", 6)),
reflective_clip_seconds=int(reflective.get("clip_seconds", 5)),
reflective_prompt_template=str(
reflective.get("clip_prompt_template", default_template)
),
reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)),
loras=loras,
wan22_base_repo=str(
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-TI2V-5B")
),
wan22_dit_repo=str(
models_raw.get(
"wan22_dit_repo",
"hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF",
)
),
wan22_dit_quant_scheme=str(
models_raw.get("wan22_dit_quant_scheme", "gguf-Q8_0")
),
wan22_t5_quantized=bool(
models_raw.get("wan22_t5_quantized", True)
),
wan22_config_json=str(
models_raw.get(
"wan22_config_json",
"/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json",
)
),
wan22_model_cls=str(
models_raw.get("wan22_model_cls", "wan2.2")
),
musetalk_enabled=bool(raw.get("musetalk", {}).get("enabled", True))
if isinstance(raw.get("musetalk"), dict)
else bool(raw.get("musetalk_enabled", True)),
musetalk_model_path=str(
models_raw.get("musetalk_path", "TMElyralab/MuseTalk")
),
)
# Library-mode base-clip prompts. Varied gestures so the pre-baked set feels
# less repetitive when replayed. Kept module-level so tests can import them.
LIBRARY_BASE_PROMPTS = [
"webcam view of a person speaking, subtle head nods, casual expression, "
"natural lighting, soft focus background",
"webcam view of a person speaking, slight smile, gentle hand gesture, "
"natural lighting, soft focus background",
"webcam view of a person speaking, looking thoughtful, small head tilt, "
"natural lighting, soft focus background",
"webcam view of a person speaking, engaged and attentive, minor shoulder "
"movement, natural lighting, soft focus background",
"webcam view of a person speaking, relaxed posture, blinking naturally, "
"natural lighting, soft focus background",
]
IDLE_PROMPT = (
"webcam view of a person listening quietly, mouth closed, subtle "
"breathing, occasional blinks, calm expression, natural lighting, "
"soft focus background"
)
class VideoEngine:
"""Top-level video generation orchestrator.
Holds the Wan2.2 and MuseTalk model wrappers, plus the current avatar's
pre-rendered clips. Exposed to ``ConversationSession`` via
``ModelManager.video_engine``.
"""
def __init__(self, cfg: VideoConfig):
self.cfg = cfg
self._lock = threading.Lock()
# Avatar state
self.avatar_path: str | None = None
self.idle_clip_mp4: bytes | None = None
# Pre-baked speaking base clips for library mode. Each entry is a
# contiguous ``np.ndarray`` of shape ``[T, H, W, 3]`` uint8.
self.speaking_base_frames: list[np.ndarray] = []
# Round-robin pointer for picking a library clip per turn
self._library_cursor = 0
# Model wrappers — instantiated lazily by ``load_models()`` so unit
# tests can exercise VideoEngine without touching CUDA at all.
self._wan22 = None # server.video_models.wan22.Wan22Pipeline
self._musetalk = None # server.video_models.musetalk.MuseTalkEngine
log.info(
"VideoEngine initialised (mode=%s, resolution=%d, fps=%d, loras=%d).",
cfg.mode, cfg.resolution, cfg.fps, len(cfg.loras),
)
# --- Model loading --------------------------------------------------
def load_models(self) -> None:
"""Instantiate the underlying model wrappers.
Separated from ``__init__`` so tests can mock ``_wan22``/``_musetalk``
without triggering Wan2.2's ~12-16GB VRAM allocation.
"""
from server.video_models.wan22 import Wan22Pipeline
from server.video_models.musetalk import MuseTalkEngine
log.info(
"Loading Wan2.2 pipeline (base=%s, dit=%s, quant=%s)...",
self.cfg.wan22_base_repo, self.cfg.wan22_dit_repo,
self.cfg.wan22_dit_quant_scheme,
)
self._wan22 = Wan22Pipeline(
base_repo=self.cfg.wan22_base_repo,
dit_repo=self.cfg.wan22_dit_repo,
config_json=self.cfg.wan22_config_json,
model_cls=self.cfg.wan22_model_cls,
resolution=self.cfg.resolution,
fps=self.cfg.fps,
dit_quant_scheme=self.cfg.wan22_dit_quant_scheme,
t5_quantized=self.cfg.wan22_t5_quantized,
)
if self.cfg.loras:
self._wan22.load_loras(self.cfg.loras)
log.info("Wan2.2 pipeline ready.")
if self.cfg.musetalk_enabled:
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
log.info("MuseTalk engine ready.")
else:
log.info("MuseTalk disabled via config — skipping lip-sync pass.")
self._musetalk = None
# --- Readiness ------------------------------------------------------
def is_ready(self) -> bool:
"""True when an avatar is set and a speaking clip can be produced."""
musetalk_ok = (not self.cfg.musetalk_enabled) or self._musetalk is not None
return (
self._wan22 is not None
and musetalk_ok
and self.avatar_path is not None
and self.idle_clip_mp4 is not None
)
# --- LoRA management ------------------------------------------------
def load_loras(self, specs: list[LoRASpec]) -> None:
"""Apply a list of LoRA adapters to the Wan2.2 base.
Replaces any previously applied LoRAs. Safe to call after init for
hot-reload via ``POST /api/reload-loras``.
"""
if self._wan22 is None:
raise RuntimeError("load_loras called before load_models()")
with self._lock:
self._wan22.unload_loras()
self._wan22.load_loras(specs)
self.cfg.loras = list(specs)
log.info("Applied %d LoRA(s): %s",
len(specs),
", ".join(s.name or s.path for s in specs) or "<none>")
# --- Avatar lifecycle ----------------------------------------------
def set_avatar(self, image_path: str) -> None:
"""Register an avatar image and pre-generate cached clips.
- Always: generate the idle loop.
- Library mode: also pre-generate ``library.base_clip_count``
speaking base clips.
- Reflective mode: idle loop only.
Lazily calls load_models() on first invocation so that Wan2.2's VRAM
footprint doesn't exist until video is actually used.
"""
if self._wan22 is None:
self.load_models()
with self._lock:
log.info("Setting avatar: %s", image_path)
self.avatar_path = image_path
# Drop any previously cached clips so the new avatar's library
# doesn't mix with the old.
self.speaking_base_frames = []
self.idle_clip_mp4 = None
# Idle clip: short loop, neutral/listening prompt.
log.info("Generating idle clip...")
idle_frames = self._wan22.generate_i2v(
image_path=image_path,
prompt=IDLE_PROMPT,
seconds=self.cfg.library_base_clip_seconds,
seed=0,
)
from server.video_models.muxer import frames_to_mp4_loop
self.idle_clip_mp4 = frames_to_mp4_loop(idle_frames, fps=self.cfg.fps)
log.info("Idle clip ready (%d bytes).", len(self.idle_clip_mp4))
# Library mode: pre-bake N speaking base clips.
if self.cfg.mode == "library":
n = self.cfg.library_base_clip_count
log.info("Pre-baking %d speaking base clip(s) for library mode.", n)
for i in range(n):
prompt = LIBRARY_BASE_PROMPTS[i % len(LIBRARY_BASE_PROMPTS)]
frames = self._wan22.generate_i2v(
image_path=image_path,
prompt=prompt,
seconds=self.cfg.library_base_clip_seconds,
seed=i + 1,
)
self.speaking_base_frames.append(frames)
log.info(" base clip %d/%d rendered", i + 1, n)
self._library_cursor = 0
def get_idle_clip(self) -> bytes | None:
return self.idle_clip_mp4
# --- Per-turn generation -------------------------------------------
def generate_speaking_clip(
self,
audio_f32: np.ndarray,
sample_rate: int,
reply_text: str,
) -> bytes:
"""Produce a lip-synced MP4 for one assistant turn."""
if not self.is_ready():
raise RuntimeError(
"generate_speaking_clip: engine not ready "
"(avatar set? models loaded?)"
)
assert self._wan22 is not None
# 1. Source base frames.
if self.cfg.mode == "library":
base_frames = self._pick_library_frames(audio_f32, sample_rate)
else: # reflective
prompt = self._derive_prompt(reply_text)
log.info("Reflective prompt: %s", prompt[:120])
base_frames = self._wan22.generate_i2v(
image_path=self.avatar_path or "",
prompt=prompt,
seconds=self.cfg.reflective_clip_seconds,
seed=None, # random each turn
)
# 2. Lip-sync the base frames to the given audio (if enabled).
if self._musetalk is not None:
synced_frames = self._musetalk.lip_sync(
frames=base_frames,
audio=audio_f32,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
else:
synced_frames = base_frames
# 3. Mux frames + audio into an MP4.
from server.video_models.muxer import frames_and_audio_to_mp4
return frames_and_audio_to_mp4(
frames=synced_frames,
audio=audio_f32,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
def _pick_library_frames(
self, audio_f32: np.ndarray, sample_rate: int
) -> np.ndarray:
"""Round-robin pick from the pre-baked library, clipped to the segment duration.
Does not loop frames — callers that need longer coverage should split
the audio into segments and call this once per segment.
"""
if not self.speaking_base_frames:
raise RuntimeError(
"Library mode has no pre-baked base clips. "
"Was set_avatar called with mode=library?"
)
frames = self.speaking_base_frames[
self._library_cursor % len(self.speaking_base_frames)
]
self._library_cursor += 1
target_frames = int(round(len(audio_f32) / sample_rate * self.cfg.fps))
if target_frames <= 0:
return frames
return frames[:min(target_frames, len(frames))]
def generate_speaking_clips_streaming(
self,
audio_f32: np.ndarray,
sample_rate: int,
reply_text: str,
) -> Iterator[tuple[bytes, int]]:
"""Generate one MP4 per clip-length audio segment, yielding each when ready.
Splits ``audio_f32`` into segments of ``reflective_clip_seconds`` (or
``library_base_clip_seconds`` for library mode) and generates + lip-syncs
one clip per segment. Yields ``(mp4_bytes, duration_ms)`` tuples so the
caller can stream each clip to the client as soon as it's ready rather
than waiting for the full response.
"""
if not self.is_ready():
raise RuntimeError(
"generate_speaking_clips_streaming: engine not ready "
"(avatar set? models loaded?)"
)
assert self._wan22 is not None
if len(audio_f32) == 0:
return
clip_sec = (
self.cfg.library_base_clip_seconds
if self.cfg.mode == "library"
else self.cfg.reflective_clip_seconds
)
clip_samples = int(clip_sec * sample_rate)
segments = [
audio_f32[i : i + clip_samples]
for i in range(0, len(audio_f32), clip_samples)
]
for seg_audio in segments:
if self.cfg.mode == "library":
base_frames = self._pick_library_frames(seg_audio, sample_rate)
else:
prompt = self._derive_prompt(reply_text)
log.info("Reflective prompt (clip segment): %s", prompt[:80])
base_frames = self._wan22.generate_i2v(
image_path=self.avatar_path or "",
prompt=prompt,
seconds=self.cfg.reflective_clip_seconds,
seed=None,
)
if self._musetalk is not None:
synced_frames = self._musetalk.lip_sync(
frames=base_frames,
audio=seg_audio,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
else:
synced_frames = base_frames
from server.video_models.muxer import frames_and_audio_to_mp4
mp4_bytes = frames_and_audio_to_mp4(
frames=synced_frames,
audio=seg_audio,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
duration_ms = int(len(seg_audio) / sample_rate * 1000)
yield mp4_bytes, duration_ms
def _derive_prompt(self, reply_text: str) -> str:
"""Template-based prompt builder for reflective mode.
Takes up to ``prompt_reply_words`` words from the start of the reply
and interpolates them into the configured template. Cheap,
deterministic, no extra LLM call.
"""
words = (reply_text or "").split()
hint = " ".join(words[: self.cfg.reflective_prompt_reply_words]).strip()
if not hint:
hint = "calm and friendly"
return self.cfg.reflective_prompt_template.format(reply_hint=hint)
+78
View File
@@ -0,0 +1,78 @@
# Agent guide — server/video_models/
Wrappers around 3rd-party video models. These are the trickiest files in the repo: LightX2V's internals move quickly upstream, and the Blackwell (RTX 5090 / SM120) GPU path requires several non-obvious patches layered on top. Read this before editing.
## Scope
- [wan22.py](wan22.py) — LightX2V Wan2.2-I2V A14B MoE pipeline. Supports fp8 safetensors and GGUF DIT checkpoints. Loaded once at startup, held resident; per-turn calls go through `generate_i2v` and `switch_lora`.
- [musetalk.py](musetalk.py) — MuseTalk lip-sync over base frames + TTS audio.
- [muxer.py](muxer.py) — thin ffmpeg wrappers: frames → MP4 loop, frames + audio → MP4.
Nothing here is imported unless `config.video.enabled` is true.
## LightX2V entry points (upstream API)
Use these symbols, not internal/private ones:
```python
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner
config = set_config(args) # args is an argparse.Namespace
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config) # loads all weights — ONCE
update_input_info_from_dict(input_info, {...}) # per-turn inputs
runner.run_pipeline(input_info) # MP4 written to save_result_path
runner.switch_lora(lora_path, strength) # hot-swap
```
Keep model load out of the per-turn path — `init_runner` is expensive.
## Blackwell (SM120) patches — do not remove without testing
The GGUF pipeline works on a 5090 only because of layered patches in `wan22.py` and tuning in the LightX2V JSON configs under [configs/lightx2v/](../../configs/lightx2v/). Each patch exists because a stock upstream path segfaults or silently miscomputes on SM120.
**Dtype plumbing (GGUF path):**
- Default `DTYPE` must be `BF16` at `init_runner()` time — T5 offload buffers break if FP16 at init.
- Flip `BF16 → FP16` *after* `init_runner()`.
- Wrap T5 encoder so it runs under BF16 internally, then cast outputs `bf16 → fp16` before handing to the DIT. See `_patch_t5_dtype_for_gguf`.
- Cast VAE **both** layers: the inner `.model` via `.to(fp16)` **and** the outer `WanVAE` wrapper's `mean` / `inv_std` / `scale` tensors. Missing the wrapper tensors upcasts the latent during decode's `z/inv_std + mean`.
- DIT `pre_weight.patch_embedding.pin_weight` loads as fp32 (only `pin_bias` is fp16). Cast **and** re-pin via `.pin_memory()` — skipping re-pin segfaults during `to_cuda` H2D copy.
- `sgl_kernel`'s fp8 scaled matmul is patched to `torch._scaled_mm` in `_patch_fp8_scaled_mm_for_blackwell`.
**LightX2V JSON config (see `wan22_i2v_gguf_distill.json`):**
- `modulate_type: "torch"` — Triton `fuse_scale_shift_kernel` segfaults in `ast_to_ttir` on Triton 3.4 + SM120.
- `rope_type: "torch"` — flashinfer isn't installed.
- `self_attn_1_type` / `cross_attn_*_type`: `"torch_sdpa"` — flash_attn3 unavailable; `sageattention==1.0.6` from PyPI segfaults on Blackwell (newer requires source build).
If you add a new quant scheme or a new model_cls, create its own JSON under `configs/lightx2v/` mirroring these choices, and exercise it end-to-end via a new `tests/component/test_NN_*.py` before wiring it into the default config.
## HF download layout
`Wan-AI/Wan2.2-I2V-A14B` ships ~28 GB of bf16 DIT shards we replace with the quantised `dit_repo`. `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](wan22.py) excludes them but **keeps** `high_noise_model/*.json` and `low_noise_model/*.json``set_config` parses architecture params (`dim`, etc.) from those. Don't broaden the ignore pattern without checking.
Supported quant schemes live in `wan22_dit_quant_scheme`:
- `fp8-sgl``lightx2v/Wan2.2-Distill-Models`, two `.safetensors` files
- `gguf-Q4_K_M`, `gguf-Q8_0`, … — `QuantStack/Wan2.2-I2V-A14B-GGUF`, layout `HighNoise/…` and `LowNoise/…`
Filenames are templated at the top of `wan22.py`; update those if the upstream repos rename files.
## Testing
- `tests/component/test_02_wan22_loras.py` — full pipeline load + LoRA apply
- `tests/component/test_09_gguf_generate.py` — GGUF end-to-end I2V
- `tests/component/test_10_t5_encode.py` — T5 encoder dtype path
- `tests/component/test_11_image_encode.py` — image → VAE latent
- `tests/component/test_12_dit_single_step.py` — one DIT step per expert
- `tests/component/test_13_vae_decode.py` — VAE decode → RGB
When diagnosing a Blackwell regression, run 10 → 11 → 12 → 13 in that order; the failure localises to the first failing stage.
## LoRAs
`switch_lora(path, strength)` applies; `switch_lora("", 0.0)` removes. `load_loras`/`unload_loras` in this wrapper iterate over `LoRASpec`s from config and call `switch_lora` per `target` sub-model (`high_noise`, `low_noise`, or `both`). Wrong target = silently wrong output.
+10
View File
@@ -0,0 +1,10 @@
"""Thin wrappers around 3rd-party video generation models.
Each submodule isolates one external dependency so the real API surface
can be updated in a single file without touching the pipeline.
Submodules:
- ``wan22``: Wan2.2-Lightning image-to-video via LightX2V
- ``musetalk``: MuseTalk audio-driven lip-sync
- ``muxer``: ffmpeg-based frame/audio → MP4 encoding
"""
+151
View File
@@ -0,0 +1,151 @@
"""MuseTalk audio-driven lip-sync wrapper.
MuseTalk takes a sequence of face frames + driving audio and returns a new
sequence of frames where the mouth region is animated to match the audio.
This module isolates MuseTalk's real API behind a single ``lip_sync()``
method. MuseTalk's upstream Python surface varies between forks — if the
real import path or call signature differs, update this file only.
"""
from __future__ import annotations
import logging
import os
import numpy as np
log = logging.getLogger(__name__)
class MuseTalkEngine:
"""Thin wrapper over MuseTalk inference."""
def __init__(self, model_path: str = "TMElyralab/MuseTalk"):
self.model_path = model_path
# MuseTalk's canonical entry point is ``musetalk.inference`` or a
# similar ``MuseTalkInfer`` class. Try the most common imports.
self._infer = self._load_impl(model_path)
log.info("MuseTalk engine loaded from %s", model_path)
@staticmethod
def _load_impl(model_path: str):
"""Load the MuseTalk inference implementation.
Upstream MuseTalk has no library-style entry point — it's a bundle
of training/inference CLI scripts. The bhetherman/MuseTalk fork at
``third_party/MuseTalk`` adds package metadata but the low-level
API is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
modules. We import them here to verify the install succeeded; the
actual pipeline (VAE, UNet, Whisper, face detection, blending)
is wired up inside ``MuseTalkEngine.lip_sync``.
"""
resolved = model_path
if not os.path.isdir(model_path) and "/" in model_path:
try:
from huggingface_hub import snapshot_download
resolved = snapshot_download(repo_id=model_path)
except Exception as e: # pragma: no cover
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
try:
from musetalk.utils.utils import load_all_model # type: ignore[import-not-found] # noqa: F401
from musetalk.utils.audio_processor import AudioProcessor # type: ignore[import-not-found] # noqa: F401
except ImportError as e:
raise RuntimeError(
"MuseTalk Python package is not importable. "
"Check that third_party/MuseTalk was installed via "
"`pip install /opt/MuseTalk` in the Dockerfile."
) from e
# Return the resolved weight path; lip_sync loads models lazily on
# first call so import-time failures don't block voice-only startup.
return {"model_path": resolved, "loaded": False}
# --- Inference ---------------------------------------------------------
def lip_sync(
self,
frames: np.ndarray,
audio: np.ndarray,
sample_rate: int,
fps: int,
) -> np.ndarray:
"""Return new frames with lip-sync applied to match ``audio``.
Args:
frames: uint8 ``[T, H, W, 3]`` RGB base frames.
audio: float32 mono 1D audio.
sample_rate: sample rate of ``audio``.
fps: frame rate of ``frames``.
Returns:
uint8 ``[T', H, W, 3]`` RGB frames. ``T'`` is trimmed or padded
to match audio duration at ``fps``.
"""
if frames.ndim != 4 or frames.shape[-1] != 3:
raise ValueError(
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
)
# Normalise frame count to audio duration so the caller doesn't have
# to do the arithmetic.
target_t = int(round(len(audio) / sample_rate * fps))
if target_t > 0 and len(frames) != target_t:
frames = _fit_frames_to_length(frames, target_t)
# MuseTalk's real inference path (see third_party/MuseTalk/scripts/
# realtime_inference.py::Avatar.inference) needs:
# - mmpose + mmcv + mmengine (dwpose keypoint detection)
# - face_alignment (bbox)
# - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
# - Whisper encoder (openai/whisper-tiny)
# - face_parsing weights
# Plus its preprocessing module has import-time side effects that
# load dwpose weights from a CWD-relative path. Turn the full
# pipeline on by extending this method once those deps are
# installed and weights are resolved — until then, callers should
# keep ``config.video.musetalk.enabled: false`` and VideoEngine
# will skip the lip-sync pass.
raise NotImplementedError(
"MuseTalk lip-sync pipeline is not wired up yet. "
"Set config.video.musetalk.enabled=false to bypass."
)
def _fit_frames_to_length(frames: np.ndarray, target_t: int) -> np.ndarray:
"""Trim or repeat ``frames`` (contiguous T axis) to exactly ``target_t``.
Repeats with a ping-pong / boomerang tail so the seam between loops is
less jarring than a hard cut back to frame 0.
"""
if target_t <= 0:
return frames
t = len(frames)
if t == target_t:
return frames
if t > target_t:
return frames[:target_t]
# Extend via ping-pong looping
extended = [frames]
total = t
flip = True
while total < target_t:
seg = frames[::-1] if flip else frames
extended.append(seg)
total += t
flip = not flip
return np.concatenate(extended, axis=0)[:target_t]
def _ensure_uint8_rgb(arr) -> np.ndarray:
"""Coerce the MuseTalk output to uint8 [T, H, W, 3] RGB."""
result = np.asarray(arr)
if result.dtype != np.uint8:
if result.dtype in (np.float32, np.float64):
result = np.clip(result * 255.0, 0, 255).astype(np.uint8)
else:
result = result.astype(np.uint8)
if result.ndim == 3:
result = result[None, ...]
return result
+146
View File
@@ -0,0 +1,146 @@
"""ffmpeg-based frame + audio → MP4 muxing.
Uses the system ``ffmpeg`` binary already installed in the Dockerfile.
No extra python dependencies beyond ``numpy``.
"""
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import tempfile
import numpy as np
log = logging.getLogger(__name__)
def _ffmpeg_bin() -> str:
bin_path = shutil.which("ffmpeg")
if bin_path is None:
raise RuntimeError(
"ffmpeg binary not found on PATH. It should be installed by "
"the Dockerfile (line 13). Ensure you're running inside the "
"docker image or install ffmpeg locally."
)
return bin_path
def _write_raw_frames(frames: np.ndarray, path: str) -> tuple[int, int]:
"""Write uint8 RGB frames to ``path`` as raw rgb24 bytes. Returns (h, w)."""
if frames.ndim != 4 or frames.shape[-1] != 3:
raise ValueError(
f"frames must be [T, H, W, 3] uint8, got {frames.shape}"
)
if frames.dtype != np.uint8:
frames = frames.astype(np.uint8)
with open(path, "wb") as f:
f.write(frames.tobytes())
_, h, w, _ = frames.shape
return h, w
def _write_wav(audio: np.ndarray, sample_rate: int, path: str) -> None:
"""Write a float32 mono audio array to a 16-bit PCM WAV at ``path``."""
from scipy.io import wavfile # type: ignore[import-not-found]
audio = np.asarray(audio, dtype=np.float32).reshape(-1)
int16 = np.clip(audio * 32767.0, -32768, 32767).astype(np.int16)
wavfile.write(path, sample_rate, int16)
def frames_to_mp4_loop(frames: np.ndarray, fps: int) -> bytes:
"""Encode ``frames`` to a silent MP4 suitable for looping playback.
Used for the idle clip: no audio track, loopable on an HTMLMediaElement
without audible seams.
"""
if frames.size == 0:
raise ValueError("frames_to_mp4_loop: empty frames")
ffmpeg = _ffmpeg_bin()
with tempfile.TemporaryDirectory() as td:
raw_path = os.path.join(td, "frames.raw")
out_path = os.path.join(td, "out.mp4")
h, w = _write_raw_frames(frames, raw_path)
cmd = [
ffmpeg, "-y",
"-f", "rawvideo",
"-pix_fmt", "rgb24",
"-s", f"{w}x{h}",
"-r", str(fps),
"-i", raw_path,
"-an",
"-c:v", "libx264",
"-preset", "veryfast",
"-pix_fmt", "yuv420p",
"-movflags", "+faststart",
out_path,
]
log.debug("muxer idle clip: %s", " ".join(cmd))
_run_ffmpeg(cmd)
with open(out_path, "rb") as f:
return f.read()
def frames_and_audio_to_mp4(
frames: np.ndarray,
audio: np.ndarray,
sample_rate: int,
fps: int,
) -> bytes:
"""Encode ``frames`` + ``audio`` to an MP4 with H.264 video + AAC audio.
Used for per-turn speaking clips.
"""
if frames.size == 0:
raise ValueError("frames_and_audio_to_mp4: empty frames")
if audio.size == 0:
raise ValueError("frames_and_audio_to_mp4: empty audio")
ffmpeg = _ffmpeg_bin()
with tempfile.TemporaryDirectory() as td:
raw_path = os.path.join(td, "frames.raw")
wav_path = os.path.join(td, "audio.wav")
out_path = os.path.join(td, "out.mp4")
h, w = _write_raw_frames(frames, raw_path)
_write_wav(audio, sample_rate, wav_path)
cmd = [
ffmpeg, "-y",
"-f", "rawvideo",
"-pix_fmt", "rgb24",
"-s", f"{w}x{h}",
"-r", str(fps),
"-i", raw_path,
"-i", wav_path,
"-c:v", "libx264",
"-preset", "veryfast",
"-pix_fmt", "yuv420p",
"-c:a", "aac",
"-b:a", "128k",
"-shortest",
"-movflags", "+faststart",
out_path,
]
log.debug("muxer speaking clip: %s", " ".join(cmd))
_run_ffmpeg(cmd)
with open(out_path, "rb") as f:
return f.read()
def _run_ffmpeg(cmd: list[str]) -> None:
try:
proc = subprocess.run(
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
log.error("ffmpeg failed (exit %d): %s", e.returncode, e.stderr.decode(errors="replace"))
raise
if proc.returncode != 0: # pragma: no cover
raise RuntimeError(f"ffmpeg returned {proc.returncode}")
+643
View File
@@ -0,0 +1,643 @@
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.
This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner
args = argparse.Namespace(model_cls=..., task="i2v", model_path=..., config_json=..., ...)
config = set_config(args)
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config) # loads all weights — done ONCE
update_input_info_from_dict(input_info, {"seed": ..., "prompt": ..., "image_path": ..., "save_result_path": ...})
runner.run_pipeline(input_info) # per-turn; MP4 written to save_result_path
# LoRA hot-swap:
runner.switch_lora(lora_path, strength) # swap in
runner.switch_lora("", 0.0) # remove
Model weights are loaded once at construction and held resident across turns
so reflective mode doesn't re-pay the load cost each reply.
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards are SKIPPED via
ignore_patterns — replaced by the GGUF
checkpoint from dit_repo.
- dit_repo (configurable) — single dense GGUF DIT checkpoint, e.g.
hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import random
import tempfile
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from server.video import LoRASpec
log = logging.getLogger(__name__)
# --- GGUF filename for the dense 5B Turbo repo ------------------------------
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF ships flat: Wan2_2-TI2V-5B-Turbo-{quant}.gguf
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
# The Wan-AI base repo ships bf16 DIT weight shards alongside the T5/VAE/
# tokenizer support files. We only need the latter — the GGUF from dit_repo
# replaces the DIT weights entirely. Keep config.json / tokenizer files.
BASE_REPO_IGNORE_PATTERNS = [
"*.pt",
"diffusion_pytorch_model*.safetensors",
"assets/*",
"examples/*",
"nohup.out",
"*.md",
]
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
"""Recursively find fp32 tensors reachable from ``obj`` and cast to fp16.
The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
and are missed by the structured sweep. This generic traversal catches
them. Bounded depth + visited-set to avoid cycles.
"""
import torch
if visited is None:
visited = set()
obj_id = id(obj)
if obj_id in visited or depth > 6:
return 0
visited.add(obj_id)
n = 0
for attr_name in dir(obj):
if attr_name.startswith("__"):
continue
try:
val = getattr(obj, attr_name)
except Exception:
continue
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
try:
setattr(obj, attr_name, val.to(torch.float16))
n += 1
except Exception:
pass
elif hasattr(val, "__dict__") and not callable(val):
n += _cast_all_fp32_tensors(val, visited, depth + 1)
return n
def _patch_fp8_scaled_mm_for_blackwell() -> None:
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
sgl_kernel's CUTLASS-based fp8 GEMM doesn't ship SM120 kernels yet.
PyTorch 2.8+'s native ``_scaled_mm`` works on all architectures
including Blackwell. This patch is idempotent.
"""
try:
import sgl_kernel # type: ignore[import-not-found]
except ImportError:
return # no sgl_kernel → fp8 T5 not in use
if getattr(sgl_kernel, "_fp8_patched_for_blackwell", False):
return
import torch
if not torch.cuda.is_available():
return
cap = torch.cuda.get_device_capability()
if cap[0] < 12:
return # only patch on Blackwell+
_orig = sgl_kernel.fp8_scaled_mm
def _torch_fp8_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# torch._scaled_mm expects (M,K) @ (N,K).t() with:
# scale_a: scalar or (M,1)
# scale_b: scalar or (1,N)
# sgl_kernel provides scale_b as (N,1) — transpose it.
if b_scale.dim() == 2 and b_scale.shape[1] == 1:
b_scale = b_scale.t()
# _scaled_mm requires B to be column-major (stride(0)==1).
bt = b.t().contiguous().t()
out = torch._scaled_mm(
a, bt,
scale_a=a_scale, scale_b=b_scale,
out_dtype=out_dtype, bias=bias,
)
return out
sgl_kernel.fp8_scaled_mm = _torch_fp8_scaled_mm
sgl_kernel._fp8_patched_for_blackwell = True
log.info("Patched sgl_kernel.fp8_scaled_mm → torch._scaled_mm for Blackwell (SM%d%d).", *cap)
class Wan22Pipeline:
"""Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.
The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.
Constructor downloads (if needed) both HF repos, writes a runtime JSON
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
``generate_i2v`` runs one inference turn against the already-loaded runner.
"""
def __init__(
self,
base_repo: str,
dit_repo: str,
config_json: str,
model_cls: str = "wan2.2",
resolution: int = 480,
fps: int = 16,
dit_quant_scheme: str = "gguf-Q8_0",
t5_quantized: bool = True,
):
self.base_repo = base_repo
self.dit_repo = dit_repo
self.config_json_template = config_json
self.model_cls = model_cls
self.resolution = resolution
self.fps = fps
self.dit_quant_scheme = dit_quant_scheme
self.t5_quantized = t5_quantized
self._applied_loras: list[LoRASpec] = []
self._is_gguf = dit_quant_scheme.startswith("gguf-")
if not self._is_gguf:
raise ValueError(
f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
)
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
self._model_root = self._ensure_base_repo(base_repo)
self._dit_ckpt = self._ensure_dit_checkpoint(
dit_repo, dit_quant_scheme,
)
self._t5_fp8_ckpt = (
self._ensure_t5_fp8() if t5_quantized else None
)
# 2. Materialize a runtime JSON config with absolute ckpt paths.
self._runtime_json_path = self._build_runtime_config()
# 3. Build the argparse-like namespace LightX2V.set_config() expects.
args = self._build_args(
model_cls=model_cls,
model_path=self._model_root,
config_json=self._runtime_json_path,
)
# 4. Import LightX2V (scoped here so ``import server.video_models.wan22``
# never pulls in lightx2v — tests can import this module on CPU).
from lightx2v.utils.set_config import set_config # type: ignore[import-not-found]
from lightx2v.utils.input_info import init_empty_input_info # type: ignore[import-not-found]
from lightx2v.infer import init_runner # type: ignore[import-not-found]
_patch_fp8_scaled_mm_for_blackwell()
# 5. Load all models under default DTYPE=BF16 so T5 (which is
# hardcoded to bf16 weights) initialises its offload buffers
# correctly. We flip to FP16 *after* init_runner completes.
log.info("LightX2V set_config (model_cls=%s, model_path=%s)",
model_cls, self._model_root)
self._config = set_config(args)
self._input_info_template = init_empty_input_info(
args.task, args.support_tasks
)
log.info("LightX2V init_runner — loading weights (this takes a while)...")
self._runner = init_runner(self._config)
log.info("LightX2V runner loaded; weights resident.")
# 6. GGUF: switch global DTYPE to FP16 for inference. GGUF DIT
# dequantises to fp16, and many intermediate tensors inside the
# DIT forward pass are allocated via GET_DTYPE(). The T5 encoder
# is wrapped to temporarily restore BF16 during its forward.
if self._is_gguf:
os.environ["DTYPE"] = "FP16"
from lightx2v.utils.envs import GET_DTYPE # type: ignore[import-not-found]
GET_DTYPE.cache_clear()
log.info("Set DTYPE=FP16 for GGUF (GET_DTYPE()=%s)", GET_DTYPE())
self._patch_t5_dtype_for_gguf()
self._patch_vae_dtype_for_gguf()
self._patch_dit_fp32_weights_for_gguf()
# --- GGUF dtype compatibility patch ----------------------------------------
def _patch_t5_dtype_for_gguf(self) -> None:
"""Wrap the T5 encoder so it temporarily restores DTYPE=BF16.
The T5 encoder is hardcoded to bfloat16 weights (wan_runner.py). When
the global DTYPE is FP16 (required for GGUF DIT), the T5's CPU-offload
path breaks because intermediate tensor dtypes no longer match the bf16
weights. We wrap ``run_text_encoder`` to temporarily flip GET_DTYPE()
back to bf16, then restore fp16 before the DIT runs.
"""
import os
import types
from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE # type: ignore[import-not-found]
runner = self._runner
orig_run_text_encoder = runner.run_text_encoder.__func__
def bf16_text_encoder(self_runner, *args, **kwargs):
import torch
# Flip DTYPE to BF16 so the T5 encoder works with its bf16 weights.
os.environ["DTYPE"] = "BF16"
GET_DTYPE.cache_clear()
GET_SENSITIVE_DTYPE.cache_clear()
try:
result = orig_run_text_encoder(self_runner, *args, **kwargs)
finally:
# Restore FP16 for the DIT / rest of the pipeline.
os.environ["DTYPE"] = "FP16"
GET_DTYPE.cache_clear()
GET_SENSITIVE_DTYPE.cache_clear()
# Cast bf16 T5 outputs to fp16 so they match the GGUF DIT dtype.
def _to_fp16(x):
if isinstance(x, torch.Tensor) and x.dtype == torch.bfloat16:
return x.to(torch.float16)
if isinstance(x, list):
return [_to_fp16(v) for v in x]
if isinstance(x, tuple):
return tuple(_to_fp16(v) for v in x)
if isinstance(x, dict):
return {k: _to_fp16(v) for k, v in x.items()}
return x
return _to_fp16(result)
runner.run_text_encoder = types.MethodType(bf16_text_encoder, runner)
log.info("Patched T5 encoder to use BF16 under GGUF FP16 pipeline.")
def _patch_vae_dtype_for_gguf(self) -> None:
"""Cast VAE encoder/decoder weights to fp16 to match GGUF DIT dtype.
The VAE weights load as bf16 (the default). Under GGUF the DIT runs in
fp16 and the runner casts VAE inputs via ``.to(GET_DTYPE())`` — which
under DTYPE=FP16 collides with bf16 VAE weights in Conv3d. Since the
VAE is a plain float model (not quantized), simply converting its
weights to fp16 avoids both input-vs-weight mismatches and the need
for any runtime dtype juggling.
"""
import torch
runner = self._runner
for name in ("vae_encoder", "vae_decoder"):
mod = getattr(runner, name, None)
if mod is None:
continue
inner = getattr(mod, "model", mod)
if hasattr(inner, "to"):
inner.to(dtype=torch.float16)
# The outer WanVAE wrapper also holds mean/inv_std/scale tensors
# used by encode/decode (z = z/inv_std + mean). Cast them too, or
# the first op upcasts fp16 latents back to fp32/bf16.
for attr in ("mean", "inv_std"):
t = getattr(mod, attr, None)
if isinstance(t, torch.Tensor):
setattr(mod, attr, t.to(torch.float16))
scale = getattr(mod, "scale", None)
if isinstance(scale, list):
mod.scale = [
t.to(torch.float16) if isinstance(t, torch.Tensor) else t
for t in scale
]
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
def _patch_dit_fp32_weights_for_gguf(self) -> None:
"""Cast leftover fp32 DIT weights to fp16 (dense model).
GGUF dequantises the transformer blocks to fp16, but a handful of
non-quantised weights (notably ``patch_embedding.pin_weight``) end up
loaded as fp32. That breaks the first conv in the DIT forward pass
(fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
directly at ``runner.model`` (no MoE wrapper). After the structured
pre/post weight sweep, we also run a recursive traversal to catch
fp32 conv3d biases etc. that live outside pre/post_weight.
"""
runner = self._runner
n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
n_extra = _cast_all_fp32_tensors(runner.model)
log.info(
"Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
n_struct, n_extra,
)
@staticmethod
def _cast_fp32_dit_weights_in_model(m) -> int:
import torch
n_cast = 0
for weights_attr in ("pre_weight", "post_weight"):
w = getattr(m, weights_attr, None)
if w is None:
continue
for sub_name in dir(w):
if sub_name.startswith("_"):
continue
try:
sub = getattr(w, sub_name)
except Exception:
continue
if sub is None:
continue
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
t = getattr(sub, t_name, None)
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
casted = t.to(torch.float16)
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
try:
casted = casted.pin_memory()
except RuntimeError:
pass
setattr(sub, t_name, casted)
n_cast += 1
return n_cast
# --- Weight provisioning -------------------------------------------------
@staticmethod
def _ensure_base_repo(base_repo: str) -> str:
"""Return a local directory containing the Wan2.2 base support files.
If ``base_repo`` is already a local directory, use it as-is. Otherwise
snapshot_download the HF repo into HF_HOME, skipping the bf16 DIT
shards (they're replaced by the quantised files).
"""
if os.path.isdir(base_repo):
return base_repo
from huggingface_hub import snapshot_download
log.info("Downloading Wan2.2 base support files from %s "
"(skipping bf16 DIT shards)...", base_repo)
return snapshot_download(
repo_id=base_repo,
ignore_patterns=BASE_REPO_IGNORE_PATTERNS,
)
@staticmethod
def _ensure_dit_checkpoint(
dit_repo: str,
dit_quant_scheme: str,
) -> str:
"""Return the local path to the single dense GGUF DIT checkpoint."""
if not dit_repo:
raise ValueError("dit_repo must be a HF repo id or local directory.")
if not dit_quant_scheme.startswith("gguf-"):
raise ValueError(
f"Only GGUF quant schemes are supported for dense 5B Turbo "
f"(got {dit_quant_scheme!r})."
)
quant = dit_quant_scheme.replace("gguf-", "")
filename = GGUF_TURBO_5B_FILE.format(quant=quant)
if os.path.isdir(dit_repo):
path = os.path.join(dit_repo, filename)
if not os.path.isfile(path):
raise FileNotFoundError(
f"DIT checkpoint not found in {dit_repo}: expected {filename}"
)
return path
from huggingface_hub import hf_hub_download
log.info("Downloading %s DIT checkpoint from %s ...",
dit_quant_scheme, dit_repo)
return hf_hub_download(repo_id=dit_repo, filename=filename)
@staticmethod
def _ensure_t5_fp8() -> str:
"""Download the fp8 T5 encoder from lightx2v/Encoders (if not cached).
Returns the local path to the safetensors file (~6 GB).
"""
from huggingface_hub import hf_hub_download
log.info("Downloading fp8 T5 encoder from %s ...", T5_FP8_REPO)
return hf_hub_download(repo_id=T5_FP8_REPO, filename=T5_FP8_FILE)
def _build_runtime_config(self) -> str:
"""Load the template JSON, inject absolute ckpt paths, persist to temp."""
with open(self.config_json_template, "r", encoding="utf-8") as f:
cfg = json.load(f)
# Drop editorial comments before passing to LightX2V.
cfg.pop("_comment", None)
cfg["dit_quantized_ckpt"] = self._dit_ckpt
cfg.setdefault("fps", self.fps)
# T5 fp8 quantization.
if self._t5_fp8_ckpt:
cfg["t5_quantized"] = True
cfg["t5_quant_scheme"] = "fp8-sgl"
cfg["t5_quantized_ckpt"] = self._t5_fp8_ckpt
tmp = tempfile.NamedTemporaryFile(
prefix="wan22_dit_", suffix=".json",
mode="w", delete=False, encoding="utf-8",
)
json.dump(cfg, tmp, indent=2)
tmp.close()
log.info("Runtime LightX2V config: %s", tmp.name)
return tmp.name
@staticmethod
def _build_args(
*, model_cls: str, model_path: str, config_json: str
) -> argparse.Namespace:
"""Mirror every field from ``lightx2v.infer.main``'s argparse so
``set_config`` finds the attributes it expects. We only customize the
model/task/path fields; everything else stays at the CLI defaults.
"""
return argparse.Namespace(
seed=42,
model_cls=model_cls,
task="i2v",
support_tasks=[],
model_path=model_path,
sf_model_path=None,
config_json=config_json,
use_prompt_enhancer=False,
prompt="",
negative_prompt="",
image_path="",
last_frame_path="",
audio_path="",
image_strength="1.0",
image_frame_idx="",
src_ref_images=None,
src_video=None,
src_mask=None,
src_pose_path=None,
src_face_path=None,
src_bg_path=None,
src_mask_path=None,
pose=None,
action_path=None,
action_ckpt=None,
save_result_path=None,
return_result_tensor=False,
target_shape=[],
target_video_length=81,
aspect_ratio="",
video_path=None,
sr_ratio=2.0,
)
# --- LoRA --------------------------------------------------------------
def load_loras(self, specs: list["LoRASpec"]) -> None:
"""Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.
Dense has a single DIT (no MoE experts), so ``target`` must be
``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
so ``switch_lora`` would crash — we use the dynamic-apply path that
merges LoRAs during GGUF dequant.
"""
if not specs:
return
resolved: list[tuple["LoRASpec", str]] = []
for spec in specs:
if spec.target != "both":
raise ValueError(
f"Dense 5B Turbo has a single DIT; LoRA target must be "
f"'both' (got {spec.target!r})."
)
local_path = self._resolve_lora_path(spec.path)
log.info(" LoRA %s → strength=%.2f (%s)",
spec.name or spec.path, spec.weight, local_path)
resolved.append((spec, local_path))
lora_cfgs = [
{"path": local_path, "strength": spec.weight}
for spec, local_path in resolved
]
self._runner.set_config({
"lora_configs": lora_cfgs,
"lora_dynamic_apply": True,
})
self._applied_loras = list(specs)
def unload_loras(self) -> None:
"""Remove all currently applied LoRAs."""
if not self._applied_loras:
return
self._runner.set_config({
"lora_configs": None,
"lora_dynamic_apply": False,
})
self._applied_loras = []
@staticmethod
def _resolve_lora_path(path: str) -> str:
"""Resolve a LoRA path. Supports:
- Absolute/relative local paths (returned as-is if the file exists)
- ``repo_id:filename`` HuggingFace references
"""
if os.path.isfile(path):
return path
if ":" in path and not path.startswith(("/", "./")):
repo_id, filename = path.split(":", 1)
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=repo_id, filename=filename)
return path
# --- Inference ---------------------------------------------------------
def generate_i2v(
self,
image_path: str,
prompt: str,
seconds: int,
seed: int | None = None,
negative_prompt: str = "",
) -> np.ndarray:
"""Run image-to-video inference and return decoded frames.
Returns ``np.ndarray`` shape ``[T, H, W, 3]`` dtype uint8 in RGB.
"""
if seed is None:
seed = random.randint(0, 2**31 - 1)
# Wan2.2 target_video_length is "frames including the conditioning
# frame", so N seconds → N*fps + 1.
target_frames = seconds * self.fps + 1
from lightx2v.utils.input_info import update_input_info_from_dict # type: ignore[import-not-found]
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
out_path = tf.name
try:
log.info("Wan2.2 generate: prompt=%r seconds=%d seed=%d%s",
prompt[:80], seconds, seed, out_path)
update_input_info_from_dict(
self._input_info_template,
{
"seed": seed,
"prompt": prompt,
"negative_prompt": negative_prompt,
"image_path": image_path,
"save_result_path": out_path,
"target_video_length": target_frames,
"return_result_tensor": False,
},
)
self._runner.run_pipeline(self._input_info_template)
return _read_mp4_to_frames(out_path)
finally:
try:
os.remove(out_path)
except OSError:
pass
# --- MP4 decoding helper ------------------------------------------------------
def _read_mp4_to_frames(path: str) -> np.ndarray:
"""Decode an MP4 into an RGB uint8 frame array ``[T, H, W, 3]``."""
try:
import imageio.v3 as iio # type: ignore[import-not-found]
frames = iio.imread(path, plugin="pyav")
arr = np.asarray(frames)
if arr.ndim == 3:
arr = arr[None, ...]
return arr.astype(np.uint8)
except Exception as e: # pragma: no cover - fallback path
log.warning("imageio decode failed (%s); falling back to cv2", e)
import cv2 # type: ignore[import-not-found]
cap = cv2.VideoCapture(path)
frames: list[np.ndarray] = []
while True:
ok, frame = cap.read()
if not ok:
break
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cap.release()
if not frames:
raise RuntimeError(f"Failed to decode any frames from {path}")
return np.stack(frames, axis=0).astype(np.uint8)
+181 -1
View File
@@ -18,9 +18,20 @@ let pendingTextChunks = []; // [{chunkId, text}] - text waiting for its audio to
let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback let scheduledTextTimers = []; // timer IDs for text display scheduled to match audio playback
let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user let lastDisplayedChunkId = -1; // last chunk whose text was actually shown to the user
// --- Video mode state ---
let videoModeEnabled = false; // true when server has video engine active AND ready
let videoModeName = "off"; // "off" | "library" | "reflective"
let idleClipUrl = null; // URL string (server-served) or null
let pendingSpeakingClipMeta = null; // {chunk_id, duration_ms, text} waiting for MP4 binary
let currentSpeakingClipBlobUrl = null;
let speakingClipQueue = []; // [{blobUrl, meta}] clips waiting to play
let currentClipGeneration = 0; // incremented each clip start; guards stale onended handlers
const chatArea = document.getElementById("chat-area"); const chatArea = document.getElementById("chat-area");
const statusBadge = document.getElementById("status-badge"); const statusBadge = document.getElementById("status-badge");
const micBtn = document.getElementById("mic-btn"); const micBtn = document.getElementById("mic-btn");
const avatarVideo = document.getElementById("avatar-video");
const stageEl = document.getElementById("stage");
// --- WebSocket --- // --- WebSocket ---
@@ -44,7 +55,18 @@ function connectWS() {
ws.onmessage = (event) => { ws.onmessage = (event) => {
if (event.data instanceof ArrayBuffer) { if (event.data instanceof ArrayBuffer) {
playAudioChunk(event.data); // In video mode, the next binary frame after a "speaking_clip"
// envelope is an MP4 blob; otherwise it's a PCM audio chunk.
if (pendingSpeakingClipMeta) {
const meta = pendingSpeakingClipMeta;
pendingSpeakingClipMeta = null;
playSpeakingClip(event.data, meta);
} else if (videoModeEnabled) {
// Video mode is active but we didn't get a speaking_clip envelope
// first — ignore raw PCM so we don't double-play audio.
} else {
playAudioChunk(event.data);
}
} else { } else {
handleJSON(JSON.parse(event.data)); handleJSON(JSON.parse(event.data));
} }
@@ -59,6 +81,7 @@ function handleJSON(msg) {
case "interrupt": case "interrupt":
stopPlayback(); stopPlayback();
stopSpeakingClip();
// Finalize with interrupted marker — text already reflects only what was heard // Finalize with interrupted marker — text already reflects only what was heard
finalizeAssistantMessage(true); finalizeAssistantMessage(true);
break; break;
@@ -80,6 +103,160 @@ function handleJSON(msg) {
pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text }); pendingTextChunks.push({ chunkId: msg.chunk_id, text: msg.text });
} }
break; break;
case "video_mode":
// Sent once on WS open. Toggles the video element + speaking-clip path.
applyVideoModeState(msg);
break;
case "speaking_clip":
// Envelope preceding an MP4 binary frame with the full turn.
pendingSpeakingClipMeta = {
chunk_id: msg.chunk_id,
duration_ms: msg.duration_ms,
text: msg.text,
};
break;
}
}
// --- Video mode ------------------------------------------------------------
function applyVideoModeState(msg) {
videoModeEnabled = !!msg.enabled && !!msg.ready;
videoModeName = msg.mode || "off";
idleClipUrl = msg.idle_clip_url || null;
refreshStage();
}
function refreshStage() {
if (videoModeEnabled && idleClipUrl) {
stageEl.classList.add("active");
if (avatarVideo.src !== location.origin + idleClipUrl) {
_returnToIdle();
}
} else {
stageEl.classList.remove("active");
}
}
function _returnToIdle() {
if (!idleClipUrl) return;
avatarVideo.onended = null;
avatarVideo.loop = false;
avatarVideo.muted = true;
avatarVideo.src = idleClipUrl;
avatarVideo.play().catch(() => {});
}
function playSpeakingClip(arrayBuffer, meta) {
const blob = new Blob([arrayBuffer], { type: "video/mp4" });
const blobUrl = URL.createObjectURL(blob);
if (currentSpeakingClipBlobUrl !== null) {
// A clip is already playing — queue this one.
speakingClipQueue.push({ blobUrl, meta });
} else {
_startSpeakingClip(blobUrl, meta);
}
}
function _startSpeakingClip(blobUrl, meta) {
const gen = ++currentClipGeneration;
if (currentSpeakingClipBlobUrl) {
URL.revokeObjectURL(currentSpeakingClipBlobUrl);
}
currentSpeakingClipBlobUrl = blobUrl;
avatarVideo.loop = false;
avatarVideo.muted = false;
avatarVideo.src = blobUrl;
if (meta && meta.text) {
appendAssistantText(meta.text);
}
isPlaying = true;
avatarVideo.onended = () => {
if (currentClipGeneration !== gen) return; // stale handler from a replaced clip
URL.revokeObjectURL(currentSpeakingClipBlobUrl);
currentSpeakingClipBlobUrl = null;
const next = speakingClipQueue.shift();
if (next) {
_startSpeakingClip(next.blobUrl, next.meta);
} else {
isPlaying = false;
finalizeAssistantMessage(false);
_returnToIdle();
}
};
avatarVideo.play().catch((e) => {
console.error("speaking clip play failed:", e);
});
}
function stopSpeakingClip() {
// Discard any queued clips.
for (const { blobUrl } of speakingClipQueue) {
URL.revokeObjectURL(blobUrl);
}
speakingClipQueue = [];
currentClipGeneration++; // invalidate any in-flight onended handlers
if (!currentSpeakingClipBlobUrl) return;
try { avatarVideo.pause(); } catch (_) {}
avatarVideo.onended = null;
URL.revokeObjectURL(currentSpeakingClipBlobUrl);
currentSpeakingClipBlobUrl = null;
isPlaying = false;
_returnToIdle();
}
async function uploadAvatar() {
const fileInput = document.getElementById("avatar-file");
const status = document.getElementById("avatar-status");
if (!fileInput.files || !fileInput.files[0]) {
status.textContent = "Pick an image first.";
return;
}
status.textContent = "Uploading and rendering idle clip (this takes a while)...";
const fd = new FormData();
fd.append("image", fileInput.files[0]);
try {
const resp = await fetch("/api/set-avatar", { method: "POST", body: fd });
if (!resp.ok) throw new Error(await resp.text());
const data = await resp.json();
idleClipUrl = data.idle_clip_url + "?t=" + Date.now(); // cache-bust
videoModeEnabled = true;
videoModeName = data.mode || videoModeName;
refreshStage();
status.textContent = "Avatar ready (" + data.mode + ")";
} catch (err) {
console.error(err);
status.textContent = "Failed: " + err.message;
}
}
async function applyVideoMode() {
const sel = document.getElementById("video-mode-select");
const status = document.getElementById("avatar-status");
const fd = new FormData();
fd.append("mode", sel.value);
try {
const resp = await fetch("/api/set-video-mode", { method: "POST", body: fd });
if (!resp.ok) throw new Error(await resp.text());
const data = await resp.json();
videoModeName = data.mode;
if (data.mode === "off") {
videoModeEnabled = false;
stageEl.classList.remove("active");
}
status.textContent = "Mode: " + data.mode + (data.note ? " — " + data.note : "");
} catch (err) {
status.textContent = "Failed: " + err.message;
} }
} }
@@ -275,6 +452,7 @@ async function startMic() {
if (bargeInCount >= BARGE_IN_FRAMES) { if (bargeInCount >= BARGE_IN_FRAMES) {
// User is speaking over the assistant - interrupt // User is speaking over the assistant - interrupt
stopPlayback(); stopPlayback();
stopSpeakingClip();
const msg = { type: "interrupt" }; const msg = { type: "interrupt" };
if (lastDisplayedChunkId >= 0) { if (lastDisplayedChunkId >= 0) {
msg.last_chunk_id = lastDisplayedChunkId; msg.last_chunk_id = lastDisplayedChunkId;
@@ -353,3 +531,5 @@ async function applyVoice() {
// Expose to HTML onclick // Expose to HTML onclick
window.toggleMic = toggleMic; window.toggleMic = toggleMic;
window.applyVoice = applyVoice; window.applyVoice = applyVoice;
window.uploadAvatar = uploadAvatar;
window.applyVideoMode = applyVideoMode;
+32
View File
@@ -12,6 +12,17 @@
<span id="status-badge">Disconnected</span> <span id="status-badge">Disconnected</span>
</header> </header>
<div id="stage">
<video
id="avatar-video"
autoplay
muted
loop
playsinline
preload="auto"
></video>
</div>
<div id="chat-area"></div> <div id="chat-area"></div>
<details id="voice-panel"> <details id="voice-panel">
@@ -40,6 +51,27 @@
</div> </div>
</details> </details>
<details id="avatar-panel">
<summary>Avatar / Video</summary>
<div class="panel-content">
<label>
Avatar image
<input type="file" id="avatar-file" accept="image/*" />
</label>
<button id="upload-avatar-btn" onclick="uploadAvatar()">Upload</button>
<label>
Mode
<select id="video-mode-select">
<option value="off">Off</option>
<option value="library">Library (pre-baked)</option>
<option value="reflective" selected>Reflective (per-turn)</option>
</select>
</label>
<button id="apply-mode-btn" onclick="applyVideoMode()">Apply mode</button>
<span id="avatar-status"></span>
</div>
</details>
<div id="controls"> <div id="controls">
<button id="mic-btn" onclick="toggleMic()">&#x1F3A4;</button> <button id="mic-btn" onclick="toggleMic()">&#x1F3A4;</button>
</div> </div>
+39 -4
View File
@@ -52,6 +52,28 @@ header h1 {
color: #a78bfa; color: #a78bfa;
} }
#stage {
display: none; /* toggled on when video mode is enabled */
align-items: center;
justify-content: center;
padding: 16px 24px 0;
background: #0a0a0a;
}
#stage.active {
display: flex;
}
#avatar-video {
width: 100%;
max-width: 480px;
aspect-ratio: 16 / 9;
background: #000;
border-radius: 12px;
object-fit: cover;
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4);
}
#chat-area { #chat-area {
flex: 1; flex: 1;
overflow-y: auto; overflow-y: auto;
@@ -130,21 +152,34 @@ header h1 {
50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); } 50% { box-shadow: 0 0 0 12px rgba(239, 68, 68, 0); }
} }
/* Voice clone panel */ /* Voice + avatar panels */
#voice-panel { #voice-panel,
#avatar-panel {
padding: 12px 24px; padding: 12px 24px;
border-top: 1px solid #222; border-top: 1px solid #222;
background: #0a0a0a; background: #0a0a0a;
} }
#voice-panel summary { #voice-panel select,
#avatar-panel select {
background: #1a1a1a;
border: 1px solid #333;
border-radius: 6px;
padding: 6px 10px;
color: #e0e0e0;
font-size: 13px;
}
#voice-panel summary,
#avatar-panel summary {
cursor: pointer; cursor: pointer;
font-size: 13px; font-size: 13px;
color: #888; color: #888;
user-select: none; user-select: none;
} }
#voice-panel .panel-content { #voice-panel .panel-content,
#avatar-panel .panel-content {
margin-top: 12px; margin-top: 12px;
display: flex; display: flex;
gap: 12px; gap: 12px;
+67
View File
@@ -0,0 +1,67 @@
# Voice-chat tests
Two tiers.
## Unit tests — fast, GPU-free
```
python -m pytest tests/unit -v
```
These exercise pure logic: config parsing, prompt derivation, LoRA spec
parsing, frame-length fitting, library round-robin selection, the
pipeline's video branch, and ffmpeg mux argument shaping. They do not
touch CUDA, Wan2.2, MuseTalk, or a real ffmpeg binary. Safe to run on
Windows, outside Docker, without any models installed.
Current unit files:
- `test_video_config.py``VideoConfig.from_dict` round-trip, LoRA target validation
- `test_video_engine_logic.py` — prompt derivation, library cursor, frame fitting
- `test_pipeline_video_branch.py` — pipeline takes the video path iff engine is ready
- `test_musetalk_fit_frames.py` — frame-length adjustment to match audio duration
- `test_muxer_ffmpeg.py` — ffmpeg command construction
## Component tests — slow, GPU-required, run inside Docker
Each script in `tests/component/` exercises one subsystem end-to-end
against the real models. The numbered prefix reflects the implementation
phase each script gates, and also serves as a reasonable run order when
debugging a fresh environment:
| Script | Phase | Tests |
|---|---|---|
| `test_01_video_skeleton.py` | 1 | VideoEngine loads, config gate respected |
| `test_02_wan22_loras.py` | 2 | Wan2.2 pipeline loads, LoRA stack applies |
| `test_03_idle_clip.py` | 3 | `set_avatar` → idle MP4, written to disk for eyeballing |
| `test_04_library_prebake.py` | 4 | library mode pre-bakes N base clips |
| `test_05_musetalk_lipsync.py` | 5 | MuseTalk lip-sync on library frames + ffmpeg mux |
| `test_06_reflective.py` | 6 | reflective mode: fresh Wan2.2 per reply |
| `test_07_endpoints.py` | 7 | HTTP endpoints return sane responses |
| `test_08_lora_reload.py` | 8 | `/api/reload-loras` swaps LoRAs live |
| `test_09_gguf_generate.py` | 9 | GGUF-quantised DIT end-to-end I2V generation |
| `test_10_t5_encode.py` | 10 | T5 encoder (optionally fp8-quantised) on CUDA |
| `test_11_image_encode.py` | 11 | Avatar image → VAE latent path |
| `test_12_dit_single_step.py` | 12 | Single DIT step on the loaded expert(s) |
| `test_13_vae_decode.py` | 13 | VAE decode back to RGB frames |
Tests 09-13 are focused on the GGUF + Blackwell (SM120) path and are how
new quant schemes / attention backends get validated before wiring them
into the full pipeline.
Run one:
```
# Inside the container:
docker compose exec voice-chat python -m tests.component.test_03_idle_clip
```
Run all (slow, ~20+ minutes on a 5090):
```
docker compose exec voice-chat python -m tests.component.run_all
```
Each component script writes its artifacts (MP4s, PNG frame dumps, logs)
to `tests/component/_out/` so you can visually inspect results. That
directory is gitignored.
View File
View File
+72
View File
@@ -0,0 +1,72 @@
"""Shared utilities for component tests.
Component tests run inside the Docker image against real GPU models. They
write their output artefacts (MP4s, PNGs, logs) to ``_out/`` so you can
visually inspect results.
"""
from __future__ import annotations
import logging
import os
import sys
import numpy as np
OUT_DIR = os.path.join(os.path.dirname(__file__), "_out")
os.makedirs(OUT_DIR, exist_ok=True)
# A tiny 256x256 portrait PNG lives next to the component tests so tests
# don't need a user-supplied file. If it's missing we synthesise one on
# the fly.
SAMPLE_AVATAR = os.path.join(os.path.dirname(__file__), "sample_avatar.png")
def get_logger(name: str) -> logging.Logger:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(name)s %(levelname)s %(message)s",
stream=sys.stdout,
)
return logging.getLogger(name)
def ensure_sample_avatar() -> str:
"""Guarantee a usable avatar image exists. Returns its path."""
if os.path.isfile(SAMPLE_AVATAR):
return SAMPLE_AVATAR
# Synthesise a simple gradient PNG as a last resort (won't look like a
# person but is valid input for Wan2.2 so the pipeline doesn't fail).
try:
from PIL import Image # type: ignore[import-not-found]
except ImportError:
import imageio.v3 as iio # type: ignore[import-not-found]
arr = np.zeros((256, 256, 3), dtype=np.uint8)
for y in range(256):
arr[y, :, 0] = y
arr[y, :, 1] = 255 - y
arr[y, :, 2] = 128
iio.imwrite(SAMPLE_AVATAR, arr)
return SAMPLE_AVATAR
arr = np.zeros((256, 256, 3), dtype=np.uint8)
for y in range(256):
arr[y, :, 0] = y
arr[y, :, 1] = 255 - y
arr[y, :, 2] = 128
Image.fromarray(arr).save(SAMPLE_AVATAR)
return SAMPLE_AVATAR
def write_bytes(name: str, data: bytes) -> str:
"""Write an artefact to _out/<name> and return the full path."""
path = os.path.join(OUT_DIR, name)
with open(path, "wb") as f:
f.write(data)
return path
def synth_tone(seconds: float, sample_rate: int = 24000, freq: float = 220.0) -> np.ndarray:
"""Return a float32 sine tone usable as stand-in TTS audio."""
t = np.arange(int(seconds * sample_rate), dtype=np.float32) / sample_rate
return (0.2 * np.sin(2 * np.pi * freq * t)).astype(np.float32)
+46
View File
@@ -0,0 +1,46 @@
"""Run every component test in order. Stops at first failure.
docker compose exec voice-chat python -m tests.component.run_all
"""
import importlib
import sys
import traceback
SCRIPTS = [
"tests.component.test_01_video_skeleton",
"tests.component.test_02_wan22_loras",
"tests.component.test_03_idle_clip",
"tests.component.test_04_library_prebake",
"tests.component.test_05_musetalk_lipsync",
"tests.component.test_06_reflective",
"tests.component.test_07_endpoints",
"tests.component.test_08_lora_reload",
]
def main() -> int:
failed: list[str] = []
for name in SCRIPTS:
print(f"\n{'=' * 70}\nRUNNING: {name}\n{'=' * 70}")
try:
mod = importlib.import_module(name)
mod.run()
except SystemExit as e:
if e.code:
print(f"FAILED: {name} (exit {e.code})")
failed.append(name)
break # hard-stop on failure
except Exception:
traceback.print_exc()
failed.append(name)
break
if failed:
print(f"\n{len(failed)} failed: {failed}")
return 1
print("\nALL COMPONENT TESTS PASSED")
return 0
if __name__ == "__main__":
sys.exit(main())
Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

+69
View File
@@ -0,0 +1,69 @@
"""Phase 1 component test: VideoEngine skeleton + config gate.
Verifies:
- ``ModelManager`` can be imported and constructed.
- When ``config.video.enabled=false``, ``_load_video`` skips and leaves
``video_engine=None`` (existing voice path unaffected).
- When ``config.video.enabled=true``, a ``VideoEngine`` instance is created
and ``is_ready()`` returns False (no models loaded yet).
Does NOT load Wan2.2 or MuseTalk — this test is safe to run on any machine
with the python deps installed (no GPU needed).
Run inside Docker:
docker compose exec voice-chat python -m tests.component.test_01_video_skeleton
"""
from __future__ import annotations
import sys
from server.models import ModelManager
from server.video import VideoConfig, VideoEngine
from tests.component._common import get_logger
log = get_logger("test_01")
def run():
# --- disabled path ---
log.info("[case 1] config.video.enabled=False → engine skipped")
mgr = ModelManager()
# Monkey-patch the config module to simulate disabled
import server.config as cfgmod
original = cfgmod.config
cfgmod.config = {"video": {"enabled": False}, **{k: v for k, v in original.items() if k != "video"}}
try:
mgr._load_video()
assert mgr.video_engine is None, "video_engine should be None when disabled"
log.info(" PASS: video_engine is None")
finally:
cfgmod.config = original
# --- enabled path (no models loaded) ---
log.info("[case 2] config.video.enabled=True → engine created, not ready")
mgr2 = ModelManager()
cfgmod.config = {
**original,
"video": {"enabled": True, "mode": "reflective", "loras": []},
}
try:
mgr2._load_video()
assert mgr2.video_engine is not None, "video_engine should be created"
assert isinstance(mgr2.video_engine, VideoEngine)
assert mgr2.video_engine.is_ready() is False
log.info(" PASS: engine=%s, ready=%s",
type(mgr2.video_engine).__name__, mgr2.video_engine.is_ready())
finally:
cfgmod.config = original
log.info("ALL PASSED")
if __name__ == "__main__":
try:
run()
sys.exit(0)
except AssertionError as e:
log.error("FAILED: %s", e)
sys.exit(1)
+108
View File
@@ -0,0 +1,108 @@
"""Phase 2 component test: dense Wan2.2-TI2V-5B-Turbo pipeline + LoRA stacking.
Verifies:
- ``Wan22Pipeline`` loads successfully (exercises the real LightX2V
set_config -> init_runner flow).
- ``load_loras`` / ``unload_loras`` survive with any user LoRAs at
``/cache/loras/*.safetensors`` (target='both', dense single DIT).
Supports any GGUF quant published in hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
Set ``DIT_QUANT`` to switch (default: ``gguf-Q8_0``).
DIT_QUANT=gguf-Q4_K_M docker compose exec voice-chat \
python -m tests.component.test_02_wan22_loras
Requires GPU and a first-run download of the base repo + GGUF DIT.
If LightX2V isn't installed the test is skipped.
Run:
docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
"""
from __future__ import annotations
import glob
import os
import sys
from tests.component._common import get_logger
log = get_logger("test_02")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Wan22Pipeline import failed: %s", e)
log.warning("SKIP: phase 2 deps not installed")
sys.exit(0)
from server.video import LoRASpec
log.info("[case 1] Instantiate Wan22Pipeline "
"(quant=%s, dit_repo=%s)...", DIT_QUANT, DIT_REPO)
try:
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
except Exception as e:
log.error("FAIL: Wan22Pipeline construction raised: %s", e)
log.error("Check: LightX2V install, HF cache at /cache/huggingface, "
"VRAM headroom, and that %s exists inside the container.",
CONFIG_JSON)
sys.exit(2)
log.info(" PASS: pipeline constructed")
# --- LoRAs ---
log.info("[case 2] load_loras with empty list -> no-op")
pipe.load_loras([])
log.info(" PASS")
lora_files = sorted(glob.glob("/cache/loras/*.safetensors"))
if not lora_files:
log.warning("SKIP: no LoRA files found in /cache/loras/")
log.info("ALL PASSED (partial — LoRA cases skipped)")
return
lora_path = lora_files[0]
log.info("[case 3] load_loras with one 5B-compatible LoRA (%s)", lora_path)
specs = [
LoRASpec(
path=lora_path,
weight=1.0,
target="both",
name=os.path.basename(lora_path),
),
]
try:
pipe.load_loras(specs)
except Exception as e:
log.error("FAIL: load_loras raised: %s", e)
log.error("Check: LoRA checkpoint shape matches dense 5B DIT.")
sys.exit(3)
log.info(" PASS: LoRAs applied")
log.info("[case 4] unload_loras")
try:
pipe.unload_loras()
except Exception as e:
log.error("FAIL: unload_loras raised: %s", e)
sys.exit(4)
log.info(" PASS")
log.info("ALL PASSED")
if __name__ == "__main__":
run()
+66
View File
@@ -0,0 +1,66 @@
"""Phase 3 component test: avatar upload → idle clip generation.
Verifies:
- ``VideoEngine.load_models()`` + ``set_avatar(image)`` produces a non-empty
idle MP4 blob.
- The blob decodes as a valid MP4 (ftyp header).
Writes the idle clip to ``tests/component/_out/phase3_idle.mp4`` so you can
inspect it visually.
Run:
docker compose exec voice-chat python -m tests.component.test_03_idle_clip
"""
from __future__ import annotations
import sys
from server.video import VideoConfig, VideoEngine
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_03")
def run():
avatar_path = ensure_sample_avatar()
log.info("Using avatar: %s", avatar_path)
cfg = VideoConfig.from_dict(
{
"enabled": True,
"mode": "reflective", # reflective skips the library prebake
"resolution": 480,
"fps": 16,
"library": {"base_clip_count": 0, "base_clip_seconds": 3},
}
)
engine = VideoEngine(cfg)
log.info("Loading models (Wan2.2 + MuseTalk)...")
try:
engine.load_models()
except Exception as e:
log.error("FAIL: load_models raised: %s", e)
sys.exit(2)
log.info("Models loaded.")
log.info("Generating idle clip for avatar...")
try:
engine.set_avatar(avatar_path)
except Exception as e:
log.error("FAIL: set_avatar raised: %s", e)
sys.exit(3)
idle = engine.get_idle_clip()
assert idle is not None and len(idle) > 0, "idle clip is empty"
assert idle[4:8] == b"ftyp", "idle clip is not a valid MP4"
out_path = write_bytes("phase3_idle.mp4", idle)
log.info("PASS: idle clip written to %s (%d bytes)", out_path, len(idle))
assert engine.is_ready() is True
log.info(" engine.is_ready() = True (avatar + models present)")
if __name__ == "__main__":
run()
@@ -0,0 +1,55 @@
"""Phase 4 component test: library mode pre-bake of speaking-base clips.
Verifies:
- ``set_avatar`` under ``mode=library`` populates ``speaking_base_frames``
with ``library_base_clip_count`` entries.
- Each cached entry has shape ``[T, H, W, 3]`` uint8.
Run:
docker compose exec voice-chat python -m tests.component.test_04_library_prebake
"""
from __future__ import annotations
import sys
import numpy as np
from server.video import VideoConfig, VideoEngine
from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_04")
def run():
avatar_path = ensure_sample_avatar()
cfg = VideoConfig.from_dict(
{
"enabled": True,
"mode": "library",
"resolution": 480,
"fps": 16,
"library": {"base_clip_count": 2, "base_clip_seconds": 3},
}
)
engine = VideoEngine(cfg)
log.info("Loading models...")
engine.load_models()
log.info("Pre-baking 2 library clips...")
engine.set_avatar(avatar_path)
assert len(engine.speaking_base_frames) == 2, \
f"expected 2 base clips, got {len(engine.speaking_base_frames)}"
for i, frames in enumerate(engine.speaking_base_frames):
assert isinstance(frames, np.ndarray)
assert frames.ndim == 4 and frames.shape[-1] == 3
assert frames.dtype == np.uint8
log.info(" clip %d: shape=%s", i, frames.shape)
assert engine.get_idle_clip() is not None
log.info("PASS: library pre-bake complete")
if __name__ == "__main__":
run()
@@ -0,0 +1,57 @@
"""Phase 5 component test: MuseTalk lip-sync + ffmpeg mux.
Verifies the full library-mode per-turn path:
- Pre-bake a library clip.
- Generate a stand-in TTS waveform (sine tone).
- Call ``VideoEngine.generate_speaking_clip`` and get a valid MP4 back.
Writes the resulting clip to ``tests/component/_out/phase5_speaking.mp4``.
Run:
docker compose exec voice-chat python -m tests.component.test_05_musetalk_lipsync
"""
from __future__ import annotations
import sys
from server.video import VideoConfig, VideoEngine
from tests.component._common import (
ensure_sample_avatar,
get_logger,
synth_tone,
write_bytes,
)
log = get_logger("test_05")
def run():
avatar_path = ensure_sample_avatar()
cfg = VideoConfig.from_dict(
{
"enabled": True,
"mode": "library",
"resolution": 480,
"fps": 16,
"library": {"base_clip_count": 1, "base_clip_seconds": 4},
}
)
engine = VideoEngine(cfg)
engine.load_models()
engine.set_avatar(avatar_path)
audio = synth_tone(seconds=3.0, sample_rate=24000, freq=220.0)
log.info("Generating library-mode speaking clip (3s audio)...")
mp4 = engine.generate_speaking_clip(
audio_f32=audio,
sample_rate=24000,
reply_text="Hello, this is a lip-sync test.",
)
assert isinstance(mp4, bytes) and len(mp4) > 0
assert mp4[4:8] == b"ftyp"
out = write_bytes("phase5_speaking.mp4", mp4)
log.info("PASS: speaking clip written to %s (%d bytes)", out, len(mp4))
if __name__ == "__main__":
run()
+69
View File
@@ -0,0 +1,69 @@
"""Phase 6 component test: reflective mode (fresh Wan2.2 clip per turn).
Verifies that with ``mode=reflective``, ``generate_speaking_clip`` runs
the Wan2.2 image-to-video pipeline once per call (so the base frames
differ from turn to turn) and the prompt is derived from the reply text.
Run:
docker compose exec voice-chat python -m tests.component.test_06_reflective
"""
from __future__ import annotations
import numpy as np
from server.video import VideoConfig, VideoEngine
from tests.component._common import (
ensure_sample_avatar,
get_logger,
synth_tone,
write_bytes,
)
log = get_logger("test_06")
def run():
avatar_path = ensure_sample_avatar()
cfg = VideoConfig.from_dict(
{
"enabled": True,
"mode": "reflective",
"resolution": 480,
"fps": 16,
"reflective": {"clip_seconds": 3},
}
)
engine = VideoEngine(cfg)
engine.load_models()
engine.set_avatar(avatar_path)
# Verify prompt derivation includes the reply hint
prompt = engine._derive_prompt(
"The assistant walks along a sunny beach watching seagulls."
)
log.info("derived prompt: %s", prompt)
assert "beach" in prompt, "reply_hint did not survive template interpolation"
audio = synth_tone(seconds=3.0)
log.info("Generating reflective speaking clip #1...")
mp4_a = engine.generate_speaking_clip(
audio, 24000, "The assistant walks along a sunny beach watching seagulls."
)
write_bytes("phase6_reflective_beach.mp4", mp4_a)
log.info("Generating reflective speaking clip #2...")
mp4_b = engine.generate_speaking_clip(
audio, 24000, "Now the character stands in a snow-covered forest at dusk."
)
write_bytes("phase6_reflective_snow.mp4", mp4_b)
# Not a strict assertion (same prompt could yield identical bytes if seeded),
# but with different prompts and random seeds the blobs should differ.
if mp4_a != mp4_b:
log.info("PASS: reflective clips differ as expected")
else:
log.warning("clips are byte-identical — check that seeds are random")
if __name__ == "__main__":
run()
+114
View File
@@ -0,0 +1,114 @@
"""Phase 7 component test: HTTP endpoints (/api/set-avatar, /api/idle-clip,
/api/set-video-mode, /api/reload-loras, WebSocket handshake video_mode msg).
Uses FastAPI's ``TestClient`` so we don't need a running uvicorn server.
Stubs the model manager to avoid loading Wan2.2 — we only care that the
HTTP surface is plumbed correctly.
Run:
docker compose exec voice-chat python -m tests.component.test_07_endpoints
"""
from __future__ import annotations
import io
import json
import sys
from tests.component._common import get_logger
log = get_logger("test_07")
def _stub_video_engine():
class StubCfg:
mode = "reflective"
class StubEngine:
cfg = StubCfg()
avatar_path = None
def __init__(self): self.idle = b"FAKE_MP4"
def is_ready(self): return bool(self.avatar_path)
def get_idle_clip(self): return self.idle
def set_avatar(self, path): self.avatar_path = path
def load_loras(self, specs): self._last_loras = specs
return StubEngine()
def run():
from fastapi.testclient import TestClient
import server.main as main_mod
# Inject a stub engine so we never touch Wan2.2.
main_mod.model_mgr.video_engine = _stub_video_engine()
# Bypass the heavy lifespan (model loading) so TestClient starts fast.
main_mod.app.router.lifespan_context = None # type: ignore[attr-defined]
client = TestClient(main_mod.app)
# --- set-avatar ---
log.info("[case 1] POST /api/set-avatar")
fake_png = b"\x89PNG\r\n\x1a\n" + b"\x00" * 64 # minimal PNG header
resp = client.post(
"/api/set-avatar",
files={"image": ("avatar.png", io.BytesIO(fake_png), "image/png")},
)
assert resp.status_code == 200, f"got {resp.status_code}: {resp.text}"
data = resp.json()
assert data["status"] == "ok"
assert data["idle_clip_url"] == "/api/idle-clip"
log.info(" PASS: %s", data)
# --- idle-clip ---
log.info("[case 2] GET /api/idle-clip")
resp = client.get("/api/idle-clip")
assert resp.status_code == 200
assert resp.content == b"FAKE_MP4"
assert resp.headers["content-type"] == "video/mp4"
log.info(" PASS")
# --- set-video-mode ---
log.info("[case 3] POST /api/set-video-mode")
for mode in ("off", "library", "reflective"):
resp = client.post("/api/set-video-mode", data={"mode": mode})
assert resp.status_code == 200
assert resp.json()["mode"] == mode
resp = client.post("/api/set-video-mode", data={"mode": "bogus"})
assert resp.status_code == 400
log.info(" PASS")
# --- reload-loras ---
log.info("[case 4] POST /api/reload-loras")
body = {
"loras": [
{"path": "/cache/loras/a.safetensors", "weight": 0.8,
"target": "both", "name": "test-a"},
{"path": "/cache/loras/b.safetensors", "weight": 0.4,
"target": "both"},
]
}
resp = client.post("/api/reload-loras", json=body)
assert resp.status_code == 200, resp.text
data = resp.json()
assert data["lora_count"] == 2
log.info(" PASS: %s", data)
# --- WebSocket video_mode handshake ---
log.info("[case 5] WebSocket /ws/chat → video_mode announcement")
with client.websocket_connect("/ws/chat") as websocket:
msgs = []
for _ in range(5):
try:
msg = websocket.receive_json()
msgs.append(msg)
if msg.get("type") == "video_mode":
break
except Exception:
break
assert any(m.get("type") == "video_mode" for m in msgs), msgs
log.info(" PASS")
log.info("ALL PASSED")
if __name__ == "__main__":
run()
+52
View File
@@ -0,0 +1,52 @@
"""Phase 8 component test: /api/reload-loras hot-swap.
Verifies that ``VideoEngine.load_loras`` can be called again after startup
and the idle clip is regenerated to reflect the new style.
This test is the 'real model' version of test_07's reload endpoint stub.
Run:
docker compose exec voice-chat python -m tests.component.test_08_lora_reload
"""
from __future__ import annotations
import hashlib
from server.video import LoRASpec, VideoConfig, VideoEngine
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_08")
def run():
avatar_path = ensure_sample_avatar()
cfg = VideoConfig.from_dict({"enabled": True, "mode": "reflective"})
engine = VideoEngine(cfg)
engine.load_models()
# Initial state: no LoRAs
engine.set_avatar(avatar_path)
idle_a = engine.get_idle_clip()
assert idle_a is not None
hash_a = hashlib.sha256(idle_a).hexdigest()
write_bytes("phase8_idle_noloras.mp4", idle_a)
log.info("idle (no LoRAs) sha256=%s", hash_a[:16])
# Hot-reload flow: unload (no-op), reload empty list, verify clip still generates.
# There are no published 5B-Turbo-compatible LoRAs yet; when one exists,
# construct a LoRASpec(path=..., target="both", weight=1.0) and compare hashes.
engine.load_loras([])
engine.set_avatar(avatar_path)
idle_b = engine.get_idle_clip()
assert idle_b is not None
hash_b = hashlib.sha256(idle_b).hexdigest()
write_bytes("phase8_idle_reloaded.mp4", idle_b)
log.info("idle (post-reload) sha256=%s", hash_b[:16])
log.info("PASS: hot-reload round-trip completed "
"(hash match=%s — expected without a real LoRA applied).",
hash_a == hash_b)
if __name__ == "__main__":
run()
+78
View File
@@ -0,0 +1,78 @@
"""Quick smoke test: generate a video clip with the dense 5B Turbo GGUF pipeline.
Calls Wan22Pipeline.generate_i2v directly (no MuseTalk, no VideoEngine)
and writes the result to tests/component/_out/phase9_gguf.mp4.
Run:
docker compose exec -e DIT_QUANT=gguf-Q8_0 voice-chat \
python -m tests.component.test_09_gguf_generate
"""
from __future__ import annotations
import os
import sys
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_09")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
# Debug: verify DTYPE is set correctly for GGUF
from lightx2v.utils.envs import GET_DTYPE
log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))
log.info("Generating 3-second i2v clip...")
frames = pipe.generate_i2v(
image_path=avatar,
prompt="a person looking at the camera, natural lighting, soft focus background",
seconds=3,
seed=42,
)
log.info("Got frames: shape=%s dtype=%s", frames.shape, frames.dtype)
# Encode to MP4
import imageio.v3 as iio
import tempfile
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tf:
tmp = tf.name
try:
iio.imwrite(tmp, frames, fps=16, codec="libx264")
with open(tmp, "rb") as f:
mp4_bytes = f.read()
finally:
os.remove(tmp)
out = write_bytes("phase9_gguf.mp4", mp4_bytes)
log.info("PASS: video written to %s (%d bytes, %d frames)", out, len(mp4_bytes), frames.shape[0])
if __name__ == "__main__":
run()
+83
View File
@@ -0,0 +1,83 @@
"""Smoke test: T5 text encoding under GGUF pipeline.
Builds the Wan22Pipeline (loads all weights including DIT) but only
exercises the T5 encoder — no image-to-video generation. Validates that
the DTYPE=FP16 ↔ BF16 patching lets T5 encode a prompt successfully.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_10_t5_encode
"""
from __future__ import annotations
import os
import sys
from tests.component._common import get_logger
log = get_logger("test_10")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
# Check DTYPE state after init
from lightx2v.utils.envs import GET_DTYPE
log.info("GET_DTYPE() = %s (DTYPE env = %s)", GET_DTYPE(), os.environ.get("DTYPE"))
# Run only the T5 text encoder
runner = pipe._runner
prompt = "a person looking at the camera, natural lighting"
log.info("Running T5 text encoder on prompt: %r", prompt)
import copy
input_info = copy.deepcopy(pipe._input_info_template)
input_info.prompt = prompt
runner.run_text_encoder(input_info)
log.info("T5 encode complete.")
# Inspect output — check all dataclass fields for tensor results
import torch
for attr in vars(input_info):
val = getattr(input_info, attr)
if isinstance(val, torch.Tensor):
log.info(" %s: shape=%s dtype=%s device=%s", attr, val.shape, val.dtype, val.device)
# Verify DTYPE is back to FP16 after T5 runs (if GGUF)
if DIT_QUANT.startswith("gguf-"):
current = GET_DTYPE()
import torch
expected = torch.float16
if current == expected:
log.info("PASS: DTYPE correctly restored to FP16 after T5 encode.")
else:
log.error("FAIL: DTYPE is %s after T5, expected %s", current, expected)
sys.exit(1)
log.info("PASS: T5 encoding succeeded under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()
+107
View File
@@ -0,0 +1,107 @@
"""Smoke test: image reading + VAE encoder (+ CLIP if enabled) under GGUF pipeline.
Builds the Wan22Pipeline, loads a sample avatar, reads the image input,
runs the CLIP image encoder (if use_image_encoder is true in the config),
and runs the VAE encoder. Validates outputs under DTYPE=FP16.
Note: The GGUF distill config sets use_image_encoder=false, so CLIP is
skipped by default. The VAE encoder is always exercised.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_11_image_encode
"""
from __future__ import annotations
import os
import sys
import torch
from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_11")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
runner = pipe._runner
# Set up input_info so runner methods can access it
from lightx2v.utils.input_info import update_input_info_from_dict
update_input_info_from_dict(
pipe._input_info_template,
{
"seed": 42,
"prompt": "a person looking at the camera, natural lighting",
"negative_prompt": "",
"image_path": avatar,
"target_video_length": 17,
},
)
runner.input_info = pipe._input_info_template
# 1. Load image
log.info("Reading image input...")
img, img_ori = runner.read_image_input(avatar)
log.info("img: shape=%s dtype=%s device=%s", img.shape, img.dtype, img.device)
# 2. CLIP image encoder (only if enabled in config)
use_clip = runner.config.get("use_image_encoder", True)
if use_clip:
log.info("Running CLIP image encoder...")
clip_out = runner.run_image_encoder(img)
log.info("clip_out: shape=%s dtype=%s device=%s", clip_out.shape, clip_out.dtype, clip_out.device)
assert isinstance(clip_out, torch.Tensor), f"Expected tensor, got {type(clip_out)}"
log.info("PASS: CLIP image encoder succeeded.")
else:
log.info("CLIP image encoder disabled (use_image_encoder=false) — skipping.")
# 3. VAE encoder
vae_input = img_ori if runner.vae_encoder_need_img_original else img
log.info("Running VAE encoder (using %s)...",
"img_ori" if runner.vae_encoder_need_img_original else "img tensor")
vae_out, latent_shape = runner.run_vae_encoder(vae_input)
log.info("latent_shape: %s", latent_shape)
if isinstance(vae_out, torch.Tensor):
log.info("vae_out: shape=%s dtype=%s device=%s", vae_out.shape, vae_out.dtype, vae_out.device)
elif isinstance(vae_out, dict):
for k, v in vae_out.items():
if isinstance(v, torch.Tensor):
log.info("vae_out[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
else:
log.info("vae_out[%s]: type=%s", k, type(v))
else:
log.info("vae_out: type=%s", type(vae_out))
log.info("PASS: VAE encoder succeeded.")
log.info("PASS: All image encoding stages completed under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()
+102
View File
@@ -0,0 +1,102 @@
"""Smoke test: single DIT denoising step with GGUF weights.
Builds the pipeline, runs all encoders, initializes the scheduler, then
executes exactly one DIT forward pass (step_pre → infer → step_post).
This isolates the GGUF fp16 DIT from the rest of the pipeline.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_12_dit_single_step
"""
from __future__ import annotations
import copy
import os
import sys
import torch
from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_12")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
runner = pipe._runner
# Set up input_info for a short clip
from lightx2v.utils.input_info import update_input_info_from_dict
update_input_info_from_dict(
pipe._input_info_template,
{
"seed": 42,
"prompt": "a person looking at the camera, natural lighting",
"negative_prompt": "",
"image_path": avatar,
"target_video_length": 17, # 1 second at 16fps + 1
},
)
runner.input_info = pipe._input_info_template
# 1. Run all encoders (T5 + CLIP + VAE)
log.info("Running all input encoders (T5 + CLIP + VAE)...")
runner.inputs = runner.run_input_encoder()
log.info("Encoder outputs ready.")
for k, v in runner.inputs.items():
if isinstance(v, torch.Tensor):
log.info(" inputs[%s]: shape=%s dtype=%s", k, v.shape, v.dtype)
elif isinstance(v, dict):
for k2, v2 in v.items():
if isinstance(v2, torch.Tensor):
log.info(" inputs[%s][%s]: shape=%s dtype=%s", k, k2, v2.shape, v2.dtype)
# 2. Initialize run (sets up scheduler, creates noise latents)
log.info("Initializing run (scheduler.prepare)...")
runner.init_run()
latents = runner.model.scheduler.latents
log.info("Initial latents: shape=%s dtype=%s", latents.shape, latents.dtype)
# 3. Single DIT step
log.info("Running single DIT step (step_pre → infer → step_post)...")
runner.model.scheduler.step_pre(step_index=0)
runner.model.infer(runner.inputs)
runner.model.scheduler.step_post()
latents_after = runner.model.scheduler.latents
log.info("Latents after step: shape=%s dtype=%s", latents_after.shape, latents_after.dtype)
# Verify latents changed (denoising did something)
assert not torch.equal(latents, latents_after), "Latents unchanged after DIT step"
log.info("PASS: DIT single step completed, latents updated.")
log.info("PASS: DIT forward pass succeeded under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()
+104
View File
@@ -0,0 +1,104 @@
"""Smoke test: VAE decoder under GGUF pipeline.
Builds the pipeline, runs all encoders, initializes the scheduler, executes
one DIT denoising step, then decodes the resulting latents back to pixel
frames via the VAE decoder. Validates the full encode→denoise→decode path.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
python -m tests.component.test_13_vae_decode
"""
from __future__ import annotations
import os
import sys
import torch
from tests.component._common import ensure_sample_avatar, get_logger, write_bytes
log = get_logger("test_13")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
try:
from server.video_models.wan22 import Wan22Pipeline
except ImportError as e:
log.error("Import failed: %s", e)
sys.exit(0)
avatar = ensure_sample_avatar()
log.info("Avatar: %s", avatar)
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
log.info("Pipeline ready.")
runner = pipe._runner
# Set up input_info for a short clip
from lightx2v.utils.input_info import update_input_info_from_dict
update_input_info_from_dict(
pipe._input_info_template,
{
"seed": 42,
"prompt": "a person looking at the camera, natural lighting",
"negative_prompt": "",
"image_path": avatar,
"target_video_length": 17, # 1 second at 16fps + 1
},
)
runner.input_info = pipe._input_info_template
# 1. Run all encoders (T5 + CLIP + VAE)
log.info("Running all input encoders (T5 + CLIP + VAE)...")
runner.inputs = runner.run_input_encoder()
log.info("Encoder outputs ready.")
# 2. Initialize run (sets up scheduler, creates noise latents)
log.info("Initializing run (scheduler.prepare)...")
runner.init_run()
log.info("Initial latents: shape=%s dtype=%s",
runner.model.scheduler.latents.shape,
runner.model.scheduler.latents.dtype)
# 3. Single DIT step (so we have realistic latents to decode)
log.info("Running single DIT step...")
runner.model.scheduler.step_pre(step_index=0)
runner.model.infer(runner.inputs)
runner.model.scheduler.step_post()
latents = runner.model.scheduler.latents
log.info("Latents after step: shape=%s dtype=%s", latents.shape, latents.dtype)
# 4. VAE decode
log.info("Running VAE decoder...")
video_out = runner.run_vae_decoder(latents)
log.info("VAE decoder output type: %s", type(video_out))
if isinstance(video_out, torch.Tensor):
log.info("video_out: shape=%s dtype=%s device=%s",
video_out.shape, video_out.dtype, video_out.device)
elif isinstance(video_out, list):
log.info("video_out: list of %d items", len(video_out))
if len(video_out) > 0 and isinstance(video_out[0], torch.Tensor):
log.info(" first item: shape=%s dtype=%s", video_out[0].shape, video_out[0].dtype)
else:
log.info("video_out: %s", video_out)
log.info("PASS: VAE decoder succeeded under %s pipeline.", DIT_QUANT)
if __name__ == "__main__":
run()
View File
+65
View File
@@ -0,0 +1,65 @@
"""Unit tests for the frame-length fitting helper in server.video_models.musetalk.
Pure-python: does not import MuseTalk itself.
"""
import numpy as np
from server.video_models.musetalk import _fit_frames_to_length, _ensure_uint8_rgb
def _make_frames(t, h=2, w=2):
return np.arange(t * h * w * 3, dtype=np.uint8).reshape(t, h, w, 3)
def test_fit_frames_trim():
frames = _make_frames(10)
out = _fit_frames_to_length(frames, 4)
assert out.shape == (4, 2, 2, 3)
np.testing.assert_array_equal(out, frames[:4])
def test_fit_frames_passthrough_when_equal():
frames = _make_frames(5)
out = _fit_frames_to_length(frames, 5)
assert out is frames or np.array_equal(out, frames)
def test_fit_frames_extends_with_pingpong():
frames = _make_frames(3)
out = _fit_frames_to_length(frames, 8)
assert out.shape == (8, 2, 2, 3)
# First 3 frames match the original
np.testing.assert_array_equal(out[:3], frames)
# Next 3 are the reverse (ping-pong)
np.testing.assert_array_equal(out[3:6], frames[::-1])
# Then forward again
np.testing.assert_array_equal(out[6:8], frames[:2])
def test_fit_frames_zero_target_returns_original():
frames = _make_frames(3)
out = _fit_frames_to_length(frames, 0)
np.testing.assert_array_equal(out, frames)
def test_ensure_uint8_rgb_from_float():
arr = np.ones((5, 2, 2, 3), dtype=np.float32) * 0.5
out = _ensure_uint8_rgb(arr)
assert out.dtype == np.uint8
assert out.shape == (5, 2, 2, 3)
assert out[0, 0, 0, 0] == 127
def test_ensure_uint8_rgb_promotes_3d_to_4d():
arr = np.zeros((2, 2, 3), dtype=np.uint8)
out = _ensure_uint8_rgb(arr)
assert out.shape == (1, 2, 2, 3)
def test_ensure_uint8_rgb_clips_float_out_of_range():
arr = np.ones((1, 1, 1, 3), dtype=np.float32) * 2.0 # 2.0 → clipped to 255
out = _ensure_uint8_rgb(arr)
assert out[0, 0, 0, 0] == 255
arr2 = np.ones((1, 1, 1, 3), dtype=np.float32) * -1.0
out2 = _ensure_uint8_rgb(arr2)
assert out2[0, 0, 0, 0] == 0
+67
View File
@@ -0,0 +1,67 @@
"""Unit tests for the ffmpeg muxer.
Requires ``ffmpeg`` on PATH. On Windows, if ffmpeg is not installed these
tests are skipped (they will run inside the Docker image where ffmpeg is
always present).
"""
import os
import shutil
import struct
import numpy as np
import pytest
from server.video_models.muxer import frames_and_audio_to_mp4, frames_to_mp4_loop
pytestmark = pytest.mark.skipif(
shutil.which("ffmpeg") is None,
reason="ffmpeg not installed locally; run these inside Docker",
)
def _rgb_frames(t, h=64, w=64):
"""Coloured checker frames so the encoder has real content."""
frames = np.zeros((t, h, w, 3), dtype=np.uint8)
for i in range(t):
frames[i, :, :, 0] = (i * 20) % 255
frames[i, :h // 2, :, 1] = 255
frames[i, :, :w // 2, 2] = 255
return frames
def test_frames_to_mp4_loop_produces_mp4_bytes():
frames = _rgb_frames(8)
data = frames_to_mp4_loop(frames, fps=16)
assert isinstance(data, bytes)
assert len(data) > 0
# MP4 files start with an ftyp box: 4 bytes size + 'ftyp'
assert data[4:8] == b"ftyp"
def test_frames_and_audio_to_mp4_produces_mp4_bytes():
frames = _rgb_frames(16)
# 1s silent audio at 24kHz
audio = np.zeros(24000, dtype=np.float32)
data = frames_and_audio_to_mp4(frames, audio, sample_rate=24000, fps=16)
assert isinstance(data, bytes)
assert len(data) > 0
assert data[4:8] == b"ftyp"
def test_frames_to_mp4_loop_rejects_empty():
with pytest.raises(ValueError):
frames_to_mp4_loop(np.empty((0, 64, 64, 3), dtype=np.uint8), fps=16)
def test_frames_and_audio_to_mp4_rejects_empty_audio():
frames = _rgb_frames(4)
with pytest.raises(ValueError):
frames_and_audio_to_mp4(
frames, np.empty(0, dtype=np.float32), sample_rate=24000, fps=16
)
def test_frames_to_mp4_loop_rejects_wrong_shape():
with pytest.raises(ValueError):
frames_to_mp4_loop(np.zeros((4, 64, 64), dtype=np.uint8), fps=16)
+144
View File
@@ -0,0 +1,144 @@
"""Unit test for the video-mode branch in ConversationSession.
Stubs every model involved (ASR, LLM, TTS, VideoEngine) so we can verify:
1. When video_engine is not ready, the existing PCM streaming path runs.
2. When video_engine IS ready, the per-chunk PCM sends are skipped and a
single ``speaking_clip`` JSON + MP4 binary is sent instead.
Pure asyncio; no CUDA, no real models.
"""
from __future__ import annotations
import asyncio
import types
from unittest.mock import MagicMock
import numpy as np
import pytest
from server.pipeline import ConversationSession
class _FakeVAD:
is_speaking = False
def process_chunk(self, _): return None
class _FakeASR:
def __init__(self, text="hello"):
self.text = text
def transcribe(self, _): return self.text
class _FakeLLM:
def __init__(self, response="Hi there."):
self.response = response
def generate(self, *_a, **_k):
return self.response, None
def trim_cache(self, state, _): return state
class _FakeTTSIterable:
"""Drop-in replacement for Kokoro's pipeline(..) generator."""
def __init__(self, chunks):
self._chunks = chunks
def __call__(self, segment, voice=None):
for i, audio in enumerate(self._chunks):
yield f"w{i}", None, audio
class _FakeTTSEngine:
def __init__(self, chunks):
self.pipeline = _FakeTTSIterable(chunks)
self.voice = "v"
self.sample_rate = 24000
class _FakeVideoEngineReady:
class _Cfg:
mode = "reflective"
cfg = _Cfg()
def __init__(self):
self.called_with = None
def is_ready(self): return True
def generate_speaking_clip(self, audio, sr, reply_text):
self.called_with = {"len": len(audio), "sr": sr, "reply": reply_text}
return b"FAKE_MP4_BYTES"
class _FakeModelsBase:
def __init__(self, tts_chunks):
self.asr_engine = _FakeASR()
self.llm_engine = _FakeLLM()
self.tts_engine = _FakeTTSEngine(tts_chunks)
def create_vad(self): return _FakeVAD()
class _FakeModelsStreaming(_FakeModelsBase):
video_engine = None
class _FakeModelsVideo(_FakeModelsBase):
def __init__(self, tts_chunks):
super().__init__(tts_chunks)
self.video_engine = _FakeVideoEngineReady()
@pytest.mark.asyncio
async def test_streaming_path_when_video_engine_absent():
json_sent: list = []
bytes_sent: list = []
async def send_json(d): json_sent.append(d)
async def send_bytes(b): bytes_sent.append(b)
chunks = [
np.ones(240, dtype=np.float32),
np.ones(480, dtype=np.float32),
]
models = _FakeModelsStreaming(tts_chunks=chunks)
session = ConversationSession(models, send_json, send_bytes)
await session._process_utterance(np.zeros(16000, dtype=np.float32))
# PCM bytes were sent (one per TTS chunk).
assert len(bytes_sent) == 2
# Per-chunk response_text messages were sent (not video's one-shot).
text_msgs = [m for m in json_sent if m.get("type") == "response_text"]
assert any(not m.get("final") for m in text_msgs)
# No speaking_clip envelope
assert not any(m.get("type") == "speaking_clip" for m in json_sent)
@pytest.mark.asyncio
async def test_video_path_when_engine_ready():
json_sent: list = []
bytes_sent: list = []
async def send_json(d): json_sent.append(d)
async def send_bytes(b): bytes_sent.append(b)
chunks = [
np.full(480, 0.5, dtype=np.float32),
np.full(480, 0.25, dtype=np.float32),
]
models = _FakeModelsVideo(tts_chunks=chunks)
session = ConversationSession(models, send_json, send_bytes)
await session._process_utterance(np.zeros(16000, dtype=np.float32))
# MP4 blob was sent once.
assert bytes_sent == [b"FAKE_MP4_BYTES"]
# speaking_clip envelope was sent exactly once.
envelopes = [m for m in json_sent if m.get("type") == "speaking_clip"]
assert len(envelopes) == 1
assert envelopes[0]["size_bytes"] == len(b"FAKE_MP4_BYTES")
assert envelopes[0]["text"] == "Hi there."
# The video engine received the concatenated audio.
ve = models.video_engine
assert ve.called_with is not None
assert ve.called_with["len"] == 960 # 480 + 480
assert ve.called_with["reply"] == "Hi there."
# No per-chunk PCM bytes were streamed (video path suppresses them).
# Only the MP4 blob is in bytes_sent.
assert len(bytes_sent) == 1
+138
View File
@@ -0,0 +1,138 @@
"""Unit tests for VideoConfig parsing and LoRASpec validation.
Pure-python, no model imports, no CUDA, no ffmpeg. Safe for Windows CI.
"""
import pytest
from server.video import VideoConfig, LoRASpec
def test_defaults_when_raw_is_empty():
cfg = VideoConfig.from_dict({})
assert cfg.enabled is False
assert cfg.backend == "lightx2v"
assert cfg.mode == "reflective"
assert cfg.resolution == 480
assert cfg.fps == 16
assert cfg.library_base_clip_count == 4
assert cfg.reflective_prompt_reply_words == 18
assert cfg.loras == []
def test_defaults_when_raw_is_none():
cfg = VideoConfig.from_dict(None) # type: ignore[arg-type]
assert cfg.enabled is False
def test_library_section_override():
cfg = VideoConfig.from_dict(
{"enabled": True, "mode": "library", "library": {"base_clip_count": 7, "base_clip_seconds": 3}}
)
assert cfg.enabled is True
assert cfg.mode == "library"
assert cfg.library_base_clip_count == 7
assert cfg.library_base_clip_seconds == 3
def test_reflective_section_override():
cfg = VideoConfig.from_dict(
{
"reflective": {
"clip_seconds": 9,
"clip_prompt_template": "my template: {reply_hint}",
"prompt_reply_words": 5,
}
}
)
assert cfg.reflective_clip_seconds == 9
assert cfg.reflective_prompt_template == "my template: {reply_hint}"
assert cfg.reflective_prompt_reply_words == 5
def test_lora_parse_minimal():
cfg = VideoConfig.from_dict({"loras": [{"path": "/tmp/a.safetensors"}]})
assert len(cfg.loras) == 1
lora = cfg.loras[0]
assert lora.path == "/tmp/a.safetensors"
assert lora.weight == 1.0
assert lora.target == "both"
assert lora.name is None
def test_lora_parse_full():
cfg = VideoConfig.from_dict(
{
"loras": [
{
"path": "/tmp/a.safetensors",
"weight": 0.7,
"target": "both",
"name": "style-a",
},
{
"path": "/tmp/b.safetensors",
"weight": 0.4,
"target": "both",
"name": "style-b",
},
]
}
)
assert len(cfg.loras) == 2
assert cfg.loras[0].target == "both"
assert cfg.loras[0].name == "style-a"
assert cfg.loras[1].target == "both"
assert cfg.loras[1].weight == 0.4
def test_lora_legacy_moe_target_coerced_to_both():
"""Legacy MoE configs with target='high_noise'/'low_noise' get coerced."""
cfg = VideoConfig.from_dict(
{
"loras": [
{"path": "/tmp/hi.safetensors", "target": "high_noise"},
{"path": "/tmp/lo.safetensors", "target": "low_noise"},
{"path": "/tmp/x.safetensors", "target": "bogus"},
]
}
)
assert all(l.target == "both" for l in cfg.loras)
def test_lora_entries_without_path_are_dropped():
cfg = VideoConfig.from_dict(
{"loras": [{"weight": 0.5}, {"path": "/tmp/ok.safetensors"}, None]}
)
assert len(cfg.loras) == 1
assert cfg.loras[0].path == "/tmp/ok.safetensors"
def test_models_section_override():
cfg = VideoConfig.from_dict(
{
"models": {
"wan22_base_repo": "/local/weights/wan22",
"wan22_dit_repo": "/local/weights/wan22-dit",
"wan22_dit_quant_scheme": "gguf-Q4_K_M",
"wan22_config_json": "/local/cfg/turbo.json",
"wan22_model_cls": "wan2.2",
"musetalk_path": "/local/weights/musetalk",
}
}
)
assert cfg.wan22_base_repo == "/local/weights/wan22"
assert cfg.wan22_dit_repo == "/local/weights/wan22-dit"
assert cfg.wan22_dit_quant_scheme == "gguf-Q4_K_M"
assert cfg.wan22_config_json == "/local/cfg/turbo.json"
assert cfg.wan22_model_cls == "wan2.2"
assert cfg.musetalk_model_path == "/local/weights/musetalk"
def test_models_section_defaults_to_5b_turbo():
cfg = VideoConfig.from_dict({})
assert cfg.wan22_base_repo == "Wan-AI/Wan2.2-TI2V-5B"
assert cfg.wan22_dit_repo == "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
assert cfg.wan22_dit_quant_scheme == "gguf-Q8_0"
assert cfg.wan22_t5_quantized is True
assert cfg.wan22_model_cls == "wan2.2"
assert cfg.wan22_config_json == "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
+106
View File
@@ -0,0 +1,106 @@
"""Unit tests for pure-python logic inside VideoEngine.
No models are loaded: we instantiate ``VideoEngine`` and hand-stub its
``_wan22`` / ``_musetalk`` attributes to test prompt derivation, library
round-robin, and frame fitting.
"""
import numpy as np
import pytest
from server.video import VideoConfig, VideoEngine
@pytest.fixture
def engine():
cfg = VideoConfig.from_dict(
{
"enabled": True,
"mode": "reflective",
"fps": 16,
"reflective": {
"clip_prompt_template": "A: {reply_hint} B",
"prompt_reply_words": 5,
},
}
)
return VideoEngine(cfg)
def test_derive_prompt_truncates_to_word_limit(engine):
out = engine._derive_prompt("one two three four five six seven eight")
assert out == "A: one two three four five B"
def test_derive_prompt_handles_empty_reply(engine):
out = engine._derive_prompt("")
assert out == "A: calm and friendly B"
out2 = engine._derive_prompt(None) # type: ignore[arg-type]
assert out2 == "A: calm and friendly B"
def test_derive_prompt_strips_and_passes_through(engine):
out = engine._derive_prompt(" hello world ")
assert out == "A: hello world B"
def test_is_ready_false_without_models(engine):
# Models haven't been loaded — is_ready must be False so the pipeline
# falls back to the PCM streaming path.
assert engine.is_ready() is False
def test_pick_library_frames_round_robin(engine):
engine.cfg.mode = "library"
engine.cfg.fps = 2
# Two base clips, 4 frames each.
a = np.tile(np.array([[[[0, 0, 0]]]], dtype=np.uint8), (4, 1, 1, 1))
b = np.tile(np.array([[[[255, 255, 255]]]], dtype=np.uint8), (4, 1, 1, 1))
engine.speaking_base_frames = [a, b]
# 2s of audio at 16kHz → 4 frames at fps=2
audio = np.zeros(16000 * 2, dtype=np.float32)
f1 = engine._pick_library_frames(audio, 16000)
f2 = engine._pick_library_frames(audio, 16000)
f3 = engine._pick_library_frames(audio, 16000)
assert f1.shape == (4, 1, 1, 3)
assert f1[0, 0, 0, 0] == 0 # first pick = clip A
assert f2[0, 0, 0, 0] == 255 # second pick = clip B
assert f3[0, 0, 0, 0] == 0 # wraps back to A
def test_pick_library_frames_trims_to_audio_duration(engine):
engine.cfg.mode = "library"
engine.cfg.fps = 4
frames = np.zeros((20, 1, 1, 3), dtype=np.uint8)
engine.speaking_base_frames = [frames]
# 1s audio → 4 frames
audio = np.zeros(16000, dtype=np.float32)
out = engine._pick_library_frames(audio, 16000)
assert out.shape == (4, 1, 1, 3)
def test_pick_library_frames_loops_for_long_audio(engine):
engine.cfg.mode = "library"
engine.cfg.fps = 4
frames = np.zeros((4, 1, 1, 3), dtype=np.uint8)
engine.speaking_base_frames = [frames]
# 3s audio → 12 frames, base has only 4
audio = np.zeros(16000 * 3, dtype=np.float32)
out = engine._pick_library_frames(audio, 16000)
assert out.shape == (12, 1, 1, 3)
def test_pick_library_frames_raises_when_empty(engine):
engine.cfg.mode = "library"
engine.speaking_base_frames = []
with pytest.raises(RuntimeError, match="no pre-baked base clips"):
engine._pick_library_frames(np.zeros(100, dtype=np.float32), 16000)
def test_generate_speaking_clip_raises_when_not_ready(engine):
with pytest.raises(RuntimeError, match="not ready"):
engine.generate_speaking_clip(
audio_f32=np.zeros(100, dtype=np.float32),
sample_rate=16000,
reply_text="hi",
)
Vendored Submodule
+1
Submodule third_party/MuseTalk added at ca5b7a8f28