working ok

This commit is contained in:
2026-04-16 10:00:37 -04:00
parent 9debc56137
commit 129df7d1fa
24 changed files with 674 additions and 539 deletions
+49
View File
@@ -0,0 +1,49 @@
# Agent guide — voice-chat
Orientation for AI agents making changes in this repo. Keep edits scoped; this pipeline has several sharp edges that a naive refactor will hit.
## What this project is
A local real-time voice (and optionally video-avatar) chat server. FastAPI + WebSocket on the backend, a small vanilla-JS UI on the frontend. All ML runs locally on a single NVIDIA GPU; there are no cloud API calls at runtime.
## Two-tier architecture
1. **Audio pipeline** (always on): mic PCM → [vad.py](server/vad.py) → [asr.py](server/asr.py) → [llm.py](server/llm.py) → [tts.py](server/tts.py) → PCM back to the browser. Orchestrated by [pipeline.py](server/pipeline.py) via `ConversationSession`. Supports barge-in (user speaks → cancel in-flight reply).
2. **Video pipeline** (optional, gated by `config.video.enabled`): per assistant turn, [video.py](server/video.py)'s `VideoEngine` calls [video_models/wan22.py](server/video_models/wan22.py) (LightX2V Wan2.2-I2V) to produce base frames, then [video_models/musetalk.py](server/video_models/musetalk.py) to lip-sync them to the TTS audio, then [video_models/muxer.py](server/video_models/muxer.py) to produce an MP4. The MP4 is sent over the same `/ws/chat` WebSocket as a `speaking_clip` message.
The audio path must keep working when `video.enabled` is false. Don't make video models load-bearing for the audio pipeline.
## Key files
- [server/main.py](server/main.py) — FastAPI app, WebSocket, video/avatar HTTP endpoints
- [server/models.py](server/models.py) — lifetime of all models; `ModelManager.video_engine` is `None` when video is disabled
- [server/pipeline.py](server/pipeline.py) — per-session audio pipeline + video branch
- [server/config.py](server/config.py) — parses [config.yml](config.yml)
- [server/video.py](server/video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine` (library vs reflective modes)
- [server/video_models/wan22.py](server/video_models/wan22.py) — LightX2V wrapper; fp8 + GGUF loading; Blackwell patches
- [configs/lightx2v/](configs/lightx2v/) — LightX2V inference config templates; must match `wan22_dit_quant_scheme`
- [tests/unit/](tests/unit/) — GPU-free tests, runnable on Windows host
- [tests/component/](tests/component/) — end-to-end tests, must run inside the Docker container
## Conventions
- Config: single source of truth is [config.yml](config.yml) → `server/config.py` dataclasses. Don't read env vars for runtime behaviour; if you need a new knob, add it to the dataclass and document it in `config.yml`.
- Logging: `log = logging.getLogger(__name__)` at module top; log level is set once in `server/main.py`. INFO for lifecycle, DEBUG for per-chunk chatter.
- Async: WebSocket handlers and endpoints are async, but heavy model work is sync — wrap via `asyncio.to_thread(...)` at the call site (see `set_avatar` in `main.py`).
- Concurrency: `VideoEngine` serialises generation with `self._lock`. Don't call model methods without holding it from another thread.
- Tests: every non-trivial logic change in `server/video.py` or `server/pipeline.py` should have a corresponding `tests/unit/` test. GPU-dependent behaviour goes in `tests/component/`.
## Gotchas
- **GPU architecture.** This is tuned for RTX 5090 / Blackwell (SM120) with PyTorch 2.8 + Triton 3.4. Several upstream kernels (flashinfer, flash_attn3, sgl_kernel fp8 matmul, Triton-fused scale/shift) are broken or unavailable there. See [server/video_models/AGENT.md](server/video_models/AGENT.md) before touching the Wan2.2 wrapper.
- **First launch is slow.** Hugging Face downloads land in the `huggingface-cache` Docker volume; a cold run pulls >20 GB.
- **Wan-AI/Wan2.2-I2V-A14B** ships bf16 DIT shards we don't want — `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](server/video_models/wan22.py) excludes them. Keep that list in sync if the repo layout changes.
- **LoRA targets matter.** Wan2.2 is a MoE (high_noise + low_noise sub-models). A LoRA with the wrong `target` loads silently and produces subtly wrong output.
- **Don't mix audio+video state.** The audio pipeline must not block on video generation; video is produced for a turn *after* the full reply audio is available, and sent as a separate message.
## When in doubt
- Run `python -m pytest tests/unit -v` — it's fast and catches most regressions.
- For GPU changes, run the lowest-numbered relevant component test first; they're ordered to isolate failure to a single stage.
- Check [memory](../../.claude/projects/c--Users-bheth-Documents-voice-chat/memory/) (auto-loaded) for prior-session findings that aren't in the code.
+12 -3
View File
@@ -61,10 +61,19 @@ RUN python3.11 -m pip install --no-cache-dir --no-deps \
"sgl-kernel @ https://github.com/sgl-project/whl/releases/download/v0.3.14.post1/sgl_kernel-0.3.14.post1%2Bcu128-cp310-abi3-manylinux2014_x86_64.whl" || \
echo "sgl-kernel install failed — fp8 T5 will fall back to bf16"
#
# MuseTalk (audio-driven lip-sync) — same story.
# MuseTalk (audio-driven lip-sync) — installed from the bhetherman/MuseTalk
# fork checked in as a submodule at third_party/MuseTalk. The upstream repo
# has no setup.py / pyproject.toml; our fork adds them so `pip install .`
# just works. We deliberately do NOT install its requirements.txt (it pins
# numpy==1.23.5, transformers==4.39.2, tensorflow==2.12.0 which conflict
# with the rest of the stack) — instead we install its real runtime deps
# explicitly here.
COPY third_party/MuseTalk /opt/MuseTalk
RUN python3.11 -m pip install --no-cache-dir --no-deps /opt/MuseTalk || \
echo "MuseTalk install failed — config.video.musetalk.enabled must stay false until fixed"
RUN python3.11 -m pip install --no-cache-dir \
"git+https://github.com/TMElyralab/MuseTalk.git" || \
echo "MuseTalk install failed — config.video.enabled must stay false until fixed"
librosa einops omegaconf ffmpeg-python || \
echo "MuseTalk runtime deps install failed"
#
# LoRA directory (user drops .safetensors here; bind-mounted in compose).
RUN mkdir -p /cache/loras
+55 -12
View File
@@ -1,21 +1,29 @@
# Voice Chat
A real-time voice conversation app powered by local AI models. Speak into your mic and get spoken responses back — all running on your own GPU with no cloud APIs.
A real-time voice conversation app powered by local AI models. Speak into your mic and get spoken responses back — and optionally a lip-synced talking-head video of a chosen avatar — all running on your own GPU with no cloud APIs.
## Pipeline
**Mic input** → **VAD** (Silero ONNX) → **ASR** (Qwen3-ASR-0.6B) → **LLM** (Qwen3.5-0.8B) → **TTS** (Kokoro) → **Speaker output**
When the optional video stack is enabled, each assistant turn also produces an MP4 via:
**TTS audio + avatar image** → **Wan2.2-Lightning I2V** (LightX2V, fp8 or GGUF) → **MuseTalk lip-sync** → **ffmpeg mux** → **`speaking_clip` WebSocket message**
- **VAD** — Silero VAD via ONNX Runtime, detects speech/silence boundaries on CPU
- **ASR** — Qwen3-ASR-0.6B, bfloat16 on CUDA
- **LLM** — Qwen3.5-0.8B, loaded via transformers
- **LLM** — Qwen3.5-0.8B (local) or any model served by LM Studio
- **TTS** — Kokoro, streams sentence-by-sentence audio at 24 kHz
- **Barge-in** — interrupt the assistant mid-response by speaking
- **Video (optional)** — Wan2.2 I2V 14B MoE with LightX2V distill LoRAs, fp8 or GGUF-quantised DIT, lip-synced by MuseTalk. Two modes:
- `library` — pre-bakes a small set of speaking base clips per avatar, picks round-robin per turn, lip-syncs on the fly
- `reflective` — generates a fresh Wan2.2 clip per turn from a prompt derived from the assistant's reply
## Requirements
- NVIDIA GPU with CUDA 12.8 support
- NVIDIA GPU with CUDA 12.8 support (tested on RTX 5090 / SM120 Blackwell)
- Docker + Docker Compose with the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
- ~24 GB VRAM recommended when video is enabled (fp8); ~16 GB with `gguf-Q4_K_M`
## Quick Start
@@ -27,6 +35,17 @@ Then open [http://localhost:8000](http://localhost:8000) in your browser.
Models are downloaded from Hugging Face on first launch and cached in a Docker volume (`huggingface-cache`) so they persist across rebuilds.
## Configuration
All runtime behaviour is driven by [config.yml](config.yml):
- `llm.backend``local` (in-process transformers) or `lmstudio` (talks to a local LM Studio server)
- `llm.system_prompt`, `llm.max_cache_tokens` — prompt and KV-cache limit per session
- `video.enabled` — master toggle for the avatar video stack. When `false`, no video models load and the app behaves exactly like the audio-only pipeline
- `video.mode``library` or `reflective`
- `video.models.wan22_dit_quant_scheme``gguf-Q8_0` (default) or `gguf-Q4_K_M` for lower VRAM; any GGUF level LightX2V supports (dense 5B Turbo is GGUF-only)
- `video.loras` — list of LoRA adapters applied to the dense Wan2.2-TI2V-5B DIT at load time. Each entry has `path`, `weight`, `target` (always `both` — the 5B DIT is not MoE; legacy `high_noise`/`low_noise` values are coerced), and optional `name`. User LoRAs are mounted from `./loras/` into the container at `/cache/loras/`
## Local Development (without Docker)
```bash
@@ -45,17 +64,41 @@ python run.py
The server starts on port 8000.
## API
- `GET /` — browser UI
- `WS /ws/chat` — bidirectional audio + control WebSocket
- `POST /api/set-voice` — switch the Kokoro voice
- `POST /api/set-avatar` *(video only)* — upload an avatar image; (re)generates idle + library clips
- `GET /api/idle-clip` *(video only)* — cached idle loop MP4
- `POST /api/set-video-mode` *(video only)* — switch between `off` / `library` / `reflective`
- `POST /api/reload-loras` *(video only)* — hot-swap the LoRA stack; regenerates the idle clip
## Project Structure
```
server/
main.py — FastAPI app, WebSocket endpoint
models.py — Model loading and management
pipeline.py — VAD -> ASR -> LLM -> TTS orchestration
vad.py — Silero VAD (ONNX) streaming wrapper
asr.py — Speech recognition engine
llm.py — Language model engine
tts.py — Kokoro TTS engine
audio_utils.py — PCM/float32 conversion helpers
static/ — Browser UI (HTML/JS/CSS)
main.py — FastAPI app, WebSocket + video endpoints
models.py — Model loading and management (audio + optional video)
pipeline.py — VAD -> ASR -> LLM -> TTS orchestration, video branch
config.py — config.yml parsing
vad.py — Silero VAD (ONNX) streaming wrapper
asr.py — Speech recognition engine
llm.py — Language model engine (local + LM Studio backends)
tts.py — Kokoro TTS engine
audio_utils.py — PCM/float32 conversion helpers
video.py — VideoEngine orchestrator + VideoConfig + LoRASpec
video_models/
wan22.py — LightX2V Wan2.2 I2V pipeline wrapper (fp8 + GGUF)
musetalk.py — MuseTalk lip-sync wrapper
muxer.py — ffmpeg helpers: frames -> MP4, frames+audio -> MP4
configs/
lightx2v/ — LightX2V inference config templates (fp8 + GGUF variants)
static/ — Browser UI (HTML/JS/CSS)
avatars/ — uploaded avatar images (gitignored)
loras/ — user-supplied Wan2.2 LoRAs, mounted into the container
reference_audio/ — Kokoro voice reference samples
tests/ — unit + component tests (see tests/README.md)
```
+33 -29
View File
@@ -13,9 +13,9 @@ llm:
url: http://host.docker.internal:1234 # host.docker.internal resolves to your PC from inside Docker
model: "" # leave empty to use whatever model LM Studio has loaded
# Avatar video generation (Wan2.2-Lightning fp8 via LightX2V + MuseTalk lip-sync)
# Avatar video generation (Wan2.2-TI2V-5B-Turbo GGUF via LightX2V + MuseTalk lip-sync)
video:
enabled: false # master toggle — when false, video models are not loaded
enabled: true # master toggle — when false, video models are not loaded
backend: lightx2v # only option for now
mode: reflective # "library" (pre-baked clips) | "reflective" (fresh per turn)
resolution: 480 # 480 or 720
@@ -25,6 +25,12 @@ video:
base_clip_count: 4 # how many speaking base clips to pre-generate per avatar
base_clip_seconds: 6 # duration of each pre-baked clip
# MuseTalk audio-driven lip-sync. When disabled, Wan2.2 base frames are
# used as-is without a lip-sync pass — useful when MuseTalk isn't installed
# or while iterating on the base pipeline.
musetalk:
enabled: false # toggle lip-sync on/off
reflective:
clip_seconds: 5 # target length of each fresh Wan2.2 clip per turn
clip_prompt_template: >-
@@ -33,35 +39,33 @@ video:
prompt_reply_words: 18 # max words lifted from reply to inject as {reply_hint}
# Model sources for the video stack. T5/VAE/tokenizer come from the
# Wan-AI base repo. DIT weights come from wan22_dit_repo in the format
# specified by wan22_dit_quant_scheme. Both repos download on first run
# into HF_HOME=/cache/huggingface.
# Wan-AI base repo. The single dense DIT comes from wan22_dit_repo as
# GGUF (Turbo 4-step distill). Both repos download on first run into
# HF_HOME=/cache/huggingface.
#
# Supported dit_quant_scheme values:
# fp8-sgl — fp8 e4m3 safetensors (~15 GB/expert, from lightx2v/Wan2.2-Distill-Models)
# gguf-Q4_K_M — GGUF 4-bit (~9.65 GB/expert, from QuantStack/Wan2.2-I2V-A14B-GGUF)
# gguf-Q8_0 — GGUF 8-bit (~15.4 GB/expert)
# (any gguf-<level> supported by LightX2V — see base_model.py MM_WEIGHT_REGISTER)
# Supported dit_quant_scheme values (dense 5B Turbo — GGUF only):
# gguf-Q8_0 — 8-bit, ~6 GB DIT, ~6.5 GB VRAM at load (default)
# gguf-Q4_K_M — 4-bit, ~3.5 GB DIT, lower VRAM for tight budgets
# (any gguf-<level> published in hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF)
models:
wan22_base_repo: Wan-AI/Wan2.2-I2V-A14B
wan22_dit_repo: QuantStack/Wan2.2-I2V-A14B-GGUF
wan22_dit_quant_scheme: gguf-Q4_K_M
wan22_base_repo: Wan-AI/Wan2.2-TI2V-5B
wan22_dit_repo: hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF
wan22_dit_quant_scheme: gguf-Q8_0
wan22_t5_quantized: true
wan22_model_cls: wan2.2_moe_distill
wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_distill.json
wan22_model_cls: wan2.2
wan22_config_json: /app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json
musetalk_path: TMElyralab/MuseTalk
# LoRAs applied to the fp8 base at load time via runtime switch_lora.
# Wan2.2 is a MoE with separate high-noise and low-noise sub-models
# `target` picks which sub-model each LoRA attaches to. The two files
# below are the user-supplied ./loras/wan22-[HL]-e8.safetensors mounted
# into the container at /cache/loras/.
loras:
- path: /cache/loras/wan22-H-e8.safetensors
weight: 1.0
target: high_noise
name: wan22-H-e8
- path: /cache/loras/wan22-L-e8.safetensors
weight: 1.0
target: low_noise
name: wan22-L-e8
# LoRAs applied to the dense 5B DIT at load time via LightX2V's
# lora_dynamic_apply path (merged during GGUF dequant). Dense has a
# single set of weights so `target` is always `both`.
#
# The old MoE-trained wan22-H-e8 / wan22-L-e8 LoRAs are NOT compatible
# with the 5B DIT and are disabled here. Future 5B-compatible LoRAs
# should follow the shape shown below.
loras: []
# loras:
# - path: /cache/loras/your-5b-lora.safetensors
# weight: 1.0
# target: both
# name: your-5b-lora
@@ -1,36 +0,0 @@
{
"_comment": "Wan2.2 i2v MoE 4-step distill, fp8 e4m3 quantized. Built for 24 GB-class GPUs — cpu_offload keeps DIT layers swapping in block-by-block. Derived from LightX2V's configs/distill/wan22/wan_moe_i2v_distill_4090.json plus the quant scheme + ckpt overrides from wan_moe_i2v_distill_quant.json. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py with absolute paths to the files downloaded into HF_HOME.",
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"resize_mode": "adaptive",
"resolution": "480p",
"target_height": 480,
"target_width": 480,
"fps": 16,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"lazy_load": true,
"t5_cpu_offload": true,
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250],
"dit_quantized": true,
"dit_quant_scheme": "fp8-sgl",
"t5_quantized": false
}
@@ -0,0 +1,40 @@
{
"_comment": "LightX2V config for Wan2.2-TI2V-5B-Turbo (dense, GGUF). Single DIT checkpoint (not MoE). dit_quantized_ckpt is filled in at runtime by Wan22Pipeline.",
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"resize_mode": "adaptive",
"resolution": "480p",
"target_height": 480,
"target_width": 480,
"fps": 16,
"vae_stride": [4, 16, 16],
"num_channels_latents": 48,
"self_attn_1_type": "torch_sdpa",
"cross_attn_1_type": "torch_sdpa",
"cross_attn_2_type": "torch_sdpa",
"modulate_type": "torch",
"rope_type": "torch",
"sample_guide_scale": 1.0,
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"offload_granularity": "model",
"t5_cpu_offload": true,
"vae_cpu_offload": false,
"use_image_encoder": false,
"denoising_step_list": [1000, 750, 500, 250],
"dit_quantized": true,
"dit_quant_scheme": "gguf-Q8_0",
"t5_quantized": true,
"t5_quant_scheme": "fp8-sgl"
}
@@ -1,41 +0,0 @@
{
"_comment": "Wan2.2 i2v MoE 4-step distill, GGUF quantized. Uses QuantStack/Wan2.2-I2V-A14B-GGUF checkpoints instead of fp8 safetensors. GGUF does not support block-level offload so offload_granularity is set to 'model' — the entire DIT is moved to GPU when active. With Q4_K_M (~9.65 GB per expert) this fits comfortably in 24+ GB VRAM. high_noise_quantized_ckpt / low_noise_quantized_ckpt are filled in at runtime by server/video_models/wan22.py. IMPORTANT: GGUF dequantizes to fp16, so you must set DTYPE=FP16 in the container environment.",
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"resize_mode": "adaptive",
"resolution": "480p",
"target_height": 480,
"target_width": 480,
"fps": 16,
"_comment_attn": "flash_attn3/sageattn3 aren't installed (no Blackwell-ready pre-built wheels). Use PyTorch SDPA which works on SM120.",
"self_attn_1_type": "torch_sdpa",
"cross_attn_1_type": "torch_sdpa",
"cross_attn_2_type": "torch_sdpa",
"_comment_modulate": "Triton fuse_scale_shift_kernel segfaults during JIT compile on Blackwell SM120 (triton 3.4 + cu128). Use the PyTorch modulate fallback until the Triton issue is resolved.",
"modulate_type": "torch",
"_comment_rope": "flashinfer not installed; fall back to PyTorch rope.",
"rope_type": "torch",
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "model",
"t5_cpu_offload": true,
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250],
"dit_quantized": true,
"dit_quant_scheme": "gguf-Q4_K_M",
"t5_quantized": false
}
+57
View File
@@ -0,0 +1,57 @@
# Agent guide — server/
The audio pipeline and FastAPI surface. The video stack lives under [video_models/](video_models/) and has its own guide.
## Module map
- [main.py](main.py) — FastAPI app, lifespan, `/ws/chat` WebSocket, video/avatar HTTP endpoints. Keep business logic out of here; this is a transport layer.
- [models.py](models.py) — `ModelManager.load_all()`. All models are loaded once at startup in a fixed order: VAD → ASR → LLM → TTS → (optional) Video. `video_engine` stays `None` when `config.video.enabled` is false — callers MUST tolerate that.
- [config.py](config.py) — thin YAML loader, exposes a single `config` dict. Don't scatter `yaml.safe_load` elsewhere.
- [pipeline.py](pipeline.py) — `ConversationSession`, one instance per WebSocket. Owns per-session state (VAD stream, conversation history, KV cache, cancel event). Orchestrates VAD → ASR → LLM → TTS and optionally the video branch.
- [vad.py](vad.py) — Silero VAD via ONNX Runtime on CPU. `StreamingVAD.process_chunk(pcm_16k) → utterance | None`. Returns a full utterance only on speech→silence transition.
- [asr.py](asr.py) — Qwen3-ASR wrapper. Sync, called under `asyncio.to_thread`.
- [llm.py](llm.py) — two backends behind a common `generate(history, max_new_tokens, kv_cache_state) → (text, KVCacheState)` signature: `LLMEngine` (local transformers) and `LMStudioEngine` (HTTP, no KV cache).
- [tts.py](tts.py) — Kokoro wrapper. The per-segment generator yields `(graphemes, _ps, audio_f32)` tuples at 24 kHz.
- [audio_utils.py](audio_utils.py) — `pcm_bytes_to_float32` / `float32_to_pcm_bytes`. The WebSocket protocol is 16 kHz int16 PCM in, 24 kHz int16 PCM out.
- [video.py](video.py) — `VideoConfig`, `LoRASpec`, `VideoEngine`. Orchestrates Wan2.2 + MuseTalk. Gated by `config.video.enabled`.
- [video_models/](video_models/) — see [video_models/AGENT.md](video_models/AGENT.md) for Blackwell/GGUF gotchas.
## Session lifecycle (pipeline.py)
1. Client connects → `ConversationSession(models, send_json, send_bytes)``start()` emits `{type: "status", state: "listening"}` and a `{type: "video_mode", ...}` hint.
2. Inbound binary frames are 16 kHz int16 PCM → `handle_audio_chunk` → VAD. On speech→silence the session kicks off `_process_utterance` as an `asyncio.Task` so it never blocks the WebSocket receive loop.
3. `_process_utterance` flow: ASR → LLM → TTS stream. Each blocking call wraps in `asyncio.to_thread`.
4. TTS output is split into short segments via `_split_into_segments` before synthesis so streaming chunks stay small.
5. TTS runs on a background `threading.Thread`, feeding a `queue.Queue` that the async loop drains.
## Barge-in
- `cancel_event: threading.Event` is the single stop signal. Checked between every stage and inside the TTS queue drain loop.
- Two ways to trigger: new VAD utterance while `is_responding` (`handle_audio_chunk`), or an explicit `{type: "interrupt", last_chunk_id}` text message (`interrupt`).
- On cancel: set the event, send `{type: "interrupt"}` to the client so it flushes its audio buffer, and discard any pending video clip.
- Don't add work after `cancel_event.is_set()` without checking — that's how zombie audio/video reaches the client after a barge-in.
## Video branch
The audio pipeline changes shape when `video_engine.is_ready()`:
- PCM chunks and `response_text` are **not** streamed during the turn — they're buffered.
- TTS audio is concatenated into one float32 array.
- After TTS completes, `video_engine.generate_speaking_clip(audio, sr, reply_text)` renders the MP4 (blocking, wrapped in `to_thread`).
- The full clip + final text is sent as a single `speaking_clip` message.
If you extend the audio pipeline, preserve this dual-mode behaviour: the client's UX is very different between the two paths, and mixing them (e.g. sending PCM *and* a clip) will double-play the audio.
## Conventions
- Dataclasses for structured config (see `VideoConfig`, `LoRASpec`). Parse once in `*.from_dict`; don't re-read `config.yml` mid-session.
- `asyncio.to_thread` for any sync model call from an async context. Never call `.generate()` / `.transcribe()` / `.pipeline()` directly on the event loop.
- Locks: `VideoEngine._lock` serialises model state mutations; `ConversationSession` is not locked because each WebSocket gets its own instance.
- Logging: one `log = logging.getLogger(__name__)` per module. INFO for lifecycle and per-turn milestones; avoid DEBUG spam in hot loops.
- Keep `main.py` thin. New endpoints should delegate to a method on `ModelManager` or an engine class.
## Testing
- `tests/unit/test_pipeline_video_branch.py` — the video vs. audio path selection. Update it if you change the `use_video` condition.
- `tests/unit/test_video_config.py` / `test_video_engine_logic.py` — config parsing and the pure logic in `video.py`.
- Component tests live in `tests/component/` and require the Docker GPU environment. See [tests/README.md](../tests/README.md).
+1 -1
View File
@@ -133,7 +133,7 @@ async def reload_loras(body: dict):
if not entry or "path" not in entry:
continue
target = str(entry.get("target", "both")).lower()
if target not in ("high_noise", "low_noise", "both"):
if target != "both":
target = "both"
specs.append(
LoRASpec(
+1 -2
View File
@@ -121,8 +121,7 @@ class ModelManager:
log.info("Loading avatar video engine...")
cfg = VideoConfig.from_dict(video_cfg_raw)
self.video_engine = VideoEngine(cfg)
if cfg.loras:
self.video_engine.load_loras(cfg.loras)
self.video_engine.load_models()
log.info("Avatar video engine loaded (mode=%s).", cfg.mode)
def create_vad(self) -> StreamingVAD:
+49 -37
View File
@@ -18,18 +18,18 @@ import numpy as np
log = logging.getLogger(__name__)
LoRATarget = Literal["high_noise", "low_noise", "both"]
LoRATarget = Literal["both"]
@dataclass
class LoRASpec:
"""One LoRA adapter entry from ``config.video.loras``.
Wan2.2 I2V is a Mixture-of-Experts model with separate high-noise and
low-noise sub-models. Most LightX2V distill LoRAs come paired (one per
sub-model) and must be applied to the correct target. Allow
``target="both"`` for LoRAs that should be applied to both sub-models
(e.g. style LoRAs).
The dense Wan2.2-TI2V-5B DIT has a single set of weights (no MoE
experts), so ``target`` is always ``"both"``. The field is kept for
forward compatibility and config-file compatibility with older MoE
configs — legacy ``"high_noise"`` / ``"low_noise"`` values are coerced
to ``"both"`` in ``VideoConfig.from_dict``.
"""
path: str
@@ -60,18 +60,20 @@ class VideoConfig:
# Model paths — can be overridden via config.yml.video.models.
# wan22_base_repo : HF repo id (or local dir) providing T5/VAE/tokenizer.
# The bf16 DIT shards in this repo are skipped — we
# replace them with quantised files from wan22_dit_repo.
# wan22_dit_repo : HF repo id (or local dir) providing the quantised
# DIT checkpoints (fp8 or GGUF).
# wan22_dit_quant_scheme : quantisation format, e.g. "fp8-sgl" or "gguf-Q4_K_M".
# replace them with a quantised GGUF from wan22_dit_repo.
# wan22_dit_repo : HF repo id (or local dir) providing the single
# dense GGUF DIT checkpoint (5B Turbo).
# wan22_dit_quant_scheme : GGUF quant level, e.g. "gguf-Q8_0" (default)
# or "gguf-Q4_K_M" for lower VRAM.
# wan22_config_json : path to the LightX2V inference config template the
# Wan22Pipeline will fill in with absolute ckpt paths.
wan22_base_repo: str = "Wan-AI/Wan2.2-I2V-A14B"
wan22_dit_repo: str = "lightx2v/Wan2.2-Distill-Models"
wan22_dit_quant_scheme: str = "fp8-sgl"
wan22_t5_quantized: bool = False
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
wan22_model_cls: str = "wan2.2_moe_distill"
wan22_base_repo: str = "Wan-AI/Wan2.2-TI2V-5B"
wan22_dit_repo: str = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
wan22_dit_quant_scheme: str = "gguf-Q8_0"
wan22_t5_quantized: bool = True
wan22_config_json: str = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
wan22_model_cls: str = "wan2.2"
musetalk_enabled: bool = True
musetalk_model_path: str = "TMElyralab/MuseTalk"
@classmethod
@@ -92,9 +94,10 @@ class VideoConfig:
if not entry or "path" not in entry:
continue
target = str(entry.get("target", "both")).lower()
if target not in ("high_noise", "low_noise", "both"):
if target != "both":
log.warning(
"LoRA %s: invalid target %r, defaulting to 'both'",
"LoRA %s: target %r is MoE-era; coercing to 'both' "
"(dense 5B has a single DIT).",
entry.get("path"), target,
)
target = "both"
@@ -122,30 +125,32 @@ class VideoConfig:
reflective_prompt_reply_words=int(reflective.get("prompt_reply_words", 18)),
loras=loras,
wan22_base_repo=str(
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-I2V-A14B")
models_raw.get("wan22_base_repo", "Wan-AI/Wan2.2-TI2V-5B")
),
wan22_dit_repo=str(
models_raw.get(
"wan22_dit_repo",
# Backwards compat: fall back to old key name.
models_raw.get("wan22_fp8_repo", "lightx2v/Wan2.2-Distill-Models"),
"hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF",
)
),
wan22_dit_quant_scheme=str(
models_raw.get("wan22_dit_quant_scheme", "fp8-sgl")
models_raw.get("wan22_dit_quant_scheme", "gguf-Q8_0")
),
wan22_t5_quantized=bool(
models_raw.get("wan22_t5_quantized", False)
models_raw.get("wan22_t5_quantized", True)
),
wan22_config_json=str(
models_raw.get(
"wan22_config_json",
"/app/configs/lightx2v/wan22_i2v_fp8_distill.json",
"/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json",
)
),
wan22_model_cls=str(
models_raw.get("wan22_model_cls", "wan2.2_moe_distill")
models_raw.get("wan22_model_cls", "wan2.2")
),
musetalk_enabled=bool(raw.get("musetalk", {}).get("enabled", True))
if isinstance(raw.get("musetalk"), dict)
else bool(raw.get("musetalk_enabled", True)),
musetalk_model_path=str(
models_raw.get("musetalk_path", "TMElyralab/MuseTalk")
),
@@ -235,17 +240,22 @@ class VideoEngine:
self._wan22.load_loras(self.cfg.loras)
log.info("Wan2.2 pipeline ready.")
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
log.info("MuseTalk engine ready.")
if self.cfg.musetalk_enabled:
log.info("Loading MuseTalk engine (%s)...", self.cfg.musetalk_model_path)
self._musetalk = MuseTalkEngine(model_path=self.cfg.musetalk_model_path)
log.info("MuseTalk engine ready.")
else:
log.info("MuseTalk disabled via config — skipping lip-sync pass.")
self._musetalk = None
# --- Readiness ------------------------------------------------------
def is_ready(self) -> bool:
"""True when an avatar is set and a speaking clip can be produced."""
musetalk_ok = (not self.cfg.musetalk_enabled) or self._musetalk is not None
return (
self._wan22 is not None
and self._musetalk is not None
and musetalk_ok
and self.avatar_path is not None
and self.idle_clip_mp4 is not None
)
@@ -336,7 +346,6 @@ class VideoEngine:
"(avatar set? models loaded?)"
)
assert self._wan22 is not None
assert self._musetalk is not None
# 1. Source base frames.
if self.cfg.mode == "library":
@@ -351,13 +360,16 @@ class VideoEngine:
seed=None, # random each turn
)
# 2. Lip-sync the base frames to the given audio.
synced_frames = self._musetalk.lip_sync(
frames=base_frames,
audio=audio_f32,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
# 2. Lip-sync the base frames to the given audio (if enabled).
if self._musetalk is not None:
synced_frames = self._musetalk.lip_sync(
frames=base_frames,
audio=audio_f32,
sample_rate=sample_rate,
fps=self.cfg.fps,
)
else:
synced_frames = base_frames
# 3. Mux frames + audio into an MP4.
from server.video_models.muxer import frames_and_audio_to_mp4
+78
View File
@@ -0,0 +1,78 @@
# Agent guide — server/video_models/
Wrappers around 3rd-party video models. These are the trickiest files in the repo: LightX2V's internals move quickly upstream, and the Blackwell (RTX 5090 / SM120) GPU path requires several non-obvious patches layered on top. Read this before editing.
## Scope
- [wan22.py](wan22.py) — LightX2V Wan2.2-I2V A14B MoE pipeline. Supports fp8 safetensors and GGUF DIT checkpoints. Loaded once at startup, held resident; per-turn calls go through `generate_i2v` and `switch_lora`.
- [musetalk.py](musetalk.py) — MuseTalk lip-sync over base frames + TTS audio.
- [muxer.py](muxer.py) — thin ffmpeg wrappers: frames → MP4 loop, frames + audio → MP4.
Nothing here is imported unless `config.video.enabled` is true.
## LightX2V entry points (upstream API)
Use these symbols, not internal/private ones:
```python
from lightx2v.utils.set_config import set_config
from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.infer import init_runner
config = set_config(args) # args is an argparse.Namespace
input_info = init_empty_input_info(args.task, args.support_tasks)
runner = init_runner(config) # loads all weights — ONCE
update_input_info_from_dict(input_info, {...}) # per-turn inputs
runner.run_pipeline(input_info) # MP4 written to save_result_path
runner.switch_lora(lora_path, strength) # hot-swap
```
Keep model load out of the per-turn path — `init_runner` is expensive.
## Blackwell (SM120) patches — do not remove without testing
The GGUF pipeline works on a 5090 only because of layered patches in `wan22.py` and tuning in the LightX2V JSON configs under [configs/lightx2v/](../../configs/lightx2v/). Each patch exists because a stock upstream path segfaults or silently miscomputes on SM120.
**Dtype plumbing (GGUF path):**
- Default `DTYPE` must be `BF16` at `init_runner()` time — T5 offload buffers break if FP16 at init.
- Flip `BF16 → FP16` *after* `init_runner()`.
- Wrap T5 encoder so it runs under BF16 internally, then cast outputs `bf16 → fp16` before handing to the DIT. See `_patch_t5_dtype_for_gguf`.
- Cast VAE **both** layers: the inner `.model` via `.to(fp16)` **and** the outer `WanVAE` wrapper's `mean` / `inv_std` / `scale` tensors. Missing the wrapper tensors upcasts the latent during decode's `z/inv_std + mean`.
- DIT `pre_weight.patch_embedding.pin_weight` loads as fp32 (only `pin_bias` is fp16). Cast **and** re-pin via `.pin_memory()` — skipping re-pin segfaults during `to_cuda` H2D copy.
- `sgl_kernel`'s fp8 scaled matmul is patched to `torch._scaled_mm` in `_patch_fp8_scaled_mm_for_blackwell`.
**LightX2V JSON config (see `wan22_i2v_gguf_distill.json`):**
- `modulate_type: "torch"` — Triton `fuse_scale_shift_kernel` segfaults in `ast_to_ttir` on Triton 3.4 + SM120.
- `rope_type: "torch"` — flashinfer isn't installed.
- `self_attn_1_type` / `cross_attn_*_type`: `"torch_sdpa"` — flash_attn3 unavailable; `sageattention==1.0.6` from PyPI segfaults on Blackwell (newer requires source build).
If you add a new quant scheme or a new model_cls, create its own JSON under `configs/lightx2v/` mirroring these choices, and exercise it end-to-end via a new `tests/component/test_NN_*.py` before wiring it into the default config.
## HF download layout
`Wan-AI/Wan2.2-I2V-A14B` ships ~28 GB of bf16 DIT shards we replace with the quantised `dit_repo`. `BASE_REPO_IGNORE_PATTERNS` in [wan22.py](wan22.py) excludes them but **keeps** `high_noise_model/*.json` and `low_noise_model/*.json``set_config` parses architecture params (`dim`, etc.) from those. Don't broaden the ignore pattern without checking.
Supported quant schemes live in `wan22_dit_quant_scheme`:
- `fp8-sgl``lightx2v/Wan2.2-Distill-Models`, two `.safetensors` files
- `gguf-Q4_K_M`, `gguf-Q8_0`, … — `QuantStack/Wan2.2-I2V-A14B-GGUF`, layout `HighNoise/…` and `LowNoise/…`
Filenames are templated at the top of `wan22.py`; update those if the upstream repos rename files.
## Testing
- `tests/component/test_02_wan22_loras.py` — full pipeline load + LoRA apply
- `tests/component/test_09_gguf_generate.py` — GGUF end-to-end I2V
- `tests/component/test_10_t5_encode.py` — T5 encoder dtype path
- `tests/component/test_11_image_encode.py` — image → VAE latent
- `tests/component/test_12_dit_single_step.py` — one DIT step per expert
- `tests/component/test_13_vae_decode.py` — VAE decode → RGB
When diagnosing a Blackwell regression, run 10 → 11 → 12 → 13 in that order; the failure localises to the first failing stage.
## LoRAs
`switch_lora(path, strength)` applies; `switch_lora("", 0.0)` removes. `load_loras`/`unload_loras` in this wrapper iterate over `LoRASpec`s from config and call `switch_lora` per `target` sub-model (`high_noise`, `low_noise`, or `both`). Wrong target = silently wrong output.
+34 -47
View File
@@ -32,8 +32,13 @@ class MuseTalkEngine:
def _load_impl(model_path: str):
"""Load the MuseTalk inference implementation.
If none of the known entry points work the error message points at
this file so you know where to fix it.
Upstream MuseTalk has no library-style entry point — it's a bundle
of training/inference CLI scripts. The bhetherman/MuseTalk fork at
``third_party/MuseTalk`` adds package metadata but the low-level
API is still the raw ``musetalk.utils.*`` and ``musetalk.models.*``
modules. We import them here to verify the install succeeded; the
actual pipeline (VAE, UNet, Whisper, face detection, blending)
is wired up inside ``MuseTalkEngine.lip_sync``.
"""
resolved = model_path
if not os.path.isdir(model_path) and "/" in model_path:
@@ -43,28 +48,19 @@ class MuseTalkEngine:
except Exception as e: # pragma: no cover
log.warning("Could not snapshot_download MuseTalk repo: %s", e)
# Try upstream MuseTalk repo layout.
try:
from musetalk.musetalk_inference import MuseTalkInference # type: ignore[import-not-found]
return MuseTalkInference(model_path=resolved)
except ImportError:
pass
try:
from musetalk.inference import MuseTalkInfer # type: ignore[import-not-found]
return MuseTalkInfer(model_path=resolved)
except ImportError:
pass
try:
from musetalk import Inference # type: ignore[import-not-found]
return Inference(model_path=resolved)
except ImportError:
pass
from musetalk.utils.utils import load_all_model # type: ignore[import-not-found] # noqa: F401
from musetalk.utils.audio_processor import AudioProcessor # type: ignore[import-not-found] # noqa: F401
except ImportError as e:
raise RuntimeError(
"MuseTalk Python package is not importable. "
"Check that third_party/MuseTalk was installed via "
"`pip install /opt/MuseTalk` in the Dockerfile."
) from e
raise RuntimeError(
"MuseTalk is installed but no known Python entry point was found. "
"Update server/video_models/musetalk.py::MuseTalkEngine._load_impl "
"to match the installed MuseTalk version."
)
# Return the resolved weight path; lip_sync loads models lazily on
# first call so import-time failures don't block voice-only startup.
return {"model_path": resolved, "loaded": False}
# --- Inference ---------------------------------------------------------
@@ -98,31 +94,22 @@ class MuseTalkEngine:
if target_t > 0 and len(frames) != target_t:
frames = _fit_frames_to_length(frames, target_t)
# The real MuseTalk call signature varies. Most common is a method
# like ``run(frames, audio, sr, fps)`` or ``infer(...)``.
for method_name in ("run", "infer", "lip_sync", "__call__"):
method = getattr(self._infer, method_name, None)
if method is None:
continue
try:
result = method(
frames=frames,
audio=audio,
sample_rate=sample_rate,
fps=fps,
)
return _ensure_uint8_rgb(result)
except TypeError:
# Try positional
try:
result = method(frames, audio, sample_rate, fps)
return _ensure_uint8_rgb(result)
except TypeError:
continue
raise RuntimeError(
"MuseTalk wrapper could not find a working inference method. "
"Update server/video_models/musetalk.py::MuseTalkEngine.lip_sync."
# MuseTalk's real inference path (see third_party/MuseTalk/scripts/
# realtime_inference.py::Avatar.inference) needs:
# - mmpose + mmcv + mmengine (dwpose keypoint detection)
# - face_alignment (bbox)
# - MuseTalk UNet + VAE weights (TMElyralab/MuseTalk HF repo)
# - Whisper encoder (openai/whisper-tiny)
# - face_parsing weights
# Plus its preprocessing module has import-time side effects that
# load dwpose weights from a CWD-relative path. Turn the full
# pipeline on by extending this method once those deps are
# installed and weights are resolved — until then, callers should
# keep ``config.video.musetalk.enabled: false`` and VideoEngine
# will skip the lip-sync pass.
raise NotImplementedError(
"MuseTalk lip-sync pipeline is not wired up yet. "
"Set config.video.musetalk.enabled=false to bypass."
)
+145 -187
View File
@@ -1,4 +1,4 @@
"""Wan2.2-Lightning image-to-video wrapper via LightX2V.
"""Wan2.2-TI2V-5B-Turbo (dense) image-to-video wrapper via LightX2V.
This wrapper targets LightX2V's actual Python entry points (verified against
the upstream ``lightx2v.infer.main`` in ModelTC/LightX2V@main):
@@ -22,15 +22,12 @@ Model weights are loaded once at construction and held resident across turns
so reflective mode doesn't re-pay the load cost each reply.
Two HuggingFace repos are consumed on first run (cached under HF_HOME):
- Wan-AI/Wan2.2-I2V-A14B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards under high_noise_model/
and low_noise_model/ are SKIPPED via
ignore_patterns — we replace them with
quantised checkpoints from dit_repo.
- dit_repo (configurable) — quantised DIT checkpoints. Supported
formats:
* fp8 safetensors (lightx2v/Wan2.2-Distill-Models)
* GGUF (QuantStack/Wan2.2-I2V-A14B-GGUF)
- Wan-AI/Wan2.2-TI2V-5B — T5 encoder, VAE, tokenizer/config only.
The bf16 DIT shards are SKIPPED via
ignore_patterns — replaced by the GGUF
checkpoint from dit_repo.
- dit_repo (configurable) — single dense GGUF DIT checkpoint, e.g.
hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
"""
from __future__ import annotations
@@ -49,27 +46,20 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
# --- fp8 distill filenames --------------------------------------------------
FP8_HIGH_NOISE_FILE = "wan2.2_i2v_A14b_high_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
FP8_LOW_NOISE_FILE = "wan2.2_i2v_A14b_low_noise_scaled_fp8_e4m3_lightx2v_4step.safetensors"
# --- GGUF filenames (QuantStack layout: HighNoise/<name>.gguf) ---------------
GGUF_HIGH_NOISE_TEMPLATE = "HighNoise/Wan2.2-I2V-A14B-HighNoise-{quant}.gguf"
GGUF_LOW_NOISE_TEMPLATE = "LowNoise/Wan2.2-I2V-A14B-LowNoise-{quant}.gguf"
# --- GGUF filename for the dense 5B Turbo repo ------------------------------
# hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF ships flat: Wan2_2-TI2V-5B-Turbo-{quant}.gguf
GGUF_TURBO_5B_FILE = "Wan2_2-TI2V-5B-Turbo-{quant}.gguf"
# --- fp8 T5 encoder (lightx2v/Encoders repo) --------------------------------
T5_FP8_REPO = "lightx2v/Encoders"
T5_FP8_FILE = "models_t5_umt5-xxl-enc-fp8.safetensors"
# The Wan-AI base repo ships bf16 DIT weight shards (~28 GB) alongside the
# T5/VAE/tokenizer support files (~12 GB). We only need the latter — the
# quantised files from dit_repo replace the DIT weights entirely. We must
# keep the config.json / index.json metadata under high_noise_model/ and
# low_noise_model/ (LightX2V's set_config reads architecture params like
# ``dim`` from them) and the tokenizer files under google/.
# The Wan-AI base repo ships bf16 DIT weight shards alongside the T5/VAE/
# tokenizer support files. We only need the latter — the GGUF from dit_repo
# replaces the DIT weights entirely. Keep config.json / tokenizer files.
BASE_REPO_IGNORE_PATTERNS = [
"high_noise_model/*.safetensors",
"low_noise_model/*.safetensors",
"*.pt",
"diffusion_pytorch_model*.safetensors",
"assets/*",
"examples/*",
"nohup.out",
@@ -77,6 +67,40 @@ BASE_REPO_IGNORE_PATTERNS = [
]
def _cast_all_fp32_tensors(obj, visited=None, depth=0) -> int:
"""Recursively find fp32 tensors reachable from ``obj`` and cast to fp16.
The dense ``wan2.2`` DIT isn't a standard ``nn.Module`` — some fp32
tensors (conv3d bias etc.) live outside ``pre_weight``/``post_weight``
and are missed by the structured sweep. This generic traversal catches
them. Bounded depth + visited-set to avoid cycles.
"""
import torch
if visited is None:
visited = set()
obj_id = id(obj)
if obj_id in visited or depth > 6:
return 0
visited.add(obj_id)
n = 0
for attr_name in dir(obj):
if attr_name.startswith("__"):
continue
try:
val = getattr(obj, attr_name)
except Exception:
continue
if isinstance(val, torch.Tensor) and val.dtype == torch.float32 and val.numel() > 0:
try:
setattr(obj, attr_name, val.to(torch.float16))
n += 1
except Exception:
pass
elif hasattr(val, "__dict__") and not callable(val):
n += _cast_all_fp32_tensors(val, visited, depth + 1)
return n
def _patch_fp8_scaled_mm_for_blackwell() -> None:
"""Replace sgl_kernel.fp8_scaled_mm with torch._scaled_mm on Blackwell.
@@ -132,13 +156,11 @@ def _patch_fp8_scaled_mm_for_blackwell() -> None:
class Wan22Pipeline:
"""Wrapper around LightX2V's Wan2.2 MoE distill runner.
"""Wrapper around LightX2V's dense Wan2.2-TI2V-5B-Turbo runner.
Supports two DIT quantisation formats:
* **fp8** — ``dit_quant_scheme="fp8-sgl"`` (default, from
``lightx2v/Wan2.2-Distill-Models``)
* **GGUF** — ``dit_quant_scheme="gguf-Q4_K_M"`` (or any quant level,
from ``QuantStack/Wan2.2-I2V-A14B-GGUF``)
The 5B Turbo repo ships a single dense DIT checkpoint (not MoE) as GGUF.
``dit_quant_scheme`` must be a GGUF variant (``gguf-Q8_0`` default,
``gguf-Q4_K_M`` for lower VRAM); no fp8 5B Turbo weights exist.
Constructor downloads (if needed) both HF repos, writes a runtime JSON
config with absolute ckpt paths, then drives ``lightx2v.infer.init_runner``.
@@ -150,11 +172,11 @@ class Wan22Pipeline:
base_repo: str,
dit_repo: str,
config_json: str,
model_cls: str = "wan2.2_moe_distill",
model_cls: str = "wan2.2",
resolution: int = 480,
fps: int = 16,
dit_quant_scheme: str = "fp8-sgl",
t5_quantized: bool = False,
dit_quant_scheme: str = "gguf-Q8_0",
t5_quantized: bool = True,
):
self.base_repo = base_repo
self.dit_repo = dit_repo
@@ -167,10 +189,15 @@ class Wan22Pipeline:
self._applied_loras: list[LoRASpec] = []
self._is_gguf = dit_quant_scheme.startswith("gguf-")
if not self._is_gguf:
raise ValueError(
f"dit_quant_scheme must be a GGUF variant for dense 5B Turbo "
f"(got {dit_quant_scheme!r}); no fp8 5B Turbo weights exist."
)
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpts.
# 1. Resolve / download base repo (T5/VAE/config) and DIT ckpt.
self._model_root = self._ensure_base_repo(base_repo)
self._dit_high, self._dit_low = self._ensure_dit_checkpoints(
self._dit_ckpt = self._ensure_dit_checkpoint(
dit_repo, dit_quant_scheme,
)
self._t5_fp8_ckpt = (
@@ -306,52 +333,53 @@ class Wan22Pipeline:
log.info("Cast VAE encoder/decoder weights + scale to fp16 for GGUF FP16 pipeline.")
def _patch_dit_fp32_weights_for_gguf(self) -> None:
"""Cast leftover fp32 DIT pre/post weights to fp16.
"""Cast leftover fp32 DIT weights to fp16 (dense model).
GGUF Q4_K_M dequantises the transformer blocks to fp16, but a handful
of non-quantised weights (notably ``patch_embedding.pin_weight``) end
up loaded as fp32. That breaks the first conv in the DIT forward pass
(fp16 input vs fp32 weight). Cast any such tensors to fp16 so the DIT
runs uniformly in fp16.
GGUF dequantises the transformer blocks to fp16, but a handful of
non-quantised weights (notably ``patch_embedding.pin_weight``) end up
loaded as fp32. That breaks the first conv in the DIT forward pass
(fp16 input vs fp32 weight). Dense ``wan2.2`` exposes the model
directly at ``runner.model`` (no MoE wrapper). After the structured
pre/post weight sweep, we also run a recursive traversal to catch
fp32 conv3d biases etc. that live outside pre/post_weight.
"""
import torch
runner = self._runner
models = getattr(runner.model, "model", None)
if models is None:
return
if not isinstance(models, (list, tuple)):
models = [models]
n_struct = self._cast_fp32_dit_weights_in_model(runner.model)
n_extra = _cast_all_fp32_tensors(runner.model)
log.info(
"Cast %d (structured) + %d (recursive) fp32 DIT tensors to fp16 for GGUF pipeline.",
n_struct, n_extra,
)
@staticmethod
def _cast_fp32_dit_weights_in_model(m) -> int:
import torch
n_cast = 0
for m in models:
for weights_attr in ("pre_weight", "post_weight"):
w = getattr(m, weights_attr, None)
if w is None:
for weights_attr in ("pre_weight", "post_weight"):
w = getattr(m, weights_attr, None)
if w is None:
continue
for sub_name in dir(w):
if sub_name.startswith("_"):
continue
for sub_name in dir(w):
if sub_name.startswith("_"):
continue
try:
sub = getattr(w, sub_name)
except Exception:
continue
if sub is None:
continue
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
t = getattr(sub, t_name, None)
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
casted = t.to(torch.float16)
# Preserve pinned-memory status on pin_* tensors so
# move_attr_to_cuda's non-blocking H2D copy is safe.
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
try:
casted = casted.pin_memory()
except RuntimeError:
pass
setattr(sub, t_name, casted)
n_cast += 1
log.info("Cast %d fp32 DIT weight tensors to fp16 for GGUF pipeline.", n_cast)
try:
sub = getattr(w, sub_name)
except Exception:
continue
if sub is None:
continue
for t_name in ("weight", "bias", "pin_weight", "pin_bias"):
t = getattr(sub, t_name, None)
if isinstance(t, torch.Tensor) and t.dtype == torch.float32:
casted = t.to(torch.float16)
if t_name.startswith("pin_") and t.is_pinned() and not casted.is_pinned():
try:
casted = casted.pin_memory()
except RuntimeError:
pass
setattr(sub, t_name, casted)
n_cast += 1
return n_cast
# --- Weight provisioning -------------------------------------------------
@@ -374,46 +402,34 @@ class Wan22Pipeline:
)
@staticmethod
def _ensure_dit_checkpoints(
def _ensure_dit_checkpoint(
dit_repo: str,
dit_quant_scheme: str,
) -> tuple[str, str]:
"""Return (high_noise_path, low_noise_path) for the DIT pair.
Supports both fp8 safetensors and GGUF formats.
"""
) -> str:
"""Return the local path to the single dense GGUF DIT checkpoint."""
if not dit_repo:
raise ValueError("dit_repo must be a HF repo id or local directory.")
if not dit_quant_scheme.startswith("gguf-"):
raise ValueError(
f"Only GGUF quant schemes are supported for dense 5B Turbo "
f"(got {dit_quant_scheme!r})."
)
is_gguf = dit_quant_scheme.startswith("gguf-")
quant = dit_quant_scheme.replace("gguf-", "")
filename = GGUF_TURBO_5B_FILE.format(quant=quant)
if is_gguf:
# Extract quant level, e.g. "gguf-Q4_K_M" → "Q4_K_M"
quant = dit_quant_scheme.replace("gguf-", "")
high_file = GGUF_HIGH_NOISE_TEMPLATE.format(quant=quant)
low_file = GGUF_LOW_NOISE_TEMPLATE.format(quant=quant)
else:
high_file = FP8_HIGH_NOISE_FILE
low_file = FP8_LOW_NOISE_FILE
# Local directory?
if os.path.isdir(dit_repo):
high = os.path.join(dit_repo, high_file)
low = os.path.join(dit_repo, low_file)
if not (os.path.isfile(high) and os.path.isfile(low)):
path = os.path.join(dit_repo, filename)
if not os.path.isfile(path):
raise FileNotFoundError(
f"DIT checkpoints not found in {dit_repo}: expected "
f"{high_file} and {low_file}"
f"DIT checkpoint not found in {dit_repo}: expected {filename}"
)
return high, low
return path
# HuggingFace download.
from huggingface_hub import hf_hub_download
log.info("Downloading %s DIT checkpoints from %s ...",
log.info("Downloading %s DIT checkpoint from %s ...",
dit_quant_scheme, dit_repo)
high = hf_hub_download(repo_id=dit_repo, filename=high_file)
low = hf_hub_download(repo_id=dit_repo, filename=low_file)
return high, low
return hf_hub_download(repo_id=dit_repo, filename=filename)
@staticmethod
def _ensure_t5_fp8() -> str:
@@ -431,8 +447,7 @@ class Wan22Pipeline:
cfg = json.load(f)
# Drop editorial comments before passing to LightX2V.
cfg.pop("_comment", None)
cfg["high_noise_quantized_ckpt"] = self._dit_high
cfg["low_noise_quantized_ckpt"] = self._dit_low
cfg["dit_quantized_ckpt"] = self._dit_ckpt
cfg.setdefault("fps", self.fps)
# T5 fp8 quantization.
@@ -496,103 +511,46 @@ class Wan22Pipeline:
# --- LoRA --------------------------------------------------------------
def load_loras(self, specs: list["LoRASpec"]) -> None:
"""Apply LoRAs to the Wan2.2 MoE distill pipeline.
"""Apply LoRAs to the dense Wan2.2-TI2V-5B pipeline.
Each spec's ``target`` must be ``"high_noise"`` or ``"low_noise"``
to route the LoRA to the correct expert.
With ``lazy_load`` the DIT models are ``None`` at this point, so
runtime ``switch_lora`` is impossible. Instead we inject
``lora_configs`` + ``lora_dynamic_apply`` into the runner config so
the LoRAs are applied when the models materialise on first inference.
Without ``lazy_load`` (models already resident) we call
``switch_lora`` with explicit high/low keyword args.
Dense has a single DIT (no MoE experts), so ``target`` must be
``"both"``. GGUF DIT weights don't expose a ``lora_down`` buffer,
so ``switch_lora`` would crash — we use the dynamic-apply path that
merges LoRAs during GGUF dequant.
"""
if not specs:
return
# Resolve every path up-front (may trigger HF download).
resolved: list[tuple["LoRASpec", str]] = []
for spec in specs:
if spec.target != "both":
raise ValueError(
f"Dense 5B Turbo has a single DIT; LoRA target must be "
f"'both' (got {spec.target!r})."
)
local_path = self._resolve_lora_path(spec.path)
log.info(" LoRA %s → strength=%.2f target=%s (%s)",
spec.name or spec.path, spec.weight, spec.target,
local_path)
log.info(" LoRA %s → strength=%.2f (%s)",
spec.name or spec.path, spec.weight, local_path)
resolved.append((spec, local_path))
lazy = self._config.get("lazy_load", False)
if lazy:
# Build the lora_configs list that LightX2V's lazy-load path
# reads inside MultiDistillModelStruct.infer().
lora_cfgs = []
for spec, local_path in resolved:
# LightX2V expects name "high_noise_model" / "low_noise_model"
cfg_name = {
"high_noise": "high_noise_model",
"low_noise": "low_noise_model",
}.get(spec.target)
if cfg_name is None:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
lora_cfgs.append({
"name": cfg_name,
"path": local_path,
"strength": spec.weight,
})
self._runner.set_config({
"lora_configs": lora_cfgs,
"lora_dynamic_apply": True,
})
else:
# Models are loaded — use runtime hot-swap.
high_path = high_strength = None
low_path = low_strength = None
for spec, local_path in resolved:
if spec.target == "high_noise":
high_path, high_strength = local_path, spec.weight
elif spec.target == "low_noise":
low_path, low_strength = local_path, spec.weight
else:
raise ValueError(
f"LoRA target must be 'high_noise' or 'low_noise', "
f"got {spec.target!r}")
kwargs: dict = {}
if high_path is not None:
kwargs["high_lora_path"] = high_path
kwargs["high_lora_strength"] = high_strength
if low_path is not None:
kwargs["low_lora_path"] = low_path
kwargs["low_lora_strength"] = low_strength
ok = self._runner.switch_lora(**kwargs)
if not ok:
raise RuntimeError(
"runner.switch_lora returned False. Check that your "
"LightX2V build supports runtime LoRA updates for "
f"{self.model_cls}.")
lora_cfgs = [
{"path": local_path, "strength": spec.weight}
for spec, local_path in resolved
]
self._runner.set_config({
"lora_configs": lora_cfgs,
"lora_dynamic_apply": True,
})
self._applied_loras = list(specs)
def unload_loras(self) -> None:
"""Remove all currently applied LoRAs."""
if not self._applied_loras:
return
lazy = self._config.get("lazy_load", False)
if lazy:
self._runner.set_config({
"lora_configs": None,
"lora_dynamic_apply": False,
})
# If models were materialised, drop them so the next inference
# recreates them without LoRAs.
model_struct = getattr(self._runner, "model", None)
if model_struct is not None and hasattr(model_struct, "model"):
for i in range(len(model_struct.model)):
model_struct.model[i] = None
else:
self._runner.switch_lora("", 0.0)
self._runner.set_config({
"lora_configs": None,
"lora_dynamic_apply": False,
})
self._applied_loras = []
@staticmethod
+28 -8
View File
@@ -9,25 +9,45 @@ python -m pytest tests/unit -v
```
These exercise pure logic: config parsing, prompt derivation, LoRA spec
parsing, frame-length fitting, library round-robin selection. They do not
touch CUDA, Wan2.2, MuseTalk, or ffmpeg. Safe to run on Windows, outside
Docker, without any models installed.
parsing, frame-length fitting, library round-robin selection, the
pipeline's video branch, and ffmpeg mux argument shaping. They do not
touch CUDA, Wan2.2, MuseTalk, or a real ffmpeg binary. Safe to run on
Windows, outside Docker, without any models installed.
Current unit files:
- `test_video_config.py``VideoConfig.from_dict` round-trip, LoRA target validation
- `test_video_engine_logic.py` — prompt derivation, library cursor, frame fitting
- `test_pipeline_video_branch.py` — pipeline takes the video path iff engine is ready
- `test_musetalk_fit_frames.py` — frame-length adjustment to match audio duration
- `test_muxer_ffmpeg.py` — ffmpeg command construction
## Component tests — slow, GPU-required, run inside Docker
Each script in `tests/component/` exercises one subsystem end-to-end against
the real models. They are ordered to match the implementation phases:
Each script in `tests/component/` exercises one subsystem end-to-end
against the real models. The numbered prefix reflects the implementation
phase each script gates, and also serves as a reasonable run order when
debugging a fresh environment:
| Script | Phase | Tests |
|---|---|---|
| `test_01_video_skeleton.py` | 1 | VideoEngine loads, config gate respected |
| `test_02_wan22_loras.py` | 2 | Wan2.2 pipeline loads, LoRA stack applies |
| `test_03_idle_clip.py` | 3 | set_avatar → idle MP4, written to disk for eyeballing |
| `test_03_idle_clip.py` | 3 | `set_avatar` → idle MP4, written to disk for eyeballing |
| `test_04_library_prebake.py` | 4 | library mode pre-bakes N base clips |
| `test_05_musetalk_lipsync.py` | 5 | MuseTalk lip-sync on library frames + ffmpeg mux |
| `test_06_reflective.py` | 6 | reflective mode: fresh Wan2.2 per reply |
| `test_07_endpoints.py` | 7 | HTTP endpoints return sane responses |
| `test_08_lora_reload.py` | 8 | /api/reload-loras swaps LoRAs live |
| `test_08_lora_reload.py` | 8 | `/api/reload-loras` swaps LoRAs live |
| `test_09_gguf_generate.py` | 9 | GGUF-quantised DIT end-to-end I2V generation |
| `test_10_t5_encode.py` | 10 | T5 encoder (optionally fp8-quantised) on CUDA |
| `test_11_image_encode.py` | 11 | Avatar image → VAE latent path |
| `test_12_dit_single_step.py` | 12 | Single DIT step on the loaded expert(s) |
| `test_13_vae_decode.py` | 13 | VAE decode back to RGB frames |
Tests 09-13 are focused on the GGUF + Blackwell (SM120) path and are how
new quant schemes / attention backends get validated before wiring them
into the full pipeline.
Run one:
@@ -36,7 +56,7 @@ Run one:
docker compose exec voice-chat python -m tests.component.test_03_idle_clip
```
Run all (slow, ~20+ minutes on 5090):
Run all (slow, ~20+ minutes on a 5090):
```
docker compose exec voice-chat python -m tests.component.run_all
+23 -39
View File
@@ -1,26 +1,26 @@
"""Phase 2 component test: Wan2.2 pipeline + LoRA stacking.
"""Phase 2 component test: dense Wan2.2-TI2V-5B-Turbo pipeline + LoRA stacking.
Verifies:
- ``Wan22Pipeline`` loads successfully (exercises the real LightX2V
set_config -> init_runner flow).
- ``load_loras`` / ``unload_loras`` survive with the two user LoRAs at
``/cache/loras/wan22-[HL]-e8.safetensors``.
- ``load_loras`` / ``unload_loras`` survive with any user LoRAs at
``/cache/loras/*.safetensors`` (target='both', dense single DIT).
Supports both fp8 and GGUF DIT quantisation. Set the ``DIT_QUANT``
environment variable to switch (default: ``fp8-sgl``).
Supports any GGUF quant published in hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF.
Set ``DIT_QUANT`` to switch (default: ``gguf-Q8_0``).
DIT_QUANT=gguf-Q4_K_M docker compose exec voice-chat \
python -m tests.component.test_02_wan22_loras
Requires GPU and a first-run download of both HF repos (base support files
~12 GB, DIT size depends on quant — fp8 ~30 GB, GGUF Q4_K_M ~19 GB).
Requires GPU and a first-run download of the base repo + GGUF DIT.
If LightX2V isn't installed the test is skipped.
Run (default fp8):
Run:
docker compose exec voice-chat python -m tests.component.test_02_wan22_loras
"""
from __future__ import annotations
import glob
import os
import sys
@@ -28,19 +28,9 @@ from tests.component._common import get_logger
log = get_logger("test_02")
# --- Quant-dependent defaults ------------------------------------------------
DIT_QUANT = os.environ.get("DIT_QUANT", "fp8-sgl")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
LORA_HIGH = "/cache/loras/wan22-H-e8.safetensors"
LORA_LOW = "/cache/loras/wan22-L-e8.safetensors"
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
@@ -57,13 +47,14 @@ def run():
"(quant=%s, dit_repo=%s)...", DIT_QUANT, DIT_REPO)
try:
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
t5_quantized=True,
)
except Exception as e:
log.error("FAIL: Wan22Pipeline construction raised: %s", e)
@@ -78,34 +69,27 @@ def run():
pipe.load_loras([])
log.info(" PASS")
if not (os.path.isfile(LORA_HIGH) and os.path.isfile(LORA_LOW)):
log.warning("SKIP: expected LoRA files not found at %s / %s",
LORA_HIGH, LORA_LOW)
lora_files = sorted(glob.glob("/cache/loras/*.safetensors"))
if not lora_files:
log.warning("SKIP: no LoRA files found in /cache/loras/")
log.info("ALL PASSED (partial — LoRA cases skipped)")
return
log.info("[case 3] load_loras with the two MoE distill LoRAs")
lora_path = lora_files[0]
log.info("[case 3] load_loras with one 5B-compatible LoRA (%s)", lora_path)
specs = [
LoRASpec(
path=LORA_HIGH,
path=lora_path,
weight=1.0,
target="high_noise",
name="wan22-H-e8",
),
LoRASpec(
path=LORA_LOW,
weight=1.0,
target="low_noise",
name="wan22-L-e8",
target="both",
name=os.path.basename(lora_path),
),
]
try:
pipe.load_loras(specs)
except Exception as e:
log.error("FAIL: load_loras raised: %s", e)
log.error("Check: switch_lora support for wan2.2_moe_distill in the "
"installed LightX2V build. If it errors there, pre-declare "
"LoRAs in the config_json 'lora_configs' field instead.")
log.error("Check: LoRA checkpoint shape matches dense 5B DIT.")
sys.exit(3)
log.info(" PASS: LoRAs applied")
+2 -2
View File
@@ -81,9 +81,9 @@ def run():
body = {
"loras": [
{"path": "/cache/loras/a.safetensors", "weight": 0.8,
"target": "high_noise", "name": "test-a"},
"target": "both", "name": "test-a"},
{"path": "/cache/loras/b.safetensors", "weight": 0.4,
"target": "low_noise"},
"target": "both"},
]
}
resp = client.post("/api/reload-loras", json=body)
+9 -17
View File
@@ -32,28 +32,20 @@ def run():
write_bytes("phase8_idle_noloras.mp4", idle_a)
log.info("idle (no LoRAs) sha256=%s", hash_a[:16])
# Hot-reload with a distill LoRA
specs = [
LoRASpec(
path="lightx2v/Wan2.2-Distill-Loras:"
"wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step.safetensors",
weight=1.0,
target="high_noise",
name="distill-hi",
),
]
engine.load_loras(specs)
# Hot-reload flow: unload (no-op), reload empty list, verify clip still generates.
# There are no published 5B-Turbo-compatible LoRAs yet; when one exists,
# construct a LoRASpec(path=..., target="both", weight=1.0) and compare hashes.
engine.load_loras([])
engine.set_avatar(avatar_path)
idle_b = engine.get_idle_clip()
assert idle_b is not None
hash_b = hashlib.sha256(idle_b).hexdigest()
write_bytes("phase8_idle_withlora.mp4", idle_b)
log.info("idle (with LoRA) sha256=%s", hash_b[:16])
write_bytes("phase8_idle_reloaded.mp4", idle_b)
log.info("idle (post-reload) sha256=%s", hash_b[:16])
if hash_a != hash_b:
log.info("PASS: idle clip changed after LoRA reload")
else:
log.warning("clips identical — LoRA may not be applied; eyeball _out/*.mp4")
log.info("PASS: hot-reload round-trip completed "
"(hash match=%s — expected without a real LoRA applied).",
hash_a == hash_b)
if __name__ == "__main__":
+7 -12
View File
@@ -1,10 +1,10 @@
"""Quick smoke test: generate a video clip with the GGUF pipeline.
"""Quick smoke test: generate a video clip with the dense 5B Turbo GGUF pipeline.
Calls Wan22Pipeline.generate_i2v directly (no MuseTalk, no VideoEngine)
and writes the result to tests/component/_out/phase9_gguf.mp4.
Run:
docker compose exec -e DIT_QUANT=gguf-Q4_K_M voice-chat \
docker compose exec -e DIT_QUANT=gguf-Q8_0 voice-chat \
python -m tests.component.test_09_gguf_generate
"""
from __future__ import annotations
@@ -16,14 +16,9 @@ from tests.component._common import ensure_sample_avatar, get_logger, write_byte
log = get_logger("test_09")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
@@ -38,10 +33,10 @@ def run():
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
+5 -10
View File
@@ -17,14 +17,9 @@ from tests.component._common import get_logger
log = get_logger("test_10")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
@@ -36,10 +31,10 @@ def run():
log.info("Building pipeline (quant=%s) — this loads T5 + DIT weights...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
+5 -10
View File
@@ -22,14 +22,9 @@ from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_11")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
@@ -44,10 +39,10 @@ def run():
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
+5 -10
View File
@@ -20,14 +20,9 @@ from tests.component._common import ensure_sample_avatar, get_logger
log = get_logger("test_12")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
@@ -42,10 +37,10 @@ def run():
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
+5 -10
View File
@@ -19,14 +19,9 @@ from tests.component._common import ensure_sample_avatar, get_logger, write_byte
log = get_logger("test_13")
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q4_K_M")
if DIT_QUANT.startswith("gguf-"):
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_distill.json"
DIT_REPO = "QuantStack/Wan2.2-I2V-A14B-GGUF"
else:
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_fp8_distill.json"
DIT_REPO = "lightx2v/Wan2.2-Distill-Models"
DIT_QUANT = os.environ.get("DIT_QUANT", "gguf-Q8_0")
CONFIG_JSON = "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"
DIT_REPO = "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
def run():
@@ -41,10 +36,10 @@ def run():
log.info("Building pipeline (quant=%s)...", DIT_QUANT)
pipe = Wan22Pipeline(
base_repo="Wan-AI/Wan2.2-I2V-A14B",
base_repo="Wan-AI/Wan2.2-TI2V-5B",
dit_repo=DIT_REPO,
config_json=CONFIG_JSON,
model_cls="wan2.2_moe_distill",
model_cls="wan2.2",
resolution=480,
fps=16,
dit_quant_scheme=DIT_QUANT,
+31 -26
View File
@@ -64,32 +64,39 @@ def test_lora_parse_full():
{
"loras": [
{
"path": "/tmp/hi.safetensors",
"path": "/tmp/a.safetensors",
"weight": 0.7,
"target": "high_noise",
"name": "hi-noise-style",
"target": "both",
"name": "style-a",
},
{
"path": "/tmp/lo.safetensors",
"path": "/tmp/b.safetensors",
"weight": 0.4,
"target": "low_noise",
"name": "lo-noise-style",
"target": "both",
"name": "style-b",
},
]
}
)
assert len(cfg.loras) == 2
assert cfg.loras[0].target == "high_noise"
assert cfg.loras[0].name == "hi-noise-style"
assert cfg.loras[1].target == "low_noise"
assert cfg.loras[0].target == "both"
assert cfg.loras[0].name == "style-a"
assert cfg.loras[1].target == "both"
assert cfg.loras[1].weight == 0.4
def test_lora_invalid_target_falls_back_to_both():
def test_lora_legacy_moe_target_coerced_to_both():
"""Legacy MoE configs with target='high_noise'/'low_noise' get coerced."""
cfg = VideoConfig.from_dict(
{"loras": [{"path": "/tmp/x.safetensors", "target": "bogus"}]}
{
"loras": [
{"path": "/tmp/hi.safetensors", "target": "high_noise"},
{"path": "/tmp/lo.safetensors", "target": "low_noise"},
{"path": "/tmp/x.safetensors", "target": "bogus"},
]
}
)
assert cfg.loras[0].target == "both"
assert all(l.target == "both" for l in cfg.loras)
def test_lora_entries_without_path_are_dropped():
@@ -107,8 +114,8 @@ def test_models_section_override():
"wan22_base_repo": "/local/weights/wan22",
"wan22_dit_repo": "/local/weights/wan22-dit",
"wan22_dit_quant_scheme": "gguf-Q4_K_M",
"wan22_config_json": "/local/cfg/fp8.json",
"wan22_model_cls": "wan2.2_moe",
"wan22_config_json": "/local/cfg/turbo.json",
"wan22_model_cls": "wan2.2",
"musetalk_path": "/local/weights/musetalk",
}
}
@@ -116,18 +123,16 @@ def test_models_section_override():
assert cfg.wan22_base_repo == "/local/weights/wan22"
assert cfg.wan22_dit_repo == "/local/weights/wan22-dit"
assert cfg.wan22_dit_quant_scheme == "gguf-Q4_K_M"
assert cfg.wan22_config_json == "/local/cfg/fp8.json"
assert cfg.wan22_model_cls == "wan2.2_moe"
assert cfg.wan22_config_json == "/local/cfg/turbo.json"
assert cfg.wan22_model_cls == "wan2.2"
assert cfg.musetalk_model_path == "/local/weights/musetalk"
def test_models_section_backwards_compat_fp8_repo():
"""Old config key wan22_fp8_repo still works via fallback."""
cfg = VideoConfig.from_dict(
{
"models": {
"wan22_fp8_repo": "/local/weights/wan22-fp8",
}
}
)
assert cfg.wan22_dit_repo == "/local/weights/wan22-fp8"
def test_models_section_defaults_to_5b_turbo():
cfg = VideoConfig.from_dict({})
assert cfg.wan22_base_repo == "Wan-AI/Wan2.2-TI2V-5B"
assert cfg.wan22_dit_repo == "hum-ma/Wan2.2-TI2V-5B-Turbo-GGUF"
assert cfg.wan22_dit_quant_scheme == "gguf-Q8_0"
assert cfg.wan22_t5_quantized is True
assert cfg.wan22_model_cls == "wan2.2"
assert cfg.wan22_config_json == "/app/configs/lightx2v/wan22_i2v_gguf_5b_turbo.json"